use crate::dtos::{SentimentRequest, SentimentResponse}; use burn::backend::Wgpu; use model::data::TweetSentimentDataset; use model::model::SamModel; use rocket::serde::json::Json; use rocket::tokio::sync::Mutex; use rocket::{State, post}; use std::sync::Arc; #[cfg(not(feature = "f16"))] #[allow(dead_code)] type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; pub type BackendImpl = Wgpu; // ---- API Route ---- // #[post("/predict", format = "json", data = "")] pub async fn predict( input: Json, state: &State, ) -> Json { let model = state.model.lock().await; let label = model.predict(&input.text); Json(SentimentResponse { label: label.into(), }) } // ---- App State ---- // pub struct AppState { pub(crate) model: Arc>>, }