// Author: WarriorsSami
// Commit 2d4eafe — feat: scaffold model, training, inference and api
use burn::data::dataset::{Dataset, SqliteDataset, source::huggingface::HuggingfaceDatasetLoader};
use derive_new::new;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
/// A single text-classification example: the raw text and its class index.
#[derive(new, Clone, Debug)]
pub struct TextClassificationItem {
    /// The text to be classified.
    pub text: String,
    /// Index of the class (classification category) assigned to `text`.
    pub label: usize,
}
/// A [`Dataset`] of [`TextClassificationItem`]s that also describes its label space.
pub trait TextClassificationDataset: Dataset<TextClassificationItem> {
    /// Number of distinct classes that appear as `label` values in this dataset.
    fn num_classes() -> usize;

    /// Human-readable name of the class identified by `label`.
    fn class_name(label: usize) -> String;
}
/// One raw row of the Tweet Sentiment Extraction dataset.
///
/// Field names double as serde keys, so they presumably have to match the
/// dataset's column names — do not rename them without checking the source data.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TweetSentimentItem {
    /// The tweet text.
    pub text: String,
    /// Sentiment label: 0 = Negative, 1 = Neutral, 2 = Positive.
    pub label: usize,
}
/// The Tweet Sentiment Extraction dataset, backed by a [`SqliteDataset`] of raw rows.
pub struct TweetSentimentDataset {
    /// Raw rows of the selected split.
    dataset: SqliteDataset<TweetSentimentItem>,
}
// Implement the Dataset trait for the Tweet Sentiment Extraction dataset
impl Dataset<TextClassificationItem> for TweetSentimentDataset {
fn get(&self, index: usize) -> Option<TextClassificationItem> {
self.dataset
.get(index)
.map(|item| TextClassificationItem::new(item.text, item.label))
}
fn len(&self) -> usize {
self.dataset.len()
}
}
// Implement methods for constructing the Tweet Sentiment Extraction dataset
impl TweetSentimentDataset {
    /// Hugging Face Hub identifier of the dataset loaded by [`TweetSentimentDataset::new`].
    const HF_DATASET_ID: &'static str = "mteb/tweet_sentiment_extraction";

    /// Returns the training portion of the dataset.
    ///
    /// # Panics
    ///
    /// Panics if the `"train"` split cannot be loaded (see [`TweetSentimentDataset::new`]).
    pub fn train() -> Self {
        Self::new("train")
    }

    /// Returns the testing portion of the dataset.
    ///
    /// # Panics
    ///
    /// Panics if the `"test"` split cannot be loaded (see [`TweetSentimentDataset::new`]).
    pub fn test() -> Self {
        Self::new("test")
    }

    /// Constructs the dataset from a split (e.g. `"train"` or `"test"`) by loading it
    /// through `HuggingfaceDatasetLoader`.
    ///
    /// # Panics
    ///
    /// Panics if the loader returns an error. The panic message names the dataset
    /// and split and includes the loader's error, so failures (network, cache,
    /// unknown split) are diagnosable instead of a bare `unwrap` message.
    pub fn new(split: &str) -> Self {
        let dataset: SqliteDataset<TweetSentimentItem> =
            HuggingfaceDatasetLoader::new(Self::HF_DATASET_ID)
                .dataset(split)
                .unwrap_or_else(|err| {
                    panic!(
                        "failed to load split `{}` of Hugging Face dataset `{}`: {:?}",
                        split,
                        Self::HF_DATASET_ID,
                        err
                    )
                });
        Self { dataset }
    }
}
// Label space of the Tweet Sentiment Extraction dataset: three sentiment classes.
impl TextClassificationDataset for TweetSentimentDataset {
    /// Negative, Neutral and Positive.
    fn num_classes() -> usize {
        3
    }

    /// Name of the sentiment class for `label` (0, 1 or 2).
    ///
    /// # Panics
    ///
    /// Panics with "Invalid class label" for any other value.
    fn class_name(label: usize) -> String {
        let sentiment = match label {
            0 => SentimentLabel::Negative,
            1 => SentimentLabel::Neutral,
            2 => SentimentLabel::Positive,
            _ => panic!("Invalid class label"),
        };
        String::from(sentiment)
    }
}
/// Sentiment category of a tweet.
///
/// Variant order matches the integer labels of the dataset:
/// 0 = `Negative`, 1 = `Neutral`, 2 = `Positive`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SentimentLabel {
    Negative,
    Neutral,
    Positive,
}

impl SentimentLabel {
    /// Maps a class index (0, 1, 2) to its label, or `None` if the index is out of range.
    ///
    /// Non-panicking counterpart of `SentimentLabel::from(usize)`.
    pub const fn from_index(index: usize) -> Option<Self> {
        match index {
            0 => Some(SentimentLabel::Negative),
            1 => Some(SentimentLabel::Neutral),
            2 => Some(SentimentLabel::Positive),
            _ => None,
        }
    }

    /// Parses an exact, case-sensitive label name (`"Negative"`, `"Neutral"`,
    /// `"Positive"`), or returns `None` for anything else.
    ///
    /// Non-panicking counterpart of `SentimentLabel::from(String)`.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "Negative" => Some(SentimentLabel::Negative),
            "Neutral" => Some(SentimentLabel::Neutral),
            "Positive" => Some(SentimentLabel::Positive),
            _ => None,
        }
    }

    /// Canonical name of the label; the single source of truth for `Display`
    /// and `From<SentimentLabel> for String`.
    pub const fn as_str(self) -> &'static str {
        match self {
            SentimentLabel::Negative => "Negative",
            SentimentLabel::Neutral => "Neutral",
            SentimentLabel::Positive => "Positive",
        }
    }
}

impl From<usize> for SentimentLabel {
    /// Converts a class index into a label.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid sentiment label" if `value > 2`; use
    /// [`SentimentLabel::from_index`] for a fallible conversion.
    fn from(value: usize) -> Self {
        Self::from_index(value).unwrap_or_else(|| panic!("Invalid sentiment label"))
    }
}

impl From<String> for SentimentLabel {
    /// Converts a label name into a label.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid sentiment label" if the name is not recognized; use
    /// [`SentimentLabel::from_name`] for a fallible conversion.
    fn from(value: String) -> Self {
        Self::from_name(&value).unwrap_or_else(|| panic!("Invalid sentiment label"))
    }
}

impl From<SentimentLabel> for String {
    fn from(value: SentimentLabel) -> Self {
        value.as_str().to_owned()
    }
}

impl Display for SentimentLabel {
    /// Writes the label name. `Formatter::pad` honours the caller's width, fill and
    /// alignment (e.g. `{:>10}`); the previous `write!(f, "{}", ..)` silently
    /// dropped them and allocated a `String` on every call.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}