Dmitry Beresnev
		
	commited on
		
		
					Commit 
							
							·
						
						23c855e
	
1
								Parent(s):
							
							ad23307
								
add stock predictor, fix news iterator, etc
Browse files- .gitignore +17 -3
- pyproject.toml +4 -1
- src/core/risk_management/risk_analyzer.py +1 -1
- src/services/async_stock_price_predictor.py +42 -0
- src/services/news_iterator.py +9 -8
- src/services/stock_predictor.py +451 -0
- src/telegram_bot/telegram_bot_service.py +25 -2
    	
        .gitignore
    CHANGED
    
    | @@ -1,6 +1,20 @@ | |
| 1 | 
            -
             | 
|  | |
|  | |
| 2 | 
             
            .venv/
         | 
|  | |
|  | |
|  | |
| 3 | 
             
            uv.lock
         | 
| 4 | 
            -
             | 
| 5 | 
             
            src/your_project_name.egg-info/
         | 
| 6 | 
            -
            src/news_sentoment_analyzer.egg-info/
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Ignore Python cache files
         | 
| 2 | 
            +
            __pycache__/
         | 
| 3 | 
            +
            # Ignore virtual environment directories
         | 
| 4 | 
             
            .venv/
         | 
| 5 | 
            +
            # Ignore environment variable files
         | 
| 6 | 
            +
            .env
         | 
| 7 | 
            +
            # Ignore lock files
         | 
| 8 | 
             
            uv.lock
         | 
| 9 | 
            +
            # Ignore egg-info directories
         | 
| 10 | 
             
            src/your_project_name.egg-info/
         | 
| 11 | 
            +
            src/news_sentoment_analyzer.egg-info/
         | 
| 12 | 
            +
            # Ignore Python bytecode files
         | 
| 13 | 
            +
            *.pyc
         | 
| 14 | 
            +
            *.pyo
         | 
| 15 | 
            +
            # Ignore Jupyter Notebook checkpoints
         | 
| 16 | 
            +
            .ipynb_checkpoints/
         | 
| 17 | 
            +
            # Ignore IDE specific files
         | 
| 18 | 
            +
            .idea/
         | 
| 19 | 
            +
            # Ignore logs
         | 
| 20 | 
            +
            logs/
         | 
    	
        pyproject.toml
    CHANGED
    
    | @@ -1,5 +1,6 @@ | |
| 1 | 
             
            [project]
         | 
| 2 | 
            -
             | 
|  | |
| 3 | 
             
            version = "0.1.0"
         | 
| 4 | 
             
            description = "new sentiment analyzer using lexicon based and machine learning techniques"
         | 
| 5 | 
             
            authors = [
         | 
| @@ -63,6 +64,8 @@ dependencies = [ | |
| 63 | 
             
                "isort>=5.10.0",
         | 
| 64 | 
             
                "python-telegram-bot>=20.0",
         | 
| 65 | 
             
                "dotenv>=0.9.9",
         | 
|  | |
|  | |
| 66 | 
             
            ]
         | 
| 67 |  | 
| 68 | 
             
            [build-system]
         | 
|  | |
| 1 | 
             
            [project]
         | 
| 2 | 
            +
            requires-python = ">=3.11"
         | 
| 3 | 
            +
            name = "financial_news_bot"
         | 
| 4 | 
             
            version = "0.1.0"
         | 
| 5 | 
             
            description = "new sentiment analyzer using lexicon based and machine learning techniques"
         | 
| 6 | 
             
            authors = [
         | 
|  | |
| 64 | 
             
                "isort>=5.10.0",
         | 
| 65 | 
             
                "python-telegram-bot>=20.0",
         | 
| 66 | 
             
                "dotenv>=0.9.9",
         | 
| 67 | 
            +
                "keras>=3.11.2",
         | 
| 68 | 
            +
                "tensorflow>=2.20.0",
         | 
| 69 | 
             
            ]
         | 
| 70 |  | 
| 71 | 
             
            [build-system]
         | 
    	
        src/core/risk_management/risk_analyzer.py
    CHANGED
    
    | @@ -25,7 +25,7 @@ class RiskAnalyzer: | |
| 25 | 
             
                        http_options=types.HttpOptions(api_version='v1')
         | 
| 26 | 
             
                    )
         | 
| 27 |  | 
| 28 | 
            -
                def get_stock_data(self, ticker: str, period: str = " | 
| 29 | 
             
                    try:
         | 
| 30 | 
             
                        stock = yf.Ticker(ticker)
         | 
| 31 | 
             
                        data = stock.history(period=period)
         | 
|  | |
| 25 | 
             
                        http_options=types.HttpOptions(api_version='v1')
         | 
| 26 | 
             
                    )
         | 
| 27 |  | 
| 28 | 
            +
                def get_stock_data(self, ticker: str, period: str = "3mo") -> pd.DataFrame:
         | 
| 29 | 
             
                    try:
         | 
| 30 | 
             
                        stock = yf.Ticker(ticker)
         | 
| 31 | 
             
                        data = stock.history(period=period)
         | 
    	
        src/services/async_stock_price_predictor.py
    CHANGED
    
    | @@ -4,6 +4,7 @@ import asyncio | |
| 4 | 
             
            from datetime import datetime, timezone, timedelta
         | 
| 5 | 
             
            from collections import defaultdict
         | 
| 6 | 
             
            from typing import Any
         | 
|  | |
| 7 |  | 
| 8 | 
             
            import numpy as np
         | 
| 9 | 
             
            import pandas as pd
         | 
| @@ -14,6 +15,7 @@ from sklearn.preprocessing import MinMaxScaler | |
| 14 | 
             
            from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
         | 
| 15 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 16 | 
             
            from concurrent.futures import ProcessPoolExecutor
         | 
|  | |
| 17 |  | 
| 18 | 
             
            from src.telegram_bot.logger import main_logger as logger
         | 
| 19 |  | 
| @@ -301,6 +303,46 @@ class AsyncStockPricePredictor: | |
| 301 | 
             
                        logger.error(f"Error fetching stock data for {ticker}: {e}")
         | 
| 302 | 
             
                        raise RuntimeError(f"Failed to fetch stock data: {e}")
         | 
| 303 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 304 | 
             
                async def fetch_news(self, ticker: str) -> list[dict[str, Any]]:
         | 
| 305 | 
             
                    """Fetch recent news for a stock ticker."""
         | 
| 306 | 
             
                    url = f"https://query1.finance.yahoo.com/v6/finance/news?symbols={ticker.upper()}"
         | 
|  | |
| 4 | 
             
            from datetime import datetime, timezone, timedelta
         | 
| 5 | 
             
            from collections import defaultdict
         | 
| 6 | 
             
            from typing import Any
         | 
| 7 | 
            +
            import warnings
         | 
| 8 |  | 
| 9 | 
             
            import numpy as np
         | 
| 10 | 
             
            import pandas as pd
         | 
|  | |
| 15 | 
             
            from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
         | 
| 16 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 17 | 
             
            from concurrent.futures import ProcessPoolExecutor
         | 
| 18 | 
            +
            import yfinance as yf
         | 
| 19 |  | 
| 20 | 
             
            from src.telegram_bot.logger import main_logger as logger
         | 
| 21 |  | 
|  | |
| 303 | 
             
                        logger.error(f"Error fetching stock data for {ticker}: {e}")
         | 
| 304 | 
             
                        raise RuntimeError(f"Failed to fetch stock data: {e}")
         | 
| 305 |  | 
| 306 | 
            +
                @staticmethod
         | 
| 307 | 
            +
                async def fetch_prices(ticker: str, period: str = "6mo", interval: str = "1d") -> pd.DataFrame | None:
         | 
| 308 | 
            +
                    """
         | 
| 309 | 
            +
                    Fetch historical stock price data from Yahoo Finance.
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    Args:
         | 
| 312 | 
            +
                        ticker: Stock ticker symbol (e.g., 'AAPL')
         | 
| 313 | 
            +
                        period: Time period for data (1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max)
         | 
| 314 | 
            +
                        interval: Data interval (1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo)
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    Returns:
         | 
| 317 | 
            +
                        DataFrame with OHLCV data or None if error occurs
         | 
| 318 | 
            +
                    """
         | 
| 319 | 
            +
                    try:
         | 
| 320 | 
            +
                        logger.info(f"Fetching data for {ticker}")
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                        # Suppress yfinance warnings
         | 
| 323 | 
            +
                        with warnings.catch_warnings():
         | 
| 324 | 
            +
                            warnings.simplefilter("ignore")
         | 
| 325 | 
            +
                            df = yf.download(ticker, period=period, interval=interval, progress=False)
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                        if df.empty:
         | 
| 328 | 
            +
                            logger.error(f"No data found for ticker {ticker}")
         | 
| 329 | 
            +
                            return None
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                        # Select relevant columns
         | 
| 332 | 
            +
                        df = df[["Open", "High", "Low", "Close", "Volume"]].copy()
         | 
| 333 | 
            +
                        df.dropna(inplace=True)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                        if len(df) < 60:
         | 
| 336 | 
            +
                            logger.warning(f"Insufficient data for {ticker}. Got {len(df)} days, need at least 60")
         | 
| 337 | 
            +
                            return None
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                        logger.info(f"Successfully fetched {len(df)} data points for {ticker}")
         | 
| 340 | 
            +
                        return df
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    except Exception as e:
         | 
| 343 | 
            +
                        logger.error(f"Error fetching data for {ticker}: {e}")
         | 
| 344 | 
            +
                        return None
         | 
| 345 | 
            +
             | 
| 346 | 
             
                async def fetch_news(self, ticker: str) -> list[dict[str, Any]]:
         | 
| 347 | 
             
                    """Fetch recent news for a stock ticker."""
         | 
| 348 | 
             
                    url = f"https://query1.finance.yahoo.com/v6/finance/news?symbols={ticker.upper()}"
         | 
    	
        src/services/news_iterator.py
    CHANGED
    
    | @@ -1,8 +1,8 @@ | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            import json
         | 
| 3 | 
            -
            import time
         | 
| 4 | 
             
            import re
         | 
| 5 | 
            -
            from  | 
|  | |
| 6 | 
             
            import asyncio
         | 
| 7 | 
             
            from dataclasses import dataclass
         | 
| 8 |  | 
| @@ -121,12 +121,12 @@ class CompanyNewsPostsIterator: | |
| 121 | 
             
                NEWS_PING_PERIOD = 120 # seconds
         | 
| 122 |  | 
| 123 | 
             
                def __init__(self, finhub_api_key: str | None, google_api_key: str | None, sentiments_only: list[str] | None = None,
         | 
| 124 | 
            -
                             company_symbol: str = "NVDA",  | 
| 125 | 
             
                    self.finhub_api_key = finhub_api_key
         | 
| 126 | 
             
                    self.finnhub_client = self.build_finnhub_client()
         | 
| 127 | 
             
                    self.sentiment_analyzer = SentimentAnalyzer(google_api_key=google_api_key)
         | 
| 128 | 
             
                    self.company_symbol = company_symbol
         | 
| 129 | 
            -
                    self. | 
| 130 | 
             
                    self.sentiments_only = sentiments_only or ['positive', 'negative']
         | 
| 131 | 
             
                    self.news_posts = []
         | 
| 132 | 
             
                    self.latest_timestamp = None
         | 
| @@ -161,9 +161,10 @@ class CompanyNewsPostsIterator: | |
| 161 |  | 
| 162 | 
             
                async def read_news_posts(self):
         | 
| 163 | 
             
                    loop = asyncio.get_event_loop()
         | 
|  | |
| 164 | 
             
                    news_posts = await loop.run_in_executor(
         | 
| 165 | 
             
                        None,
         | 
| 166 | 
            -
                        lambda: self.finnhub_client.company_news(self.company_symbol, _from= | 
| 167 | 
             
                    )
         | 
| 168 | 
             
                    decorated_news_posts = [NewsPostDecorator.from_dict(news_post) for news_post in news_posts]
         | 
| 169 | 
             
                    self.news_posts = self.unread_posts(decorated_news_posts)
         | 
| @@ -171,15 +172,15 @@ class CompanyNewsPostsIterator: | |
| 171 | 
             
                def build_finnhub_client(self):
         | 
| 172 | 
             
                    return finnhub.Client(api_key=self.finhub_api_key)
         | 
| 173 |  | 
| 174 | 
            -
                def default_date(self):
         | 
| 175 | 
            -
                    return datetime.now().strftime('%Y-%m-%d')
         | 
| 176 | 
            -
             | 
| 177 | 
             
                def unread_posts(self, posts):
         | 
| 178 | 
             
                    if not self.latest_timestamp:
         | 
| 179 | 
             
                        return posts
         | 
| 180 | 
             
                    idx = next((i for i, post in enumerate(posts) if post.datetime == self.latest_timestamp), 0)
         | 
| 181 | 
             
                    return posts[:idx]
         | 
| 182 |  | 
|  | |
|  | |
|  | |
| 183 |  | 
| 184 | 
             
            if __name__ == '__main__':
         | 
| 185 | 
             
                iterator = CompanyNewsPostsIterator(
         | 
|  | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            import json
         | 
|  | |
| 3 | 
             
            import re
         | 
| 4 | 
            +
            from zoneinfo import ZoneInfo
         | 
| 5 | 
            +
            from datetime import datetime
         | 
| 6 | 
             
            import asyncio
         | 
| 7 | 
             
            from dataclasses import dataclass
         | 
| 8 |  | 
|  | |
| 121 | 
             
                NEWS_PING_PERIOD = 120 # seconds
         | 
| 122 |  | 
| 123 | 
             
                def __init__(self, finhub_api_key: str | None, google_api_key: str | None, sentiments_only: list[str] | None = None,
         | 
| 124 | 
            +
                             company_symbol: str = "NVDA", time_zone: str = "America/New_York"):
         | 
| 125 | 
             
                    self.finhub_api_key = finhub_api_key
         | 
| 126 | 
             
                    self.finnhub_client = self.build_finnhub_client()
         | 
| 127 | 
             
                    self.sentiment_analyzer = SentimentAnalyzer(google_api_key=google_api_key)
         | 
| 128 | 
             
                    self.company_symbol = company_symbol
         | 
| 129 | 
            +
                    self.time_zone = time_zone
         | 
| 130 | 
             
                    self.sentiments_only = sentiments_only or ['positive', 'negative']
         | 
| 131 | 
             
                    self.news_posts = []
         | 
| 132 | 
             
                    self.latest_timestamp = None
         | 
|  | |
| 161 |  | 
| 162 | 
             
                async def read_news_posts(self):
         | 
| 163 | 
             
                    loop = asyncio.get_event_loop()
         | 
| 164 | 
            +
                    date = self.news_date()
         | 
| 165 | 
             
                    news_posts = await loop.run_in_executor(
         | 
| 166 | 
             
                        None,
         | 
| 167 | 
            +
                        lambda: self.finnhub_client.company_news(self.company_symbol, _from=date, to=date)
         | 
| 168 | 
             
                    )
         | 
| 169 | 
             
                    decorated_news_posts = [NewsPostDecorator.from_dict(news_post) for news_post in news_posts]
         | 
| 170 | 
             
                    self.news_posts = self.unread_posts(decorated_news_posts)
         | 
|  | |
| 172 | 
             
                def build_finnhub_client(self):
         | 
| 173 | 
             
                    return finnhub.Client(api_key=self.finhub_api_key)
         | 
| 174 |  | 
|  | |
|  | |
|  | |
| 175 | 
             
                def unread_posts(self, posts):
         | 
| 176 | 
             
                    if not self.latest_timestamp:
         | 
| 177 | 
             
                        return posts
         | 
| 178 | 
             
                    idx = next((i for i, post in enumerate(posts) if post.datetime == self.latest_timestamp), 0)
         | 
| 179 | 
             
                    return posts[:idx]
         | 
| 180 |  | 
| 181 | 
            +
                def news_date(self):
         | 
| 182 | 
            +
                    return datetime.now(ZoneInfo(self.time_zone)).strftime('%Y-%m-%d')
         | 
| 183 | 
            +
             | 
| 184 |  | 
| 185 | 
             
            if __name__ == '__main__':
         | 
| 186 | 
             
                iterator = CompanyNewsPostsIterator(
         | 
    	
        src/services/stock_predictor.py
    ADDED
    
    | @@ -0,0 +1,451 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Async Stock Price Predictor using Amazon Chronos T5-Small Time Series Model
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            Required installations:
         | 
| 5 | 
            +
            pip install chronos-forecasting yfinance torch numpy pandas aiohttp asyncio
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            Usage:
         | 
| 8 | 
            +
            python stock_predictor.py
         | 
| 9 | 
            +
            """
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import yfinance as yf
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
            from chronos import ChronosPipeline
         | 
| 15 | 
            +
            import pandas as pd
         | 
| 16 | 
            +
            import logging
         | 
| 17 | 
            +
            import asyncio
         | 
| 18 | 
            +
            import aiohttp
         | 
| 19 | 
            +
            from concurrent.futures import ThreadPoolExecutor
         | 
| 20 | 
            +
            from typing import Optional, Tuple, List, Dict
         | 
| 21 | 
            +
            from datetime import datetime
         | 
| 22 | 
            +
            import warnings
         | 
| 23 | 
            +
            import time
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            # Configure logging
         | 
| 26 | 
            +
            logging.basicConfig(level=logging.INFO)
         | 
| 27 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            class AsyncStockPredictor:
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                An async stock price predictor using Amazon Chronos T5 time series model.
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                This class fetches historical stock data asynchronously and uses the Chronos model
         | 
| 35 | 
            +
                to predict future stock prices and movement trends with concurrent processing.
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                def __init__(self, model_name: str = "amazon/chronos-t5-small", max_workers: int = 4):
         | 
| 39 | 
            +
                    """
         | 
| 40 | 
            +
                    Initialize the async stock predictor with Chronos model.
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    Args:
         | 
| 43 | 
            +
                        model_name: Name of the Chronos model to use
         | 
| 44 | 
            +
                        max_workers: Maximum number of worker threads for CPU-intensive tasks
         | 
| 45 | 
            +
                    """
         | 
| 46 | 
            +
                    self.model_name = model_name
         | 
| 47 | 
            +
                    self.max_workers = max_workers
         | 
| 48 | 
            +
                    self.executor = ThreadPoolExecutor(max_workers=max_workers)
         | 
| 49 | 
            +
                    self.pipeline = None
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                async def initialize(self):
         | 
| 52 | 
            +
                    """Initialize the model asynchronously."""
         | 
| 53 | 
            +
                    try:
         | 
| 54 | 
            +
                        logger.info(f"Loading Chronos model: {self.model_name}")
         | 
| 55 | 
            +
                        # Run model loading in thread pool to avoid blocking
         | 
| 56 | 
            +
                        self.pipeline = await asyncio.get_event_loop().run_in_executor(
         | 
| 57 | 
            +
                            self.executor, self._load_model
         | 
| 58 | 
            +
                        )
         | 
| 59 | 
            +
                        logger.info("Chronos model loaded successfully")
         | 
| 60 | 
            +
                    except Exception as e:
         | 
| 61 | 
            +
                        logger.error(f"Error loading model: {e}")
         | 
| 62 | 
            +
                        raise
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                def _load_model(self):
         | 
| 65 | 
            +
                    """Load the Chronos model (CPU intensive, runs in thread pool)."""
         | 
| 66 | 
            +
                    try:
         | 
| 67 | 
            +
                        return ChronosPipeline.from_pretrained(
         | 
| 68 | 
            +
                            self.model_name,
         | 
| 69 | 
            +
                            device_map="auto",
         | 
| 70 | 
            +
                            torch_dtype=torch.bfloat16,
         | 
| 71 | 
            +
                        )
         | 
| 72 | 
            +
                    except Exception as e:
         | 
| 73 | 
            +
                        logger.warning(f"Failed to load with optimized settings: {e}")
         | 
| 74 | 
            +
                        logger.info("Attempting to load with default settings...")
         | 
| 75 | 
            +
                        return ChronosPipeline.from_pretrained(self.model_name)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                async def fetch_prices_async(self, ticker: str, period: str = "6mo", interval: str = "1d") -> Optional[
         | 
| 78 | 
            +
                    pd.DataFrame]:
         | 
| 79 | 
            +
                    """
         | 
| 80 | 
            +
                    Fetch historical stock price data asynchronously.
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    Args:
         | 
| 83 | 
            +
                        ticker: Stock ticker symbol (e.g., 'AAPL')
         | 
| 84 | 
            +
                        period: Time period for data (1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max)
         | 
| 85 | 
            +
                        interval: Data interval (1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    Returns:
         | 
| 88 | 
            +
                        DataFrame with OHLCV data or None if error occurs
         | 
| 89 | 
            +
                    """
         | 
| 90 | 
            +
                    try:
         | 
| 91 | 
            +
                        logger.info(f"Fetching data for {ticker}")
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                        # Run yfinance download in thread pool to avoid blocking
         | 
| 94 | 
            +
                        df = await asyncio.get_event_loop().run_in_executor(
         | 
| 95 | 
            +
                            self.executor, self._fetch_data_sync, ticker, period, interval
         | 
| 96 | 
            +
                        )
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                        if df is None or df.empty:
         | 
| 99 | 
            +
                            logger.error(f"No data found for ticker {ticker}")
         | 
| 100 | 
            +
                            return None
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                        # Select relevant columns
         | 
| 103 | 
            +
                        df = df[["Open", "High", "Low", "Close", "Volume"]].copy()
         | 
| 104 | 
            +
                        df.dropna(inplace=True)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                        if len(df) < 30:
         | 
| 107 | 
            +
                            logger.warning(f"Insufficient data for {ticker}. Got {len(df)} days, need at least 30")
         | 
| 108 | 
            +
                            return None
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                        logger.info(f"Successfully fetched {len(df)} data points for {ticker}")
         | 
| 111 | 
            +
                        return df
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    except Exception as e:
         | 
| 114 | 
            +
                        logger.error(f"Error fetching data for {ticker}: {e}")
         | 
| 115 | 
            +
                        return None
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                def _fetch_data_sync(self, ticker: str, period: str, interval: str) -> Optional[pd.DataFrame]:
         | 
| 118 | 
            +
                    """Synchronous data fetching (runs in thread pool)."""
         | 
| 119 | 
            +
                    try:
         | 
| 120 | 
            +
                        with warnings.catch_warnings():
         | 
| 121 | 
            +
                            warnings.simplefilter("ignore")
         | 
| 122 | 
            +
                            df = yf.download(ticker, period=period, interval=interval, progress=False)
         | 
| 123 | 
            +
                        return df
         | 
| 124 | 
            +
                    except Exception as e:
         | 
| 125 | 
            +
                        logger.error(f"Error in sync data fetch for {ticker}: {e}")
         | 
| 126 | 
            +
                        return None
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                async def predict_next_day_async(self, prices: pd.DataFrame, prediction_length: int = 1, num_samples: int = 20) -> \
         | 
| 129 | 
            +
                Tuple[str, float, List[float]]:
         | 
| 130 | 
            +
                    """
         | 
| 131 | 
            +
                    Predict next day's price using Chronos time series model asynchronously.
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    Args:
         | 
| 134 | 
            +
                        prices: DataFrame with historical price data
         | 
| 135 | 
            +
                        prediction_length: Number of future periods to predict
         | 
| 136 | 
            +
                        num_samples: Number of sample predictions to generate
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    Returns:
         | 
| 139 | 
            +
                        Tuple of (trend_description, confidence_score, predicted_prices)
         | 
| 140 | 
            +
                    """
         | 
| 141 | 
            +
                    if self.pipeline is None:
         | 
| 142 | 
            +
                        return "❌ Model not initialized", 0.0, []
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    if prices is None or len(prices) < 30:
         | 
| 145 | 
            +
                        return "❓ Insufficient data", 0.0, []
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    try:
         | 
| 148 | 
            +
                        # Run prediction in thread pool as it's CPU intensive
         | 
| 149 | 
            +
                        result = await asyncio.get_event_loop().run_in_executor(
         | 
| 150 | 
            +
                            self.executor, self._predict_sync, prices, prediction_length, num_samples
         | 
| 151 | 
            +
                        )
         | 
| 152 | 
            +
                        return result
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    except Exception as e:
         | 
| 155 | 
            +
                        logger.error(f"Error during async prediction: {e}")
         | 
| 156 | 
            +
                        return "❌ Prediction error", 0.0, []
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                def _predict_sync(self, prices: pd.DataFrame, prediction_length: int, num_samples: int) -> Tuple[
         | 
| 159 | 
            +
                    str, float, List[float]]:
         | 
| 160 | 
            +
                    """Synchronous prediction (runs in thread pool)."""
         | 
| 161 | 
            +
                    try:
         | 
| 162 | 
            +
                        # Use closing prices as the time series
         | 
| 163 | 
            +
                        closes = prices["Close"].values
         | 
| 164 | 
            +
                        context_length = min(len(closes), 512)  # Chronos context limit
         | 
| 165 | 
            +
                        context = closes[-context_length:]
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                        logger.info(f"Using {context_length} data points for prediction")
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                        # Convert to tensor
         | 
| 170 | 
            +
                        context_tensor = torch.tensor(context, dtype=torch.float32).unsqueeze(0)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                        # Generate predictions
         | 
| 173 | 
            +
                        with torch.no_grad():
         | 
| 174 | 
            +
                            forecast = self.pipeline.predict(
         | 
| 175 | 
            +
                                context=context_tensor,
         | 
| 176 | 
            +
                                prediction_length=prediction_length,
         | 
| 177 | 
            +
                                num_samples=num_samples
         | 
| 178 | 
            +
                            )
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                        # Extract predictions
         | 
| 181 | 
            +
                        predictions = forecast[0, :, 0].numpy()  # First batch, all samples, first prediction step
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                        # Calculate statistics
         | 
| 184 | 
            +
                        mean_prediction = np.mean(predictions)
         | 
| 185 | 
            +
                        std_prediction = np.std(predictions)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                        current_price = float(closes[-1])
         | 
| 188 | 
            +
                        price_change_pct = ((mean_prediction - current_price) / current_price) * 100
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                        # Determine trend based on prediction
         | 
| 191 | 
            +
                        if price_change_pct > 2.0:
         | 
| 192 | 
            +
                            trend = "🚀 Strong Growth Expected"
         | 
| 193 | 
            +
                            confidence = min(0.9, abs(price_change_pct) / 10.0)
         | 
| 194 | 
            +
                        elif price_change_pct > 0.5:
         | 
| 195 | 
            +
                            trend = "📈 Moderate Growth Expected"
         | 
| 196 | 
            +
                            confidence = min(0.7, abs(price_change_pct) / 5.0)
         | 
| 197 | 
            +
                        elif price_change_pct < -2.0:
         | 
| 198 | 
            +
                            trend = "📉 Strong Decline Expected"
         | 
| 199 | 
            +
                            confidence = min(0.9, abs(price_change_pct) / 10.0)
         | 
| 200 | 
            +
                        elif price_change_pct < -0.5:
         | 
| 201 | 
            +
                            trend = "📉 Moderate Decline Expected"
         | 
| 202 | 
            +
                            confidence = min(0.7, abs(price_change_pct) / 5.0)
         | 
| 203 | 
            +
                        else:
         | 
| 204 | 
            +
                            trend = "➡️ Sideways Movement Expected"
         | 
| 205 | 
            +
                            confidence = 0.5
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                        # Adjust confidence based on prediction variance
         | 
| 208 | 
            +
                        variance_factor = min(1.0, std_prediction / current_price)
         | 
| 209 | 
            +
                        confidence = max(0.1, confidence * (1 - variance_factor))
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                        logger.info(f"Prediction: ${mean_prediction:.2f} ({price_change_pct:+.2f}%) - {trend}")
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                        return trend, confidence, predictions.tolist()
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    except Exception as e:
         | 
| 216 | 
            +
                        logger.error(f"Error in sync prediction: {e}")
         | 
| 217 | 
            +
                        return "❌ Prediction error", 0.0, []
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                async def calculate_technical_indicators_async(self, prices: pd.DataFrame) -> dict:
         | 
| 220 | 
            +
                    """
         | 
| 221 | 
            +
                    Calculate basic technical indicators asynchronously.
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    Args:
         | 
| 224 | 
            +
                        prices: DataFrame with historical price data
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    Returns:
         | 
| 227 | 
            +
                        Dictionary with technical indicators
         | 
| 228 | 
            +
                    """
         | 
| 229 | 
            +
                    try:
         | 
| 230 | 
            +
                        # Run calculations in thread pool for consistency
         | 
| 231 | 
            +
                        indicators = await asyncio.get_event_loop().run_in_executor(
         | 
| 232 | 
            +
                            self.executor, self._calculate_indicators_sync, prices
         | 
| 233 | 
            +
                        )
         | 
| 234 | 
            +
                        return indicators
         | 
| 235 | 
            +
                    except Exception as e:
         | 
| 236 | 
            +
                        logger.error(f"Error calculating technical indicators: {e}")
         | 
| 237 | 
            +
                        return {}
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                def _calculate_indicators_sync(self, prices: pd.DataFrame) -> dict:
         | 
| 240 | 
            +
                    """Synchronous indicator calculation."""
         | 
| 241 | 
            +
                    try:
         | 
| 242 | 
            +
                        # Simple moving averages
         | 
| 243 | 
            +
                        sma_20 = prices['Close'].rolling(window=20).mean().iloc[-1]
         | 
| 244 | 
            +
                        sma_50 = prices['Close'].rolling(window=50).mean().iloc[-1]
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                        # Price change
         | 
| 247 | 
            +
                        price_change = ((prices['Close'].iloc[-1] - prices['Close'].iloc[-2]) / prices['Close'].iloc[-2]) * 100
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                        # Volume analysis
         | 
| 250 | 
            +
                        avg_volume = prices['Volume'].rolling(window=20).mean().iloc[-1]
         | 
| 251 | 
            +
                        current_volume = prices['Volume'].iloc[-1]
         | 
| 252 | 
            +
                        volume_ratio = current_volume / avg_volume if avg_volume > 0 else 1.0
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                        return {
         | 
| 255 | 
            +
                            'sma_20': sma_20,
         | 
| 256 | 
            +
                            'sma_50': sma_50,
         | 
| 257 | 
            +
                            'price_change': price_change,
         | 
| 258 | 
            +
                            'volume_ratio': volume_ratio
         | 
| 259 | 
            +
                        }
         | 
| 260 | 
            +
                    except Exception as e:
         | 
| 261 | 
            +
                        logger.error(f"Error in sync indicator calculation: {e}")
         | 
| 262 | 
            +
                        return {}
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                async def analyze_stock_async(self, ticker: str) -> str:
         | 
| 265 | 
            +
                    """
         | 
| 266 | 
            +
                    Perform complete stock analysis asynchronously.
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    Args:
         | 
| 269 | 
            +
                        ticker: Stock ticker symbol
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    Returns:
         | 
| 272 | 
            +
                        Formatted analysis message
         | 
| 273 | 
            +
                    """
         | 
| 274 | 
            +
                    try:
         | 
| 275 | 
            +
                        # Fetch price data
         | 
| 276 | 
            +
                        prices = await self.fetch_prices_async(ticker)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                        if prices is None:
         | 
| 279 | 
            +
                            return f"❌ Could not fetch data for {ticker}"
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                        # Run prediction and technical analysis concurrently
         | 
| 282 | 
            +
                        prediction_task = self.predict_next_day_async(prices)
         | 
| 283 | 
            +
                        indicators_task = self.calculate_technical_indicators_async(prices)
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                        # Wait for both tasks to complete
         | 
| 286 | 
            +
                        (trend, confidence, predictions), indicators = await asyncio.gather(
         | 
| 287 | 
            +
                            prediction_task, indicators_task
         | 
| 288 | 
            +
                        )
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                        # Create analysis message
         | 
| 291 | 
            +
                        message = await self.create_analysis_message_async(
         | 
| 292 | 
            +
                            ticker, prices, trend, confidence, predictions, indicators
         | 
| 293 | 
            +
                        )
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                        return message
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    except Exception as e:
         | 
| 298 | 
            +
                        logger.error(f"Error analyzing {ticker}: {e}")
         | 
| 299 | 
            +
                        return f"❌ Error analyzing {ticker}: {e}"
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                async def create_analysis_message_async(self, ticker: str, prices: pd.DataFrame, trend: str,
         | 
| 302 | 
            +
                                                        confidence: float, predictions: List[float] = None,
         | 
| 303 | 
            +
                                                        indicators: dict = None) -> str:
         | 
| 304 | 
            +
                    """
         | 
| 305 | 
            +
                    Create a comprehensive analysis message asynchronously.
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    Args:
         | 
| 308 | 
            +
                        ticker: Stock ticker symbol
         | 
| 309 | 
            +
                        prices: DataFrame with price data
         | 
| 310 | 
            +
                        trend: Predicted trend
         | 
| 311 | 
            +
                        confidence: Prediction confidence score
         | 
| 312 | 
            +
                        predictions: List of predicted prices
         | 
| 313 | 
            +
                        indicators: Technical indicators dictionary
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    Returns:
         | 
| 316 | 
            +
                        Formatted analysis message
         | 
| 317 | 
            +
                    """
         | 
| 318 | 
            +
                    if prices is None or prices.empty:
         | 
| 319 | 
            +
                        return f"❌ Unable to analyze {ticker} - no data available"
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    try:
         | 
| 322 | 
            +
                        last_close = float(prices["Close"].iloc[-1])
         | 
| 323 | 
            +
                        last_date = prices.index[-1].strftime('%Y-%m-%d')
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                        message_parts = [
         | 
| 326 | 
            +
                            f"📊 **Stock Analysis: {ticker}**",
         | 
| 327 | 
            +
                            f"📅 Date: {last_date}",
         | 
| 328 | 
            +
                            f"💰 Current Price: ${last_close:.2f}",
         | 
| 329 | 
            +
                            f"🔮 Prediction: {trend}",
         | 
| 330 | 
            +
                            f"🎯 Confidence: {confidence:.1%}",
         | 
| 331 | 
            +
                            ""
         | 
| 332 | 
            +
                        ]
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                        # Add price prediction if available
         | 
| 335 | 
            +
                        if predictions and len(predictions) > 0:
         | 
| 336 | 
            +
                            mean_pred = np.mean(predictions)
         | 
| 337 | 
            +
                            min_pred = np.min(predictions)
         | 
| 338 | 
            +
                            max_pred = np.max(predictions)
         | 
| 339 | 
            +
                            price_change = ((mean_pred - last_close) / last_close) * 100
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                            message_parts.extend([
         | 
| 342 | 
            +
                                "🎲 **Price Predictions:**",
         | 
| 343 | 
            +
                                f"• Expected Price: ${mean_pred:.2f} ({price_change:+.2f}%)",
         | 
| 344 | 
            +
                                f"• Price Range: ${min_pred:.2f} - ${max_pred:.2f}",
         | 
| 345 | 
            +
                                f"• Prediction Samples: {len(predictions)}",
         | 
| 346 | 
            +
                                ""
         | 
| 347 | 
            +
                            ])
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                        # Add technical indicators if available
         | 
| 350 | 
            +
                        if indicators:
         | 
| 351 | 
            +
                            message_parts.extend([
         | 
| 352 | 
            +
                                "📈 **Technical Indicators:**",
         | 
| 353 | 
            +
                                f"• 20-day SMA: ${indicators.get('sma_20', 0):.2f}",
         | 
| 354 | 
            +
                                f"• 50-day SMA: ${indicators.get('sma_50', 0):.2f}",
         | 
| 355 | 
            +
                                f"• Daily Change: {indicators.get('price_change', 0):.2f}%",
         | 
| 356 | 
            +
                                f"• Volume Ratio: {indicators.get('volume_ratio', 0):.2f}x",
         | 
| 357 | 
            +
                                ""
         | 
| 358 | 
            +
                            ])
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                        message_parts.extend([
         | 
| 361 | 
            +
                            "⚠️ **Disclaimer:** This is AI-generated analysis, not financial advice.",
         | 
| 362 | 
            +
                            "Predictions are based on historical patterns and may not reflect future performance.",
         | 
| 363 | 
            +
                            "Always do your own research and consult financial advisors before investing."
         | 
| 364 | 
            +
                        ])
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                        return "\n".join(message_parts)
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                    except Exception as e:
         | 
| 369 | 
            +
                        logger.error(f"Error creating message: {e}")
         | 
| 370 | 
            +
                        return f"❌ Error creating analysis for {ticker}"
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                async def analyze_multiple_stocks(self, tickers: List[str]) -> Dict[str, str]:
         | 
| 373 | 
            +
                    """
         | 
| 374 | 
            +
                    Analyze multiple stocks concurrently.
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    Args:
         | 
| 377 | 
            +
                        tickers: List of stock ticker symbols
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                    Returns:
         | 
| 380 | 
            +
                        Dictionary mapping tickers to analysis messages
         | 
| 381 | 
            +
                    """
         | 
| 382 | 
            +
                    tasks = [self.analyze_stock_async(ticker) for ticker in tickers]
         | 
| 383 | 
            +
                    results = await asyncio.gather(*tasks, return_exceptions=True)
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    analysis_results = {}
         | 
| 386 | 
            +
                    for ticker, result in zip(tickers, results):
         | 
| 387 | 
            +
                        if isinstance(result, Exception):
         | 
| 388 | 
            +
                            analysis_results[ticker] = f"❌ Error analyzing {ticker}: {result}"
         | 
| 389 | 
            +
                        else:
         | 
| 390 | 
            +
                            analysis_results[ticker] = result
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                    return analysis_results
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                async def close(self):
         | 
| 395 | 
            +
                    """Clean up resources."""
         | 
| 396 | 
            +
                    if hasattr(self, 'executor'):
         | 
| 397 | 
            +
                        self.executor.shutdown(wait=True)
         | 
| 398 | 
            +
                    logger.info("AsyncStockPredictor resources cleaned up")
         | 
| 399 | 
            +
             | 
| 400 | 
            +
             | 
| 401 | 
            +
            async def main():
         | 
| 402 | 
            +
                """Main async function to demonstrate the stock predictor."""
         | 
| 403 | 
            +
                predictor = AsyncStockPredictor()
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                try:
         | 
| 406 | 
            +
                    # Initialize the model
         | 
| 407 | 
            +
                    await predictor.initialize()
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                    # List of tickers to analyze
         | 
| 410 | 
            +
                    tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA", "AMD"]
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                    print(f"\n🚀 Starting concurrent analysis of {len(tickers)} stocks...")
         | 
| 413 | 
            +
                    start_time = time.time()
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    # Analyze all stocks concurrently
         | 
| 416 | 
            +
                    results = await predictor.analyze_multiple_stocks(tickers)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                    end_time = time.time()
         | 
| 419 | 
            +
                    total_time = end_time - start_time
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                    # Display results
         | 
| 422 | 
            +
                    for ticker, analysis in results.items():
         | 
| 423 | 
            +
                        print(f"\n{'=' * 60}")
         | 
| 424 | 
            +
                        print(f"Analysis for {ticker}")
         | 
| 425 | 
            +
                        print('=' * 60)
         | 
| 426 | 
            +
                        print(analysis)
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    print(f"\n🏁 Analysis completed in {total_time:.2f} seconds")
         | 
| 429 | 
            +
                    print(f"⚡ Average time per stock: {total_time / len(tickers):.2f} seconds")
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                except Exception as e:
         | 
| 432 | 
            +
                    logger.error(f"Error in main execution: {e}")
         | 
| 433 | 
            +
                    print(f"❌ Application error: {e}")
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                finally:
         | 
| 436 | 
            +
                    # Clean up resources
         | 
| 437 | 
            +
                    await predictor.close()
         | 
| 438 | 
            +
             | 
| 439 | 
            +
             | 
| 440 | 
            +
            def run_async_analysis():
         | 
| 441 | 
            +
                """Entry point for running the async analysis."""
         | 
| 442 | 
            +
                try:
         | 
| 443 | 
            +
                    asyncio.run(main())
         | 
| 444 | 
            +
                except KeyboardInterrupt:
         | 
| 445 | 
            +
                    print("\n🛑 Analysis interrupted by user")
         | 
| 446 | 
            +
                except Exception as e:
         | 
| 447 | 
            +
                    print(f"❌ Fatal error: {e}")
         | 
| 448 | 
            +
             | 
| 449 | 
            +
             | 
| 450 | 
            +
            if __name__ == "__main__":
         | 
| 451 | 
            +
                run_async_analysis()
         | 
    	
        src/telegram_bot/telegram_bot_service.py
    CHANGED
    
    | @@ -19,6 +19,7 @@ from src.services.news_pooling_service import NewsPollingService | |
| 19 | 
             
            from src.core.risk_management.risk_analyzer import RiskAnalyzer
         | 
| 20 | 
             
            from src.services.news_iterator import CompanyNewsPostsIterator
         | 
| 21 | 
             
            from src.services.async_stock_price_predictor import AsyncStockPricePredictor, handle_stock_prediction
         | 
|  | |
| 22 |  | 
| 23 |  | 
| 24 | 
             
            class TelegramBotService:
         | 
| @@ -522,11 +523,16 @@ class TelegramBotService: | |
| 522 | 
             
                ) -> None:
         | 
| 523 | 
             
                    ticker = 'NVDA'  # Default ticker if not specified
         | 
| 524 | 
             
                    # Show available backends
         | 
|  | |
| 525 | 
             
                    available_backends = AsyncStockPricePredictor.get_available_backends()
         | 
| 526 | 
             
                    main_logger.info(f"Available Keras backends: {available_backends}")
         | 
| 527 | 
             
                    main_logger.info(f"\nUsing Keras backend: {os.environ.get('KERAS_BACKEND')}")
         | 
|  | |
|  | |
|  | |
| 528 | 
             
                    ticker = command_parts[1].upper() if len(command_parts) == 2 else ticker
         | 
| 529 | 
             
                    await self.send_message_via_proxy(chat_id, f"Predicting the price for the ticker {ticker} ...")
         | 
|  | |
| 530 | 
             
                    # Initialize with Keras 3.0 and JAX backend
         | 
| 531 | 
             
                    predictor = AsyncStockPricePredictor(
         | 
| 532 | 
             
                        lstm_model_repo="jengyang/lstm-stock-prediction-model",
         | 
| @@ -535,6 +541,23 @@ class TelegramBotService: | |
| 535 | 
             
                        max_workers=2,
         | 
| 536 | 
             
                        keras_backend="jax"  # Can also use "torch" or "tensorflow"
         | 
| 537 | 
             
                    )
         | 
| 538 | 
            -
                     | 
| 539 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 540 | 
             
                    await self.send_message_via_proxy(chat_id, result)
         | 
|  | |
| 19 | 
             
            from src.core.risk_management.risk_analyzer import RiskAnalyzer
         | 
| 20 | 
             
            from src.services.news_iterator import CompanyNewsPostsIterator
         | 
| 21 | 
             
            from src.services.async_stock_price_predictor import AsyncStockPricePredictor, handle_stock_prediction
         | 
| 22 | 
            +
            from src.services.stock_predictor import AsyncStockPredictor
         | 
| 23 |  | 
| 24 |  | 
| 25 | 
             
            class TelegramBotService:
         | 
|  | |
| 523 | 
             
                ) -> None:
         | 
| 524 | 
             
                    ticker = 'NVDA'  # Default ticker if not specified
         | 
| 525 | 
             
                    # Show available backends
         | 
| 526 | 
            +
                    '''
         | 
| 527 | 
             
                    available_backends = AsyncStockPricePredictor.get_available_backends()
         | 
| 528 | 
             
                    main_logger.info(f"Available Keras backends: {available_backends}")
         | 
| 529 | 
             
                    main_logger.info(f"\nUsing Keras backend: {os.environ.get('KERAS_BACKEND')}")
         | 
| 530 | 
            +
                    '''
         | 
| 531 | 
            +
                    predictor = AsyncStockPredictor()
         | 
| 532 | 
            +
                    await predictor.initialize()
         | 
| 533 | 
             
                    ticker = command_parts[1].upper() if len(command_parts) == 2 else ticker
         | 
| 534 | 
             
                    await self.send_message_via_proxy(chat_id, f"Predicting the price for the ticker {ticker} ...")
         | 
| 535 | 
            +
                    '''
         | 
| 536 | 
             
                    # Initialize with Keras 3.0 and JAX backend
         | 
| 537 | 
             
                    predictor = AsyncStockPricePredictor(
         | 
| 538 | 
             
                        lstm_model_repo="jengyang/lstm-stock-prediction-model",
         | 
|  | |
| 541 | 
             
                        max_workers=2,
         | 
| 542 | 
             
                        keras_backend="jax"  # Can also use "torch" or "tensorflow"
         | 
| 543 | 
             
                    )
         | 
| 544 | 
            +
                    '''
         | 
| 545 | 
            +
                    # List of tickers to analyze
         | 
| 546 | 
            +
                    tickers = [ticker]
         | 
| 547 | 
            +
                    main_logger.info(f"\n🚀 Starting concurrent analysis of {len(tickers)} stocks...")
         | 
| 548 | 
            +
                    start_time = time.time()
         | 
| 549 | 
            +
                    # Analyze all stocks concurrently
         | 
| 550 | 
            +
                    results = await predictor.analyze_multiple_stocks(tickers)
         | 
| 551 | 
            +
                    end_time = time.time()
         | 
| 552 | 
            +
                    total_time = end_time - start_time
         | 
| 553 | 
            +
                    # Display results
         | 
| 554 | 
            +
                    for ticker, analysis in results.items():
         | 
| 555 | 
            +
                        main_logger.info(f"\n{'=' * 60}")
         | 
| 556 | 
            +
                        main_logger.info(f"Analysis for {ticker}")
         | 
| 557 | 
            +
                        main_logger.info('=' * 60)
         | 
| 558 | 
            +
                        main_logger.info(analysis)
         | 
| 559 | 
            +
                    main_logger.info(f"\n🏁 Analysis completed in {total_time:.2f} seconds")
         | 
| 560 | 
            +
                    main_logger.info(f"⚡ Average time per stock: {total_time / len(tickers):.2f} seconds")
         | 
| 561 | 
            +
                    #main_logger.info(f'predicted price: {result}')
         | 
| 562 | 
            +
                    result = results.get(ticker)
         | 
| 563 | 
             
                    await self.send_message_via_proxy(chat_id, result)
         | 
