Spaces:
Sleeping
Sleeping
import os | |
import uuid | |
import json | |
import time | |
import asyncio | |
import random | |
import threading | |
from curl_cffi.requests import AsyncSession | |
from fastapi import FastAPI, Request, HTTPException, Depends, status | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
from fastapi.responses import StreamingResponse | |
from dotenv import load_dotenv | |
import secrets | |
from pydantic import BaseModel, Field | |
from typing import List, Optional, Dict, Any, Literal, Union | |
from contextlib import asynccontextmanager | |
# Load environment variables from a local .env file, if one exists.
load_dotenv()

# --- Concurrency configuration ---
# Number of parallel requests fired at Notion per chat completion.
CONCURRENT_REQUESTS = 1

# --- Retry configuration ---
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds to wait between retry attempts
# --- Models (Integrated from models.py) --- | |
# Input Models (OpenAI-like) | |
class ChatMessage(BaseModel):
    """A single OpenAI-style chat message."""
    role: Literal["system", "user", "assistant"]
    content: str
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible request body for the chat completions endpoint."""
    messages: List[ChatMessage]
    model: str = "notion-proxy"
    stream: bool = False
    # Underlying Notion model the request is routed to.
    notion_model: str = "anthropic-opus-4"
# Notion Models | |
class NotionTranscriptConfigValue(BaseModel):
    """Value of a 'config' transcript item sent to Notion."""
    type: str = "markdown-chat"
    model: str  # e.g., "anthropic-opus-4"
class NotionTranscriptItem(BaseModel):
    """One entry of the Notion inference transcript.

    The value shape depends on `type`: nested string lists for "user"
    turns, a plain string for "markdown-chat", a config object for "config".
    """
    type: Literal["config", "user", "markdown-chat"]
    value: Union[List[List[str]], str, NotionTranscriptConfigValue]
class NotionDebugOverrides(BaseModel):
    """Debug flags forwarded verbatim to Notion (all disabled by default)."""
    cachedInferences: Dict = Field(default_factory=dict)
    annotationInferences: Dict = Field(default_factory=dict)
    emitInferences: bool = False
class NotionRequestBody(BaseModel):
    """Payload for Notion's runInferenceTranscript API."""
    traceId: str = Field(default_factory=lambda: str(uuid.uuid4()))
    spaceId: str
    transcript: List[NotionTranscriptItem]
    # threadId is removed, createThread will be set to true
    createThread: bool = True
    debugOverrides: NotionDebugOverrides = Field(default_factory=NotionDebugOverrides)
    generateTitle: bool = False
    saveAllThreadOperations: bool = True
# Output Models (OpenAI SSE) | |
class ChoiceDelta(BaseModel):
    """Incremental content piece inside an SSE choice."""
    content: Optional[str] = None
class Choice(BaseModel):
    """One streamed completion choice (OpenAI SSE schema)."""
    index: int = 0
    delta: ChoiceDelta
    finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionChunk(BaseModel):
    """One OpenAI-compatible SSE chunk emitted to the client."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "notion-proxy"  # Or could reflect the underlying Notion model
    choices: List[Choice]
# Models for /v1/models Endpoint | |
class Model(BaseModel):
    """One entry in the /v1/models listing."""
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "notion"  # Or specify based on actual model origin if needed
class ModelList(BaseModel):
    """OpenAI-style wrapper for the model listing."""
    object: str = "list"
    data: List[Model]
# --- Configuration ---
NOTION_API_URL = "https://www.notion.so/api/v3/runInferenceTranscript"
# IMPORTANT: Load the Notion cookie securely from environment variables
NOTION_COOKIE = os.getenv("NOTION_COOKIE")
NOTION_SPACE_ID = os.getenv("NOTION_SPACE_ID")
if not NOTION_COOKIE:
    print("Error: NOTION_COOKIE environment variable not set.")
    # Consider raising HTTPException or exiting in a real app
if not NOTION_SPACE_ID:
    print("Warning: NOTION_SPACE_ID environment variable not set. Using a default UUID.")
    # Using a default might not be ideal, depends on Notion's behavior
    # Consider raising an error instead: raise ValueError("NOTION_SPACE_ID not set")
    NOTION_SPACE_ID = str(uuid.uuid4())  # Default or raise error
# --- Cookie Management ---
# Combined Cookie header string shared across requests; guarded by cookie_lock.
browser_cookies = ""
cookie_lock = threading.Lock()
last_cookie_update = 0  # epoch seconds of the last successful refresh
COOKIE_UPDATE_INTERVAL = 30 * 60  # 30 minutes in seconds
async def get_browser_cookies():
    """Fetch fresh browser cookies from notion.so and cache them globally.

    Visits the Notion homepage with a Chrome-impersonating session, collects
    cookies scoped to the .notion.so domain (working around curl_cffi
    CookieConflict issues between .notion.so and .notion.com), appends the
    token_v2 cookie from the environment, and stores the joined header string
    in the module-global `browser_cookies` under `cookie_lock`.

    Returns:
        bool: True if a usable cookie string was cached, False otherwise.

    Fix vs. original: the bare `except:` fallback is narrowed to
    `except Exception:` so KeyboardInterrupt/SystemExit are not swallowed.
    """
    global browser_cookies, last_cookie_update
    try:
        print("正在获取Notion浏览器cookie...")
        async with AsyncSession(impersonate="chrome136") as session:
            response = await session.get("https://www.notion.so")
            if response.status_code == 200:
                cookies = response.cookies
                notion_so_cookies = []
                # Work around CookieConflict: keep only .notion.so cookies.
                try:
                    if hasattr(cookies, 'get_dict'):
                        # Preferred path: filter by domain via get_dict.
                        notion_so_dict = cookies.get_dict(domain='.notion.so')
                        for name, value in notion_so_dict.items():
                            notion_so_cookies.append(f"{name}={value}")
                    elif hasattr(cookies, 'jar'):
                        # Iterate the underlying jar and filter by domain.
                        for cookie in cookies.jar:
                            if hasattr(cookie, 'domain') and cookie.domain:
                                if '.notion.so' in cookie.domain and '.notion.com' not in cookie.domain:
                                    notion_so_cookies.append(f"{cookie.name}={cookie.value}")
                    else:
                        # Fallback: parse Set-Cookie response headers directly.
                        set_cookie_headers = response.headers.get_list('Set-Cookie') if hasattr(response.headers, 'get_list') else []
                        if not set_cookie_headers and 'Set-Cookie' in response.headers:
                            set_cookie_headers = [response.headers['Set-Cookie']]
                        for cookie_header in set_cookie_headers:
                            if 'domain=.notion.so' in cookie_header or ('notion.so' in cookie_header and 'notion.com' not in cookie_header):
                                # Extract the "name=value" part before attributes.
                                cookie_parts = cookie_header.split(';')[0].strip()
                                if '=' in cookie_parts:
                                    notion_so_cookies.append(cookie_parts)
                    # Still empty? Try iterating response.cookies directly.
                    if not notion_so_cookies and hasattr(response, 'cookies'):
                        try:
                            for cookie in response.cookies:
                                if hasattr(cookie, 'domain') and cookie.domain and '.notion.so' in cookie.domain:
                                    notion_so_cookies.append(f"{cookie.name}={cookie.value}")
                        except Exception as inner_e:
                            print(f"内部cookie处理错误: {inner_e}")
                except Exception as cookie_error:
                    print(f"处理cookie时出现错误: {cookie_error}")
                    # Every filtering strategy failed; take the session jar as-is.
                    if hasattr(session, 'cookies'):
                        try:
                            for name, value in session.cookies.items():
                                notion_so_cookies.append(f"{name}={value}")
                        except Exception:
                            # Best-effort only; proceed with whatever we have.
                            pass
                # Append the environment token as a token_v2 cookie.
                if NOTION_COOKIE:
                    notion_so_cookies.append(f"token_v2={NOTION_COOKIE}")
                # If nothing was collected at all, fall back to the env cookie alone.
                if not notion_so_cookies and NOTION_COOKIE:
                    notion_so_cookies = [f"token_v2={NOTION_COOKIE}"]
                with cookie_lock:
                    browser_cookies = "; ".join(notion_so_cookies)
                    last_cookie_update = time.time()
                # Log the cookie names (never the values).
                cookie_names = []
                for cookie_str in notion_so_cookies:
                    if '=' in cookie_str:
                        cookie_names.append(cookie_str.split('=')[0])
                print(f"成功获取到 {len(notion_so_cookies)} 个cookie")
                print(f"Cookie名称列表: {', '.join(cookie_names)}")
                return True
            else:
                print(f"获取cookie失败,HTTP状态码: {response.status_code}")
                return False
    except Exception as e:
        print(f"获取browser cookie时出错: {e}")
        print(f"错误详情: {type(e).__name__}: {str(e)}")
        # On total failure, fall back to the environment cookie if present.
        if NOTION_COOKIE:
            with cookie_lock:
                browser_cookies = f"token_v2={NOTION_COOKIE}"
                last_cookie_update = time.time()
            print("使用环境变量cookie作为备用")
            return True
        return False
def should_update_cookies():
    """Return True when the cached cookies are older than the refresh interval."""
    age = time.time() - last_cookie_update
    return age > COOKIE_UPDATE_INTERVAL
async def ensure_cookies_available():
    """Guarantee a usable cookie string is cached, refreshing when stale.

    Falls back to the NOTION_COOKIE environment variable when a browser
    refresh fails, and raises HTTP 500 when no cookie source exists at all.
    """
    global browser_cookies
    # Fresh enough and non-empty: nothing to do.
    if browser_cookies and not should_update_cookies():
        return
    refreshed = await get_browser_cookies()
    if refreshed or browser_cookies:
        return
    # Refresh failed and nothing is cached: last-ditch env fallback.
    if not NOTION_COOKIE:
        raise HTTPException(status_code=500, detail="无法获取Notion cookie")
    with cookie_lock:
        browser_cookies = f"token_v2={NOTION_COOKIE}"
    print("使用环境变量cookie作为备用")
def start_cookie_updater():
    """Spawn a daemon thread that periodically refreshes the Notion cookies."""
    def cookie_updater():
        # The worker thread needs its own event loop to drive the async refresher.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        while True:
            try:
                if should_update_cookies():
                    print("开始定时更新cookie...")
                    loop.run_until_complete(get_browser_cookies())
                time.sleep(60)  # poll once per minute
            except Exception as e:
                print(f"定时更新cookie时出错: {e}")
                time.sleep(60)
    threading.Thread(target=cookie_updater, daemon=True).start()
    print("cookie定时更新器已启动")
# --- Authentication ---
# Bearer token clients must present; configure via PROXY_AUTH_TOKEN.
EXPECTED_TOKEN = os.getenv("PROXY_AUTH_TOKEN", "default_token")  # Default token
security = HTTPBearer()
def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the client's bearer token against EXPECTED_TOKEN.

    Uses a constant-time comparison to resist timing attacks. The comparison
    is performed on UTF-8 bytes because secrets.compare_digest raises
    TypeError when given non-ASCII str arguments, which would surface as an
    opaque 500 instead of a 401.

    Raises:
        HTTPException: 401 when the token does not match.

    Returns:
        bool: True on success (used as a FastAPI dependency).
    """
    provided = credentials.credentials.encode("utf-8")
    expected = EXPECTED_TOKEN.encode("utf-8")
    if not secrets.compare_digest(provided, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            # WWW-Authenticate header removed for Bearer
        )
    return True  # Indicate successful authentication
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prime the Notion cookies, then keep them fresh.

    FastAPI's `lifespan` parameter expects an async context manager; the
    original bare async generator only worked through a deprecated Starlette
    compatibility shim, so the function is now wrapped with
    @asynccontextmanager (which was already imported but unused).
    """
    # Startup: fetch an initial cookie set before serving traffic.
    print("正在初始化Notion浏览器cookie...")
    await get_browser_cookies()
    # Start the background cookie refresher thread.
    start_cookie_updater()
    yield
    # No shutdown cleanup required.
# --- FastAPI App ---
# Register the lifespan handler so cookies are primed before serving.
app = FastAPI(lifespan=lifespan)
# --- Helper Functions --- | |
def build_notion_request(request_data: ChatCompletionRequest) -> NotionRequestBody:
    """Transform OpenAI-style messages into a Notion transcript request."""
    # The transcript always starts with a config item naming the model.
    items = [
        NotionTranscriptItem(
            type="config",
            value=NotionTranscriptConfigValue(model=request_data.notion_model),
        )
    ]
    for msg in request_data.messages:
        if msg.role == "assistant":
            # Notion stores assistant replies as "markdown-chat" entries.
            items.append(NotionTranscriptItem(type="markdown-chat", value=msg.content))
        else:
            # User, system, and any other roles all become "user" turns.
            items.append(NotionTranscriptItem(type="user", value=[[msg.content]]))
    return NotionRequestBody(
        spaceId=NOTION_SPACE_ID,      # from the environment
        transcript=items,
        createThread=True,            # always start a fresh thread
        traceId=str(uuid.uuid4()),    # new trace per request
        debugOverrides=NotionDebugOverrides(
            cachedInferences={},
            annotationInferences={},
            emitInferences=False,
        ),
        generateTitle=False,
        saveAllThreadOperations=False,
    )
async def check_first_response_line(session: AsyncSession, notion_request_body: NotionRequestBody, headers: dict, request_id: int):
    """Fire one Notion request and inspect its first JSON line for a 500 error.

    Returns a 3-tuple:
        ((response, buffered_text), None, None) for a healthy stream,
        (None, response, error_string) when the HTTP call failed or Notion
        reported "error code 500" in its first line, or
        (None, None, error_string) when the request itself raised.
    """
    try:
        # With more than one concurrent request, add a random delay so
        # the attempts do not all hit Notion at the same instant.
        if CONCURRENT_REQUESTS > 1:
            delay = random.uniform(0, 1.0)
            print(f"并发请求 {request_id} 延迟 {delay:.2f}秒")
            await asyncio.sleep(delay)
        # Give each concurrent attempt its own body with a fresh traceId.
        request_body_copy = notion_request_body.model_copy()
        request_body_copy.traceId = str(uuid.uuid4())
        response = await session.post(
            NOTION_API_URL,
            json=request_body_copy.model_dump(),
            headers=headers,
            stream=True
        )
        if response.status_code != 200:
            return None, response, f"HTTP {response.status_code}"
        # Read until the first complete JSON line to classify the stream.
        buffer = ""
        async for chunk in response.aiter_content():
            if isinstance(chunk, bytes):
                chunk = chunk.decode('utf-8')
            buffer += chunk
            # Try to parse the first complete JSON line seen so far.
            lines = buffer.split('\n')
            for line in lines:
                line = line.strip()
                if line:
                    try:
                        data = json.loads(line)
                        if (data.get("type") == "error" and
                                data.get("message") and
                                "error code 500" in data.get("message", "")):
                            print(f"并发请求 {request_id} 检测到500错误: {data}")
                            return None, response, "500 error"
                        else:
                            # Healthy stream: hand back the response plus
                            # everything already buffered so no data is lost.
                            print(f"并发请求 {request_id} 响应正常")
                            return (response, buffer), None, None
                    except json.JSONDecodeError:
                        # Incomplete line; keep accumulating.
                        continue
        return None, response, "No valid response"
    except Exception as e:
        print(f"并发请求 {request_id} 发生异常: {e}")
        return None, None, str(e)
def _sse_data(chunk_id: str, created_time: int, content: str) -> str:
    """Format one content delta as an OpenAI-style SSE `data:` line."""
    chunk_obj = ChatCompletionChunk(
        id=chunk_id,
        created=created_time,
        choices=[Choice(delta=ChoiceDelta(content=content))]
    )
    return f"data: {chunk_obj.model_dump_json()}\n\n"

def _drain_ndjson_lines(lines, tail: str, chunk_id: str, created_time: int):
    """Convert complete NDJSON lines into SSE events.

    Args:
        lines: complete (newline-terminated) lines to process.
        tail: the still-incomplete buffer remainder; flushed as a final
            chunk when a recordMap terminator is encountered.

    Returns:
        (events, stopped): the SSE strings to emit, and whether the Notion
        stream signalled completion via a recordMap payload.
    """
    events = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError as e:
            print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
            continue
        try:
            if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                # Only forward non-empty content chunks.
                if data["value"]:
                    events.append(_sse_data(chunk_id, created_time, data["value"]))
            elif "recordMap" in data:
                # recordMap marks the end of the text stream.
                print("Detected recordMap, stopping stream.")
                # Flush any final partial line before stopping.
                if tail.strip():
                    try:
                        last_data = json.loads(tail.strip())
                        if (last_data.get("type") == "markdown-chat"
                                and isinstance(last_data.get("value"), str)
                                and last_data["value"]):
                            events.append(_sse_data(chunk_id, created_time, last_data["value"]))
                    except Exception:
                        # Partial tail was not valid JSON; drop it silently.
                        pass
                return events, True
        except Exception as e:
            print(f"Error processing line: {str(e)}")
    return events, False

async def stream_notion_response_single(session: AsyncSession, response, initial_buffer: str, chunk_id: str, created_time: int):
    """Stream one Notion NDJSON response as OpenAI-compatible SSE chunks.

    First processes the text already read during the health check
    (`initial_buffer`), then continues draining the live response. The
    line-handling logic that was duplicated in the original implementation
    is factored into _drain_ndjson_lines; a bare `except:` in the tail
    flush is narrowed to `except Exception:`.
    """
    buffer = initial_buffer
    lines = buffer.split('\n')
    buffer = lines[-1]
    events, stopped = _drain_ndjson_lines(lines[:-1], buffer, chunk_id, created_time)
    for event in events:
        yield event
    if stopped:
        return
    # Keep reading the remainder of the response.
    async for chunk in response.aiter_content():
        if isinstance(chunk, bytes):
            chunk = chunk.decode('utf-8')
        buffer += chunk
        lines = buffer.split('\n')
        # Keep the last (possibly incomplete) line buffered.
        buffer = lines[-1]
        events, stopped = _drain_ndjson_lines(lines[:-1], buffer, chunk_id, created_time)
        for event in events:
            yield event
        if stopped:
            return
async def stream_notion_response(notion_request_body: NotionRequestBody):
    """Streams the request to Notion and yields OpenAI-compatible SSE chunks.

    Strategy: fan out CONCURRENT_REQUESTS parallel attempts and stream the
    first healthy one; if every attempt fails, fall back to a sequential
    retry loop (MAX_RETRIES attempts, RETRY_DELAY seconds apart). Because
    the SSE response has already started, late failures are reported
    in-band as content chunks rather than as HTTP error statuses.
    """
    # Make sure a cookie string is cached before contacting Notion.
    await ensure_cookies_available()
    with cookie_lock:
        current_cookies = browser_cookies
    headers = {
        'accept': 'application/x-ndjson',
        'accept-encoding': 'gzip, deflate, br, zstd',
        'accept-language': 'en-US,zh;q=0.9',
        'content-type': 'application/json',
        'dnt': '1',
        'notion-audit-log-platform': 'web',
        'notion-client-version': '23.13.0.3661',
        'origin': 'https://www.notion.so',
        'referer': 'https://www.notion.so/',
        'priority': 'u=1, i',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"Windows"',
        'sec-fetch-dest': 'empty',
        'sec-fetch-mode': 'cors',
        'sec-fetch-site': 'same-origin',
        'cookie': current_cookies,
        'x-notion-space-id': NOTION_SPACE_ID
    }
    # Conditionally add the active user header
    notion_active_user = os.getenv("NOTION_ACTIVE_USER_HEADER")
    if notion_active_user:  # Checks for None and empty string implicitly
        headers['x-notion-active-user-header'] = notion_active_user
    chunk_id = f"chatcmpl-{uuid.uuid4()}"
    created_time = int(time.time())
    # Use the module-level retry configuration.
    max_retries = MAX_RETRIES
    retry_delay = RETRY_DELAY
    # First, try the concurrent fan-out.
    print(f"同时发起 {CONCURRENT_REQUESTS} 个并发请求...")
    async with AsyncSession(impersonate="chrome136") as session:
        # Launch each concurrent attempt as an independent task.
        tasks = []
        for i in range(CONCURRENT_REQUESTS):
            task = asyncio.create_task(
                check_first_response_line(session, notion_request_body, headers, i + 1)
            )
            tasks.append(task)
        # Wait until all tasks finish or the first healthy response appears.
        successful_response = None
        failed_count = 0
        completed_tasks = set()
        while len(completed_tasks) < CONCURRENT_REQUESTS and not successful_response:
            # Wait for any one task to complete.
            done, pending = await asyncio.wait(
                [t for t in tasks if t not in completed_tasks],
                return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                completed_tasks.add(task)
                result, response, error = await task
                if result:
                    # Healthy response found; use it immediately.
                    successful_response = result
                    print(f"找到成功的并发响应,立即使用")
                    # Cancel the attempts that are still in flight.
                    for t in tasks:
                        if t not in completed_tasks:
                            t.cancel()
                    break
                else:
                    # Record the failure.
                    failed_count += 1
                    if error:
                        print(f"并发请求失败: {error}")
        # Stream the healthy response through to the client.
        if successful_response:
            response, initial_buffer = successful_response
            print("使用成功的并发响应进行流式传输")
            async for data in stream_notion_response_single(session, response, initial_buffer, chunk_id, created_time):
                yield data
            # Send the final chunk indicating stop
            final_chunk = ChatCompletionChunk(
                id=chunk_id,
                created=created_time,
                choices=[Choice(delta=ChoiceDelta(), finish_reason="stop")]
            )
            yield f"data: {final_chunk.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return
    # Only when every concurrent attempt failed do we enter the retry flow.
    print(f"所有 {CONCURRENT_REQUESTS} 个并发请求都失败,开始单请求重试流程...")
    # Sequential retry logic (no concurrency).
    for attempt in range(max_retries):
        try:
            # Using curl_cffi with chrome136 impersonation for better anti-bot bypass
            async with AsyncSession(impersonate="chrome136") as session:
                # Stream the response
                response = await session.post(
                    NOTION_API_URL,
                    json=notion_request_body.model_dump(),
                    headers=headers,
                    stream=True
                )
                if response.status_code != 200:
                    error_content = await response.atext()
                    print(f"Error from Notion API: {response.status_code}")
                    print(f"Response: {error_content}")
                    raise HTTPException(status_code=response.status_code, detail=f"Notion API Error: {error_content}")
                # curl_cffi streams raw chunks; reassemble NDJSON lines manually.
                buffer = ""
                first_line_checked = False
                is_error_response = False
                async for chunk in response.aiter_content():
                    # Decode chunk if it's bytes
                    if isinstance(chunk, bytes):
                        chunk = chunk.decode('utf-8')
                    buffer += chunk
                    # Split by newlines and process complete lines
                    lines = buffer.split('\n')
                    # Keep the last incomplete line in the buffer
                    buffer = lines[-1]
                    for line in lines[:-1]:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                            # Check whether the very first line is a 500 error.
                            if not first_line_checked:
                                first_line_checked = True
                                if (data.get("type") == "error" and
                                        data.get("message") and
                                        "error code 500" in data.get("message", "")):
                                    print(f"检测到Notion API 500错误 (重试 {attempt + 1}/{max_retries}): {data}")
                                    is_error_response = True
                                    break
                            # Not an error: forward text chunks in real time.
                            # Check if it's the type of message containing text chunks
                            if data.get("type") == "markdown-chat" and isinstance(data.get("value"), str):
                                content_chunk = data["value"]
                                if content_chunk:  # Only send if there's content
                                    chunk_obj = ChatCompletionChunk(
                                        id=chunk_id,
                                        created=created_time,
                                        choices=[Choice(delta=ChoiceDelta(content=content_chunk))]
                                    )
                                    yield f"data: {chunk_obj.model_dump_json()}\n\n"
                            # A recordMap payload is definitely past the text stream.
                            elif "recordMap" in data:
                                print("Detected recordMap, stopping stream.")
                                # Process any remaining buffer
                                if buffer.strip():
                                    try:
                                        last_data = json.loads(buffer.strip())
                                        if last_data.get("type") == "markdown-chat" and isinstance(last_data.get("value"), str):
                                            if last_data["value"]:
                                                last_chunk = ChatCompletionChunk(
                                                    id=chunk_id,
                                                    created=created_time,
                                                    choices=[Choice(delta=ChoiceDelta(content=last_data["value"]))]
                                                )
                                                yield f"data: {last_chunk.model_dump_json()}\n\n"
                                    except:
                                        pass
                                # Exit the line loop
                                break
                        except json.JSONDecodeError as e:
                            print(f"Warning: Could not decode JSON line: {line[:100]}... Error: {str(e)}")
                        except Exception as e:
                            print(f"Error processing line: {str(e)}")
                            # Continue processing other lines
                    if is_error_response:
                        break
                # Retry when the stream reported a 500 error.
                if is_error_response:
                    if attempt < max_retries - 1:
                        print(f"等待 {retry_delay} 秒后重试...")
                        await asyncio.sleep(retry_delay)
                        continue  # retry
                    else:
                        # All retries exhausted: report the error in-stream.
                        print("所有重试都失败,返回500错误给客户端")
                        error_chunk = ChatCompletionChunk(
                            id=chunk_id,
                            created=created_time,
                            choices=[Choice(delta=ChoiceDelta(content="Error: Notion API returned error code 500 after all retries"), finish_reason="stop")]
                        )
                        yield f"data: {error_chunk.model_dump_json()}\n\n"
                        yield "data: [DONE]\n\n"
                        return
                # No error: send the final chunk indicating stop.
                final_chunk = ChatCompletionChunk(
                    id=chunk_id,
                    created=created_time,
                    choices=[Choice(delta=ChoiceDelta(), finish_reason="stop")]
                )
                yield f"data: {final_chunk.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                # Completed successfully; leave the retry loop.
                break
        except HTTPException:
            # Can't raise inside an SSE stream; report the error in-band instead.
            if attempt < max_retries - 1:
                print(f"HTTP异常,等待 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                continue
            else:
                print("HTTP异常且无更多重试,返回错误信息")
                error_chunk = ChatCompletionChunk(
                    id=chunk_id,
                    created=created_time,
                    choices=[Choice(delta=ChoiceDelta(content="Error: HTTP exception occurred after all retries"), finish_reason="stop")]
                )
                yield f"data: {error_chunk.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return
        except Exception as e:
            print(f"Unexpected error during streaming (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                print(f"等待 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                continue
            else:
                print("意外错误且无更多重试,返回错误信息")
                error_chunk = ChatCompletionChunk(
                    id=chunk_id,
                    created=created_time,
                    choices=[Choice(delta=ChoiceDelta(content=f"Error: Internal server error during streaming: {e}"), finish_reason="stop")]
                )
                yield f"data: {error_chunk.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return
# --- API Endpoints ---
@app.get("/v1/models")
async def list_models(authenticated: bool = Depends(authenticate)):
    """
    Endpoint to list available Notion models, mimicking OpenAI's /v1/models.

    Fix: the route decorator was missing, so the handler was never
    registered with the FastAPI app and the endpoint returned 404.
    """
    available_models = [
        "openai-gpt-4.1",
        "anthropic-opus-4",
        "anthropic-sonnet-4"
    ]
    model_list = [
        Model(id=model_id, owned_by="notion")  # created uses default_factory
        for model_id in available_models
    ]
    return ModelList(data=model_list)
@app.post("/v1/chat/completions")
async def chat_completions(request_data: ChatCompletionRequest, request: Request, authenticated: bool = Depends(authenticate)):
    """
    Endpoint to mimic OpenAI's chat completions, proxying to Notion.

    Fix: the route decorator was missing, so the handler was never
    registered with the FastAPI app. Streaming requests are forwarded as
    SSE; non-streaming requests collect the stream internally and return a
    single OpenAI-shaped completion object.
    """
    if not NOTION_COOKIE:
        raise HTTPException(status_code=500, detail="Server configuration error: Notion cookie not set.")
    notion_request_body = build_notion_request(request_data)
    if request_data.stream:
        return StreamingResponse(
            stream_notion_response(notion_request_body),
            media_type="text/event-stream"
        )
    # --- Non-streaming: drain the SSE generator into one response ---
    full_response_content = ""
    final_finish_reason = None
    chunk_id = f"chatcmpl-{uuid.uuid4()}"  # ID for the non-streamed response
    created_time = int(time.time())
    try:
        async for line in stream_notion_response(notion_request_body):
            if line.startswith("data: ") and "[DONE]" not in line:
                try:
                    data_json = line[len("data: "):].strip()
                    if data_json:
                        chunk_data = json.loads(data_json)
                        if chunk_data.get("choices"):
                            delta = chunk_data["choices"][0].get("delta", {})
                            content = delta.get("content")
                            if content:
                                full_response_content += content
                            finish_reason = chunk_data["choices"][0].get("finish_reason")
                            if finish_reason:
                                final_finish_reason = finish_reason
                except json.JSONDecodeError:
                    print(f"Warning: Could not decode JSON line in non-streaming mode: {line}")
        # Construct the final OpenAI-compatible non-streaming response.
        return {
            "id": chunk_id,
            "object": "chat.completion",
            "created": created_time,
            "model": request_data.model,  # echo the model the client asked for
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": full_response_content,
                    },
                    "finish_reason": final_finish_reason or "stop",  # default to stop
                }
            ],
            "usage": {  # Token usage is not available from Notion
                "prompt_tokens": None,
                "completion_tokens": None,
                "total_tokens": None,
            },
        }
    except HTTPException as e:
        # Re-raise HTTP exceptions from the streaming function.
        raise e
    except Exception as e:
        print(f"Error during non-streaming processing: {e}")
        raise HTTPException(status_code=500, detail="Internal server error processing Notion response")
if __name__ == "__main__":
    import uvicorn
    print("Starting server. Access at http://localhost:7860")
    print("Ensure NOTION_COOKIE is set in your .env file or environment.")
    print("Cookie管理系统已启用,将自动获取和更新Notion浏览器cookie")
    # Run the ASGI server on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)