Spaces:
Running
Running
""" | |
Google API Client - Handles all communication with Google's Gemini API. | |
This module is used by both OpenAI compatibility layer and native Gemini endpoints. | |
""" | |
import json | |
import logging | |
import requests | |
from fastapi import Response | |
from fastapi.responses import StreamingResponse | |
from google.auth.transport.requests import Request as GoogleAuthRequest | |
from .auth import get_credentials, save_credentials, get_user_project_id, onboard_user | |
from .utils import get_user_agent | |
from .config import CODE_ASSIST_ENDPOINT, DEFAULT_SAFETY_SETTINGS | |
import asyncio | |
def send_gemini_request(payload: dict, is_streaming: bool = False) -> Response: | |
""" | |
Send a request to Google's Gemini API. | |
Args: | |
payload: The request payload in Gemini format | |
is_streaming: Whether this is a streaming request | |
Returns: | |
FastAPI Response object | |
""" | |
# Get and validate credentials | |
creds = get_credentials() | |
if not creds: | |
return Response( | |
content="Authentication failed. Please restart the proxy to log in.", | |
status_code=500 | |
) | |
# Refresh credentials if needed | |
if creds.expired and creds.refresh_token: | |
try: | |
creds.refresh(GoogleAuthRequest()) | |
save_credentials(creds) | |
except Exception as e: | |
return Response( | |
content="Token refresh failed. Please restart the proxy to re-authenticate.", | |
status_code=500 | |
) | |
elif not creds.token: | |
return Response( | |
content="No access token. Please restart the proxy to re-authenticate.", | |
status_code=500 | |
) | |
# Get project ID and onboard user | |
proj_id = get_user_project_id(creds) | |
if not proj_id: | |
return Response(content="Failed to get user project ID.", status_code=500) | |
onboard_user(creds, proj_id) | |
# Build the final payload with project info | |
final_payload = { | |
"model": payload.get("model"), | |
"project": proj_id, | |
"request": payload.get("request", {}) | |
} | |
# Determine the action and URL | |
action = "streamGenerateContent" if is_streaming else "generateContent" | |
target_url = f"{CODE_ASSIST_ENDPOINT}/v1internal:{action}" | |
if is_streaming: | |
target_url += "?alt=sse" | |
# Build request headers | |
request_headers = { | |
"Authorization": f"Bearer {creds.token}", | |
"Content-Type": "application/json", | |
"User-Agent": get_user_agent(), | |
} | |
final_post_data = json.dumps(final_payload) | |
# Send the request | |
try: | |
if is_streaming: | |
resp = requests.post(target_url, data=final_post_data, headers=request_headers, stream=True) | |
return _handle_streaming_response(resp) | |
else: | |
resp = requests.post(target_url, data=final_post_data, headers=request_headers) | |
return _handle_non_streaming_response(resp) | |
except requests.exceptions.RequestException as e: | |
logging.error(f"Request to Google API failed: {str(e)}") | |
return Response( | |
content=json.dumps({"error": {"message": f"Request failed: {str(e)}"}}), | |
status_code=500, | |
media_type="application/json" | |
) | |
except Exception as e: | |
logging.error(f"Unexpected error during Google API request: {str(e)}") | |
return Response( | |
content=json.dumps({"error": {"message": f"Unexpected error: {str(e)}"}}), | |
status_code=500, | |
media_type="application/json" | |
) | |
def _handle_streaming_response(resp) -> StreamingResponse: | |
"""Handle streaming response from Google API.""" | |
# Check for HTTP errors before starting to stream | |
if resp.status_code != 200: | |
logging.error(f"Google API returned status {resp.status_code}: {resp.text}") | |
error_message = f"Google API error: {resp.status_code}" | |
try: | |
error_data = resp.json() | |
if "error" in error_data: | |
error_message = error_data["error"].get("message", error_message) | |
except: | |
pass | |
# Return error as a streaming response | |
async def error_generator(): | |
error_response = { | |
"error": { | |
"message": error_message, | |
"type": "invalid_request_error" if resp.status_code == 404 else "api_error", | |
"code": resp.status_code | |
} | |
} | |
yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8') | |
response_headers = { | |
"Content-Type": "text/event-stream", | |
"Content-Disposition": "attachment", | |
"Vary": "Origin, X-Origin, Referer", | |
"X-XSS-Protection": "0", | |
"X-Frame-Options": "SAMEORIGIN", | |
"X-Content-Type-Options": "nosniff", | |
"Server": "ESF" | |
} | |
return StreamingResponse( | |
error_generator(), | |
media_type="text/event-stream", | |
headers=response_headers, | |
status_code=resp.status_code | |
) | |
async def stream_generator(): | |
try: | |
with resp: | |
for chunk in resp.iter_lines(): | |
if chunk: | |
if not isinstance(chunk, str): | |
chunk = chunk.decode('utf-8') | |
if chunk.startswith('data: '): | |
chunk = chunk[len('data: '):] | |
try: | |
obj = json.loads(chunk) | |
if "response" in obj: | |
response_chunk = obj["response"] | |
response_json = json.dumps(response_chunk, separators=(',', ':')) | |
response_line = f"data: {response_json}\n\n" | |
yield response_line.encode('utf-8') | |
await asyncio.sleep(0) | |
else: | |
obj_json = json.dumps(obj, separators=(',', ':')) | |
yield f"data: {obj_json}\n\n".encode('utf-8') | |
except json.JSONDecodeError: | |
continue | |
except requests.exceptions.RequestException as e: | |
logging.error(f"Streaming request failed: {str(e)}") | |
error_response = { | |
"error": { | |
"message": f"Upstream request failed: {str(e)}", | |
"type": "api_error", | |
"code": 502 | |
} | |
} | |
yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8') | |
except Exception as e: | |
logging.error(f"Unexpected error during streaming: {str(e)}") | |
error_response = { | |
"error": { | |
"message": f"An unexpected error occurred: {str(e)}", | |
"type": "api_error", | |
"code": 500 | |
} | |
} | |
yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8') | |
response_headers = { | |
"Content-Type": "text/event-stream", | |
"Content-Disposition": "attachment", | |
"Vary": "Origin, X-Origin, Referer", | |
"X-XSS-Protection": "0", | |
"X-Frame-Options": "SAMEORIGIN", | |
"X-Content-Type-Options": "nosniff", | |
"Server": "ESF" | |
} | |
return StreamingResponse( | |
stream_generator(), | |
media_type="text/event-stream", | |
headers=response_headers | |
) | |
def _handle_non_streaming_response(resp) -> Response: | |
"""Handle non-streaming response from Google API.""" | |
if resp.status_code == 200: | |
try: | |
google_api_response = resp.text | |
if google_api_response.startswith('data: '): | |
google_api_response = google_api_response[len('data: '):] | |
google_api_response = json.loads(google_api_response) | |
standard_gemini_response = google_api_response.get("response") | |
return Response( | |
content=json.dumps(standard_gemini_response), | |
status_code=200, | |
media_type="application/json; charset=utf-8" | |
) | |
except (json.JSONDecodeError, AttributeError) as e: | |
logging.error(f"Failed to parse Google API response: {str(e)}") | |
return Response( | |
content=resp.content, | |
status_code=resp.status_code, | |
media_type=resp.headers.get("Content-Type") | |
) | |
else: | |
# Log the error details | |
logging.error(f"Google API returned status {resp.status_code}: {resp.text}") | |
# Try to parse error response and provide meaningful error message | |
try: | |
error_data = resp.json() | |
if "error" in error_data: | |
error_message = error_data["error"].get("message", f"API error: {resp.status_code}") | |
error_response = { | |
"error": { | |
"message": error_message, | |
"type": "invalid_request_error" if resp.status_code == 404 else "api_error", | |
"code": resp.status_code | |
} | |
} | |
return Response( | |
content=json.dumps(error_response), | |
status_code=resp.status_code, | |
media_type="application/json" | |
) | |
except (json.JSONDecodeError, KeyError): | |
pass | |
# Fallback to original response if we can't parse the error | |
return Response( | |
content=resp.content, | |
status_code=resp.status_code, | |
media_type=resp.headers.get("Content-Type") | |
) | |
def build_gemini_payload_from_openai(openai_payload: dict) -> dict: | |
""" | |
Build a Gemini API payload from an OpenAI-transformed request. | |
This is used when OpenAI requests are converted to Gemini format. | |
""" | |
# Extract model from the payload | |
model = openai_payload.get("model") | |
# Get safety settings or use defaults | |
safety_settings = openai_payload.get("safetySettings", DEFAULT_SAFETY_SETTINGS) | |
# Build the request portion | |
request_data = { | |
"contents": openai_payload.get("contents"), | |
"systemInstruction": openai_payload.get("systemInstruction"), | |
"cachedContent": openai_payload.get("cachedContent"), | |
"tools": openai_payload.get("tools"), | |
"toolConfig": openai_payload.get("toolConfig"), | |
"safetySettings": safety_settings, | |
"generationConfig": openai_payload.get("generationConfig", {}), | |
} | |
# Remove any keys with None values | |
request_data = {k: v for k, v in request_data.items() if v is not None} | |
return { | |
"model": model, | |
"request": request_data | |
} | |
def build_gemini_payload_from_native(native_request: dict, model_from_path: str) -> dict: | |
""" | |
Build a Gemini API payload from a native Gemini request. | |
This is used for direct Gemini API calls. | |
""" | |
# Add default safety settings if not provided | |
if "safetySettings" not in native_request: | |
native_request["safetySettings"] = DEFAULT_SAFETY_SETTINGS | |
return { | |
"model": model_from_path, | |
"request": native_request | |
} |