|
use burn::data::dataset::{Dataset, SqliteDataset, source::huggingface::HuggingfaceDatasetLoader}; |
|
use derive_new::new; |
|
use serde::{Deserialize, Serialize}; |
|
use std::fmt::Display; |
|
|
|
|
|
#[derive(new, Clone, Debug)] |
|
pub struct TextClassificationItem { |
|
pub text: String, |
|
pub label: usize, |
|
} |
|
|
|
|
|
pub trait TextClassificationDataset: Dataset<TextClassificationItem> { |
|
fn num_classes() -> usize; |
|
fn class_name(label: usize) -> String; |
|
} |
|
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)] |
|
pub struct TweetSentimentItem { |
|
pub text: String, |
|
pub label: usize, |
|
} |
|
|
|
|
|
pub struct TweetSentimentDataset { |
|
dataset: SqliteDataset<TweetSentimentItem>, |
|
} |
|
|
|
|
|
impl Dataset<TextClassificationItem> for TweetSentimentDataset { |
|
fn get(&self, index: usize) -> Option<TextClassificationItem> { |
|
self.dataset |
|
.get(index) |
|
.map(|item| TextClassificationItem::new(item.text, item.label)) |
|
} |
|
|
|
fn len(&self) -> usize { |
|
self.dataset.len() |
|
} |
|
} |
|
|
|
|
|
impl TweetSentimentDataset { |
|
|
|
pub fn train() -> Self { |
|
Self::new("train") |
|
} |
|
|
|
|
|
pub fn test() -> Self { |
|
Self::new("test") |
|
} |
|
|
|
|
|
pub fn new(split: &str) -> Self { |
|
let dataset: SqliteDataset<TweetSentimentItem> = |
|
HuggingfaceDatasetLoader::new("mteb/tweet_sentiment_extraction") |
|
.dataset(split) |
|
.unwrap(); |
|
Self { dataset } |
|
} |
|
} |
|
|
|
|
|
impl TextClassificationDataset for TweetSentimentDataset { |
|
fn num_classes() -> usize { |
|
3 |
|
} |
|
|
|
fn class_name(label: usize) -> String { |
|
match label { |
|
0 => SentimentLabel::Negative.into(), |
|
1 => SentimentLabel::Neutral.into(), |
|
2 => SentimentLabel::Positive.into(), |
|
_ => panic!("Invalid class label"), |
|
} |
|
} |
|
} |
|
|
|
/// Sentiment class of a tweet.
///
/// Converts from the numeric class index (`0`, `1`, `2`) and to/from the
/// names `"Negative"`, `"Neutral"` and `"Positive"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SentimentLabel {
    /// Class index `0`.
    Negative,
    /// Class index `1`.
    Neutral,
    /// Class index `2`.
    Positive,
}

impl SentimentLabel {
    /// Returns the name of the label.
    ///
    /// Single source of truth for the names shared by the `Display` impl and
    /// the `String` conversion.
    fn as_str(self) -> &'static str {
        match self {
            SentimentLabel::Negative => "Negative",
            SentimentLabel::Neutral => "Neutral",
            SentimentLabel::Positive => "Positive",
        }
    }
}

impl From<usize> for SentimentLabel {
    /// Converts a class index into a label.
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid sentiment label"` if `value` is not `0`, `1` or `2`.
    fn from(value: usize) -> Self {
        match value {
            0 => SentimentLabel::Negative,
            1 => SentimentLabel::Neutral,
            2 => SentimentLabel::Positive,
            _ => panic!("Invalid sentiment label"),
        }
    }
}

impl From<String> for SentimentLabel {
    /// Converts a label name into a label. The match is exact and
    /// case-sensitive.
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid sentiment label"` if `value` is not one of
    /// `"Negative"`, `"Neutral"` or `"Positive"`.
    fn from(value: String) -> Self {
        match value.as_str() {
            "Negative" => SentimentLabel::Negative,
            "Neutral" => SentimentLabel::Neutral,
            "Positive" => SentimentLabel::Positive,
            _ => panic!("Invalid sentiment label"),
        }
    }
}

impl From<SentimentLabel> for String {
    /// Converts a label into its owned name.
    fn from(value: SentimentLabel) -> Self {
        value.as_str().to_owned()
    }
}

impl Display for SentimentLabel {
    /// Writes the label's name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write the static name directly instead of allocating a `String` first.
        f.write_str(self.as_str())
    }
}
|
|