"""
HuggingFace Models API

Endpoints for AI model predictions.
"""
|
|
|
|
|
|
from fastapi import APIRouter, HTTPException, Query, Body, Path, Header
|
|
|
from fastapi.responses import JSONResponse
|
|
|
from typing import Optional, Dict, Any, List
|
|
|
from datetime import datetime, timezone
|
|
|
import logging
|
|
|
import os
|
|
|
|
|
|
from backend.services.hf_client import run_sentiment
|
|
|
from backend.services.persistence_service import PersistenceService
|
|
|
from backend.routers.hf_service_api import create_response
|
|
|
from backend.services.models_adapter import (
|
|
|
predict_with_model,
|
|
|
get_model_info,
|
|
|
initialize_models as init_models
|
|
|
)
|
|
|
|
|
|
# Module-level logger, named after this module per standard logging convention.
logger = logging.getLogger(__name__)

# Router for all model endpoints; every route below is served under /api/models.
router = APIRouter(
    prefix="/api/models",
    tags=["Models API"]
)

# Shared persistence layer instance (module-level singleton for this router).
persistence_service = PersistenceService()
|
|
|
|
|
|
|
|
|
# Static fallback registry of known models.
# Used by list_models() and get_model_info_endpoint() when the dynamic
# ai_models registry (via backend.services.models_adapter) is unavailable.
# Keys are the public model_key path parameters accepted by this router.
MODEL_REGISTRY = {
    "trade-signal-v1": {
        "name": "CryptoTrader-LM",
        "repo": "agarkovv/CryptoTrader-LM",  # HuggingFace repo id
        "type": "text-generation",
        "description": "Trading signal generation model"
    },
    "sentiment-v1": {
        "name": "CryptoBERT",
        "repo": "kk08/CryptoBERT",
        "type": "sentiment-analysis",
        "description": "Cryptocurrency sentiment analysis"
    },
    "crypto-analyst": {
        "name": "crypto-gpt-o3-mini",
        "repo": "OpenC/crypto-gpt-o3-mini",
        "type": "text-generation",
        "description": "Crypto market analysis and insights"
    }
}
|
|
|
|
|
|
|
|
|
@router.get("/list")
async def list_models():
    """List available models from ai_models.py.

    Prefers the dynamic MODEL_SPECS registry exposed through
    backend.services.models_adapter; if that registry is missing or fails,
    falls back to the static MODEL_REGISTRY defined in this module.

    Returns:
        dict: {"models": [...], "data": [...], "meta": {...}} — the list is
        duplicated under both "models" and "data" for backward compatibility.

    Raises:
        HTTPException: 500 on unexpected failure.
    """
    try:
        model_info = get_model_info()

        # Preferred source: dynamic specs from ai_models.
        try:
            from backend.services.models_adapter import _get_ai_models
            ai_models = _get_ai_models()

            if hasattr(ai_models, "MODEL_SPECS"):
                models_list = []
                for key, spec in ai_models.MODEL_SPECS.items():
                    # getattr with defaults: previously spec.task was read
                    # unconditionally in the description f-string, so one
                    # spec missing "task" raised AttributeError and silently
                    # discarded the entire dynamic list.
                    task = getattr(spec, "task", "unknown")
                    category = getattr(spec, "category", None)
                    models_list.append({
                        "key": key,
                        "name": getattr(spec, "model_id", key),
                        "task": task,
                        "category": category if category is not None else "unknown",
                        "description": (
                            f"{task} model for {category}"
                            if category is not None
                            else f"{task} model"
                        )
                    })

                return {
                    "models": models_list,
                    "data": models_list,
                    "meta": {
                        "source": "hf-model",
                        "generated_at": datetime.now(timezone.utc).isoformat(),
                        "total_count": len(models_list),
                        "initialized": model_info.get("models_initialized", False)
                    }
                }
        except Exception as e:
            logger.warning(f"Could not get models from ai_models: {e}")

        # Fallback: static registry defined in this module.
        models_list = []
        for key, model in MODEL_REGISTRY.items():
            models_list.append({
                "key": key,
                "name": model.get("name", key),
                "model_id": model.get("repo", key),
                "task": model.get("type", "unknown"),
                "category": "trading" if "trade" in key else "sentiment" if "sentiment" in key else "analysis",
                "description": model.get("description", f"{model.get('type', 'unknown')} model"),
                "status": "available"
            })

        return {
            "models": models_list,
            "data": models_list,
            "meta": {
                "source": "hf",
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "total_count": len(models_list)
            }
        }
    except Exception as e:
        logger.error(f"Error listing models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
@router.post("/{model_key}/predict")
async def predict(
    model_key: str = Path(..., description="Model key"),
    request: Dict[str, Any] = Body(...),
    authorization: Optional[str] = Header(None, alias="Authorization")
):
    """
    Get prediction from a model using ai_models.py

    Token authentication is optional (uses server-side HF token from env)
    """
    try:
        # Build the adapter payload directly from the request body;
        # "text" wins over the legacy "context" key.
        payload = {
            "text": request.get("text", request.get("context", "")),
            "mode": request.get("mode", "crypto"),
            "params": request.get("params", {}),
        }

        outcome = await predict_with_model(
            model_key=model_key,
            input_payload=payload,
            symbol=request.get("symbol"),
        )

        return {"data": outcome.get("data"), "meta": outcome.get("meta")}

    except ValueError as e:
        # Bad input from the caller.
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        # Model backend not available.
        raise HTTPException(status_code=503, detail=f"Model unavailable: {str(e)}")
    except Exception as e:
        logger.error(f"Error in predict endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
@router.post("/batch/predict")
async def batch_predict(
    request: Dict[str, Any] = Body(...),
    authorization: Optional[str] = Header(None, alias="Authorization")
):
    """Batch prediction for multiple items using ai_models.py"""
    try:
        model_key = request.get("model_key", "sentiment-v1")
        items = request.get("items", request.get("symbols", []))
        mode = request.get("mode", "crypto")
        shared_params = request.get("params", {})  # same params for every item

        if not items:
            raise HTTPException(status_code=400, detail="items or symbols required")

        results = []
        for item in items:
            try:
                # Dict items carry their own text/symbol; anything else is
                # coerced to a plain text string with no symbol attached.
                if isinstance(item, dict):
                    text, symbol = item.get("text", item.get("context", "")), item.get("symbol")
                else:
                    text, symbol = str(item), None

                outcome = await predict_with_model(
                    model_key=model_key,
                    input_payload={"text": text, "mode": mode, "params": shared_params},
                    symbol=symbol,
                )
                results.append({"data": outcome.get("data"), "meta": outcome.get("meta")})
            except Exception as e:
                logger.error(f"Error predicting for item {item}: {e}")
                # Failed items stay in the batch with an error meta so the
                # caller can correlate results positionally.
                results.append({
                    "data": None,
                    "meta": {
                        "source": "none",
                        "error": str(e),
                        "generated_at": datetime.now(timezone.utc).isoformat()
                    }
                })

        return {
            "data": results,
            "meta": {
                "source": "batch",
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "total_count": len(results)
            }
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in batch_predict: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
@router.get("/status")
async def get_model_status():
    """Get model initialization status"""
    try:
        status_payload = get_model_info()
        return {
            "data": status_payload,
            "meta": {
                "source": "hf-model",
                "generated_at": datetime.now(timezone.utc).isoformat(),
            },
        }
    except Exception as e:
        logger.error(f"Error getting model status: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
@router.post("/initialize")
async def initialize_models_endpoint():
    """Initialize models from ai_models.py"""
    try:
        init_result = init_models()
        return {
            "data": init_result,
            "meta": {
                "source": "hf-model",
                "generated_at": datetime.now(timezone.utc).isoformat(),
            },
        }
    except Exception as e:
        logger.error(f"Error initializing models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
@router.get("/{model_key}/info")
async def get_model_info_endpoint(
    model_key: str = Path(..., description="Model key")
):
    """Get detailed information about a specific model.

    Looks the key up in the dynamic ai_models MODEL_SPECS registry first,
    then in the static MODEL_REGISTRY fallback.

    Raises:
        HTTPException: 404 if the key is unknown, 500 on unexpected failure.
    """
    try:
        from backend.services.models_adapter import _get_ai_models

        ai_models = _get_ai_models()

        if hasattr(ai_models, "MODEL_SPECS") and model_key in ai_models.MODEL_SPECS:
            spec = ai_models.MODEL_SPECS[model_key]

            # Probe the private registry defensively: previously this read
            # ai_models._registry._pipelines directly, so a registry object
            # without "_pipelines" turned an info lookup into a 500 error.
            registry = getattr(ai_models, "_registry", None)
            pipelines = getattr(registry, "_pipelines", None)
            is_loaded = model_key in pipelines if pipelines is not None else False

            # getattr with defaults avoids AttributeError on partial specs
            # (the old description f-string read spec.task unconditionally).
            task = getattr(spec, "task", "unknown")
            category = getattr(spec, "category", None)

            return {
                "data": {
                    "key": model_key,
                    "name": getattr(spec, "model_id", model_key),
                    "task": task,
                    "category": category if category is not None else "unknown",
                    "requires_auth": getattr(spec, "requires_auth", False),
                    "loaded": is_loaded,
                    "description": (
                        f"{task} model for {category}"
                        if category is not None
                        else f"{task} model"
                    )
                },
                "meta": {
                    "source": "hf-model",
                    "generated_at": datetime.now(timezone.utc).isoformat()
                }
            }

        # Fallback: static registry defined in this module.
        if model_key in MODEL_REGISTRY:
            model_info = MODEL_REGISTRY[model_key]
            return {
                "data": {
                    "key": model_key,
                    **model_info,
                    "loaded": False
                },
                "meta": {
                    "source": "registry",
                    "generated_at": datetime.now(timezone.utc).isoformat()
                }
            }

        raise HTTPException(status_code=404, detail=f"Model {model_key} not found")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting model info for {model_key}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|