|
import asyncio
import importlib
import os
import pickle
import warnings
from collections import defaultdict
from collections.abc import Mapping
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import date, datetime, timedelta, timezone
from typing import Any
from urllib.parse import quote

import aiohttp
import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import yfinance as yf
from huggingface_hub import hf_hub_download
from sklearn.preprocessing import MinMaxScaler
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

from src.telegram_bot.logger import main_logger as logger
|
|
|
|
|
# Preferred Keras backend. Keras reads KERAS_BACKEND when it is first imported,
# and keras is imported above, so this line cannot switch the backend of the
# running process. It only fills in the variable when the caller has not set
# one. To actually run on JAX, export KERAS_BACKEND=jax before starting Python.
# setdefault (instead of plain assignment) keeps an explicit choice made in the
# environment intact rather than silently overwriting it.
os.environ.setdefault("KERAS_BACKEND", "jax")
|
|
|
|
|
class AsyncStockPricePredictor: |
|
""" |
|
Asynchronous stock price predictor using Keras 3.0 models from Hugging Face. |
|
|
|
This class loads LSTM models and sentiment analysis models directly from |
|
Hugging Face Hub using the new Keras 3.0 model loading API. |
|
""" |
|
|
|
REQUIRED_COLUMNS = ['Open', 'High', 'Low', 'Close', 'Volume', 'Sentiment'] |
|
DEFAULT_HEADERS = { |
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' |
|
} |
|
|
|
def __init__( |
|
self, |
|
lstm_model_repo: str = "jengyang/lstm-stock-prediction-model", |
|
scaler_repo: str = "jengyang/lstm-stock-prediction-model", |
|
sentiment_model: str = "TLOB/roberta-base-finetuned-financial-text-classification", |
|
sequence_length: int = 60, |
|
news_lookback_days: int = 7, |
|
device: int = -1, |
|
max_workers: int = 4, |
|
timeout: int = 30, |
|
keras_backend: str = "jax", |
|
use_auth_token: str | None = None |
|
): |
|
""" |
|
Initialize the async stock predictor with Keras 3.0 and HuggingFace models. |
|
|
|
Args: |
|
lstm_model_repo: HF repository for LSTM model (Keras 3.0 compatible) |
|
scaler_repo: HF repository for scalers |
|
sentiment_model: HF repository for sentiment analysis |
|
sequence_length: Number of days for LSTM sequence |
|
news_lookback_days: Days of news to analyze |
|
device: Device for transformers (-1 for CPU, 0+ for GPU) |
|
max_workers: Max threads for CPU-bound operations |
|
timeout: HTTP request timeout |
|
keras_backend: Keras backend ("jax", "torch", "tensorflow") |
|
use_auth_token: HF token for private repos |
|
""" |
|
|
|
if "KERAS_BACKEND" not in os.environ: |
|
os.environ["KERAS_BACKEND"] = keras_backend |
|
self.sequence_length = sequence_length |
|
self.news_lookback_days = news_lookback_days |
|
self.timeout = aiohttp.ClientTimeout(total=timeout) |
|
self.use_auth_token = use_auth_token or os.getenv("HF_TOKEN") |
|
|
|
self._load_keras_models(lstm_model_repo, scaler_repo, sentiment_model, device) |
|
|
|
self.executor = ProcessPoolExecutor(max_workers=max_workers) |
|
|
|
def _load_keras_models( |
|
self, |
|
lstm_repo: str, |
|
scaler_repo: str, |
|
sentiment_repo: str, |
|
device: int |
|
) -> None: |
|
"""Load models from Hugging Face Hub using multiple fallback approaches.""" |
|
try: |
|
|
|
model_loaded = False |
|
|
|
|
|
try: |
|
logger.info(f"Attempting to load Keras 3.0 model from hf://{lstm_repo}") |
|
self.model = keras.saving.load_model(f"hf://{lstm_repo}") |
|
logger.info( |
|
f"Keras 3.0 model loaded successfully with {os.environ.get('KERAS_BACKEND', 'default')} backend") |
|
model_loaded = True |
|
except Exception as e: |
|
logger.warning(f"Keras 3.0 loading failed: {e}") |
|
|
|
|
|
if not model_loaded: |
|
logger.info(f"Trying to download model files from {lstm_repo}") |
|
model_files = [ |
|
"model.keras", |
|
"model.h5", |
|
"lstm_model.keras", |
|
"lstm_model.h5", |
|
"saved_model.pb", |
|
"pytorch_model.bin" |
|
] |
|
|
|
for filename in model_files: |
|
try: |
|
model_path = hf_hub_download( |
|
repo_id=lstm_repo, |
|
filename=filename, |
|
token=self.use_auth_token |
|
) |
|
logger.info(f"Found model file: {filename}") |
|
|
|
if filename.endswith('.keras') or filename.endswith('.h5'): |
|
|
|
if os.environ.get("KERAS_BACKEND") != "tensorflow": |
|
|
|
tf_model = tf.keras.models.load_model(model_path) |
|
|
|
self.model = keras.Model.from_config(tf_model.get_config()) |
|
self.model.set_weights(tf_model.get_weights()) |
|
else: |
|
self.model = keras.saving.load_model(model_path) |
|
model_loaded = True |
|
break |
|
elif filename == 'saved_model.pb': |
|
|
|
tf_model = tf.keras.models.load_model(os.path.dirname(model_path)) |
|
self.model = keras.Model.from_config(tf_model.get_config()) |
|
self.model.set_weights(tf_model.get_weights()) |
|
model_loaded = True |
|
break |
|
|
|
except Exception as e: |
|
logger.debug(f"Model file {filename} not found or failed to load: {e}") |
|
continue |
|
|
|
|
|
if not model_loaded: |
|
logger.warning(f"Could not load model from {lstm_repo}, trying alternative approaches") |
|
|
|
|
|
alternative_repos = [ |
|
"microsoft/DialoGPT-medium", |
|
"huggingface/CodeBERTa-small-v1" |
|
] |
|
|
|
for alt_repo in alternative_repos: |
|
try: |
|
logger.info(f"Trying alternative repo: {alt_repo}") |
|
|
|
break |
|
except: |
|
continue |
|
|
|
|
|
logger.warning("Creating a simple LSTM model as fallback") |
|
self.model = self._create_fallback_lstm_model() |
|
model_loaded = True |
|
|
|
if not model_loaded: |
|
raise RuntimeError(f"Could not load any model from {lstm_repo}") |
|
|
|
logger.info("LSTM model loaded successfully") |
|
|
|
|
|
logger.info(f"Downloading scalers from {scaler_repo}") |
|
|
|
scaler_files = [ |
|
"scalers.pkl", |
|
"scaler.pkl", |
|
"preprocessing.pkl", |
|
"feature_scalers.pkl", |
|
"minmax_scalers.pkl" |
|
] |
|
|
|
scaler_path = None |
|
for filename in scaler_files: |
|
try: |
|
scaler_path = hf_hub_download( |
|
repo_id=scaler_repo, |
|
filename=filename, |
|
token=self.use_auth_token |
|
) |
|
logger.info(f"Found scaler file: {filename}") |
|
break |
|
except Exception as e: |
|
logger.debug(f"Scaler file {filename} not found: {e}") |
|
continue |
|
|
|
if scaler_path: |
|
with open(scaler_path, 'rb') as f: |
|
self.scalers = pickle.load(f) |
|
logger.info("Scalers loaded successfully") |
|
|
|
|
|
missing_scalers = set(self.REQUIRED_COLUMNS) - set(self.scalers.keys()) |
|
if missing_scalers: |
|
logger.warning(f"Missing scalers for columns: {missing_scalers}") |
|
|
|
for col in missing_scalers: |
|
self.scalers[col] = MinMaxScaler() |
|
logger.info(f"Created dummy scaler for {col}") |
|
else: |
|
logger.warning("No scaler file found, will use manual normalization") |
|
self.scalers = {} |
|
|
|
|
|
logger.info(f"Loading sentiment model: {sentiment_repo}") |
|
self.tokenizer = AutoTokenizer.from_pretrained(sentiment_repo) |
|
sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_repo) |
|
self.sentiment_pipe = pipeline( |
|
"sentiment-analysis", |
|
model=sentiment_model, |
|
tokenizer=self.tokenizer, |
|
device=device |
|
) |
|
logger.info("Sentiment analysis pipeline initialized") |
|
|
|
except Exception as e: |
|
logger.error(f"Failed to load models from Hugging Face: {e}") |
|
raise |
|
|
|
def _create_fallback_lstm_model(self): |
|
"""Create a simple LSTM model as fallback.""" |
|
try: |
|
logger.info("Creating fallback LSTM model") |
|
|
|
|
|
model = keras.Sequential([ |
|
keras.layers.LSTM(50, return_sequences=True, |
|
input_shape=(self.sequence_length, len(self.REQUIRED_COLUMNS))), |
|
keras.layers.Dropout(0.2), |
|
keras.layers.LSTM(50, return_sequences=True), |
|
keras.layers.Dropout(0.2), |
|
keras.layers.LSTM(50), |
|
keras.layers.Dropout(0.2), |
|
keras.layers.Dense(1) |
|
]) |
|
|
|
model.compile(optimizer='adam', loss='mean_squared_error') |
|
|
|
|
|
dummy_input = np.random.random((1, self.sequence_length, len(self.REQUIRED_COLUMNS))) |
|
model.predict(dummy_input, verbose=0) |
|
|
|
logger.warning("Using fallback LSTM model - predictions may not be accurate") |
|
return model |
|
|
|
except Exception as e: |
|
logger.error(f"Failed to create fallback model: {e}") |
|
raise |
|
|
|
async def fetch_stock_data( |
|
self, |
|
ticker: str, |
|
period: str = "1y", |
|
interval: str = "1d" |
|
) -> pd.DataFrame: |
|
""" |
|
Asynchronously fetch historical stock data from Yahoo Finance. |
|
""" |
|
url = ( |
|
f"https://query1.finance.yahoo.com/v8/finance/chart/{ticker.upper()}" |
|
f"?range={period}&interval={interval}&includePrePost=false" |
|
) |
|
try: |
|
async with aiohttp.ClientSession( |
|
timeout=self.timeout, |
|
headers=self.DEFAULT_HEADERS |
|
) as session: |
|
async with session.get(url) as response: |
|
response.raise_for_status() |
|
data = await response.json() |
|
result = data['chart']['result'][0] |
|
timestamps = pd.to_datetime(result['timestamp'], unit='s') |
|
quotes = result['indicators']['quote'][0] |
|
|
|
ohlcv_data = { |
|
'Open': quotes.get('open', []), |
|
'High': quotes.get('high', []), |
|
'Low': quotes.get('low', []), |
|
'Close': quotes.get('close', []), |
|
'Volume': quotes.get('volume', []) |
|
} |
|
df = pd.DataFrame(ohlcv_data, index=timestamps) |
|
|
|
df.index = df.index.tz_localize(None) |
|
df = df.dropna() |
|
if df.empty: |
|
raise ValueError(f"No valid data found for ticker {ticker}") |
|
logger.info(f"Fetched {len(df)} data points for {ticker}") |
|
return df |
|
except Exception as e: |
|
logger.error(f"Error fetching stock data for {ticker}: {e}") |
|
raise RuntimeError(f"Failed to fetch stock data: {e}") |
|
|
|
@staticmethod |
|
async def fetch_prices(ticker: str, period: str = "6mo", interval: str = "1d") -> pd.DataFrame | None: |
|
""" |
|
Fetch historical stock price data from Yahoo Finance. |
|
|
|
Args: |
|
ticker: Stock ticker symbol (e.g., 'AAPL') |
|
period: Time period for data (1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max) |
|
interval: Data interval (1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo) |
|
|
|
Returns: |
|
DataFrame with OHLCV data or None if error occurs |
|
""" |
|
try: |
|
logger.info(f"Fetching data for {ticker}") |
|
|
|
|
|
with warnings.catch_warnings(): |
|
warnings.simplefilter("ignore") |
|
df = yf.download(ticker, period=period, interval=interval, progress=False) |
|
|
|
if df.empty: |
|
logger.error(f"No data found for ticker {ticker}") |
|
return None |
|
|
|
|
|
df = df[["Open", "High", "Low", "Close", "Volume"]].copy() |
|
df.dropna(inplace=True) |
|
|
|
if len(df) < 60: |
|
logger.warning(f"Insufficient data for {ticker}. Got {len(df)} days, need at least 60") |
|
return None |
|
|
|
logger.info(f"Successfully fetched {len(df)} data points for {ticker}") |
|
return df |
|
|
|
except Exception as e: |
|
logger.error(f"Error fetching data for {ticker}: {e}") |
|
return None |
|
|
|
async def fetch_news(self, ticker: str) -> list[dict[str, Any]]: |
|
"""Fetch recent news for a stock ticker.""" |
|
url = f"https://query1.finance.yahoo.com/v6/finance/news?symbols={ticker.upper()}" |
|
try: |
|
async with aiohttp.ClientSession( |
|
timeout=self.timeout, |
|
headers=self.DEFAULT_HEADERS |
|
) as session: |
|
async with session.get(url) as response: |
|
response.raise_for_status() |
|
data = await response.json() |
|
news_items = data.get('items', {}).get('result', []) |
|
logger.info(f"Fetched {len(news_items)} news items for {ticker}") |
|
return news_items |
|
except Exception as e: |
|
logger.warning(f"Error fetching news for {ticker}: {e}") |
|
return [] |
|
|
|
async def analyze_sentiment_batch( |
|
self, |
|
texts: list[str], |
|
batch_size: int = 10 |
|
) -> list[float]: |
|
"""Analyze sentiment for a batch of texts asynchronously.""" |
|
if not texts: |
|
return [] |
|
async def process_batch(batch: list[str]) -> list[float]: |
|
"""Process a single batch of texts.""" |
|
try: |
|
predictions = await asyncio.get_event_loop().run_in_executor( |
|
self.executor, |
|
lambda: self.sentiment_pipe(batch, truncation=True, max_length=512) |
|
) |
|
scores = [] |
|
for pred in predictions: |
|
label = pred.get('label', '').lower() |
|
confidence = float(pred.get('score', 0.0)) |
|
|
|
if any(pos in label for pos in ['positive', 'pos', 'bullish', 'label_2']): |
|
scores.append(confidence) |
|
elif any(neg in label for neg in ['negative', 'neg', 'bearish', 'label_0']): |
|
scores.append(-confidence) |
|
else: |
|
scores.append(0.0) |
|
return scores |
|
except Exception as e: |
|
logger.warning(f"Error in sentiment analysis batch: {e}") |
|
return [0.0] * len(batch) |
|
|
|
batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)] |
|
|
|
batch_results = await asyncio.gather(*[process_batch(batch) for batch in batches]) |
|
|
|
all_scores = [] |
|
for batch_scores in batch_results: |
|
all_scores.extend(batch_scores) |
|
return all_scores |
|
|
|
async def compute_daily_sentiment( |
|
self, |
|
news_items: list[dict[str, Any]], |
|
since_date: datetime.date |
|
) -> dict[datetime.date, float]: |
|
"""Compute daily sentiment scores from news items.""" |
|
news_by_date = defaultdict(list) |
|
for item in news_items: |
|
|
|
timestamp = ( |
|
item.get('providerPublishTime') or |
|
item.get('pubDate') or |
|
item.get('published') |
|
) |
|
if isinstance(timestamp, (int, float)): |
|
date = datetime.fromtimestamp(int(timestamp), tz=timezone.utc).date() |
|
else: |
|
try: |
|
date = pd.to_datetime(timestamp).tz_convert(None).date() |
|
except Exception: |
|
date = datetime.now().date() |
|
if date < since_date: |
|
continue |
|
|
|
title = item.get('title', '').strip() |
|
summary = item.get('summary', '').strip() |
|
text = f"{title}. {summary}".strip() |
|
if text and len(text) > 10: |
|
news_by_date[date].append(text) |
|
|
|
daily_sentiment = {} |
|
for date, texts in news_by_date.items(): |
|
if texts: |
|
sentiment_scores = await self.analyze_sentiment_batch(texts) |
|
daily_sentiment[date] = float(np.mean(sentiment_scores)) |
|
else: |
|
daily_sentiment[date] = 0.0 |
|
logger.info(f"Computed sentiment for {len(daily_sentiment)} days") |
|
return daily_sentiment |
|
|
|
def align_sentiment_to_prices( |
|
self, |
|
price_df: pd.DataFrame, |
|
daily_sentiment: dict[datetime.date, float] |
|
) -> pd.Series: |
|
"""Align daily sentiment scores to price DataFrame index.""" |
|
sentiment_values = [] |
|
for date in price_df.index.date: |
|
sentiment_values.append(daily_sentiment.get(date, 0.0)) |
|
|
|
return pd.Series(sentiment_values, index=price_df.index, name='Sentiment') |
|
|
|
def prepare_sequences(self, df: pd.DataFrame) -> np.ndarray: |
|
"""Prepare input sequences for the LSTM model.""" |
|
|
|
available_columns = [col for col in self.REQUIRED_COLUMNS if col in df.columns] |
|
df = df[available_columns].copy() |
|
|
|
|
|
scaled_data = {} |
|
for column in df.columns: |
|
if column in self.scalers: |
|
|
|
scaled_data[column] = self.scalers[column].transform( |
|
df[[column]] |
|
).flatten() |
|
else: |
|
|
|
col_values = df[column].values |
|
min_val, max_val = col_values.min(), col_values.max() |
|
if max_val > min_val: |
|
scaled_data[column] = (col_values - min_val) / (max_val - min_val) |
|
else: |
|
scaled_data[column] = np.zeros_like(col_values) |
|
|
|
df_scaled = pd.DataFrame(scaled_data, index=df.index) |
|
|
|
|
|
while len(df_scaled.columns) < len(self.REQUIRED_COLUMNS): |
|
missing_col = f"feature_{len(df_scaled.columns)}" |
|
df_scaled[missing_col] = 0.0 |
|
|
|
|
|
sequences = [] |
|
for i in range(len(df_scaled) - self.sequence_length + 1): |
|
sequence = df_scaled.iloc[i:i + self.sequence_length].values |
|
sequences.append(sequence) |
|
|
|
return np.array(sequences) |
|
|
|
async def predict_next_day_price( |
|
self, |
|
ticker: str, |
|
period: str = "1y", |
|
interval: str = "1d" |
|
) -> dict[str, Any]: |
|
"""Predict the next day's stock price using Keras 3.0 model.""" |
|
try: |
|
|
|
logger.info(f"Starting prediction for {ticker}") |
|
stock_data, news_items = await asyncio.gather( |
|
self.fetch_stock_data(ticker, period, interval), |
|
self.fetch_news(ticker) |
|
) |
|
|
|
|
|
since_date = stock_data.index[-1].date() - timedelta(days=self.news_lookback_days) |
|
daily_sentiment = await self.compute_daily_sentiment(news_items, since_date) |
|
|
|
|
|
sentiment_series = self.align_sentiment_to_prices(stock_data, daily_sentiment) |
|
combined_data = stock_data.copy() |
|
combined_data['Sentiment'] = sentiment_series |
|
|
|
|
|
if len(combined_data) < self.sequence_length: |
|
raise RuntimeError( |
|
f"Insufficient data: {len(combined_data)} days, " |
|
f"need at least {self.sequence_length} days" |
|
) |
|
|
|
|
|
sequences = self.prepare_sequences(combined_data) |
|
last_sequence = sequences[-1:] |
|
|
|
|
|
|
|
if os.environ.get("KERAS_BACKEND") == "jax": |
|
import jax.numpy as jnp |
|
last_sequence = jnp.array(last_sequence) |
|
|
|
scaled_prediction = await asyncio.get_event_loop().run_in_executor( |
|
self.executor, |
|
lambda: self.model.predict(last_sequence, verbose=0) |
|
) |
|
|
|
|
|
if hasattr(scaled_prediction, 'numpy'): |
|
scaled_prediction = scaled_prediction.numpy() |
|
elif hasattr(scaled_prediction, '__array__'): |
|
scaled_prediction = np.array(scaled_prediction) |
|
|
|
|
|
if 'Close' in self.scalers: |
|
predicted_price = self.scalers['Close'].inverse_transform( |
|
scaled_prediction.reshape(-1, 1) |
|
)[0][0] |
|
else: |
|
|
|
close_data = combined_data['Close'].values |
|
min_val, max_val = close_data.min(), close_data.max() |
|
predicted_price = scaled_prediction[0][0] * (max_val - min_val) + min_val |
|
|
|
|
|
last_close = float(combined_data['Close'].iloc[-1]) |
|
change_percent = ((predicted_price - last_close) / last_close) * 100 |
|
|
|
|
|
if daily_sentiment: |
|
avg_sentiment = np.mean(list(daily_sentiment.values())) |
|
sentiment_label = ( |
|
"Positive" if avg_sentiment > 0.1 else |
|
"Negative" if avg_sentiment < -0.1 else |
|
"Neutral" |
|
) |
|
else: |
|
avg_sentiment = 0.0 |
|
sentiment_label = "No recent news" |
|
|
|
|
|
trend_emoji = ( |
|
"π" if change_percent > 1 else |
|
"π" if change_percent < -1 else |
|
"β‘οΈ" |
|
) |
|
|
|
result = { |
|
'ticker': ticker.upper(), |
|
'predicted_price': round(float(predicted_price), 2), |
|
'last_price': round(last_close, 2), |
|
'change_percent': round(change_percent, 2), |
|
'last_date': str(stock_data.index[-1].date()), |
|
'trend_emoji': trend_emoji, |
|
'sentiment_score': round(avg_sentiment, 3), |
|
'sentiment_label': sentiment_label, |
|
'data_points': len(combined_data), |
|
'news_items': len(news_items), |
|
'backend': os.environ.get("KERAS_BACKEND", "tensorflow") |
|
} |
|
|
|
logger.info( |
|
f"Prediction completed for {ticker}: ${result['predicted_price']} (backend: {result['backend']})") |
|
return result |
|
|
|
except Exception as e: |
|
logger.error(f"Prediction failed for {ticker}: {e}") |
|
raise |
|
|
|
def format_telegram_message(self, prediction: dict[str, Any]) -> str: |
|
"""Format prediction result as a Telegram message.""" |
|
return ( |
|
f"π *Stock Prediction for {prediction['ticker']}*\n" |
|
f"Date: {prediction['last_date']}\n\n" |
|
f"Last closing price: `${prediction['last_price']:.2f}`\n" |
|
f"Predicted next price: *${prediction['predicted_price']:.2f}* {prediction['trend_emoji']}\n\n" |
|
f"Expected change: {prediction['change_percent']:+.2f}%\n\n" |
|
f"π° News sentiment: {prediction['sentiment_label']}\n" |
|
f"π Data points used: {prediction['data_points']}\n" |
|
f"π News articles: {prediction['news_items']}\n\n" |
|
f"π€ Powered by Keras 3.0 ({prediction.get('backend', 'JAX')}) + HuggingFace" |
|
) |
|
|
|
async def cleanup(self) -> None: |
|
"""Cleanup resources.""" |
|
if hasattr(self, 'executor'): |
|
self.executor.shutdown(wait=True) |
|
logger.info("Thread executor shut down") |
|
|
|
@staticmethod |
|
def get_available_backends() -> list[str]: |
|
"""Get list of available Keras backends.""" |
|
backends = ["tensorflow", "jax", "torch"] |
|
available = [] |
|
|
|
for backend in backends: |
|
try: |
|
if backend == "jax": |
|
import jax |
|
available.append("jax") |
|
elif backend == "torch": |
|
import torch |
|
available.append("torch") |
|
elif backend == "tensorflow": |
|
import tensorflow |
|
available.append("tensorflow") |
|
except ImportError: |
|
continue |
|
|
|
return available |
|
|
|
@classmethod |
|
def create_with_backend(cls, backend: str = "jax", **kwargs): |
|
"""Create predictor with specific Keras backend.""" |
|
os.environ["KERAS_BACKEND"] = backend |
|
return cls(**kwargs) |
|
|
|
|
|
|
|
async def handle_stock_prediction(ticker: str, predictor: "AsyncStockPricePredictor") -> str:
    """Run a prediction for ``ticker`` and format it as a Telegram reply.

    Prediction and formatting failures are turned into a user-facing error
    line, so the bot always has something to send back.
    """
    symbol = ticker.upper()
    try:
        prediction = await predictor.predict_next_day_price(symbol)
        return predictor.format_telegram_message(prediction)
    except Exception as e:
        return f"❌ Error predicting {symbol}: {e}"
|
|
|
|
|
|
|
async def main():
    """Example usage of the async stock predictor with Keras 3.0."""
    available_backends = AsyncStockPricePredictor.get_available_backends()
    print(f"Available Keras backends: {available_backends}")

    predictor = AsyncStockPricePredictor(
        lstm_model_repo="jengyang/lstm-stock-prediction-model",
        scaler_repo="jengyang/lstm-stock-prediction-model",
        sentiment_model="TLOB/roberta-base-finetuned-financial-text-classification",
        max_workers=2,
        keras_backend="jax"
    )

    try:
        result = await handle_stock_prediction("AAPL", predictor)
        print(result)

        # The KERAS_BACKEND variable can differ from the backend keras picked at
        # import time, so report the one actually in use.
        print(f"\nUsing Keras backend: {keras.backend.backend()}")
    finally:
        await predictor.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
|
|