# NOTE(review): the three lines below are stray HuggingFace upload-page text
# (uploader / commit info), not Python — commented out so the module imports.
# Really-amin's picture
# Upload 602 files
# 89ead5a verified
"""
HuggingFace Models API
Endpoints for AI model predictions
"""
from fastapi import APIRouter, HTTPException, Query, Body, Path, Header
from fastapi.responses import JSONResponse
from typing import Optional, Dict, Any, List
from datetime import datetime, timezone
import logging
import os
from backend.services.hf_client import run_sentiment
from backend.services.persistence_service import PersistenceService
from backend.routers.hf_service_api import create_response
from backend.services.models_adapter import (
predict_with_model,
get_model_info,
initialize_models as init_models
)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# All routes in this module are mounted under /api/models.
router = APIRouter(
    prefix="/api/models",
    tags=["Models API"]
)
# Shared persistence layer instance (currently unused by the endpoints below;
# presumably kept for future request logging — TODO confirm).
persistence_service = PersistenceService()
# Model registry
# Static fallback registry used when the live ai_models MODEL_SPECS cannot be
# loaded. Keys are the public model identifiers accepted by /{model_key}/predict.
MODEL_REGISTRY = {
    "trade-signal-v1": {
        "name": "CryptoTrader-LM",
        "repo": "agarkovv/CryptoTrader-LM",
        "type": "text-generation",
        "description": "Trading signal generation model"
    },
    "sentiment-v1": {
        "name": "CryptoBERT",
        "repo": "kk08/CryptoBERT",
        "type": "sentiment-analysis",
        "description": "Cryptocurrency sentiment analysis"
    },
    "crypto-analyst": {
        "name": "crypto-gpt-o3-mini",
        "repo": "OpenC/crypto-gpt-o3-mini",
        "type": "text-generation",
        "description": "Crypto market analysis and insights"
    }
}
@router.get("/list")
async def list_models():
    """List available models.

    Prefers the live ``MODEL_SPECS`` registry exposed by ai_models (via the
    models adapter); falls back to the static ``MODEL_REGISTRY`` defined in
    this module when that lookup fails for any reason.

    Returns:
        dict with "models" (key the frontend reads), "data" (compat alias
        holding the same list) and "meta" (source/timestamp/count).

    Raises:
        HTTPException: 500 on unexpected errors.
    """
    try:
        model_info = get_model_info()
        # Get model registry from ai_models if available
        try:
            from backend.services.models_adapter import _get_ai_models
            ai_models = _get_ai_models()
            if hasattr(ai_models, "MODEL_SPECS"):
                models_list = []
                for key, spec in ai_models.MODEL_SPECS.items():
                    # Use getattr defaults so a partially-defined spec cannot
                    # break the whole listing: the original guarded `task`
                    # with hasattr on one line but read `spec.task` unguarded
                    # when building the description, raising AttributeError
                    # for specs without `task` and dropping to the fallback.
                    task = getattr(spec, "task", "unknown")
                    category = getattr(spec, "category", None)
                    models_list.append({
                        "key": key,
                        "name": getattr(spec, "model_id", key),
                        "task": task,
                        "category": category if category is not None else "unknown",
                        "description": (
                            f"{task} model for {category}"
                            if category is not None
                            else f"{task} model"
                        )
                    })
                return {
                    "models": models_list,  # Frontend expects "models" key
                    "data": models_list,  # Also include "data" for compatibility
                    "meta": {
                        "source": "hf-model",
                        "generated_at": datetime.now(timezone.utc).isoformat(),
                        "total_count": len(models_list),
                        "initialized": model_info.get("models_initialized", False)
                    }
                }
        except Exception as e:
            # Best-effort: any failure here just means we serve the static list.
            logger.warning(f"Could not get models from ai_models: {e}")
        # Fallback to registry
        models_list = []
        for key, model in MODEL_REGISTRY.items():
            models_list.append({
                "key": key,
                "name": model.get("name", key),
                "model_id": model.get("repo", key),
                "task": model.get("type", "unknown"),
                # Coarse category inferred from the registry key naming scheme.
                "category": "trading" if "trade" in key else "sentiment" if "sentiment" in key else "analysis",
                "description": model.get("description", f"{model.get('type', 'unknown')} model"),
                "status": "available"
            })
        return {
            "models": models_list,  # Frontend expects "models" key
            "data": models_list,  # Also include "data" for compatibility
            "meta": {
                "source": "hf",
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "total_count": len(models_list)
            }
        }
    except Exception as e:
        logger.error(f"Error listing models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/{model_key}/predict")
async def predict(
    model_key: str = Path(..., description="Model key"),
    request: Dict[str, Any] = Body(...),
    authorization: Optional[str] = Header(None, alias="Authorization")
):
    """Run a single prediction through the models adapter (ai_models.py).

    The Authorization header is accepted but optional — the server-side HF
    token from the environment is used for model access.

    Raises:
        HTTPException: 400 on invalid input, 503 when the model is
        unavailable, 500 on anything else.
    """
    try:
        # Assemble the adapter payload straight from the request body;
        # "context" is accepted as an alias for "text".
        payload = {
            "text": request.get("text", request.get("context", "")),
            "mode": request.get("mode", "crypto"),
            "params": request.get("params", {})
        }
        outcome = await predict_with_model(
            model_key=model_key,
            input_payload=payload,
            symbol=request.get("symbol")
        )
        # Standardized envelope: data + meta, mirroring the adapter output.
        return {
            "data": outcome.get("data"),
            "meta": outcome.get("meta")
        }
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        raise HTTPException(status_code=503, detail=f"Model unavailable: {str(e)}")
    except Exception as e:
        logger.error(f"Error in predict endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch/predict")
async def batch_predict(
    request: Dict[str, Any] = Body(...),
    authorization: Optional[str] = Header(None, alias="Authorization")
):
    """Run predictions for a batch of items via the models adapter.

    Accepts either "items" or "symbols" in the body; each entry may be a
    dict (with "text"/"context" and optional "symbol") or a bare string.
    Failures are per-item: a failed entry yields a data=None result with the
    error recorded in its meta, and the batch continues.

    Raises:
        HTTPException: 400 when no items are supplied, 500 on unexpected errors.
    """
    try:
        model_key = request.get("model_key", "sentiment-v1")
        entries = request.get("items", request.get("symbols", []))
        mode = request.get("mode", "crypto")
        if not entries:
            raise HTTPException(status_code=400, detail="items or symbols required")
        shared_params = request.get("params", {})
        results = []
        for entry in entries:
            try:
                # Normalize dict vs. string entries into (text, symbol).
                if isinstance(entry, dict):
                    text = entry.get("text", entry.get("context", ""))
                    symbol = entry.get("symbol")
                else:
                    text = str(entry)
                    symbol = None
                outcome = await predict_with_model(
                    model_key=model_key,
                    input_payload={"text": text, "mode": mode, "params": shared_params},
                    symbol=symbol
                )
                results.append({
                    "data": outcome.get("data"),
                    "meta": outcome.get("meta")
                })
            except Exception as e:
                # Best-effort batch: record the failure and keep going.
                logger.error(f"Error predicting for item {entry}: {e}")
                results.append({
                    "data": None,
                    "meta": {
                        "source": "none",
                        "error": str(e),
                        "generated_at": datetime.now(timezone.utc).isoformat()
                    }
                })
        return {
            "data": results,
            "meta": {
                "source": "batch",
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "total_count": len(results)
            }
        }
    except HTTPException:
        # Re-raise our own 400 untouched instead of wrapping it in a 500.
        raise
    except Exception as e:
        logger.error(f"Error in batch_predict: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/status")
async def get_model_status():
    """Report model initialization status from the models adapter.

    Raises:
        HTTPException: 500 when the adapter lookup fails.
    """
    try:
        status_meta = {
            "source": "hf-model",
            "generated_at": datetime.now(timezone.utc).isoformat()
        }
        return {"data": get_model_info(), "meta": status_meta}
    except Exception as e:
        logger.error(f"Error getting model status: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/initialize")
async def initialize_models_endpoint():
    """Trigger model initialization in ai_models via the adapter.

    Raises:
        HTTPException: 500 when initialization fails.
    """
    try:
        init_result = init_models()
        response_meta = {
            "source": "hf-model",
            "generated_at": datetime.now(timezone.utc).isoformat()
        }
        return {"data": init_result, "meta": response_meta}
    except Exception as e:
        logger.error(f"Error initializing models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/{model_key}/info")
async def get_model_info_endpoint(
    model_key: str = Path(..., description="Model key")
):
    """Get detailed information about a specific model.

    Looks the key up in ai_models.MODEL_SPECS first (live registry, includes
    loaded state), then falls back to the static MODEL_REGISTRY.

    Raises:
        HTTPException: 404 when the key is unknown, 500 on unexpected errors.
    """
    try:
        from backend.services.models_adapter import _get_ai_models
        ai_models = _get_ai_models()
        # Try to get model spec from ai_models
        if hasattr(ai_models, "MODEL_SPECS") and model_key in ai_models.MODEL_SPECS:
            spec = ai_models.MODEL_SPECS[model_key]
            # Check if model is loaded. Guard the `_pipelines` attribute too:
            # the original only checked for `_registry`, so a registry object
            # without `_pipelines` raised AttributeError and turned a valid
            # lookup into a 500 via the broad handler below.
            is_loaded = False
            registry = getattr(ai_models, "_registry", None)
            if registry is not None:
                is_loaded = model_key in getattr(registry, "_pipelines", {})
            # getattr defaults keep partially-defined specs from raising
            # (same inconsistency existed here as in list_models: `task` was
            # guarded on one line but read unguarded in the description).
            task = getattr(spec, "task", "unknown")
            category = getattr(spec, "category", None)
            return {
                "data": {
                    "key": model_key,
                    "name": getattr(spec, "model_id", model_key),
                    "task": task,
                    "category": category if category is not None else "unknown",
                    "requires_auth": getattr(spec, "requires_auth", False),
                    "loaded": is_loaded,
                    "description": (
                        f"{task} model for {category}"
                        if category is not None
                        else f"{task} model"
                    )
                },
                "meta": {
                    "source": "hf-model",
                    "generated_at": datetime.now(timezone.utc).isoformat()
                }
            }
        # Fallback to MODEL_REGISTRY
        if model_key in MODEL_REGISTRY:
            model_info = MODEL_REGISTRY[model_key]
            return {
                "data": {
                    "key": model_key,
                    **model_info,
                    "loaded": False  # Unknown status
                },
                "meta": {
                    "source": "registry",
                    "generated_at": datetime.now(timezone.utc).isoformat()
                }
            }
        raise HTTPException(status_code=404, detail=f"Model {model_key} not found")
    except HTTPException:
        # Preserve the 404 instead of wrapping it in a 500.
        raise
    except Exception as e:
        logger.error(f"Error getting model info for {model_key}: {e}")
        raise HTTPException(status_code=500, detail=str(e))