File size: 9,753 Bytes
8b7b267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
#!/usr/bin/env python3
"""
AI & ML API Router
==================
API endpoints for AI predictions, backtesting, and ML training
"""

from fastapi import APIRouter, HTTPException, Depends, Body, Query, Path
from fastapi.responses import JSONResponse
from typing import Any, Dict, Iterator, List, Optional
from pydantic import BaseModel, Field
from datetime import datetime
from sqlalchemy.orm import Session
import logging

from backend.services.backtesting_service import BacktestingService
from backend.services.ml_training_service import MLTrainingService
from database.db_manager import db_manager

# Module-level logger; the endpoints below log unexpected failures through it.
logger = logging.getLogger(__name__)

# Every route in this module is mounted under /api/ai and grouped under a
# single "AI & ML" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/api/ai", tags=["AI & ML"])


# ============================================================================
# Pydantic Models
# ============================================================================

class BacktestRequest(BaseModel):
    """Request model for starting a backtest."""

    # Every field is required (default=...). Ordering of start/end is not
    # enforced here; the /backtest endpoint checks it.
    strategy: str = Field(
        default=...,
        description="Strategy name (e.g., 'simple_moving_average', 'rsi_strategy', 'macd_strategy')",
    )
    symbol: str = Field(
        default=...,
        description="Trading pair (e.g., 'BTC/USDT')",
    )
    start_date: datetime = Field(
        default=...,
        description="Backtest start date",
    )
    end_date: datetime = Field(
        default=...,
        description="Backtest end date",
    )
    # Must be strictly positive (gt=0).
    initial_capital: float = Field(
        default=...,
        gt=0,
        description="Starting capital for backtest",
    )


class TrainingRequest(BaseModel):
    """Request model for starting ML training."""

    # NOTE(review): under pydantic v2 a field named `model_name` collides with
    # the protected "model_" namespace and emits a warning; if on v2, consider
    # setting protected_namespaces=() in the model config.
    model_name: str = Field(
        default=...,
        description="Name of the model to train",
    )
    training_data_start: datetime = Field(
        default=...,
        description="Start date for training data",
    )
    training_data_end: datetime = Field(
        default=...,
        description="End date for training data",
    )
    # Optional knobs: batch size defaults to 32; learning rate and the extra
    # config dict default to None and are passed through to the service as-is.
    batch_size: int = Field(
        default=32,
        gt=0,
        description="Training batch size",
    )
    learning_rate: Optional[float] = Field(
        default=None,
        gt=0,
        description="Learning rate",
    )
    config: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional training configuration",
    )


class TrainingStepRequest(BaseModel):
    """Request model for executing a training step."""

    # Only the step number is mandatory, and it is 1-based (ge=1).
    step_number: int = Field(
        default=...,
        ge=1,
        description="Step number",
    )
    loss: Optional[float] = Field(
        default=None,
        description="Training loss",
    )
    # Accuracy is constrained to the closed interval [0, 1].
    accuracy: Optional[float] = Field(
        default=None,
        ge=0,
        le=1,
        description="Training accuracy",
    )
    learning_rate: Optional[float] = Field(
        default=None,
        gt=0,
        description="Current learning rate",
    )
    metrics: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional metrics",
    )


# ============================================================================
# Dependency Injection
# ============================================================================

def get_db() -> Iterator[Session]:
    """
    Yield a database session for the duration of a request.

    This is a generator dependency, so it is annotated as ``Iterator[Session]``
    rather than ``Session``. FastAPI runs the code after ``yield`` once the
    request is finished, so the session is always closed, even if the
    endpoint raised.

    Yields:
        A session created by ``db_manager.SessionLocal()``.
    """
    db = db_manager.SessionLocal()
    try:
        yield db
    finally:
        db.close()


def get_backtesting_service(db: Session = Depends(get_db)) -> BacktestingService:
    """Build a ``BacktestingService`` bound to the session yielded by ``get_db``."""
    service = BacktestingService(db)
    return service


def get_ml_training_service(db: Session = Depends(get_db)) -> MLTrainingService:
    """Build an ``MLTrainingService`` bound to the session yielded by ``get_db``."""
    service = MLTrainingService(db)
    return service


# ============================================================================
# API Endpoints
# ============================================================================

@router.post("/backtest")
async def start_backtest(
    backtest_request: BacktestRequest,
    service: BacktestingService = Depends(get_backtesting_service)
) -> JSONResponse:
    """
    Start a backtest for a specific strategy.

    Validates the requested date range, then delegates to
    ``BacktestingService.start_backtest``. Whatever the service returns
    (presumably performance metrics such as total return, Sharpe ratio,
    max drawdown and win rate; confirm against the service) is placed
    under the ``data`` key of the response.

    Args:
        backtest_request: Backtest configuration
        service: Backtesting service instance

    Returns:
        JSON response (HTTP 200) with backtest results

    Raises:
        HTTPException: 400 if the date range is invalid or the service raises
            ``ValueError``; 500 for any other error.
    """
    start_date = backtest_request.start_date
    end_date = backtest_request.end_date
    try:
        # Comparing an offset-aware datetime with a naive one raises TypeError.
        # That previously surfaced as a 500, so reject the mix as a client error.
        if (start_date.utcoffset() is None) != (end_date.utcoffset() is None):
            raise ValueError(
                "start_date and end_date must both include a timezone offset or both omit it"
            )
        if end_date <= start_date:
            raise ValueError("end_date must be after start_date")

        results = service.start_backtest(
            strategy=backtest_request.strategy,
            symbol=backtest_request.symbol,
            start_date=start_date,
            end_date=end_date,
            initial_capital=backtest_request.initial_capital
        )

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "message": "Backtest completed successfully",
                "data": results
            }
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error running backtest: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e


@router.post("/train")
async def start_training(
    training_request: TrainingRequest,
    service: MLTrainingService = Depends(get_ml_training_service)
) -> JSONResponse:
    """
    Start training a model.

    Validates the training-data date range, then delegates to
    ``MLTrainingService.start_training`` with the request's configuration.
    The job description returned by the service is placed under the
    ``data`` key of the response.

    Args:
        training_request: Training configuration
        service: ML training service instance

    Returns:
        JSON response (HTTP 201) with training job details

    Raises:
        HTTPException: 400 if the date range is invalid or the service raises
            ``ValueError`` (consistent with the other endpoints in this
            router); 500 for any other error.
    """
    data_start = training_request.training_data_start
    data_end = training_request.training_data_end
    try:
        # Same date-range checks as /backtest. Mixing aware and naive
        # datetimes would otherwise raise TypeError during the comparison.
        if (data_start.utcoffset() is None) != (data_end.utcoffset() is None):
            raise ValueError(
                "training_data_start and training_data_end must both include "
                "a timezone offset or both omit it"
            )
        if data_end <= data_start:
            raise ValueError("training_data_end must be after training_data_start")

        job = service.start_training(
            model_name=training_request.model_name,
            training_data_start=data_start,
            training_data_end=data_end,
            batch_size=training_request.batch_size,
            learning_rate=training_request.learning_rate,
            config=training_request.config
        )

        return JSONResponse(
            status_code=201,
            content={
                "success": True,
                "message": "Training job created successfully",
                "data": job
            }
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error starting training: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e


@router.post("/train-step")
async def execute_training_step(
    job_id: str = Query(..., description="Training job ID"),
    step_request: TrainingStepRequest = Body(...),
    service: MLTrainingService = Depends(get_ml_training_service)
) -> JSONResponse:
    """
    Execute a training step.

    Records a single training step with metrics by forwarding the step data
    for ``job_id`` to ``MLTrainingService.execute_training_step``.

    Args:
        job_id: Training job ID (query parameter)
        step_request: Training step data (request body)
        service: ML training service instance

    Returns:
        JSON response (HTTP 200) with the step details returned by the service

    Raises:
        HTTPException: 400 if the service raises ``ValueError``; 500 for any
            other error.
    """
    try:
        step = service.execute_training_step(
            job_id=job_id,
            step_number=step_request.step_number,
            loss=step_request.loss,
            accuracy=step_request.accuracy,
            learning_rate=step_request.learning_rate,
            metrics=step_request.metrics
        )

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "message": "Training step executed successfully",
                "data": step
            }
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error executing training step: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e


@router.get("/train/status")
async def get_training_status(
    job_id: str = Query(..., description="Training job ID"),
    service: MLTrainingService = Depends(get_ml_training_service)
) -> JSONResponse:
    """
    Get the current training status.

    Retrieves the status for a training job via
    ``MLTrainingService.get_training_status`` and returns it under ``data``.

    Args:
        job_id: Training job ID (query parameter)
        service: ML training service instance

    Returns:
        JSON response (HTTP 200) with training status

    Raises:
        HTTPException: 404 if the service raises ``ValueError`` (presumably an
            unknown job id; confirm against the service); 500 for any other error.
    """
    try:
        status = service.get_training_status(job_id)

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "data": status
            }
        )

    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error getting training status: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e


@router.get("/train/history")
async def get_training_history(
    model_name: Optional[str] = Query(None, description="Filter by model name"),
    limit: int = Query(100, ge=1, le=1000, description="Maximum number of jobs to return"),
    service: MLTrainingService = Depends(get_ml_training_service)
) -> JSONResponse:
    """
    Get training history.

    Retrieves training jobs for all models, or only for ``model_name`` when
    given, via ``MLTrainingService.get_training_history``.

    Args:
        model_name: Optional model name filter
        limit: Maximum number of jobs to return (1..1000, default 100)
        service: ML training service instance

    Returns:
        JSON response (HTTP 200) with the history under ``data`` and its
        length under ``count``

    Raises:
        HTTPException: 500 for any error raised while fetching the history.
    """
    try:
        history = service.get_training_history(
            model_name=model_name,
            limit=limit
        )

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "count": len(history),
                "data": history
            }
        )

    except Exception as e:
        logger.exception("Error retrieving training history: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e