import json
import os
import time
import uuid
import threading
from typing import Any, Dict, List, Optional, TypedDict, Union

import requests
from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field


# CodeGeeX Token Management
class CodeGeeXToken(TypedDict):
    token: str
    is_valid: bool
    last_used: float
    error_count: int


# Global variables
VALID_CLIENT_KEYS: set = set()
CODEGEEX_TOKENS: List[CodeGeeXToken] = []
CODEGEEX_MODELS: List[str] = ["claude-3-7-sonnet", "claude-sonnet-4"]
token_rotation_lock = threading.Lock()
MAX_ERROR_COUNT = 3
ERROR_COOLDOWN = 300  # 5-minute cooldown for tokens with errors
DEBUG_MODE = os.environ.get("DEBUG_MODE", "false").lower() == "true"


# Pydantic Models
class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[Dict[str, Any]]]
    reasoning_content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    stream: bool = True
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None


class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelInfo]


class ChatCompletionChoice(BaseModel):
    message: ChatMessage
    index: int = 0
    finish_reason: str = "stop"


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChoice]
    usage: Dict[str, int] = Field(
        default_factory=lambda: {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }
    )


class StreamChoice(BaseModel):
    delta: Dict[str, Any] = Field(default_factory=dict)
    index: int = 0
    finish_reason: Optional[str] = None


class StreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamChoice]
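
# A serialized StreamResponse chunk looks roughly like this (illustrative values,
# reconstructed from the models above rather than captured from a live run):
#   {"id": "chatcmpl-<hex>", "object": "chat.completion.chunk", "created": 1700000000,
#    "model": "claude-3-7-sonnet",
#    "choices": [{"delta": {"content": "Hi"}, "index": 0, "finish_reason": null}]}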


# FastAPI App
app = FastAPI(title="CodeGeeX OpenAI API Adapter")
security = HTTPBearer(auto_error=False)


def log_debug(message: str):
    """Print a debug log message when debug mode is enabled."""
    if DEBUG_MODE:
        print(f"[DEBUG] {message}")


def load_client_api_keys():
    """Load client API keys from client_api_keys.json."""
    global VALID_CLIENT_KEYS
    try:
        with open("client_api_keys.json", "r", encoding="utf-8") as f:
            keys = json.load(f)
            VALID_CLIENT_KEYS = set(keys) if isinstance(keys, list) else set()
            print(f"Successfully loaded {len(VALID_CLIENT_KEYS)} client API keys.")
    except FileNotFoundError:
        print("Error: client_api_keys.json not found. Client authentication will fail.")
        VALID_CLIENT_KEYS = set()
    except Exception as e:
        print(f"Error loading client_api_keys.json: {e}")
        VALID_CLIENT_KEYS = set()


def load_codegeex_tokens():
    """Load CodeGeeX tokens from codegeex.txt (one token per line)."""
    global CODEGEEX_TOKENS
    CODEGEEX_TOKENS = []
    try:
        with open("codegeex.txt", "r", encoding="utf-8") as f:
            for line in f:
                token = line.strip()
                if token:
                    CODEGEEX_TOKENS.append({
                        "token": token,
                        "is_valid": True,
                        "last_used": 0,
                        "error_count": 0,
                    })
        print(f"Successfully loaded {len(CODEGEEX_TOKENS)} CodeGeeX tokens.")
    except FileNotFoundError:
        print("Error: codegeex.txt not found. API calls will fail.")
    except Exception as e:
        print(f"Error loading codegeex.txt: {e}")


def get_best_codegeex_token() -> Optional[CodeGeeXToken]:
    """Pick the least-recently-used valid token, honoring error-count cooldowns."""
    with token_rotation_lock:
        now = time.time()
        valid_tokens = [
            token for token in CODEGEEX_TOKENS
            if token["is_valid"] and (
                token["error_count"] < MAX_ERROR_COUNT
                or now - token["last_used"] > ERROR_COOLDOWN
            )
        ]
        if not valid_tokens:
            return None

        # Reset the error count for tokens that have completed their cooldown
        for token in valid_tokens:
            if token["error_count"] >= MAX_ERROR_COUNT and now - token["last_used"] > ERROR_COOLDOWN:
                token["error_count"] = 0

        # Sort by last used (oldest first), then error count (lowest first)
        valid_tokens.sort(key=lambda x: (x["last_used"], x["error_count"]))
        token = valid_tokens[0]
        token["last_used"] = now
        return token
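
# Illustrative selection order (hypothetical values, not from the original source):
# given token A (last_used=0, error_count=0) and token B (last_used=100, error_count=0),
# the sort key (last_used, error_count) picks A first, so load rotates across healthy
# tokens while error-prone ones are tried last.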


def _convert_messages_to_codegeex_format(messages: List[ChatMessage]):
    """Convert OpenAI-style messages into a CodeGeeX prompt plus history list."""
    if not messages:
        return "", []

    # Extract the last user message as the prompt
    last_user_msg = None
    for msg in reversed(messages):
        if msg.role == "user":
            last_user_msg = msg
            break
    if not last_user_msg:
        raise HTTPException(status_code=400, detail="No user message found in the conversation.")
    prompt = last_user_msg.content if isinstance(last_user_msg.content, str) else ""

    # Build history from previous messages (excluding the last user message).
    # Identity comparison ("is") is deliberate: Pydantic models compare by value,
    # so "==" would also skip earlier duplicates of the same message text.
    history = []
    user_content = ""
    assistant_content = ""
    for msg in (messages[:-1] if messages[-1].role == "user" else messages):
        if msg is last_user_msg:
            continue
        if msg.role == "user":
            # If we already have a complete pair, flush it to history
            if user_content and assistant_content:
                history.append({
                    "query": user_content,
                    "answer": assistant_content,
                    "id": f"{uuid.uuid4()}",
                })
                user_content = ""
                assistant_content = ""
            # Start a new pair with this user message
            user_content = msg.content if isinstance(msg.content, str) else ""
        elif msg.role == "assistant":
            assistant_content = msg.content if isinstance(msg.content, str) else ""
            # A user/assistant pair is now complete; add it to history
            if user_content:
                history.append({
                    "query": user_content,
                    "answer": assistant_content,
                    "id": f"{uuid.uuid4()}",
                })
                user_content = ""
                assistant_content = ""

    # Fold any remaining unpaired user message into the prompt
    if user_content and not assistant_content:
        prompt = user_content + "\n" + prompt
    return prompt, history
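
# Illustrative conversion (hypothetical messages, not from the original source):
#   [user "Hi", assistant "Hello!", user "How are you?"]
# yields prompt="How are you?" and
# history=[{"query": "Hi", "answer": "Hello!", "id": "<uuid4>"}].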


async def authenticate_client(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(security),
):
    """Authenticate the client via the API key in the Authorization header."""
    if not VALID_CLIENT_KEYS:
        raise HTTPException(
            status_code=503,
            detail="Service unavailable: Client API keys not configured on server.",
        )
    if not auth or not auth.credentials:
        raise HTTPException(
            status_code=401,
            detail="API key required in Authorization header.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if auth.credentials not in VALID_CLIENT_KEYS:
        raise HTTPException(status_code=403, detail="Invalid client API key.")


@app.on_event("startup")
async def startup():
    """Initialize configuration on application startup."""
    print("Starting CodeGeeX OpenAI API Adapter server...")
    load_client_api_keys()
    load_codegeex_tokens()
    print("Server initialization completed.")


def get_models_list_response() -> ModelList:
    """Helper to construct a ModelList response from the configured models."""
    model_infos = [
        ModelInfo(
            id=model,
            created=int(time.time()),
            owned_by="anthropic",
        )
        for model in CODEGEEX_MODELS
    ]
    return ModelList(data=model_infos)


@app.get("/v1/models", response_model=ModelList)
async def list_v1_models(_: None = Depends(authenticate_client)):
    """List available models (authenticated)."""
    return get_models_list_response()


@app.get("/models", response_model=ModelList)
async def list_models_no_auth():
    """List available models without authentication, for client compatibility."""
    return get_models_list_response()


@app.get("/debug")
async def toggle_debug(enable: bool = Query(None)):
    """Toggle debug mode via the query string."""
    global DEBUG_MODE
    if enable is not None:
        DEBUG_MODE = enable
    return {"debug_mode": DEBUG_MODE}
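
# The two parsers below assume CodeGeeX emits server-sent-event blocks separated
# by blank lines. The shape shown here is reconstructed from the parsing logic,
# not from an official spec:
#   event: add
#   data: {"text": "incremental output"}
#
#   event: finish
#   data: {"text": "final complete output"}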


def _codegeex_stream_generator(response, model: str):
    """Real-time streaming with format conversion from CodeGeeX SSE to OpenAI chunks."""
    stream_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())

    # Send the initial role delta
    yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={'role': 'assistant'})]).json()}\n\n"

    buffer = ""
    try:
        for chunk in response.iter_content(chunk_size=1024):
            if not chunk:
                continue
            chunk_text = chunk.decode("utf-8")
            log_debug(f"Received chunk: {chunk_text[:100]}..." if len(chunk_text) > 100 else chunk_text)
            buffer += chunk_text

            # Process complete event blocks in the buffer
            while "\n\n" in buffer:
                event_data, buffer = buffer.split("\n\n", 1)
                event_data = event_data.strip()
                if not event_data:
                    continue

                # Parse the event
                event_type = None
                data_json = None
                for line in event_data.split("\n"):
                    line = line.strip()
                    if line.startswith("event:"):
                        event_type = line[6:].strip()
                    elif line.startswith("data:"):
                        try:
                            data_json = json.loads(line[5:].strip())
                        except json.JSONDecodeError:
                            log_debug(f"Failed to parse JSON: {line[5:].strip()}")
                if not event_type or not data_json:
                    continue

                if event_type == "add":
                    # The 'text' field is itself the incremental content
                    delta = data_json.get("text", "")
                    if delta:
                        openai_response = StreamResponse(
                            id=stream_id,
                            created=created_time,
                            model=model,
                            choices=[StreamChoice(delta={"content": delta})],
                        )
                        yield f"data: {openai_response.json()}\n\n"
                elif event_type == "finish":
                    # The 'finish' event marks the end of the stream
                    log_debug("Received finish event.")
                    openai_response = StreamResponse(
                        id=stream_id,
                        created=created_time,
                        model=model,
                        choices=[StreamChoice(delta={}, finish_reason="stop")],
                    )
                    yield f"data: {openai_response.json()}\n\n"
                    yield "data: [DONE]\n\n"
                    return  # Terminate the generator
    except Exception as e:
        log_debug(f"Stream processing error: {e}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"

    # If the stream was interrupted unexpectedly, still send a termination signal
    log_debug("Stream finished unexpectedly, sending completion signal.")
    yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={}, finish_reason='stop')]).json()}\n\n"
    yield "data: [DONE]\n\n"


def _build_codegeex_non_stream_response(response, model: str) -> ChatCompletionResponse:
    """Build a non-streaming response by accumulating the streamed data."""
    full_content = ""
    buffer = ""
    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:
            continue
        buffer += chunk.decode("utf-8")

        # Process complete event blocks in the buffer
        while "\n\n" in buffer:
            event_data, buffer = buffer.split("\n\n", 1)
            event_data = event_data.strip()
            if not event_data:
                continue

            # Parse the event
            event_type = None
            data_json = None
            for line in event_data.split("\n"):
                line = line.strip()
                if line.startswith("event:"):
                    event_type = line[6:].strip()
                elif line.startswith("data:"):
                    try:
                        data_json = json.loads(line[5:].strip())
                    except json.JSONDecodeError:
                        continue
            if not event_type or not data_json:
                continue

            if event_type == "add":
                # Accumulate the incremental text
                full_content += data_json.get("text", "")
            elif event_type == "finish":
                # The 'text' in the finish event is the final complete text; prefer it
                finish_text = data_json.get("text", "")
                if finish_text:
                    full_content = finish_text
                # A finish event means parsing can stop early
                return ChatCompletionResponse(
                    model=model,
                    choices=[
                        ChatCompletionChoice(
                            message=ChatMessage(
                                role="assistant",
                                content=full_content,
                            )
                        )
                    ],
                )

    # If the loop ends without a finish event, fall back to the accumulated content
    return ChatCompletionResponse(
        model=model,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(
                    role="assistant",
                    content=full_content,
                )
            )
        ],
    )


@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest, _: None = Depends(authenticate_client)
):
    """Create a chat completion using the CodeGeeX backend."""
    if request.model not in CODEGEEX_MODELS:
        raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found.")
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided in the request.")
    log_debug(f"Processing request for model: {request.model}")

    # Convert the message format
    try:
        prompt, history = _convert_messages_to_codegeex_format(request.messages)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to process messages: {str(e)}")

    # Try each token in turn
    for attempt in range(len(CODEGEEX_TOKENS) + 1):  # +1 to handle the case of no tokens
        if attempt == len(CODEGEEX_TOKENS):
            raise HTTPException(
                status_code=503,
                detail="All attempts to contact CodeGeeX API failed.",
            )
        token = get_best_codegeex_token()
        if not token:
            raise HTTPException(
                status_code=503,
                detail="No valid CodeGeeX tokens available.",
            )
        try:
            # Build the request payload
            payload = {
                "user_role": 0,
                "ide": "VSCode",
                "ide_version": "",
                "plugin_version": "",
                "prompt": prompt,
                "machineId": "",
                "talkId": f"{uuid.uuid4()}",
                "locale": "",
                "model": request.model,
                "agent": None,
                "candidates": {
                    "candidate_msg_id": "",
                    "candidate_type": "",
                    "selected_candidate": "",
                },
                "history": history,
            }
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Code/1.100.3 Chrome/132.0.6834.210 Electron/34.5.1 Safari/537.36",
                "Accept": "text/event-stream",
                "Accept-Encoding": "gzip, deflate, br, zstd",
                "Content-Type": "application/json",
                "code-token": token["token"],
            }
            log_debug(f"Sending request to CodeGeeX API with token ending in ...{token['token'][-4:]}")
            response = requests.post(
                "https://codegeex.cn/prod/code/chatCodeSseV3/chat",
                data=json.dumps(payload),
                headers=headers,
                stream=True,
                timeout=300.0,
            )
            response.raise_for_status()

            if request.stream:
                log_debug("Returning stream response")
                return StreamingResponse(
                    _codegeex_stream_generator(response, request.model),
                    media_type="text/event-stream",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
                        "X-Accel-Buffering": "no",
                    },
                )
            else:
                log_debug("Building non-stream response")
                return _build_codegeex_non_stream_response(response, request.model)
        except requests.HTTPError as e:
            status_code = getattr(e.response, "status_code", 500)
            error_detail = getattr(e.response, "text", str(e))
            log_debug(f"CodeGeeX API error ({status_code}): {error_detail}")
            with token_rotation_lock:
                if status_code in [401, 403]:
                    # Mark the token as invalid
                    token["is_valid"] = False
                    print(f"Token ...{token['token'][-4:]} marked as invalid due to auth error.")
                elif status_code in [429, 500, 502, 503, 504]:
                    # Increment the error count
                    token["error_count"] += 1
                    print(f"Token ...{token['token'][-4:]} error count: {token['error_count']}")
        except Exception as e:
            log_debug(f"Request error: {e}")
            with token_rotation_lock:
                token["error_count"] += 1


async def error_stream_generator(error_detail: str, status_code: int):
    """Generate an error stream response."""
    yield f'data: {json.dumps({"error": {"message": error_detail, "type": "codegeex_api_error", "code": status_code}})}\n\n'
    yield "data: [DONE]\n\n"
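
# Example client call against this adapter (illustrative; assumes the server runs
# on localhost:8000 and "sk-..." is a key from client_api_keys.json):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       headers={"Authorization": "Bearer sk-..."},
#       json={
#           "model": "claude-3-7-sonnet",
#           "messages": [{"role": "user", "content": "Hello"}],
#           "stream": False,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])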


if __name__ == "__main__":
    import uvicorn

    # Enable debug mode via the environment variable
    if os.environ.get("DEBUG_MODE", "").lower() == "true":
        DEBUG_MODE = True
        print("Debug mode enabled via environment variable")

    if not os.path.exists("codegeex.txt"):
        print("Warning: codegeex.txt not found. Creating a dummy file.")
        with open("codegeex.txt", "w", encoding="utf-8") as f:
            f.write("your-codegeex-token-here\n")
        print("Created dummy codegeex.txt. Please replace it with a valid CodeGeeX token.")

    if not os.path.exists("client_api_keys.json"):
        print("Warning: client_api_keys.json not found. Creating a dummy file.")
        dummy_key = f"sk-dummy-{uuid.uuid4().hex}"
        with open("client_api_keys.json", "w", encoding="utf-8") as f:
            json.dump([dummy_key], f, indent=2)
        print(f"Created dummy client_api_keys.json with key: {dummy_key}")

    load_client_api_keys()
    load_codegeex_tokens()

    print("\n--- CodeGeeX OpenAI API Adapter ---")
    print(f"Debug Mode: {DEBUG_MODE}")
    print("Endpoints:")
    print("  GET /v1/models (Client API Key Auth)")
    print("  GET /models (No Auth)")
    print("  POST /v1/chat/completions (Client API Key Auth)")
    print("  GET /debug?enable=[true|false] (Toggle Debug Mode)")
    print(f"\nClient API Keys: {len(VALID_CLIENT_KEYS)}")
    if CODEGEEX_TOKENS:
        print(f"CodeGeeX Tokens: {len(CODEGEEX_TOKENS)}")
    else:
        print("CodeGeeX Tokens: None loaded. Check codegeex.txt.")
    print(f"Available models: {', '.join(CODEGEEX_MODELS)}")
    print("------------------------------------")

    uvicorn.run(app, host="0.0.0.0", port=8000)