File size: 2,979 Bytes
2d4eafe
 
 
 
 
869c434
2d4eafe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869c434
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// Deeply nested generic types (e.g. `Autodiff<Metal<ElemType, i32>>` flowing
// through burn's trait machinery) can exceed the compiler's default recursion
// limit, so raise it crate-wide.
#![recursion_limit = "256"]
// Suppress warnings about cfg feature names the compiler cannot statically
// verify (the `training`/`inference`/`f16`/`flex32` features used below).
#![allow(unexpected_cfgs)]

// Local training loop; invoked by `training_runner` when the
// `training` feature is enabled.
mod training;

/// Entry point: dispatches to the feature-gated runners.
///
/// With the `training` feature enabled the training runner executes first;
/// with `inference` enabled the inference runner executes. Both features may
/// be active, in which case training runs before inference.
fn main() {
    #[cfg(feature = "training")]
    {
        training_runner::run();
    }

    #[cfg(feature = "inference")]
    {
        inference_runner::run();
    }
}

// Numeric element type for backend tensors, selected at compile time.
// Default: full-precision f32 when no reduced-precision feature is enabled.
// `#[allow(unused)]` because the alias is only consumed by the feature-gated
// runner modules below.
#[cfg(not(any(feature = "f16", feature = "flex32")))]
#[allow(unused)]
type ElemType = f32;
// Half precision when the `f16` feature is on.
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;
// NOTE(review): `f16` and `flex32` appear to be mutually exclusive — enabling
// both would define `ElemType` twice. Confirm the features are declared as
// exclusive in Cargo.toml.
#[cfg(feature = "flex32")]
type ElemType = burn::tensor::flex32;

#[cfg(feature = "training")]
mod training_runner {
    use crate::{ElemType, training};
    use burn::backend::{Autodiff, Metal};
    use burn::nn::transformer::TransformerEncoderConfig;
    use burn::optim::AdamConfig;
    use burn::optim::decay::WeightDecayConfig;
    use burn::tensor::backend::AutodiffBackend;
    use model::data::TweetSentimentDataset;
    use model::inference::ExperimentConfig;

    /// Trains the sentiment model on the tweet dataset across `devices`,
    /// writing checkpoints and logs under the `sam-artifacts` directory.
    pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
        // Transformer encoder: d_model 256, feed-forward 1024, 8 heads,
        // 4 layers; pre-norm with quiet softmax.
        let encoder = TransformerEncoderConfig::new(256, 1024, 8, 4)
            .with_norm_first(true)
            .with_quiet_softmax(true);

        // Adam with a small weight-decay term for regularization.
        let optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5)));

        let config = ExperimentConfig::new(encoder, optimizer);

        training::train::<B, TweetSentimentDataset>(
            devices,
            TweetSentimentDataset::train(),
            TweetSentimentDataset::test(),
            config,
            "sam-artifacts",
        );
    }

    /// Launches training on the default Metal device with autodiff enabled.
    pub fn run() {
        let devices = vec![Default::default()];
        launch::<Autodiff<Metal<ElemType, i32>>>(devices);
    }
}

#[cfg(feature = "inference")]
mod inference_runner {
    use crate::ElemType;
    use burn::backend::Metal;
    use burn::prelude::Backend;
    use model::data::TweetSentimentDataset;
    use model::inference;

    /// Runs inference on a fixed batch of sample sentences using the model
    /// artifacts stored under `sam-artifacts`.
    pub fn launch<B: Backend>(device: B::Device) {
        // Hand-picked sentences covering positive, negative, and mixed
        // sentiment for a quick qualitative check of the trained model.
        let samples: Vec<String> = [
            "2am feedings for the baby are fun when he is all smiles and coos",
            "I love the smell of fresh coffee in the morning",
            "The weather is terrible today, I hate the rain",
            "I just finished reading a great book, it was so inspiring",
            "I can't believe how much I enjoyed that movie, it was fantastic",
            "I am so excited for the weekend, I have so many plans",
            "I am feeling a bit under the weather today, I hope I get better soon",
            "I am so remorseful for the mistakes I made in the past, I have learned so much from them",
            "I am grateful for the support of my friends and family, they mean the world to me",
            "I am looking forward to the future, I have so many goals and dreams to achieve",
        ]
        .into_iter()
        .map(String::from)
        .collect();

        inference::infer::<B, TweetSentimentDataset>(device, "sam-artifacts", samples);
    }

    /// Launches inference on the default Metal device.
    pub fn run() {
        launch::<Metal<ElemType, i32>>(Default::default());
    }
}