|
|
#![recursion_limit = "256"] |
|
|
#![allow(unexpected_cfgs)] |
|
|
|
|
|
mod training; |
|
|
|
|
|
/// Entry point: dispatches to the runner(s) selected by Cargo features.
///
/// - `training` enabled: calls `training_runner::run()`.
/// - `inference` enabled: calls `inference_runner::run()`.
///
/// When both features are enabled, training runs first and inference second,
/// in the order written below.
fn main() {
    #[cfg(feature = "training")]
    training_runner::run();

    #[cfg(feature = "inference")]
    inference_runner::run();

    // Built without either runner feature, the binary would otherwise exit
    // successfully having done nothing; tell the user how to select a mode.
    #[cfg(not(any(feature = "training", feature = "inference")))]
    eprintln!("nothing to run: build with `--features training` and/or `--features inference`");
}
|
|
|
|
|
// Floating-point element type shared by the training and inference backends,
// selected at compile time by Cargo feature:
//   (default) -> f32
//   `f16`     -> burn::tensor::f16
//   `flex32`  -> burn::tensor::flex32
//
// Each alias carries `#[allow(unused)]` because, when neither the `training`
// nor the `inference` feature is enabled, nothing in the crate refers to it.

// Enabling both precision features would otherwise define `ElemType` twice and
// fail with an opaque duplicate-definition error; report the real cause instead.
#[cfg(all(feature = "f16", feature = "flex32"))]
compile_error!("features `f16` and `flex32` are mutually exclusive; enable at most one");

#[cfg(not(any(feature = "f16", feature = "flex32")))]
#[allow(unused)]
type ElemType = f32;

#[cfg(feature = "f16")]
#[allow(unused)]
type ElemType = burn::tensor::f16;

// `not(feature = "f16")` keeps this alias from colliding with the `f16` one when
// both features are set, so the `compile_error!` above is the only error shown.
#[cfg(all(feature = "flex32", not(feature = "f16")))]
#[allow(unused)]
type ElemType = burn::tensor::flex32;
|
|
|
|
|
#[cfg(feature = "training")]
mod training_runner {
    //! Training entry point: builds the experiment configuration and calls
    //! `training::train` with the `TweetSentimentDataset` train/test splits.

    use crate::{training, ElemType};
    use burn::{
        backend::{Autodiff, Metal},
        nn::transformer::TransformerEncoderConfig,
        optim::{decay::WeightDecayConfig, AdamConfig},
        tensor::backend::AutodiffBackend,
    };
    use model::{data::TweetSentimentDataset, inference::ExperimentConfig};

    /// Backend used by [`run`]: Metal with `ElemType` floats and `i32` ints,
    /// wrapped in `Autodiff`.
    type TrainingBackend = Autodiff<Metal<ElemType, i32>>;

    /// Artifact directory passed to `training::train`; the inference runner
    /// uses the same literal path.
    const ARTIFACT_DIR: &str = "sam-artifacts";

    /// Builds the encoder/optimizer configuration and trains on
    /// `TweetSentimentDataset`.
    ///
    /// `devices` is forwarded unchanged to `training::train`.
    pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
        // Positional arguments follow burn's
        // `TransformerEncoderConfig::new(d_model, d_ff, n_heads, n_layers)`.
        let encoder = TransformerEncoderConfig::new(256, 1024, 8, 4)
            .with_norm_first(true)
            .with_quiet_softmax(true);
        let optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5)));
        let config = ExperimentConfig::new(encoder, optimizer);

        let train_split = TweetSentimentDataset::train();
        let test_split = TweetSentimentDataset::test();

        training::train::<B, TweetSentimentDataset>(
            devices,
            train_split,
            test_split,
            config,
            ARTIFACT_DIR,
        );
    }

    /// Trains on a single default-constructed device of [`TrainingBackend`].
    pub fn run() {
        launch::<TrainingBackend>(vec![Default::default()]);
    }
}
|
|
|
|
|
#[cfg(feature = "inference")]
mod inference_runner {
    //! Inference entry point: calls `inference::infer` with the
    //! `TweetSentimentDataset` type over a fixed list of example sentences.

    use crate::ElemType;
    use burn::{backend::Metal, prelude::Backend};
    use model::{data::TweetSentimentDataset, inference};

    /// Backend used by [`run`]: Metal with `ElemType` floats and `i32` ints.
    type InferenceBackend = Metal<ElemType, i32>;

    /// Artifact directory passed to `inference::infer`; the training runner
    /// uses the same literal path.
    const ARTIFACT_DIR: &str = "sam-artifacts";

    /// Example sentences fed to `inference::infer`, in this order.
    const SAMPLES: [&str; 10] = [
        "2am feedings for the baby are fun when he is all smiles and coos",
        "I love the smell of fresh coffee in the morning",
        "The weather is terrible today, I hate the rain",
        "I just finished reading a great book, it was so inspiring",
        "I can't believe how much I enjoyed that movie, it was fantastic",
        "I am so excited for the weekend, I have so many plans",
        "I am feeling a bit under the weather today, I hope I get better soon",
        "I am so remorseful for the mistakes I made in the past, I have learned so much from them",
        "I am grateful for the support of my friends and family, they mean the world to me",
        "I am looking forward to the future, I have so many goals and dreams to achieve",
    ];

    /// Runs `inference::infer` on `device` over [`SAMPLES`], converted to
    /// owned `String`s.
    pub fn launch<B: Backend>(device: B::Device) {
        let samples: Vec<String> = SAMPLES.iter().map(|&text| String::from(text)).collect();
        inference::infer::<B, TweetSentimentDataset>(device, ARTIFACT_DIR, samples);
    }

    /// Runs inference on the default device of [`InferenceBackend`].
    pub fn run() {
        launch::<InferenceBackend>(Default::default());
    }
}
|
|
|