const express = require('express')
const axios = require('axios')
const WebSocket = require('ws')
const router = express.Router()
const { v4: uuidv4 } = require('uuid')
const { uploadFileBuffer } = require('../lib/upload')
const verify = require('./verify')
const modelMap = require('../lib/model-map')

// Fallback model used when the requested model is not in modelMap
const DEFAULT_MODEL = "claude-3-7-sonnet-20250219"

// o-series models get a longer request timeout (see the route below)
const O_SERIES_MODELS = ["o4-mini", "o4-mini-high", "o3-mini", "o3-mini-high", "o1", "o3", "o3-2025-04-16", "o4-mini-2025-04-16"]

// Send an OpenAI-style error response. If a (streaming) response has already
// started, the status and headers can no longer change, so just end it.
function sendError(res, status, message, type = "server_error", code = "server_error") {
  if (res.headersSent) {
    if (!res.writableEnded) res.end()
    return
  }
  res.status(status).json({
    "error": { "message": message, "type": type, "param": null, "code": code }
  })
}

// Return a per-request copy of the model definition (falling back to
// DEFAULT_MODEL) so parameter overrides never mutate the shared modelMap.
function resolveModel(modelName) {
  const base = modelMap[modelName] || modelMap[DEFAULT_MODEL]
  return JSON.parse(JSON.stringify(base))
}

// Copy request fields (other than messages, model and stream) into the model
// parameters, but only for keys the model definition already declares.
function applyParameters(modelData, body) {
  if (!modelData.parameters) return modelData
  for (const key of Object.keys(body)) {
    if (key === "messages" || key === "model" || key === "stream") continue
    if (Object.prototype.hasOwnProperty.call(modelData.parameters, key)) {
      modelData.parameters[key] = body[key]
    }
  }
  return modelData
}

// Return the part of `current` that was not already sent as `previous`.
function newSuffix(current, previous) {
  return current.startsWith(previous) ? current.slice(previous.length) : current
}

// Convert one OpenAI content part into a PromptLayer content item.
async function convertContentItem(item) {
  if (item.type === "text") {
    return { type: "text", text: item.text }
  }
  if (item.type === "image_url") {
    try {
      const url = item.image_url.url
      const base64Match = url.match(/^data:image\/\w+;base64,(.+)$/)
      if (base64Match) {
        // Inline base64 image: upload it and reference the hosted file
        const data = Buffer.from(base64Match[1], 'base64')
        const uploadResult = await uploadFileBuffer(data)
        return {
          type: "media",
          media: { "type": "image", "url": uploadResult.file_url, "title": `image_${Date.now()}.png` }
        }
      }
      // Remote image: pass the URL through unchanged
      return { type: "media", media: { "type": "image", "url": url, "title": "external_image" } }
    } catch (error) {
      console.error("Error while processing image:", error)
      return { type: "text", text: "[Image processing failed]" }
    }
  }
  // Any other content type is serialized as text
  return { type: "text", text: JSON.stringify(item) }
}

// Middleware: convert OpenAI-style chat messages into the PromptLayer format
async function parseMessages(req, res, next) {
  const messages = req.body.messages
  if (!Array.isArray(messages)) {
    req.processedMessages = []
    return next()
  }

  // The provider decides whether the system role is converted (Anthropic only)
  const provider = modelMap[req.body.model]?.provider

  try {
    req.body.messages = await Promise.all(messages.map(async (msg) => {
      const message = {
        role: (msg.role === "system" && provider === "anthropic") ? "user" : msg.role,
        tool_calls: [],
        template_format: "jinja2"
      }
      if (Array.isArray(msg.content)) {
        message.content = await Promise.all(msg.content.map(convertContentItem))
      } else {
        message.content = [{ type: "text", text: msg.content || "" }]
      }
      return message
    }))
    return next()
  } catch (error) {
    console.error("Error while processing messages:", error)
    req.body.messages = []
    return next(error)
  }
}

// Create a PromptLayer playground session and store its ID on req.chatID.
// Returns the session ID, or false after sending an error response.
async function getChatID(req, res) {
  try {
    const url = 'https://api.promptlayer.com/api/dashboard/v2/workspaces/' + req.account.workspaceId + '/playground_sessions'
    const headers = { Authorization: "Bearer " + req.account.token }
    const model_data = applyParameters(resolveModel(req.body.model), req.body)

    const data = {
      "id": uuidv4(),
      "name": "Not implemented",
      "prompt_blueprint": {
        "inference_client_name": null,
        "metadata": { "model": model_data },
        "prompt_template": {
          "type": "chat",
          "messages": req.body.messages,
          "tools": req.body.tools || [],
          "tool_choice": req.body.tool_choice || "none",
          "input_variables": [],
          "functions": [],
          "function_call": null
        },
        "provider_base_url_name": null
      },
      "input_variables": []
    }

    console.log(`Model parameters: ${JSON.stringify(model_data)}`)

    const response = await axios.put(url, data, { headers })
    if (response.data.success) {
      req.chatID = response.data.playground_session.id
      console.log(`Created playground session ID: ${req.chatID}`)
      return req.chatID
    }
    sendError(res, 500, "Failed to create a playground session")
    return false
  } catch (error) {
    // console.error("Error:", error.response?.data)
    sendError(res, 500, error.message || "Internal server error")
    return false
  }
}

// Start a PromptLayer run group for the session. Returns the individual
// run request ID, or false after sending an error response.
async function sentRequest(req, res) {
  try {
    const url = 'https://api.promptlayer.com/api/dashboard/v2/workspaces/' + req.account.workspaceId + '/run_groups'
    const headers = { Authorization: "Bearer " + req.account.token }
    const model_data = applyParameters(resolveModel(req.body.model), req.body)
    const provider = model_data.provider

    // Base prompt template structure
    const prompt_template = {
      "type": "chat",
      "messages": req.body.messages,
      "tools": req.body.tools || [],
      "tool_choice": req.body.tool_choice || "none",
      "input_variables": [],
      "functions": [],
      "function_call": null
    }

    // Mistral/Cohere: send no tools and drop tool_choice and function_call entirely
    if (provider === 'mistral' || provider === 'cohere') {
      prompt_template.tools = null
      delete prompt_template.tool_choice
      delete prompt_template.function_call
    }

    const data = {
      "id": uuidv4(),
      "playground_session_id": req.chatID,
      "shared_prompt_blueprint": {
        "inference_client_name": null,
        "metadata": { "model": model_data },
        "prompt_template": prompt_template,
        "provider_base_url_name": null
      },
      "individual_run_requests": [
        { "input_variables": {}, "run_group_position": 1 }
      ]
    }
    console.log(JSON.stringify(data))

    const response = await axios.post(url, data, { headers })
    if (response.data.success) {
      return response.data.run_group.individual_run_requests[0].id
    }
    sendError(res, 500, "Failed to start the run request")
    return false
  } catch (error) {
    // console.error("Error:", error.response?.data)
    sendError(res, 500, error.message || "Internal server error")
    return false
  }
}

// Chat completions route
router.post('/v1/chat/completions', verify, parseMessages, async (req, res) => {
  // console.log(JSON.stringify(req.body))
  try {
    const isStream = req.body.stream === true

    const setHeader = () => {
      try {
        if (isStream) {
          res.setHeader('Content-Type', 'text/event-stream')
          res.setHeader('Cache-Control', 'no-cache')
          res.setHeader('Connection', 'keep-alive')
        } else {
          res.setHeader('Content-Type', 'application/json')
        }
      } catch (error) {
        // console.error("Error while setting response headers:", error)
      }
    }

    const { access_token, clientId } = req.account

    // Create the playground session; on failure an error was already sent
    if (!(await getChatID(req, res))) return

    // Ably protocol message: action 10 (ATTACH) attaches to the user's channel
    const sendAction = `{"action":10,"channel":"user:${clientId}","params":{"agent":"react-hooks/2.0.2"}}`

    // Build the Ably realtime WebSocket URL
    const wsUrl = `wss://realtime.ably.io/?access_token=${encodeURIComponent(access_token)}&clientId=${clientId}&format=json&heartbeats=true&v=3&agent=ably-js%2F2.0.2%20browser`

    // Open the WebSocket connection
    const ws = new WebSocket(wsUrl)

    // Stream state
    let ThinkingLastContent = ""
    let TextLastContent = ""
    let ThinkingStart = false
    let ThinkingEnd = false
    let RequestID = ""
    const MessageID = "chatcmpl-" + uuidv4()
    const streamChunk = {
      "id": MessageID,
      "object": "chat.completion.chunk",
      "system_fingerprint": "fp_44709d6fcb",
      "created": Math.floor(Date.now() / 1000),
      "model": req.body.model,
      "choices": [
        { "index": 0, "delta": { "content": null }, "finish_reason": null }
      ]
    }

    let pingInterval
    ws.on('open', async () => {
      ws.send(sendAction)
      RequestID = await sentRequest(req, res)
      if (!RequestID) {
        // sentRequest already sent the error response
        ws.close()
        return
      }
      setHeader()
      // Send a ping every 15 seconds to keep the connection alive
      pingInterval = setInterval(() => {
        if (ws.readyState === WebSocket.OPEN) {
          ws.ping(() => {}) // empty callback: only the ping frame matters
        }
      }, 15000)
    })

    ws.on('message', (raw) => {
      try {
        // console.log(raw.toString())
        const ContentText = JSON.parse(raw.toString())?.messages?.[0]
        if (!ContentText?.data) return
        const ContentData = JSON.parse(ContentText.data)

        // Ignore events that belong to other run requests
        const isRequestID = ContentData?.individual_run_request_id
        if (!isRequestID || isRequestID != RequestID) return

        let output = ""
        if (ContentText.name === "UPDATE_LAST_MESSAGE") {
          // Updates carry the accumulated content; emit only the part not yet sent
          const MessageArray = ContentData?.payload?.message?.content || []
          for (const item of MessageArray) {
            if (item.type === "text") {
              output = newSuffix(item.text, TextLastContent)
              if (ThinkingStart && !ThinkingEnd) {
                ThinkingEnd = true
                output = `${output}\n\n`
              }
              TextLastContent = item.text
            } else if (item.type === "thinking" && MessageArray.length === 1) {
              output = newSuffix(item.thinking, ThinkingLastContent)
              if (!ThinkingStart) {
                ThinkingStart = true
                output = `\n\n${output}`
              }
              ThinkingLastContent = item.thinking
            }
          }
          if (isStream) {
            streamChunk.choices[0].delta.content = output
            res.write(`data: ${JSON.stringify(streamChunk)}\n\n`)
          }
        } else if (ContentText.name === "INDIVIDUAL_RUN_COMPLETE") {
          if (!isStream) {
            output = ThinkingLastContent
              ? `\n\n${ThinkingLastContent}\n\n\n\n${TextLastContent}`
              : TextLastContent
          }

          // No content at all: report it as a provider credit-balance error
          if (ThinkingLastContent === "" && TextLastContent === "") {
            const modelData = modelMap[req.body.model] || modelMap[DEFAULT_MODEL]
            const provider = modelData?.provider || "anthropic"
            const providerCapitalized = provider.charAt(0).toUpperCase() + provider.slice(1)
            output = `${provider}.BadRequestError: Error code: 400 - {'type': 'error', 'error': {'type': 'invalid_request_error', 'message': 'Your credit balance is too low to access the ${providerCapitalized} API. Please go to Plans & Billing to upgrade or purchase credits.'}}`
            if (isStream) {
              streamChunk.choices[0].delta.content = output
              res.write(`data: ${JSON.stringify(streamChunk)}\n\n`)
            }
          }

          if (!isStream) {
            res.json({
              "id": MessageID,
              "object": "chat.completion",
              "created": Math.floor(Date.now() / 1000),
              "system_fingerprint": "fp_44709d6fcb",
              "model": req.body.model,
              "choices": [
                {
                  "index": 0,
                  "message": { "role": "assistant", "content": output },
                  "finish_reason": "stop"
                }
              ],
              "usage": { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 }
            })
          } else {
            // Streaming response: send the final chunk, then the [DONE] marker
            const finalChunk = {
              "id": MessageID,
              "object": "chat.completion.chunk",
              "system_fingerprint": "fp_44709d6fcb",
              "created": Math.floor(Date.now() / 1000),
              "model": req.body.model,
              "choices": [
                { "index": 0, "delta": {}, "finish_reason": "stop" }
              ]
            }
            res.write(`data: ${JSON.stringify(finalChunk)}\n\n`)
            res.write(`data: [DONE]\n\n`)
            res.end()
          }
          ws.close()
        }
      } catch (err) {
        // console.error("Error while handling WebSocket message:", err)
      }
    })

    ws.on('error', (err) => {
      clearInterval(pingInterval)
      // Standard OpenAI error response format
      sendError(res, 500, err.message)
    })

    // Default timeout is 5 minutes; o-series models get 20 minutes
    const timeoutDuration = (O_SERIES_MODELS.includes(req.body.model) ? 1200 : 300) * 1000
    const requestTimeout = setTimeout(() => {
      clearInterval(pingInterval)
      if (ws.readyState === WebSocket.OPEN) ws.close()
      // Standard OpenAI timeout error format
      sendError(res, 504, "Request timed out", "timeout", "timeout_error")
    }, timeoutDuration)

    ws.on('close', () => {
      clearInterval(pingInterval)
      clearTimeout(requestTimeout) // the WebSocket closed before the request timeout fired
      // If the upstream connection closed before the response was finished, finish it here
      sendError(res, 502, "Upstream connection closed unexpectedly")
    })
  } catch (error) {
    console.error("Error:", error)
    // Standard OpenAI generic error format
    sendError(res, 500, error.message || "Internal server error")
  }
})

module.exports = router