WarriorsSami's picture
feat: scaffold model, training, inference and api
2d4eafe
raw
history blame
2.98 kB
#![recursion_limit = "256"]
#![allow(unexpected_cfgs)]
mod training;
/// Binary entry point: dispatches to the runner(s) selected at compile time.
///
/// * `training` feature: calls `training_runner::run()`.
/// * `inference` feature: calls `inference_runner::run()`.
///
/// The two features are not exclusive: with both enabled, training runs first
/// and inference follows. With neither enabled there is nothing to run, so a
/// hint is printed to stderr instead of exiting silently.
fn main() {
    #[cfg(feature = "training")]
    training_runner::run();

    #[cfg(feature = "inference")]
    inference_runner::run();

    // Without this, a build that forgot `--features ...` would exit
    // successfully having done nothing at all.
    #[cfg(not(any(feature = "training", feature = "inference")))]
    eprintln!("nothing to do: rebuild with `--features training` and/or `--features inference`");
}
// Element type plugged into the backends below, selected by Cargo feature:
//   * `f16`    -> `burn::tensor::f16`
//   * `flex32` -> `burn::tensor::flex32` (only when `f16` is not also enabled)
//   * neither  -> `f32`
//
// The three `cfg` predicates are mutually exclusive, so exactly one alias is
// ever defined. Previously enabling both `f16` and `flex32` (for example via
// `--all-features`) produced two conflicting definitions and a compile error;
// `f16` now takes precedence in that case.
//
// `#[allow(unused)]` is applied to every variant because the alias is only
// referenced from the feature-gated runner modules, so it is legitimately
// unused when neither `training` nor `inference` is enabled.

/// Default element type: plain `f32` when no reduced-precision feature is set.
#[cfg(not(any(feature = "f16", feature = "flex32")))]
#[allow(unused)]
type ElemType = f32;

/// Element type when the `f16` feature is enabled (wins over `flex32`).
#[cfg(feature = "f16")]
#[allow(unused)]
type ElemType = burn::tensor::f16;

/// Element type when only the `flex32` feature is enabled.
#[cfg(all(feature = "flex32", not(feature = "f16")))]
#[allow(unused)]
type ElemType = burn::tensor::flex32;
/// Training entry points; compiled only when the `training` feature is enabled.
#[cfg(feature = "training")]
mod training_runner {
    use crate::training;
    use crate::ElemType;
    use burn::backend::{Autodiff, Metal};
    use burn::nn::transformer::TransformerEncoderConfig;
    use burn::optim::{decay::WeightDecayConfig, AdamConfig};
    use burn::tensor::backend::AutodiffBackend;
    use model::{data::TweetSentimentDataset, inference::ExperimentConfig};

    /// Last argument passed to `training::train`.
    /// NOTE(review): presumably the directory where training artifacts are
    /// written; confirm against `training::train`.
    const ARTIFACT_DIR: &str = "sam-artifacts";

    /// Concrete backend used by [`run`]: autodiff on top of the Metal backend,
    /// with `ElemType` and `i32` as its type parameters.
    type DefaultBackend = Autodiff<Metal<ElemType, i32>>;

    /// Builds the experiment configuration and hands it, together with the
    /// train/test splits of [`TweetSentimentDataset`], to `training::train`
    /// on the given `devices`.
    pub fn launch<B>(devices: Vec<B::Device>)
    where
        B: AutodiffBackend,
    {
        // NOTE(review): the positional arguments are believed to be
        // (d_model, d_ff, n_heads, n_layers); confirm against the pinned
        // burn version.
        let encoder = TransformerEncoderConfig::new(256, 1024, 8, 4)
            .with_norm_first(true)
            .with_quiet_softmax(true);

        let weight_decay = WeightDecayConfig::new(5e-5);
        let optimizer = AdamConfig::new().with_weight_decay(Some(weight_decay));

        let experiment = ExperimentConfig::new(encoder, optimizer);

        let train_split = TweetSentimentDataset::train();
        let test_split = TweetSentimentDataset::test();

        training::train::<B, TweetSentimentDataset>(
            devices,
            train_split,
            test_split,
            experiment,
            ARTIFACT_DIR,
        );
    }

    /// Launches training on a single default-constructed device of
    /// [`DefaultBackend`].
    pub fn run() {
        let devices = vec![Default::default()];
        launch::<DefaultBackend>(devices);
    }
}
/// Inference entry points; compiled only when the `inference` feature is enabled.
#[cfg(feature = "inference")]
mod inference_runner {
    use crate::ElemType;
    use burn::{backend::Metal, prelude::Backend};
    use model::{data::TweetSentimentDataset, inference};

    /// Second argument passed to `inference::infer`.
    /// NOTE(review): presumably the directory holding the trained artifacts;
    /// confirm against `inference::infer`.
    const ARTIFACT_DIR: &str = "sam-artifacts";

    /// Hard-coded sample sentences fed to the model by [`launch`].
    const SAMPLE_TEXTS: [&str; 10] = [
        "2am feedings for the baby are fun when he is all smiles and coos",
        "I love the smell of fresh coffee in the morning",
        "The weather is terrible today, I hate the rain",
        "I just finished reading a great book, it was so inspiring",
        "I can't believe how much I enjoyed that movie, it was fantastic",
        "I am so excited for the weekend, I have so many plans",
        "I am feeling a bit under the weather today, I hope I get better soon",
        "I am so remorseful for the mistakes I made in the past, I have learned so much from them",
        "I am grateful for the support of my friends and family, they mean the world to me",
        "I am looking forward to the future, I have so many goals and dreams to achieve",
    ];

    /// Runs `inference::infer` for [`TweetSentimentDataset`] on `device`,
    /// passing the sample sentences as owned `String`s.
    pub fn launch<B>(device: B::Device)
    where
        B: Backend,
    {
        let samples: Vec<String> = SAMPLE_TEXTS.iter().copied().map(String::from).collect();

        inference::infer::<B, TweetSentimentDataset>(device, ARTIFACT_DIR, samples);
    }

    /// Launches inference on a default-constructed device of the Metal
    /// backend, with `ElemType` and `i32` as its type parameters.
    pub fn run() {
        launch::<Metal<ElemType, i32>>(Default::default());
    }
}