package main
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
)
// Config holds the API keys and service URLs read from the environment.
type Config struct {
AnthropicKey string
GoogleKey string
ServiceURL string
DeepseekURL string
OpenAIURL string
}
var (
config Config
configOnce sync.Once
)
// TokenCountRequest is the request body for POST /count_tokens.
type TokenCountRequest struct {
Model string `json:"model" binding:"required"`
Messages []Message `json:"messages" binding:"required"`
System *string `json:"system,omitempty"`
}
type Message struct {
Role string `json:"role" binding:"required"`
Content string `json:"content" binding:"required"`
}
// TokenCountResponse is the response body for a successful count.
type TokenCountResponse struct {
InputTokens int `json:"input_tokens"`
}
// ErrorResponse is the response body returned on failure.
type ErrorResponse struct {
Error string `json:"error"`
}
// ModelRule maps a set of keywords (all of which must appear in the input)
// to a canonical model name.
type ModelRule struct {
Keywords []string
Target string
}
// Rules are checked in order and the first rule whose keywords all appear in
// the input wins, so more specific rules must come before more general ones.
var modelRules = []ModelRule{
{
Keywords: []string{"gpt"},
Target: "gpt-3.5-turbo",
},
{
Keywords: []string{"openai"},
Target: "gpt-3.5-turbo",
},
{
Keywords: []string{"deepseek"},
Target: "deepseek-v3",
},
	{
		Keywords: []string{"claude", "3", "7"},
		Target:   "claude-3-7-sonnet-latest",
	},
	{
		Keywords: []string{"claude", "3", "5", "sonnet"},
		Target:   "claude-3-5-sonnet-latest",
	},
	{
		Keywords: []string{"claude", "3", "5", "haiku"},
		Target:   "claude-3-5-haiku-latest",
	},
	{
		Keywords: []string{"claude", "3", "opus"},
		Target:   "claude-3-opus-latest",
	},
	{
		Keywords: []string{"claude", "3", "haiku"},
		Target:   "claude-3-haiku-20240307",
	},
	{
		// Listed after the 3-5/3-7 variants above, which this more
		// general rule would otherwise shadow.
		Keywords: []string{"claude", "3", "sonnet"},
		Target:   "claude-3-sonnet-20240229",
	},
{
Keywords: []string{"gemini", "2.0"},
Target: "gemini-2.0-flash",
},
{
Keywords: []string{"gemini", "2.5"},
Target: "gemini-2.0-flash", // 目前使用2.0-flash作为2.5的替代
},
{
Keywords: []string{"gemini", "1.5"},
Target: "gemini-1.5-flash",
},
}
// matchModelName normalizes a user-supplied model name to a canonical one.
func matchModelName(input string) string {
	// Match case-insensitively.
	input = strings.ToLower(input)
	// Special rule: GPT-4o, o1 and o3 are all counted as gpt-4o.
	if (strings.Contains(input, "gpt") && strings.Contains(input, "4o")) ||
		strings.Contains(input, "o1") ||
		strings.Contains(input, "o3") {
		return "gpt-4o"
	}
	// Special rule: GPT-3.5 and GPT-4 (other than 4o) are both counted as gpt-4.
	if (strings.Contains(input, "gpt") && strings.Contains(input, "3") && strings.Contains(input, "5")) ||
		(strings.Contains(input, "gpt") && strings.Contains(input, "4") && !strings.Contains(input, "4o")) {
		return "gpt-4"
	}
	// Walk the rules in order; every keyword of a rule must appear in the input.
	for _, rule := range modelRules {
		matches := true
		for _, keyword := range rule.Keywords {
			if !strings.Contains(input, strings.ToLower(keyword)) {
				matches = false
				break
			}
		}
		if matches {
			return rule.Target
		}
	}
	// No rule matched; return the input unchanged.
	return input
}
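
// A few illustrative inputs and the names they normalize to under the rules
// above (the model strings are hypothetical, chosen only to exercise the matcher):
//
//	matchModelName("GPT-4o-mini")       // "gpt-4o"  (special 4o rule)
//	matchModelName("claude-3-5-sonnet") // "claude-3-5-sonnet-latest"
//	matchModelName("gemini-1.5-pro")    // "gemini-1.5-flash"
//	matchModelName("my-custom-model")   // "my-custom-model" (unchanged)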
// loadConfig reads configuration from the environment exactly once.
func loadConfig() Config {
	configOnce.Do(func() {
		config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY")
		if config.AnthropicKey == "" {
			log.Println("Warning: ANTHROPIC_API_KEY is not set; Claude models will be unavailable")
		}
		config.GoogleKey = os.Getenv("GOOGLE_API_KEY")
		if config.GoogleKey == "" {
			log.Println("Warning: GOOGLE_API_KEY is not set; Gemini models will be unavailable")
		}
		// URL of the Deepseek tokenizer sidecar service.
		config.DeepseekURL = os.Getenv("DEEPSEEK_URL")
		if config.DeepseekURL == "" {
			config.DeepseekURL = "http://127.0.0.1:7861" // default local address
			log.Println("Using default Deepseek service URL:", config.DeepseekURL)
		}
		// URL of the OpenAI tokenizer sidecar service.
		config.OpenAIURL = os.Getenv("OPENAI_URL")
		if config.OpenAIURL == "" {
			config.OpenAIURL = "http://127.0.0.1:7862" // default local address
			log.Println("Using default OpenAI service URL:", config.OpenAIURL)
		}
		// Public URL of this service, used by the keep-alive pinger.
		config.ServiceURL = os.Getenv("SERVICE_URL")
		if config.ServiceURL == "" {
			log.Println("SERVICE_URL is not set; keep-alive is disabled")
		}
	})
	return config
}
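
// Example environment for a local run (all values are placeholders):
//
//	export ANTHROPIC_API_KEY=sk-ant-...
//	export GOOGLE_API_KEY=AIza...
//	export DEEPSEEK_URL=http://127.0.0.1:7861
//	export OPENAI_URL=http://127.0.0.1:7862
//	export SERVICE_URL=https://your-space.hf.space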
// countTokensWithClaude counts tokens via the Anthropic count_tokens API.
func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) {
	client := &http.Client{}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
	}
	request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
	}
	request.Header.Set("x-api-key", config.AnthropicKey)
	request.Header.Set("anthropic-version", "2023-06-01")
	request.Header.Set("content-type", "application/json")
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to send request to the Anthropic API: %v", err)
	}
	defer response.Body.Close()
	// A non-200 body is an error payload, not a count; surface it instead of
	// silently decoding it into a zero count.
	if response.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(response.Body)
		return TokenCountResponse{}, fmt.Errorf("Anthropic API returned status %d: %s", response.StatusCode, body)
	}
	var result TokenCountResponse
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}
	return result, nil
}
// countTokensWithGemini counts tokens via the Gemini API.
func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) {
	if config.GoogleKey == "" {
		return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY is not set")
	}
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to create Gemini client: %v", err)
	}
	defer client.Close()
	// The model name has already been normalized by matchModelName.
	model := client.GenerativeModel(req.Model)
	// Flatten the system prompt and messages into a single text prompt.
	var content string
	if req.System != nil && *req.System != "" {
		content += *req.System + "\n\n"
	}
	for _, msg := range req.Messages {
		switch msg.Role {
		case "user":
			content += "User: " + msg.Content + "\n"
		case "assistant":
			content += "Assistant: " + msg.Content + "\n"
		default:
			content += msg.Role + ": " + msg.Content + "\n"
		}
	}
	tokResp, err := model.CountTokens(ctx, genai.Text(content))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to count Gemini tokens: %v", err)
	}
	return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil
}
// countTokensWithDeepseek counts tokens via the Deepseek sidecar service.
func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) {
	client := &http.Client{}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
	}
	request, err := http.NewRequest("POST", config.DeepseekURL+"/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
	}
	request.Header.Set("Content-Type", "application/json")
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to send request to the Deepseek service: %v", err)
	}
	defer response.Body.Close()
	if response.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(response.Body)
		return TokenCountResponse{}, fmt.Errorf("Deepseek service returned status %d: %s", response.StatusCode, body)
	}
	var result TokenCountResponse
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}
	return result, nil
}
// countTokensWithOpenAI counts tokens via the OpenAI tokenizer sidecar service.
func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) {
	client := &http.Client{}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
	}
	request, err := http.NewRequest("POST", config.OpenAIURL+"/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
	}
	request.Header.Set("Content-Type", "application/json")
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to send request to the OpenAI service: %v", err)
	}
	defer response.Body.Close()
	if response.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(response.Body)
		return TokenCountResponse{}, fmt.Errorf("OpenAI service returned status %d: %s", response.StatusCode, body)
	}
	// The sidecar also reports the model and encoding it resolved; only the
	// token count is needed here.
	var result struct {
		InputTokens int    `json:"input_tokens"`
		Model       string `json:"model"`
		Encoding    string `json:"encoding"`
	}
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}
	return TokenCountResponse{InputTokens: result.InputTokens}, nil
}
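
// Note: both sidecars are assumed to expose the same contract as this service,
// i.e. POST /count_tokens accepting {"model", "messages", "system"} and
// returning at least {"input_tokens": N}; the OpenAI sidecar additionally
// reports the "model" and "encoding" it used.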
// countTokens is the handler for POST /count_tokens.
func countTokens(c *gin.Context) {
	var req TokenCountRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()})
		return
	}
	// Keep the original model name for warning messages.
	originalModel := req.Model
	// Check whether the model belongs to a supported family.
	isUnsupportedModel := true
	modelLower := strings.ToLower(req.Model)
	if strings.Contains(modelLower, "gpt") || strings.Contains(modelLower, "openai") ||
		strings.Contains(modelLower, "o1") || strings.Contains(modelLower, "o3") ||
		strings.HasPrefix(modelLower, "claude") ||
		strings.Contains(modelLower, "gemini") ||
		strings.Contains(modelLower, "deepseek") {
		isUnsupportedModel = false
	}
	// Normalize the model name.
	req.Model = matchModelName(req.Model)
	var result TokenCountResponse
	var err error
	// Route the request to the matching tokenizer backend; Deepseek is
	// checked first.
	matched := strings.ToLower(req.Model)
	if strings.Contains(matched, "deepseek") {
		result, err = countTokensWithDeepseek(req)
	} else if strings.Contains(matched, "gpt") || strings.Contains(matched, "openai") {
		result, err = countTokensWithOpenAI(req)
	} else if strings.HasPrefix(matched, "claude") {
		if config.AnthropicKey == "" {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY is not set; Claude models are unavailable"})
			return
		}
		result, err = countTokensWithClaude(req)
	} else if strings.Contains(matched, "gemini") {
		if config.GoogleKey == "" {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY is not set; Gemini models are unavailable"})
			return
		}
		result, err = countTokensWithGemini(req)
	} else if isUnsupportedModel {
		// Unsupported model: estimate with the gpt-4o tokenizer and attach a warning.
		gptReq := req
		gptReq.Model = "gpt-4o"
		result, err = countTokensWithOpenAI(gptReq)
		if err == nil {
			c.JSON(http.StatusOK, gin.H{
				"input_tokens":   result.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
			})
			return
		}
	} else {
		// A supported family that no backend claimed: still fall back to a
		// gpt-4o estimate, and only return an error if the estimate also fails.
		gptReq := req
		gptReq.Model = "gpt-4o"
		estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
		if estimateErr == nil {
			c.JSON(http.StatusOK, gin.H{
				"input_tokens":   estimatedResult.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
			})
		} else {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)})
		}
		return
	}
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()})
		return
	}
	// Return the count.
	c.JSON(http.StatusOK, result)
}
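
// For an unrecognized model the handler still answers 200 with an estimate,
// e.g. (illustrative shape only):
//
//	{"input_tokens": 12, "estimated_with": "gpt-4o",
//	 "warning": "The tokenizer for model 'my-custom-model' is not supported yet. ..."}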
// healthCheck reports service liveness for GET /health.
func healthCheck(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"time": time.Now().Format(time.RFC3339),
})
}
// startKeepAlive periodically pings this service's own /health endpoint so
// that free hosting platforms do not put it to sleep.
func startKeepAlive() {
	if config.ServiceURL == "" {
		return
	}
	healthURL := fmt.Sprintf("%s/health", config.ServiceURL)
	ticker := time.NewTicker(10 * time.Hour)
	go func() {
		log.Printf("Starting keep-alive checks to %s", healthURL)
		// Run one check immediately, then one per ticker interval.
		for {
			resp, err := http.Get(healthURL)
			if err != nil {
				log.Printf("Keep-alive check failed: %v", err)
			} else {
				resp.Body.Close()
				log.Printf("Keep-alive check successful")
			}
			// Wait for the next tick.
			<-ticker.C
		}
	}()
}
func main() {
	// Load configuration from the environment.
	loadConfig()
	// Run gin in release mode.
	gin.SetMode(gin.ReleaseMode)
	r := gin.Default()
	// Middleware: panic recovery plus permissive CORS headers.
	r.Use(gin.Recovery())
	r.Use(func(c *gin.Context) {
		c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
		c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
		c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type")
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	})
	// Routes.
	r.GET("/health", healthCheck)
	r.POST("/count_tokens", countTokens)
	// Resolve the listening port.
	port := os.Getenv("PORT")
	if port == "" {
		port = "7860" // Hugging Face's default port
	}
	// Start the keep-alive pinger.
	startKeepAlive()
	// Start the server.
	log.Printf("Server starting on port %s", port)
	if err := r.Run(":" + port); err != nil {
		log.Fatal(err)
	}
}
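
// Example request once the server is up (placeholder content):
//
//	curl -X POST http://localhost:7860/count_tokens \
//	  -H "Content-Type: application/json" \
//	  -d '{"model":"claude-3-5-sonnet","messages":[{"role":"user","content":"Hello"}]}'
//
// Expected response shape: {"input_tokens": <n>}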