"""
Async Stock Price Predictor using Amazon Chronos T5-Small Time Series Model

Required installations:
    pip install chronos-forecasting yfinance torch numpy pandas aiohttp

    (asyncio ships with the Python standard library and must not be
    installed from PyPI.)

Usage:
    python stock_predictor.py
"""
|
|
|
import asyncio
import logging
import threading
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Optional, Tuple, List, Dict

import aiohttp
import numpy as np
import pandas as pd
import torch
import yfinance as yf
from chronos import ChronosPipeline
|
|
|
|
|
# Configure the root logger once at import time so INFO-level progress
# messages emitted by this module are visible, then bind the module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
class AsyncStockPredictor:
    """
    An async stock price predictor using the Amazon Chronos T5 time series model.

    Historical prices are downloaded with yfinance, and every blocking step
    (model loading, download, inference, indicator maths) is dispatched to a
    shared ``ThreadPoolExecutor`` so several tickers can be analysed
    concurrently from a single event loop.
    """

    # Minimum number of rows required before a forecast is attempted.
    _MIN_HISTORY_POINTS = 30

    # Upper bound on how many of the most recent closes are used as model context.
    _MAX_CONTEXT_POINTS = 512

    # warnings.catch_warnings() swaps process-global filter state and is
    # documented as not thread-safe, so the download section that uses it is
    # serialised across all instances and worker threads.
    _download_lock = threading.Lock()

    def __init__(self, model_name: str = "amazon/chronos-t5-small", max_workers: int = 4):
        """
        Initialize the async stock predictor.

        The model itself is not loaded here; call ``initialize()`` first.

        Args:
            model_name: Name of the Chronos model to use
            max_workers: Maximum number of worker threads for blocking tasks
        """
        self.model_name = model_name
        self.max_workers = max_workers
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.pipeline = None

    async def initialize(self):
        """
        Load the Chronos pipeline in the worker pool and store it on the instance.

        Raises:
            Exception: Any loading failure is logged and re-raised.
        """
        try:
            logger.info(f"Loading Chronos model: {self.model_name}")
            # get_running_loop() is the correct call inside a coroutine.
            loop = asyncio.get_running_loop()
            self.pipeline = await loop.run_in_executor(self.executor, self._load_model)
            logger.info("Chronos model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def _load_model(self):
        """
        Load the Chronos model (blocking; runs in the thread pool).

        The first attempt asks for automatic device placement and bfloat16
        weights; if that raises, the model is loaded with library defaults.
        """
        try:
            return ChronosPipeline.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
        except Exception as e:
            logger.warning(f"Failed to load with optimized settings: {e}")
            logger.info("Attempting to load with default settings...")
            return ChronosPipeline.from_pretrained(self.model_name)

    @staticmethod
    def _flatten_columns(df: pd.DataFrame) -> pd.DataFrame:
        """
        Return *df* with single-level columns named by price field.

        If the columns are a MultiIndex, the level that contains the label
        "Close" becomes the new flat column index (on a copy). Frames that
        already have flat columns are returned unchanged (same object).
        """
        if not isinstance(df.columns, pd.MultiIndex):
            return df
        for level in range(df.columns.nlevels):
            labels = df.columns.get_level_values(level)
            if "Close" in labels:
                flat = df.copy()
                flat.columns = labels
                return flat
        return df

    async def fetch_prices_async(self, ticker: str, period: str = "6mo",
                                 interval: str = "1d") -> Optional[pd.DataFrame]:
        """
        Fetch historical stock price data asynchronously.

        Args:
            ticker: Stock ticker symbol (e.g., 'AAPL')
            period: Time period for data (1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max)
            interval: Data interval (1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo)

        Returns:
            DataFrame with flat Open/High/Low/Close/Volume columns and no NaN
            rows, or None if the download failed, returned nothing, or yielded
            fewer than 30 rows.
        """
        try:
            logger.info(f"Fetching data for {ticker}")
            loop = asyncio.get_running_loop()
            df = await loop.run_in_executor(
                self.executor, self._fetch_data_sync, ticker, period, interval
            )

            if df is None or df.empty:
                logger.error(f"No data found for ticker {ticker}")
                return None

            # yf.download can return (field, ticker) two-level columns even for
            # one ticker; flatten so df["Close"] is a Series downstream.
            df = self._flatten_columns(df)
            df = df[["Open", "High", "Low", "Close", "Volume"]].dropna()

            if len(df) < self._MIN_HISTORY_POINTS:
                logger.warning(
                    f"Insufficient data for {ticker}. Got {len(df)} days, "
                    f"need at least {self._MIN_HISTORY_POINTS}"
                )
                return None

            logger.info(f"Successfully fetched {len(df)} data points for {ticker}")
            return df

        except Exception as e:
            logger.error(f"Error fetching data for {ticker}: {e}")
            return None

    def _fetch_data_sync(self, ticker: str, period: str, interval: str) -> Optional[pd.DataFrame]:
        """
        Synchronous data fetching (runs in the thread pool).

        Returns the raw frame from ``yf.download`` or None if it raised.
        """
        try:
            # The lock keeps concurrent workers from interleaving the
            # save/restore of the global warning filters.
            # NOTE(review): it also serialises yf.download itself; confirm
            # whether the installed yfinance supports concurrent download()
            # calls before relaxing this.
            with self._download_lock, warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return yf.download(ticker, period=period, interval=interval, progress=False)
        except Exception as e:
            logger.error(f"Error in sync data fetch for {ticker}: {e}")
            return None

    async def predict_next_day_async(self, prices: pd.DataFrame, prediction_length: int = 1,
                                     num_samples: int = 20) -> Tuple[str, float, List[float]]:
        """
        Predict the next period's price using the Chronos model asynchronously.

        Args:
            prices: DataFrame with historical price data
            prediction_length: Number of future periods to predict
            num_samples: Number of sample predictions to generate

        Returns:
            Tuple of (trend_description, confidence_score, predicted_prices).
            On any failure the confidence is 0.0 and the list is empty.
        """
        if self.pipeline is None:
            return "❌ Model not initialized", 0.0, []

        if prices is None or len(prices) < self._MIN_HISTORY_POINTS:
            return "❓ Insufficient data", 0.0, []

        try:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self.executor, self._predict_sync, prices, prediction_length, num_samples
            )
        except Exception as e:
            logger.error(f"Error during async prediction: {e}")
            return "❌ Prediction error", 0.0, []

    @staticmethod
    def _classify_trend(price_change_pct: float) -> Tuple[str, float]:
        """Map an expected percentage change to (trend label, base confidence)."""
        magnitude = abs(price_change_pct)
        if price_change_pct > 2.0:
            return "🚀 Strong Growth Expected", min(0.9, magnitude / 10.0)
        if price_change_pct > 0.5:
            return "📈 Moderate Growth Expected", min(0.7, magnitude / 5.0)
        if price_change_pct < -2.0:
            return "📉 Strong Decline Expected", min(0.9, magnitude / 10.0)
        if price_change_pct < -0.5:
            return "📉 Moderate Decline Expected", min(0.7, magnitude / 5.0)
        return "➡️ Sideways Movement Expected", 0.5

    def _predict_sync(self, prices: pd.DataFrame, prediction_length: int,
                      num_samples: int) -> Tuple[str, float, List[float]]:
        """
        Synchronous prediction (runs in the thread pool).

        Uses at most the last 512 closes as context, samples ``num_samples``
        forecasts and summarises the first forecast step only.

        Returns:
            Tuple of (trend_description, confidence_score, sampled first-step
            prices); ("❌ Prediction error", 0.0, []) if anything raises.
        """
        try:
            frame = self._flatten_columns(prices)
            # reshape(-1) keeps this 1-D even if "Close" is a one-column frame.
            closes = np.asarray(frame["Close"], dtype=np.float64).reshape(-1)
            context_length = min(len(closes), self._MAX_CONTEXT_POINTS)
            context = closes[-context_length:]

            logger.info(f"Using {context_length} data points for prediction")

            context_tensor = torch.tensor(context, dtype=torch.float32).reshape(1, -1)

            with torch.no_grad():
                # NOTE(review): the context is passed positionally because the
                # keyword name of the first parameter appears to differ between
                # chronos-forecasting releases (context vs inputs) - confirm
                # against the installed version.
                forecast = self.pipeline.predict(
                    context_tensor,
                    prediction_length=prediction_length,
                    num_samples=num_samples,
                )

            # Indexed as [series 0, all samples, first forecast step].
            predictions = forecast[0, :, 0].numpy()

            mean_prediction = float(np.mean(predictions))
            std_prediction = float(np.std(predictions))

            current_price = float(closes[-1])
            price_change_pct = ((mean_prediction - current_price) / current_price) * 100

            trend, confidence = self._classify_trend(price_change_pct)

            # Scale confidence down as the sample spread grows relative to price.
            variance_factor = min(1.0, std_prediction / current_price)
            confidence = max(0.1, confidence * (1 - variance_factor))

            logger.info(f"Prediction: ${mean_prediction:.2f} ({price_change_pct:+.2f}%) - {trend}")

            return trend, float(confidence), predictions.tolist()

        except Exception as e:
            logger.error(f"Error in sync prediction: {e}", exc_info=True)
            return "❌ Prediction error", 0.0, []

    async def calculate_technical_indicators_async(self, prices: pd.DataFrame) -> dict:
        """
        Calculate basic technical indicators asynchronously.

        Args:
            prices: DataFrame with historical price data

        Returns:
            Dictionary with technical indicators ({} on failure)
        """
        try:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self.executor, self._calculate_indicators_sync, prices
            )
        except Exception as e:
            logger.error(f"Error calculating technical indicators: {e}")
            return {}

    def _safe_float(self, val) -> float:
        """
        Convert a value to float, mapping NaN and empty Series to 0.0.

        A Series is reduced to its last element before conversion.
        """
        if isinstance(val, pd.Series):
            if val.empty:
                return 0.0
            val = val.iloc[-1]
        if pd.isna(val):
            return 0.0
        return float(val)

    def _calculate_indicators_sync(self, prices: pd.DataFrame) -> Dict[str, float]:
        """
        Synchronous indicator calculation (runs in the thread pool).

        Returns:
            Dict with 'sma_20', 'sma_50', 'price_change' (percent versus the
            previous row) and 'volume_ratio' (last volume / 20-row mean volume).
            A rolling mean that is NaN (window longer than the data) is
            reported as 0.0. Returns {} if anything raises.
        """
        try:
            frame = self._flatten_columns(prices)
            close = frame['Close']
            volume = frame['Volume']

            sma_20 = self._safe_float(close.rolling(window=20).mean().iloc[-1])
            sma_50 = self._safe_float(close.rolling(window=50).mean().iloc[-1])

            current_price = self._safe_float(close.iloc[-1])
            previous_price = self._safe_float(close.iloc[-2])
            if previous_price != 0:
                price_change = ((current_price - previous_price) / previous_price) * 100
            else:
                price_change = 0.0

            avg_volume = self._safe_float(volume.rolling(window=20).mean().iloc[-1])
            current_volume = self._safe_float(volume.iloc[-1])
            volume_ratio = current_volume / avg_volume if avg_volume != 0 else 1.0

            tech_indicators = {
                'sma_20': sma_20,
                'sma_50': sma_50,
                'price_change': price_change,
                'volume_ratio': volume_ratio,
            }
            logger.info(f"Calculated indicators: {tech_indicators}")
            return tech_indicators
        except Exception as e:
            logger.error(f"Error in sync indicator calculation: {e}", exc_info=True)
            return {}

    async def analyze_stock_async(self, ticker: str) -> str:
        """
        Perform complete stock analysis asynchronously.

        Downloads prices, then runs the forecast and the indicator
        calculation concurrently and formats the combined result.

        Args:
            ticker: Stock ticker symbol

        Returns:
            Formatted analysis message, or an error message string on failure
        """
        try:
            prices = await self.fetch_prices_async(ticker)
            if prices is None:
                return f"❌ Could not fetch data for {ticker}"

            (trend, confidence, predictions), indicators = await asyncio.gather(
                self.predict_next_day_async(prices),
                self.calculate_technical_indicators_async(prices),
            )

            return await self.create_analysis_message_async(
                ticker, prices, trend, confidence, predictions, indicators
            )

        except Exception as e:
            logger.error(f"Error analyzing {ticker}: {e}")
            return f"❌ Error analyzing {ticker}: {e}"

    async def create_analysis_message_async(self, ticker: str, prices: pd.DataFrame, trend: str,
                                            confidence: float,
                                            predictions: Optional[List[float]] = None,
                                            indicators: Optional[dict] = None) -> str:
        """
        Create a comprehensive analysis message.

        Args:
            ticker: Stock ticker symbol
            prices: DataFrame with price data (index entries must support strftime)
            trend: Predicted trend
            confidence: Prediction confidence score
            predictions: List of predicted prices (section omitted if empty/None)
            indicators: Technical indicators dictionary (section omitted if empty/None)

        Returns:
            Formatted analysis message
        """
        if prices is None or prices.empty:
            return f"❌ Unable to analyze {ticker} - no data available"

        try:
            frame = self._flatten_columns(prices)
            last_close = self._safe_float(frame["Close"].iloc[-1])
            last_date = frame.index[-1].strftime('%Y-%m-%d')

            message_parts = [
                f"📊 **Stock Analysis: {ticker}**",
                f"📅 Date: {last_date}",
                f"💰 Current Price: ${last_close:.2f}",
                f"🔮 Prediction: {trend}",
                f"🎯 Confidence: {confidence:.1%}",
                "",
            ]

            if predictions:
                mean_pred = np.mean(predictions)
                min_pred = np.min(predictions)
                max_pred = np.max(predictions)
                # Guard the percentage against a zero reference price.
                price_change = ((mean_pred - last_close) / last_close) * 100 if last_close else 0.0

                message_parts.extend([
                    "🎲 **Price Predictions:**",
                    f"• Expected Price: ${mean_pred:.2f} ({price_change:+.2f}%)",
                    f"• Price Range: ${min_pred:.2f} - ${max_pred:.2f}",
                    "",
                ])

            if indicators:
                message_parts.extend([
                    "📈 **Technical Indicators:**",
                    f"• 20-day SMA: ${indicators.get('sma_20', 0):.2f}",
                    f"• 50-day SMA: ${indicators.get('sma_50', 0):.2f}",
                    f"• Daily Change: {indicators.get('price_change', 0):.2f}%",
                    f"• Volume Ratio: {indicators.get('volume_ratio', 0):.2f}x",
                    "",
                ])

            message_parts.extend([
                "⚠️ **Disclaimer:** This is AI-generated analysis, not financial advice.",
                "Predictions are based on historical patterns and may not reflect future performance.",
                "Always do your own research and consult financial advisors before investing.",
            ])

            return "\n".join(message_parts)

        except Exception as e:
            logger.error(f"Error creating message: {e}")
            return f"❌ Error creating analysis for {ticker}"

    async def analyze_multiple_stocks(self, tickers: List[str]) -> Dict[str, str]:
        """
        Analyze multiple stocks concurrently.

        Args:
            tickers: List of stock ticker symbols

        Returns:
            Dictionary mapping tickers to analysis messages; a ticker whose
            analysis raised maps to an error message instead.
        """
        tasks = [self.analyze_stock_async(ticker) for ticker in tickers]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        analysis_results = {}
        for ticker, result in zip(tickers, results):
            if isinstance(result, Exception):
                analysis_results[ticker] = f"❌ Error analyzing {ticker}: {result}"
            else:
                analysis_results[ticker] = result

        return analysis_results

    async def close(self):
        """Clean up resources by shutting down the worker pool (waits for running tasks)."""
        if hasattr(self, 'executor'):
            self.executor.shutdown(wait=True)
            logger.info("AsyncStockPredictor resources cleaned up")
|
|
|
|
|
async def main():
    """
    Main async function to demonstrate the stock predictor.

    Loads the model, analyses a fixed list of tickers concurrently, prints
    each analysis followed by total and per-ticker timing, and always shuts
    the predictor down afterwards. Any exception is logged and printed rather
    than propagated.
    """
    predictor = AsyncStockPredictor()

    try:
        await predictor.initialize()

        tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA", "AMD"]

        print(f"\n🚀 Starting concurrent analysis of {len(tickers)} stocks...")

        # perf_counter() is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments the way time.time() differences can.
        started = time.perf_counter()
        results = await predictor.analyze_multiple_stocks(tickers)
        total_time = time.perf_counter() - started

        for ticker, analysis in results.items():
            print(f"\n{'=' * 60}")
            print(f"Analysis for {ticker}")
            print('=' * 60)
            print(analysis)

        print(f"\n🏁 Analysis completed in {total_time:.2f} seconds")
        print(f"⚡ Average time per stock: {total_time / len(tickers):.2f} seconds")

    except Exception as e:
        logger.error(f"Error in main execution: {e}")
        print(f"❌ Application error: {e}")

    finally:
        # Release the worker pool whether or not the analysis succeeded.
        await predictor.close()
|
|
|
|
|
def run_async_analysis():
    """
    Synchronous entry point: drive ``main()`` on a fresh event loop.

    A Ctrl-C interrupt or any other exception escaping ``main()`` is reported
    on stdout instead of being propagated to the caller.
    """
    notice = None
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        notice = "\n🛑 Analysis interrupted by user"
    except Exception as exc:
        notice = f"❌ Fatal error: {exc}"

    if notice is not None:
        print(notice)
|
|
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_async_analysis()
|
|