// Command main runs an OpenAI-compatible HTTP proxy in front of the OnDemand
// chat API. It exposes /v1/chat/completions and /v1/models, authenticates
// callers with a private key, and rotates through a pool of OnDemand API keys,
// temporarily disabling keys that the upstream rejects.
package main

import (
	"bufio"
	"bytes"
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/joho/godotenv"
)

// Configuration constants.
const (
	// BadKeyRetryInterval is how long a key stays disabled before it is retried.
	BadKeyRetryInterval = 600 * time.Second // 10 minutes
	// SessionTimeout is the idle time after which the current key/session pair
	// is dropped and a new one is selected.
	SessionTimeout = 600 * time.Second // 10 minutes
	// DefaultPort is the listen port used when PORT is unset or invalid.
	DefaultPort = 7860

	// maxSSELineSize bounds a single upstream SSE line. The bufio.Scanner
	// default (64 KiB) is too small for large answer payloads.
	maxSSELineSize = 1 << 20

	// doneMarker is the in-process sentinel streamQuery sends on the result
	// channel when the upstream signals end of stream.
	doneMarker = "data: [DONE]"
)

// Package-level state populated by initConfig.
var (
	privateKey      string
	ondemandAPIKeys []string
	safeHeaders     = []string{"Authorization", "X-API-KEY"}
	ondemandAPIBase = "https://api.on-demand.io/chat/v1"
	defaultModel    = "predefined-openai-gpt4o"
)

// Shared HTTP clients; reusing them avoids allocating a client per request.
var (
	sessionClient = &http.Client{Timeout: 20 * time.Second}
	queryClient   = &http.Client{Timeout: 300 * time.Second}
)

// errNoUsableKey is returned when every key has been tried without success or
// when no key pool is configured.
var errNoUsableKey = errors.New("没有可用API KEY,请补充新KEY或联系技术支持")

// modelMap translates OpenAI-style model names into OnDemand endpoint IDs.
var modelMap = map[string]string{
	"o3-mini":         "predefined-openai-gpto3-mini",
	"o4-mini":         "predefined-openai-gpto4-mini",
	"gpt-4o":          "predefined-openai-gpt4o",
	"gpt-4.1":         "predefined-openai-gpt4.1",
	"gpt-4.1-mini":    "predefined-openai-gpt4.1-mini",
	"gpt-4o-mini":     "predefined-openai-gpt4o-mini",
	"deepseek-v3":     "predefined-deepseek-v3",
	"deepseek-r1":     "predefined-deepseek-r1",
	"claude-4-sonnet": "predefined-claude-4-sonnet",
	"claude-4-opus":   "predefined-claude-4-opus",
}

// statusError reports an unexpected HTTP status from the upstream API. Keeping
// the code as a field lets isAuthError classify failures without parsing the
// error text.
type statusError struct {
	op   string
	code int
}

// Error implements the error interface.
func (e *statusError) Error() string {
	return fmt.Sprintf("%s failed with status: %d", e.op, e.code)
}

// KeyStatus records whether an API key is currently disabled and since when.
type KeyStatus struct {
	Bad   bool      `json:"bad"`
	BadTS time.Time `json:"bad_ts"`
}

// KeyManager rotates through the API key pool and tracks the key and upstream
// session currently in use. All fields other than keyList are guarded by mu;
// keyList is never modified after construction.
type KeyManager struct {
	keyList        []string
	mu             sync.RWMutex
	keyStatus      map[string]*KeyStatus
	idx            int
	currentKey     string
	currentSession string
	lastUsedTime   time.Time
}

// NewKeyManager returns a KeyManager holding a private copy of keys, with every
// key initially marked healthy.
func NewKeyManager(keys []string) *KeyManager {
	km := &KeyManager{
		keyList:   make([]string, len(keys)),
		keyStatus: make(map[string]*KeyStatus, len(keys)),
	}
	copy(km.keyList, keys)
	for _, key := range keys {
		km.keyStatus[key] = &KeyStatus{}
	}
	return km
}

// displayKey returns a shortened form of key suitable for logs. Keys of ten
// bytes or fewer are returned unchanged.
func (km *KeyManager) displayKey(key string) string {
	if len(key) <= 10 {
		return key
	}
	return fmt.Sprintf("%s...%s", key[:6], key[len(key)-4:])
}

// activate makes key the current key and clears the cached session.
// The caller must hold km.mu.
func (km *KeyManager) activate(key string, now time.Time) {
	km.currentKey = key
	km.currentSession = ""
	km.lastUsedTime = now
}

// Get returns the API key to use for the next upstream call. It keeps using the
// current key while it is healthy and the session has not idled out; otherwise
// it picks the next healthy key round-robin, re-enabling keys whose retry
// interval has elapsed. If every key is disabled, all keys are reset and the
// first one is returned. It returns "" only when the pool is empty.
func (km *KeyManager) Get() string {
	km.mu.Lock()
	defer km.mu.Unlock()

	total := len(km.keyList)
	if total == 0 {
		return ""
	}
	now := time.Now()

	// Drop the current key/session after a period of inactivity.
	if km.currentKey != "" && !km.lastUsedTime.IsZero() && now.Sub(km.lastUsedTime) > SessionTimeout {
		log.Printf("【对话超时】上次使用时间: %s", km.lastUsedTime.Format("2006-01-02 15:04:05"))
		log.Printf("【对话超时】当前时间: %s", now.Format("2006-01-02 15:04:05"))
		log.Printf("【对话超时】超时%d分钟,切换新会话", int(SessionTimeout.Minutes()))
		km.currentKey = ""
		km.currentSession = ""
	}

	// Keep using the current key while it is healthy.
	if km.currentKey != "" {
		if !km.keyStatus[km.currentKey].Bad {
			log.Printf("【对话请求】【继续使用API KEY: %s】【状态:正常】", km.displayKey(km.currentKey))
			km.lastUsedTime = now
			return km.currentKey
		}
		// The current key has been disabled; fall through to pick another.
		km.currentKey = ""
		km.currentSession = ""
	}

	// Round-robin over the pool looking for a usable key.
	for i := 0; i < total; i++ {
		key := km.keyList[km.idx]
		km.idx = (km.idx + 1) % total
		status := km.keyStatus[key]

		switch {
		case !status.Bad:
			log.Printf("【对话请求】【使用新API KEY: %s】【状态:正常】", km.displayKey(key))
		case !status.BadTS.IsZero() && now.Sub(status.BadTS) >= BadKeyRetryInterval:
			log.Printf("【KEY自动尝试恢复】API KEY: %s 满足重试周期,标记为正常", km.displayKey(key))
			status.Bad = false
			status.BadTS = time.Time{}
		default:
			continue
		}
		km.activate(key, now)
		return key
	}

	// Every key is disabled: reset them all and try again from the first.
	log.Printf("【警告】全部KEY已被禁用,强制选用第一个KEY继续尝试: %s", km.displayKey(km.keyList[0]))
	for _, key := range km.keyList {
		km.keyStatus[key].Bad = false
		km.keyStatus[key].BadTS = time.Time{}
	}
	km.idx = 0
	km.activate(km.keyList[0], now)
	log.Printf("【对话请求】【使用API KEY: %s】【状态:强制尝试(全部异常)】", km.displayKey(km.currentKey))
	return km.currentKey
}

// MarkBad disables key until BadKeyRetryInterval has elapsed. Unknown keys and
// keys that are already disabled are ignored. If key is the current key, the
// current key and session are cleared.
func (km *KeyManager) MarkBad(key string) {
	km.mu.Lock()
	defer km.mu.Unlock()

	status, exists := km.keyStatus[key]
	if !exists || status.Bad {
		return
	}
	log.Printf("【禁用KEY】API KEY: %s,接口返回无效(将在%d分钟后自动重试)", km.displayKey(key), int(BadKeyRetryInterval.Minutes()))
	status.Bad = true
	status.BadTS = time.Now()
	if km.currentKey == key {
		km.currentKey = ""
		km.currentSession = ""
	}
}

// GetSession returns the upstream session ID to use with apikey, creating one
// if necessary. The session is cached only while apikey is the current key, so
// a session created with one key is never reused with another when a
// concurrent request has rotated the pool in between.
//
// The lock is held across the session-creation request so that concurrent
// callers share one session instead of each creating their own.
func (km *KeyManager) GetSession(ctx context.Context, apikey string) (string, error) {
	km.mu.Lock()
	defer km.mu.Unlock()

	isCurrent := apikey == km.currentKey
	if isCurrent && km.currentSession != "" {
		km.lastUsedTime = time.Now()
		return km.currentSession, nil
	}

	session, err := createSession(ctx, apikey, "", nil)
	if err != nil {
		log.Printf("【创建会话失败】错误: %v", err)
		return "", err
	}
	log.Printf("【创建新会话】SESSION ID: %s", session)
	if isCurrent {
		km.currentSession = session
	}
	km.lastUsedTime = time.Now()
	return session, nil
}

// keyManager is the process-wide key pool; it is nil when no keys are configured.
var keyManager *KeyManager

// ChatCompletionRequest is the OpenAI-style request body accepted by
// /v1/chat/completions.
type ChatCompletionRequest struct {
	Messages []Message `json:"messages"`
	Model    string    `json:"model"`
	Stream   bool      `json:"stream"`
}

// Message is a single chat message.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatCompletionResponse is the OpenAI-style response body, used for both full
// completions and streamed chunks.
type ChatCompletionResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}

// Choice is one completion choice; Message is set for full responses and Delta
// for streamed chunks.
type Choice struct {
	Index        int      `json:"index"`
	Message      *Message `json:"message,omitempty"`
	Delta        *Message `json:"delta,omitempty"`
	FinishReason *string  `json:"finish_reason"`
}

// Usage is a placeholder; it serializes as an empty JSON object.
type Usage struct{}

// ModelsResponse is the response body of /v1/models.
type ModelsResponse struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// Model describes one entry in the model list.
type Model struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	OwnedBy string `json:"owned_by"`
}

// CreateSessionRequest is the upstream request body for creating a session.
type CreateSessionRequest struct {
	ExternalUserID string   `json:"externalUserId"`
	PluginIds      []string `json:"pluginIds,omitempty"`
}

// CreateSessionResponse is the upstream response to session creation.
type CreateSessionResponse struct {
	Data struct {
		ID string `json:"id"`
	} `json:"data"`
}

// QueryRequest is the upstream request body for a chat query.
type QueryRequest struct {
	Query        string   `json:"query"`
	EndpointID   string   `json:"endpointId"`
	PluginIds    []string `json:"pluginIds"`
	ResponseMode string   `json:"responseMode"`
}

// QueryResponse is the upstream response to a synchronous query.
type QueryResponse struct {
	Data struct {
		Answer string `json:"answer"`
	} `json:"data"`
}

// init loads an optional .env file and then reads configuration from the
// environment.
func init() {
	if err := godotenv.Load(); err != nil {
		log.Println("警告:没有找到 .env 文件,将仅使用系统环境变量")
	}
	initConfig()
}

// initConfig reads PRIVATE_KEY and ONDEMAND_APIKEYS from the environment and
// builds the key manager. Outside test mode, a missing key list is fatal.
func initConfig() {
	privateKey = getEnv("PRIVATE_KEY", "testofli")
	ondemandAPIKeys = parseAPIKeys(os.Getenv("ONDEMAND_APIKEYS"))

	if len(ondemandAPIKeys) == 0 && !isTestMode() {
		log.Fatal("ONDEMAND_APIKEYS 环境变量为空,请设置API密钥")
	}
	if len(ondemandAPIKeys) > 0 {
		keyManager = NewKeyManager(ondemandAPIKeys)
	}
}

// parseAPIKeys splits a comma-separated key list, trimming surrounding
// whitespace and dropping empty and duplicate entries. It returns nil when raw
// contains no keys.
func parseAPIKeys(raw string) []string {
	var keys []string
	seen := make(map[string]struct{})
	for _, part := range strings.Split(raw, ",") {
		key := strings.TrimSpace(part)
		if key == "" {
			continue
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		keys = append(keys, key)
	}
	return keys
}

// isTestMode reports whether the process is running as a Go test binary or with
// GIN_MODE=test. Only the ".test" binary suffix and "-test." flags are
// recognised, so an unrelated argument or path that merely contains "test"
// does not disable the missing-keys check.
func isTestMode() bool {
	if os.Getenv("GIN_MODE") == "test" {
		return true
	}
	if len(os.Args) == 0 {
		return false
	}
	if bin := os.Args[0]; strings.HasSuffix(bin, ".test") || strings.HasSuffix(bin, ".test.exe") {
		return true
	}
	for _, arg := range os.Args[1:] {
		if strings.HasPrefix(arg, "-test.") {
			return true
		}
	}
	return false
}

// getEnv returns the value of the environment variable key, or defaultValue
// when it is unset or empty.
func getEnv(key, defaultValue string) string {
	if value := os.Getenv(key); value != "" {
		return value
	}
	return defaultValue
}

// presentedKey extracts the caller-supplied key from the first non-empty header
// in safeHeaders, stripping a "Bearer " prefix from Authorization.
func presentedKey(c *gin.Context) string {
	for _, header := range safeHeaders {
		value := c.GetHeader(header)
		if value == "" {
			continue
		}
		if header == "Authorization" && strings.HasPrefix(value, "Bearer ") {
			return strings.TrimSpace(strings.TrimPrefix(value, "Bearer "))
		}
		return value
	}
	return ""
}

// checkPrivateKey returns middleware that rejects requests whose key does not
// match privateKey. "/" and "/favicon.ico" are exempt.
func checkPrivateKey() gin.HandlerFunc {
	return func(c *gin.Context) {
		switch c.Request.URL.Path {
		case "/", "/favicon.ico":
			c.Next()
			return
		}

		key := presentedKey(c)
		// Constant-time comparison avoids leaking key contents through timing.
		if key == "" || subtle.ConstantTimeCompare([]byte(key), []byte(privateKey)) != 1 {
			// The request headers are deliberately not echoed back: they may
			// contain values injected by intermediaries.
			c.JSON(http.StatusUnauthorized, gin.H{
				"error": "Unauthorized, must provide correct Authorization or X-API-KEY",
			})
			c.Abort()
			return
		}
		c.Next()
	}
}

// getEndpointID maps an OpenAI model name (case-insensitive, spaces ignored) to
// an OnDemand endpoint ID. It returns "" for unknown models.
func getEndpointID(openaiModel string) string {
	model := strings.ToLower(strings.ReplaceAll(openaiModel, " ", ""))
	return modelMap[model]
}

// newJSONRequest builds an authenticated POST request with payload encoded as
// the JSON body.
func newJSONRequest(ctx context.Context, url, apikey string, payload any) (*http.Request, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("apikey", apikey)
	req.Header.Set("Content-Type", "application/json")
	return req, nil
}

// queryURL returns the upstream query URL for sessionID.
func queryURL(sessionID string) string {
	return fmt.Sprintf("%s/sessions/%s/query", ondemandAPIBase, sessionID)
}

// createSession creates an upstream chat session and returns its ID. A random
// UUID is used when externalUserID is empty. It returns a *statusError for any
// status other than 200 or 201, and an error when the response has no ID.
func createSession(ctx context.Context, apikey, externalUserID string, pluginIds []string) (string, error) {
	if externalUserID == "" {
		externalUserID = uuid.New().String()
	}

	req, err := newJSONRequest(ctx, ondemandAPIBase+"/sessions", apikey, CreateSessionRequest{
		ExternalUserID: externalUserID,
		PluginIds:      pluginIds,
	})
	if err != nil {
		return "", err
	}

	resp, err := sessionClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		return "", &statusError{op: "create session", code: resp.StatusCode}
	}

	var sessionResp CreateSessionResponse
	if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil {
		return "", err
	}
	if sessionResp.Data.ID == "" {
		return "", errors.New("create session: response contained no session id")
	}
	return sessionResp.Data.ID, nil
}

// withValidKey calls fn with a key from the pool, retrying with another key
// whenever fn fails with a key-related status (see isAuthError). It gives up
// after twice as many key failures as there are keys. Any other error from fn
// is returned immediately.
func withValidKey(ctx context.Context, fn func(ctx context.Context, key string) (any, error)) (any, error) {
	if keyManager == nil {
		return nil, errNoUsableKey
	}

	maxRetry := len(keyManager.keyList) * 2
	for badCount := 0; badCount < maxRetry; badCount++ {
		key := keyManager.Get()
		result, err := fn(ctx, key)
		if err == nil {
			return result, nil
		}
		if !isAuthError(err) {
			return nil, err
		}
		keyManager.MarkBad(key)
	}
	return nil, errNoUsableKey
}

// isAuthError reports whether err is an upstream status that should disable
// the key that was used: 401, 403, 429 or 500.
func isAuthError(err error) bool {
	var se *statusError
	if !errors.As(err, &se) {
		return false
	}
	switch se.code {
	case http.StatusUnauthorized, http.StatusForbidden,
		http.StatusTooManyRequests, http.StatusInternalServerError:
		return true
	}
	return false
}

// chatCompletions handles POST /v1/chat/completions. It forwards the most
// recent user message to the upstream endpoint for the requested model, either
// streamed or as a single response.
func chatCompletions(c *gin.Context) {
	var req ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求缺少messages字段"})
		return
	}
	if len(req.Messages) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求缺少messages字段"})
		return
	}

	// Use the most recent user message as the query.
	var userMsg string
	for i := len(req.Messages) - 1; i >= 0; i-- {
		if req.Messages[i].Role == "user" {
			userMsg = req.Messages[i].Content
			break
		}
	}
	if userMsg == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "未找到用户消息"})
		return
	}

	endpointID := getEndpointID(req.Model)
	if endpointID == "" {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": map[string]any{
				"message": fmt.Sprintf("The model '%s' does not exist", req.Model),
				"type":    "invalid_request_error",
				"param":   "model",
				"code":    "model_not_found",
			},
		})
		return
	}

	log.Printf("【模型请求】模型: %s, 端点: %s, 流式: %t", req.Model, endpointID, req.Stream)

	if req.Stream {
		handleStreamResponse(c, userMsg, endpointID, req.Model)
		return
	}
	handleNonStreamResponse(c, userMsg, endpointID, req.Model)
}

// writeSSEData writes one SSE "data:" event and flushes it to the client.
func writeSSEData(c *gin.Context, payload string) {
	// A write error means the client has gone away; that is detected through
	// the request context, so the error is intentionally ignored here.
	_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
	c.Writer.Flush()
}

// handleStreamResponse relays the upstream answer to the client as server-sent
// events. A worker goroutine produces chunks on a channel; this function
// writes them until the upstream signals completion, the worker finishes, or
// the client disconnects.
func handleStreamResponse(c *gin.Context, userMsg, endpointID, model string) {
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	// The request context cancels the upstream call and unblocks the worker
	// when the client disconnects or this handler returns.
	ctx := c.Request.Context()

	resultChan := make(chan string, 100)
	errorChan := make(chan error, 1) // buffered so the worker never blocks on it

	go func() {
		// The error is published before resultChan is closed, so the reader
		// can always drain every chunk first and then collect the outcome.
		defer close(resultChan)
		_, err := withValidKey(ctx, func(ctx context.Context, apikey string) (any, error) {
			return streamQuery(ctx, apikey, userMsg, endpointID, model, resultChan)
		})
		errorChan <- err
	}()

	for {
		select {
		case chunk, ok := <-resultChan:
			if !ok {
				// Worker finished without a [DONE] marker; report its error, if any.
				if err := <-errorChan; err != nil {
					// Marshaling a map of strings cannot fail.
					errorJSON, _ := json.Marshal(map[string]any{"error": err.Error()})
					writeSSEData(c, string(errorJSON))
				}
				return
			}
			if chunk == doneMarker {
				writeSSEData(c, "[DONE]")
				return
			}
			writeSSEData(c, chunk)
		case <-ctx.Done():
			return
		}
	}
}

// sendChunk delivers s on ch, giving up when ctx is cancelled so the producer
// cannot block forever after the consumer has gone away.
func sendChunk(ctx context.Context, ch chan<- string, s string) error {
	select {
	case ch <- s:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// firstString returns the first value in event, among keys in order, that is a
// string.
func firstString(event map[string]any, keys ...string) (string, bool) {
	for _, k := range keys {
		if s, ok := event[k].(string); ok {
			return s, true
		}
	}
	return "", false
}

// extractEventText returns the text carried by an upstream event of the given
// type, and whether any was found.
func extractEventText(eventType string, event map[string]any) (string, bool) {
	switch eventType {
	case "fulfillment":
		return firstString(event, "answer")
	case "stream", "thinking", "reasoning", "thoughts":
		// Event types that may carry intermediate reasoning output.
		return firstString(event, "answer", "text", "data", "thoughts")
	default:
		// Unknown event type: try the common text-bearing fields.
		return firstString(event, "answer", "text", "thoughts")
	}
}

// streamQuery sends userMsg to the upstream in streaming mode and forwards each
// text-bearing event to resultChan as a JSON-encoded chat.completion.chunk.
// When the upstream sends [DONE] it forwards the doneMarker sentinel and
// returns. An upstream [ERROR] event is forwarded as a JSON error object and
// ends the stream. The first return value is always nil. A non-200 status is
// returned as a *statusError.
func streamQuery(ctx context.Context, apikey, userMsg, endpointID, model string, resultChan chan<- string) (any, error) {
	sessionID, err := keyManager.GetSession(ctx, apikey)
	if err != nil {
		return nil, err
	}

	req, err := newJSONRequest(ctx, queryURL(sessionID), apikey, QueryRequest{
		Query:        userMsg,
		EndpointID:   endpointID,
		PluginIds:    []string{}, // must encode as [] rather than null
		ResponseMode: "stream",
	})
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "text/event-stream")

	resp, err := queryClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, &statusError{op: "stream query", code: resp.StatusCode}
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxSSELineSize)

	// All chunks of one completion share the same id.
	streamID := "chatcmpl-" + uuid.New().String()[:8]
	firstChunk := true

	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		dataPart := strings.TrimSpace(line[len("data:"):])

		if dataPart == "[DONE]" {
			if err := sendChunk(ctx, resultChan, doneMarker); err != nil {
				return nil, err
			}
			return nil, nil
		}

		if strings.HasPrefix(dataPart, "[ERROR]:") {
			detail := strings.TrimSpace(dataPart[len("[ERROR]:"):])
			// Marshal so quotes or backslashes in the detail cannot break the JSON.
			errJSON, err := json.Marshal(map[string]string{"error": detail})
			if err != nil {
				return nil, err
			}
			if err := sendChunk(ctx, resultChan, string(errJSON)); err != nil {
				return nil, err
			}
			return nil, nil
		}

		var eventData map[string]any
		if err := json.Unmarshal([]byte(dataPart), &eventData); err != nil {
			continue // skip lines that are not JSON events
		}
		eventType, ok := eventData["eventType"].(string)
		if !ok {
			continue
		}
		content, hasContent := extractEventText(eventType, eventData)
		if !hasContent {
			continue
		}

		// Only the first chunk carries the assistant role.
		role := ""
		if firstChunk {
			role = "assistant"
		}
		chunkJSON, err := json.Marshal(ChatCompletionResponse{
			ID:      streamID,
			Object:  "chat.completion.chunk",
			Created: time.Now().Unix(),
			Model:   model,
			Choices: []Choice{{
				Index:        0,
				Delta:        &Message{Role: role, Content: content},
				FinishReason: nil,
			}},
		})
		if err != nil {
			return nil, err
		}
		if err := sendChunk(ctx, resultChan, string(chunkJSON)); err != nil {
			return nil, err
		}
		firstChunk = false
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return nil, nil
}

// handleNonStreamResponse performs a synchronous upstream query and writes the
// result as a single JSON response, or a 500 with the error message.
func handleNonStreamResponse(c *gin.Context, userMsg, endpointID, model string) {
	result, err := withValidKey(c.Request.Context(), func(ctx context.Context, apikey string) (any, error) {
		return nonStreamQuery(ctx, apikey, userMsg, endpointID, model)
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, result)
}

// nonStreamQuery sends userMsg to the upstream in sync mode and returns the
// answer wrapped in a ChatCompletionResponse. A non-200 status is returned as
// a *statusError.
func nonStreamQuery(ctx context.Context, apikey, userMsg, endpointID, model string) (any, error) {
	sessionID, err := keyManager.GetSession(ctx, apikey)
	if err != nil {
		return nil, err
	}

	req, err := newJSONRequest(ctx, queryURL(sessionID), apikey, QueryRequest{
		Query:        userMsg,
		EndpointID:   endpointID,
		PluginIds:    []string{}, // must encode as [] rather than null
		ResponseMode: "sync",
	})
	if err != nil {
		return nil, err
	}

	resp, err := queryClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, &statusError{op: "non-stream query", code: resp.StatusCode}
	}

	var queryResp QueryResponse
	if err := json.NewDecoder(resp.Body).Decode(&queryResp); err != nil {
		return nil, err
	}

	finish := "stop"
	response := ChatCompletionResponse{
		ID:      "chatcmpl-" + uuid.New().String()[:8],
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   model,
		Choices: []Choice{{
			Index:        0,
			Message:      &Message{Role: "assistant", Content: queryResp.Data.Answer},
			FinishReason: &finish,
		}},
		Usage: Usage{},
	}
	return response, nil
}

// models handles GET /v1/models and lists the supported model names, sorted by
// ID so the output is stable across calls.
func models(c *gin.Context) {
	modelList := make([]Model, 0, len(modelMap))
	for modelID := range modelMap {
		modelList = append(modelList, Model{
			ID:      modelID,
			Object:  "model",
			OwnedBy: "ondemand-proxy",
		})
	}
	sort.Slice(modelList, func(i, j int) bool { return modelList[i].ID < modelList[j].ID })

	c.JSON(http.StatusOK, ModelsResponse{
		Object: "list",
		Data:   modelList,
	})
}

// health handles GET / and reports liveness and the number of configured keys.
func health(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{
		"status": "ok",
		"keys":   len(ondemandAPIKeys),
	})
}

// main configures logging and routing, then serves on PORT (default DefaultPort).
func main() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)

	// Default to release mode unless GIN_MODE is set explicitly.
	if os.Getenv("GIN_MODE") == "" {
		gin.SetMode(gin.ReleaseMode)
	}

	router := gin.New()
	router.Use(gin.Logger())
	router.Use(gin.Recovery())
	router.Use(checkPrivateKey())

	router.GET("/", health)
	router.POST("/v1/chat/completions", chatCompletions)
	router.GET("/v1/models", models)

	port := DefaultPort
	if portStr := os.Getenv("PORT"); portStr != "" {
		if p, err := strconv.Atoi(portStr); err == nil {
			port = p
		} else {
			log.Printf("invalid PORT %q, falling back to %d", portStr, DefaultPort)
		}
	}

	log.Printf("======== OnDemand KEY池数量:%d ========", len(ondemandAPIKeys))
	log.Printf("服务器启动在端口:%d", port)

	if err := router.Run(fmt.Sprintf(":%d", port)); err != nil {
		log.Fatal("启动服务器失败:", err)
	}
}