package main

import (
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "os"
    "strings"
    "sync"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/google/generative-ai-go/genai"
    "google.golang.org/api/option"
)
// Configuration
type Config struct {
    AnthropicKey string
    GoogleKey    string
    ServiceURL   string
    DeepseekURL  string
    OpenAIURL    string
}

var (
    config     Config
    configOnce sync.Once
)
// Request structures
type TokenCountRequest struct {
    Model    string    `json:"model" binding:"required"`
    Messages []Message `json:"messages" binding:"required"`
    System   *string   `json:"system,omitempty"`
}

type Message struct {
    Role    string `json:"role" binding:"required"`
    Content string `json:"content" binding:"required"`
}
// Response structure
type TokenCountResponse struct {
    InputTokens int `json:"input_tokens"`
}

// Error response structure
type ErrorResponse struct {
    Error string `json:"error"`
}
// Model mapping rules. A rule matches when every keyword appears in the
// (lowercased) input. Rules are ordered most-specific first because matching
// stops at the first hit; otherwise e.g. "claude-3-5-sonnet" would fall
// through to the generic claude-3-sonnet rule.
type ModelRule struct {
    Keywords []string
    Target   string
}

var modelRules = []ModelRule{
    {
        Keywords: []string{"gpt"},
        Target:   "gpt-3.5-turbo",
    },
    {
        Keywords: []string{"openai"},
        Target:   "gpt-3.5-turbo",
    },
    {
        Keywords: []string{"deepseek"},
        Target:   "deepseek-v3",
    },
    {
        Keywords: []string{"claude", "3", "7"},
        Target:   "claude-3-7-sonnet-latest",
    },
    {
        Keywords: []string{"claude", "3", "5", "sonnet"},
        Target:   "claude-3-5-sonnet-latest",
    },
    {
        Keywords: []string{"claude", "3", "5", "haiku"},
        Target:   "claude-3-5-haiku-latest",
    },
    {
        Keywords: []string{"claude", "3", "sonnet"},
        Target:   "claude-3-sonnet-20240229",
    },
    {
        Keywords: []string{"claude", "3", "opus"},
        Target:   "claude-3-opus-latest",
    },
    {
        Keywords: []string{"claude", "3", "haiku"},
        Target:   "claude-3-haiku-20240307",
    },
    {
        Keywords: []string{"gemini", "2.0"},
        Target:   "gemini-2.0-flash",
    },
    {
        Keywords: []string{"gemini", "2.5"},
        Target:   "gemini-2.0-flash", // 2.0-flash currently stands in for 2.5
    },
    {
        Keywords: []string{"gemini", "1.5"},
        Target:   "gemini-1.5-flash",
    },
}
// Smart model-name matching
func matchModelName(input string) string {
    // Lowercase the input for case-insensitive matching
    input = strings.ToLower(input)
    // Special rule: OpenAI GPT-4o (o1/o3 reasoning models are also counted
    // with the gpt-4o tokenizer)
    if (strings.Contains(input, "gpt") && strings.Contains(input, "4o")) ||
        strings.Contains(input, "o1") ||
        strings.Contains(input, "o3") {
        return "gpt-4o"
    }
    // Special rule: OpenAI GPT-4. GPT-3.5 is mapped here as well, which is
    // harmless for counting, since gpt-3.5-turbo and gpt-4 share the
    // cl100k_base tokenizer.
    if (strings.Contains(input, "gpt") && strings.Contains(input, "3") && strings.Contains(input, "5")) ||
        (strings.Contains(input, "gpt") && strings.Contains(input, "4") && !strings.Contains(input, "4o")) {
        return "gpt-4"
    }
    // Walk the rule table in order
    for _, rule := range modelRules {
        matches := true
        for _, keyword := range rule.Keywords {
            if !strings.Contains(input, strings.ToLower(keyword)) {
                matches = false
                break
            }
        }
        if matches {
            return rule.Target
        }
    }
    // No rule matched: return the (lowercased) input unchanged
    return input
}
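
// A few illustrative mappings (the input names are only examples, not an
// exhaustive list of what callers send):
//
//	matchModelName("gpt-4o-mini")                // → "gpt-4o"
//	matchModelName("claude-3-5-sonnet-20241022") // → "claude-3-5-sonnet-latest"
//	matchModelName("gemini-1.5-pro")             // → "gemini-1.5-flash"
//	matchModelName("some-unknown-model")         // → "some-unknown-model" (lowercased, unchanged)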
// Load configuration
func loadConfig() Config {
    configOnce.Do(func() {
        config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY")
        if config.AnthropicKey == "" {
            log.Println("Warning: ANTHROPIC_API_KEY is not set; Claude models will be unavailable")
        }
        config.GoogleKey = os.Getenv("GOOGLE_API_KEY")
        if config.GoogleKey == "" {
            log.Println("Warning: GOOGLE_API_KEY is not set; Gemini models will be unavailable")
        }
        // Deepseek tokenizer service URL
        config.DeepseekURL = os.Getenv("DEEPSEEK_URL")
        if config.DeepseekURL == "" {
            config.DeepseekURL = "http://127.0.0.1:7861" // default local address
            log.Println("Using default Deepseek service address:", config.DeepseekURL)
        }
        // OpenAI tokenizer service URL
        config.OpenAIURL = os.Getenv("OPENAI_URL")
        if config.OpenAIURL == "" {
            config.OpenAIURL = "http://127.0.0.1:7862" // default local address
            log.Println("Using default OpenAI service address:", config.OpenAIURL)
        }
        // Public URL of this service, used by the keep-alive task
        config.ServiceURL = os.Getenv("SERVICE_URL")
        if config.ServiceURL == "" {
            log.Println("SERVICE_URL is not set; keep-alive is disabled")
        }
    })
    return config
}
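
// Example environment for local development. All values below are
// placeholders; only the variable names, which are read above and in main(),
// are meaningful:
//
//	export ANTHROPIC_API_KEY=sk-ant-...
//	export GOOGLE_API_KEY=AIza...
//	export DEEPSEEK_URL=http://127.0.0.1:7861
//	export OPENAI_URL=http://127.0.0.1:7862
//	export SERVICE_URL=https://your-space.hf.space
//	export PORT=7860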
// Count tokens with the Anthropic (Claude) API
func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) {
    // Prepare the request to the Anthropic API
    client := &http.Client{Timeout: 30 * time.Second}
    data, err := json.Marshal(req)
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
    }
    // Build the request
    request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data))
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
    }
    // Set headers
    request.Header.Set("x-api-key", config.AnthropicKey)
    request.Header.Set("anthropic-version", "2023-06-01")
    request.Header.Set("content-type", "application/json")
    // Send the request
    response, err := client.Do(request)
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to send request to the Anthropic API: %v", err)
    }
    defer response.Body.Close()
    if response.StatusCode != http.StatusOK {
        return TokenCountResponse{}, fmt.Errorf("Anthropic API returned status %d", response.StatusCode)
    }
    // Decode the response
    var result TokenCountResponse
    if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
    }
    return result, nil
}
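
// For reference, a successful response from Anthropic's count_tokens endpoint
// has the same shape this service returns, e.g. {"input_tokens": 2095}
// (the number is illustrative).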
// Count tokens with the Gemini API
func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) {
    // Check the API key
    if config.GoogleKey == "" {
        return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY is not set")
    }
    // Create a Gemini client
    ctx := context.Background()
    client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey))
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to create Gemini client: %v", err)
    }
    defer client.Close()
    // The model name has already been normalized by matchModelName
    modelName := req.Model
    // Create the Gemini model
    model := client.GenerativeModel(modelName)
    // Flatten system prompt and messages into a single prompt string
    var content string
    if req.System != nil && *req.System != "" {
        content += *req.System + "\n\n"
    }
    for _, msg := range req.Messages {
        if msg.Role == "user" {
            content += "User: " + msg.Content + "\n"
        } else if msg.Role == "assistant" {
            content += "Assistant: " + msg.Content + "\n"
        } else {
            content += msg.Role + ": " + msg.Content + "\n"
        }
    }
    // Count tokens
    tokResp, err := model.CountTokens(ctx, genai.Text(content))
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to count Gemini tokens: %v", err)
    }
    return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil
}
// Count tokens with the Deepseek tokenizer service
func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) {
    // Prepare the request
    client := &http.Client{Timeout: 30 * time.Second}
    data, err := json.Marshal(req)
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
    }
    // Build the request
    request, err := http.NewRequest("POST", config.DeepseekURL+"/count_tokens", bytes.NewBuffer(data))
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
    }
    // Set headers
    request.Header.Set("Content-Type", "application/json")
    // Send the request
    response, err := client.Do(request)
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to send request to the Deepseek service: %v", err)
    }
    defer response.Body.Close()
    if response.StatusCode != http.StatusOK {
        return TokenCountResponse{}, fmt.Errorf("Deepseek service returned status %d", response.StatusCode)
    }
    // Decode the response
    var result TokenCountResponse
    if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
    }
    return result, nil
}
// Count tokens with the OpenAI tokenizer service
func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) {
    // Prepare the request
    client := &http.Client{Timeout: 30 * time.Second}
    data, err := json.Marshal(req)
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
    }
    // Build the request
    request, err := http.NewRequest("POST", config.OpenAIURL+"/count_tokens", bytes.NewBuffer(data))
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
    }
    // Set headers
    request.Header.Set("Content-Type", "application/json")
    // Send the request
    response, err := client.Do(request)
    if err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to send request to the OpenAI service: %v", err)
    }
    defer response.Body.Close()
    if response.StatusCode != http.StatusOK {
        return TokenCountResponse{}, fmt.Errorf("OpenAI service returned status %d", response.StatusCode)
    }
    // Decode the response; the service also reports which model and encoding it used
    var result struct {
        InputTokens int    `json:"input_tokens"`
        Model       string `json:"model"`
        Encoding    string `json:"encoding"`
    }
    if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
        return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
    }
    return TokenCountResponse{InputTokens: result.InputTokens}, nil
}
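
// Judging by the struct above, the tokenizer service's response presumably
// looks like this (all values hypothetical):
//
//	{"input_tokens": 42, "model": "gpt-4o", "encoding": "o200k_base"}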
// Count tokens
func countTokens(c *gin.Context) {
    var req TokenCountRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()})
        return
    }
    // Remember the original model name for warnings and error messages
    originalModel := req.Model
    // Assume the model is unsupported until proven otherwise
    isUnsupportedModel := true
    // Check whether the model belongs to a supported family
    modelLower := strings.ToLower(req.Model)
    if strings.Contains(modelLower, "gpt") || strings.Contains(modelLower, "openai") ||
        strings.Contains(modelLower, "o1") || strings.Contains(modelLower, "o3") ||
        strings.HasPrefix(modelLower, "claude") ||
        strings.Contains(modelLower, "gemini") ||
        strings.Contains(modelLower, "deepseek") {
        isUnsupportedModel = false
    }
    // Normalize the model name
    req.Model = matchModelName(req.Model)
    var result TokenCountResponse
    var err error
    // Check for Deepseek models first
    if strings.Contains(strings.ToLower(req.Model), "deepseek") {
        // Use the Deepseek service
        result, err = countTokensWithDeepseek(req)
    } else if strings.Contains(strings.ToLower(req.Model), "gpt") || strings.Contains(strings.ToLower(req.Model), "openai") {
        // Use the OpenAI service
        result, err = countTokensWithOpenAI(req)
    } else if strings.HasPrefix(strings.ToLower(req.Model), "claude") {
        // Use the Claude API
        if config.AnthropicKey == "" {
            c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY is not set; Claude models are unavailable"})
            return
        }
        result, err = countTokensWithClaude(req)
    } else if strings.Contains(strings.ToLower(req.Model), "gemini") {
        // Use the Gemini API
        if config.GoogleKey == "" {
            c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY is not set; Gemini models are unavailable"})
            return
        }
        result, err = countTokensWithGemini(req)
    } else if isUnsupportedModel {
        // Unsupported model: estimate with gpt-4o
        gptReq := req
        gptReq.Model = "gpt-4o"
        // Use the OpenAI service
        result, err = countTokensWithOpenAI(gptReq)
        if err == nil {
            // Return the estimate, with a warning attached
            c.JSON(http.StatusOK, gin.H{
                "input_tokens":   result.InputTokens,
                "warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
                "estimated_with": "gpt-4o",
            })
            return
        }
    } else {
        // Defensive fallback for a supported family that none of the branches
        // above caught: still return a gpt-4o estimate if possible
        gptReq := req
        gptReq.Model = "gpt-4o"
        estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
        if estimateErr == nil {
            c.JSON(http.StatusOK, gin.H{
                "input_tokens":   estimatedResult.InputTokens,
                "warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
                "estimated_with": "gpt-4o",
            })
        } else {
            c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)})
        }
        return
    }
    if err != nil {
        c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()})
        return
    }
    // Return the result
    c.JSON(http.StatusOK, result)
}
// Health check
func healthCheck(c *gin.Context) {
    c.JSON(http.StatusOK, gin.H{
        "status": "healthy",
        "time":   time.Now().Format(time.RFC3339),
    })
}
// Keep-alive task to prevent the Space from being put to sleep
func startKeepAlive() {
    if config.ServiceURL == "" {
        return
    }
    healthURL := fmt.Sprintf("%s/health", config.ServiceURL)
    ticker := time.NewTicker(10 * time.Hour)
    // Run one check immediately, then one per tick
    go func() {
        log.Printf("Starting keep-alive checks to %s", healthURL)
        for {
            resp, err := http.Get(healthURL)
            if err != nil {
                log.Printf("Keep-alive check failed: %v", err)
            } else {
                resp.Body.Close()
                log.Printf("Keep-alive check successful")
            }
            // Wait for the next tick
            <-ticker.C
        }
    }()
}
func main() {
    // Load configuration
    loadConfig()
    // Set gin mode
    gin.SetMode(gin.ReleaseMode)
    // Create the router (gin.Default() already attaches logging and recovery middleware)
    r := gin.Default()
    // Middleware: permissive CORS
    r.Use(func(c *gin.Context) {
        c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
        c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
        c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type")
        if c.Request.Method == "OPTIONS" {
            c.AbortWithStatus(http.StatusNoContent)
            return
        }
        c.Next()
    })
    // Routes
    r.GET("/health", healthCheck)
    r.POST("/count_tokens", countTokens)
    // Resolve the port
    port := os.Getenv("PORT")
    if port == "" {
        port = "7860" // Hugging Face default port
    }
    // Start the keep-alive task
    startKeepAlive()
    // Start the server
    log.Printf("Server starting on port %s", port)
    if err := r.Run(":" + port); err != nil {
        log.Fatal(err)
    }
}
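
// A minimal client sketch (assumes the service is listening locally on the
// default port 7860; the body and model name are illustrative):
//
//	body := `{"model": "claude-3-5-sonnet", "messages": [{"role": "user", "content": "Hello!"}]}`
//	resp, err := http.Post("http://127.0.0.1:7860/count_tokens", "application/json", strings.NewReader(body))
//	if err != nil {
//	    log.Fatal(err)
//	}
//	defer resp.Body.Close()
//	var out TokenCountResponse
//	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
//	    log.Fatal(err)
//	}
//	fmt.Println("input_tokens:", out.InputTokens)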