""" HuggingFace Models API Endpoints for AI model predictions """ from fastapi import APIRouter, HTTPException, Query, Body, Path, Header from fastapi.responses import JSONResponse from typing import Optional, Dict, Any, List from datetime import datetime, timezone import logging import os from backend.services.hf_client import run_sentiment from backend.services.persistence_service import PersistenceService from backend.routers.hf_service_api import create_response from backend.services.models_adapter import ( predict_with_model, get_model_info, initialize_models as init_models ) logger = logging.getLogger(__name__) router = APIRouter( prefix="/api/models", tags=["Models API"] ) persistence_service = PersistenceService() # Model registry MODEL_REGISTRY = { "trade-signal-v1": { "name": "CryptoTrader-LM", "repo": "agarkovv/CryptoTrader-LM", "type": "text-generation", "description": "Trading signal generation model" }, "sentiment-v1": { "name": "CryptoBERT", "repo": "kk08/CryptoBERT", "type": "sentiment-analysis", "description": "Cryptocurrency sentiment analysis" }, "crypto-analyst": { "name": "crypto-gpt-o3-mini", "repo": "OpenC/crypto-gpt-o3-mini", "type": "text-generation", "description": "Crypto market analysis and insights" } } @router.get("/list") async def list_models(): """List available models from ai_models.py""" try: model_info = get_model_info() # Get model registry from ai_models if available try: from backend.services.models_adapter import _get_ai_models ai_models = _get_ai_models() if hasattr(ai_models, "MODEL_SPECS"): models_list = [] for key, spec in ai_models.MODEL_SPECS.items(): models_list.append({ "key": key, "name": spec.model_id if hasattr(spec, "model_id") else key, "task": spec.task if hasattr(spec, "task") else "unknown", "category": spec.category if hasattr(spec, "category") else "unknown", "description": f"{spec.task} model for {spec.category}" if hasattr(spec, "category") else f"{spec.task} model" }) return { "models": models_list, # Frontend expects "models" key "data": models_list, # Also include "data" for compatibility "meta": { "source": "hf-model", "generated_at": datetime.now(timezone.utc).isoformat(), "total_count": len(models_list), "initialized": model_info.get("models_initialized", False) } } except Exception as e: logger.warning(f"Could not get models from ai_models: {e}") # Fallback to registry models_list = [] for key, model in MODEL_REGISTRY.items(): models_list.append({ "key": key, "name": model.get("name", key), "model_id": model.get("repo", key), "task": model.get("type", "unknown"), "category": "trading" if "trade" in key else "sentiment" if "sentiment" in key else "analysis", "description": model.get("description", f"{model.get('type', 'unknown')} model"), "status": "available" }) return { "models": models_list, # Frontend expects "models" key "data": models_list, # Also include "data" for compatibility "meta": { "source": "hf", "generated_at": datetime.now(timezone.utc).isoformat(), "total_count": len(models_list) } } except Exception as e: logger.error(f"Error listing models: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/{model_key}/predict") async def predict( model_key: str = Path(..., description="Model key"), request: Dict[str, Any] = Body(...), authorization: Optional[str] = Header(None, alias="Authorization") ): """ Get prediction from a model using ai_models.py Token authentication is optional (uses server-side HF token from env) """ try: # Extract request parameters symbol = request.get("symbol") text = request.get("text", request.get("context", "")) mode = request.get("mode", "crypto") # Build input payload for adapter input_payload = { "text": text, "mode": mode, "params": request.get("params", {}) } # Call adapter (uses ai_models.py internally) normalized_output = await predict_with_model( model_key=model_key, input_payload=input_payload, symbol=symbol ) # Return standardized response return { "data": normalized_output.get("data"), "meta": normalized_output.get("meta") } except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except RuntimeError as e: raise HTTPException(status_code=503, detail=f"Model unavailable: {str(e)}") except Exception as e: logger.error(f"Error in predict endpoint: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.post("/batch/predict") async def batch_predict( request: Dict[str, Any] = Body(...), authorization: Optional[str] = Header(None, alias="Authorization") ): """Batch prediction for multiple items using ai_models.py""" try: model_key = request.get("model_key", "sentiment-v1") items = request.get("items", request.get("symbols", [])) mode = request.get("mode", "crypto") if not items: raise HTTPException(status_code=400, detail="items or symbols required") results = [] for item in items: try: # Handle both dict and string items if isinstance(item, dict): text = item.get("text", item.get("context", "")) symbol = item.get("symbol") else: text = str(item) symbol = None input_payload = { "text": text, "mode": mode, "params": request.get("params", {}) } normalized_output = await predict_with_model( model_key=model_key, input_payload=input_payload, symbol=symbol ) results.append({ "data": normalized_output.get("data"), "meta": normalized_output.get("meta") }) except Exception as e: logger.error(f"Error predicting for item {item}: {e}") results.append({ "data": None, "meta": { "source": "none", "error": str(e), "generated_at": datetime.now(timezone.utc).isoformat() } }) return { "data": results, "meta": { "source": "batch", "generated_at": datetime.now(timezone.utc).isoformat(), "total_count": len(results) } } except HTTPException: raise except Exception as e: logger.error(f"Error in batch_predict: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.get("/status") async def get_model_status(): """Get model initialization status""" try: model_info = get_model_info() return { "data": model_info, "meta": { "source": "hf-model", "generated_at": datetime.now(timezone.utc).isoformat() } } except Exception as e: logger.error(f"Error getting model status: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/initialize") async def initialize_models_endpoint(): """Initialize models from ai_models.py""" try: result = init_models() return { "data": result, "meta": { "source": "hf-model", "generated_at": datetime.now(timezone.utc).isoformat() } } except Exception as e: logger.error(f"Error initializing models: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/{model_key}/info") async def get_model_info_endpoint( model_key: str = Path(..., description="Model key") ): """Get detailed information about a specific model""" try: from backend.services.models_adapter import _get_ai_models ai_models = _get_ai_models() # Try to get model spec from ai_models if hasattr(ai_models, "MODEL_SPECS") and model_key in ai_models.MODEL_SPECS: spec = ai_models.MODEL_SPECS[model_key] # Check if model is loaded is_loaded = False if hasattr(ai_models, "_registry"): is_loaded = model_key in ai_models._registry._pipelines return { "data": { "key": model_key, "name": spec.model_id if hasattr(spec, "model_id") else model_key, "task": spec.task if hasattr(spec, "task") else "unknown", "category": spec.category if hasattr(spec, "category") else "unknown", "requires_auth": spec.requires_auth if hasattr(spec, "requires_auth") else False, "loaded": is_loaded, "description": f"{spec.task} model for {spec.category}" if hasattr(spec, "category") else f"{spec.task} model" }, "meta": { "source": "hf-model", "generated_at": datetime.now(timezone.utc).isoformat() } } # Fallback to MODEL_REGISTRY if model_key in MODEL_REGISTRY: model_info = MODEL_REGISTRY[model_key] return { "data": { "key": model_key, **model_info, "loaded": False # Unknown status }, "meta": { "source": "registry", "generated_at": datetime.now(timezone.utc).isoformat() } } raise HTTPException(status_code=404, detail=f"Model {model_key} not found") except HTTPException: raise except Exception as e: logger.error(f"Error getting model info for {model_key}: {e}") raise HTTPException(status_code=500, detail=str(e))