WarriorsSami committed caf4734 (1 parent: 1c9f6e9)

perf: tweak hyper-params
model/src/inference.rs CHANGED
@@ -116,10 +116,41 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
         let class_index = prediction.argmax(1).squeeze::<1>(1).into_scalar(); // Get class index with the highest value
         let class = D::class_name(class_index.elem::<i32>() as usize); // Get class name
 
+        // Apply confidence threshold
+        let confidence_threshold = 0.6; // Define a confidence threshold
+        let max_logit = Tensor::<B, 2, Float>::from_data(logits.clone(), &device)
+            .max()
+            .to_data()
+            .iter()
+            .collect::<Vec<f64>>()[0];
+
         // Print sample text, predicted logits and predicted class
-        println!(
-            "\n=== Item {i} ===\n- Text: {text}\n- Logits: {logits}\n- Prediction: \
-             {class}\n================"
-        );
+        println!("\n=== Item {i} ===\n- Text: {text}");
+
+        println!("- Prediction:");
+        if max_logit < confidence_threshold {
+            println!(
+                "🤔 Model is unsure about the sentiment (confidence {:.2}).",
+                max_logit
+            );
+            println!(
+                "Top prediction would have been class {} with low confidence.",
+                class
+            );
+        } else {
+            println!(
+                "Predicted sentiment = {} (confidence {:.1}%)",
+                class,
+                max_logit * 100.0
+            );
+        }
+
+        // Print logits for each class alongside their labels
+        print!("- Logits: [");
+        for (j, logit) in logits.iter::<f64>().enumerate() {
+            let class_label = D::class_name(j);
+            print!(" ({}: {:.2}) ", class_label, logit);
+        }
+        print!("]\n====================");
     }
 }
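Reviewer note: the confidence check above compares the raw maximum logit against 0.6, and raw logits are unbounded, so the printed "confidence" is not a probability. A minimal alternative sketch, assuming Burn's burn::tensor::activation::softmax; the helper name below is illustrative and not part of this commit:

    use burn::tensor::{activation::softmax, backend::Backend, ElementConversion, Tensor};

    /// Illustrative helper: turn a [1, n_classes] logits tensor into the
    /// highest class probability, which is guaranteed to lie in [0, 1].
    fn max_probability<B: Backend>(logits: Tensor<B, 2>) -> f32 {
        let probs = softmax(logits, 1); // normalize over the class dimension
        probs.max().into_scalar().elem::<f32>() // largest class probability
    }

With that, `max_probability(...) < confidence_threshold` reads as "less than 60% confident", whereas a raw logit of 0.6 has no fixed meaning across models.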

model/src/lib.rs CHANGED
@@ -1,3 +1,3 @@
 pub mod data;
 pub mod inference;
-pub mod model;
+pub mod model;

model/src/model.rs CHANGED
@@ -8,6 +8,7 @@ use crate::data::{
     TextClassificationInferenceBatch, TextClassificationTrainingBatch,
 };
 use crate::inference::ExperimentConfig;
+use burn::nn::{Dropout, DropoutConfig};
 use burn::{
     nn::{
         Embedding, EmbeddingConfig, Linear, LinearConfig,
@@ -24,6 +25,8 @@ use std::sync::Arc;
 #[derive(Config)]
 pub struct TextClassificationModelConfig {
     transformer: TransformerEncoderConfig,
+    #[config(default = 0.1)]
+    dropout_rate: f64,
     n_classes: usize,
     vocab_size: usize,
     max_seq_length: usize,
@@ -35,6 +38,8 @@ pub struct TextClassificationModel<B: Backend> {
     transformer: TransformerEncoder<B>,
     embedding_token: Embedding<B>,
     embedding_pos: Embedding<B>,
+    embed_dropout: Dropout,
+    output_dropout: Dropout,
     output: Linear<B>,
     n_classes: usize,
     max_seq_length: usize,
@@ -44,17 +49,28 @@ pub struct TextClassificationModel<B: Backend> {
 impl TextClassificationModelConfig {
     /// Initializes a model with default weights
     pub fn init<B: Backend>(&self, device: &B::Device) -> TextClassificationModel<B> {
-        let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(device);
-        let transformer = self.transformer.init(device);
+        let transformer = self
+            .transformer
+            .clone()
+            .with_dropout(self.dropout_rate)
+            .init(device);
+
         let embedding_token =
             EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(device);
         let embedding_pos =
             EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(device);
 
+        let embed_dropout = DropoutConfig::new(self.dropout_rate).init();
+        let output_dropout = DropoutConfig::new(self.dropout_rate).init();
+
+        let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(device);
+
        TextClassificationModel {
            transformer,
            embedding_token,
            embedding_pos,
+            embed_dropout,
+            output_dropout,
            output,
            n_classes: self.n_classes,
            max_seq_length: self.max_seq_length,
@@ -83,24 +99,33 @@ impl<B: Backend> TextClassificationModel<B> {
         let embedding_tokens = self.embedding_token.forward(tokens);
         let embedding = (embedding_positions + embedding_tokens) / 2;
 
-        // Perform transformer encoding, calculate output and loss
+        // Apply dropout to the embeddings
+        let embedding = self.embed_dropout.forward(embedding);
+
+        // Perform transformer encoding
         let encoded = self
             .transformer
             .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad));
+
+        // Apply dropout to the output of the transformer
+        let encoded = self.output_dropout.forward(encoded);
+
+        // Calculate the output using the linear layer
         let output = self.output.forward(encoded);
 
-        let output_classification = output
+        let logits = output
            .slice([0..batch_size, 0..1])
            .reshape([batch_size, self.n_classes]);
 
+        // Compute the loss using Cross-Entropy
         let loss = CrossEntropyLossConfig::new()
-            .init(&output_classification.device())
-            .forward(output_classification.clone(), labels.clone());
+            .init(&logits.device())
+            .forward(logits.clone(), labels.clone());
 
         // Return the output and loss
         ClassificationOutput {
             loss,
-            output: output_classification,
+            output: logits,
             targets: labels,
         }
     }
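Reviewer note on the new dropout_rate field: with Burn's #[derive(Config)], a field annotated #[config(default = 0.1)] stays out of the generated new(...) constructor and can be overridden through a generated setter (something like config.with_dropout_rate(0.2); the setter name is assumed from the derive). The layers themselves are plain burn::nn::Dropout modules; a self-contained sketch of what they do, with illustrative function name and values only:

    use burn::nn::DropoutConfig;
    use burn::tensor::{backend::Backend, Distribution, Tensor};

    /// Illustrative only: the same DropoutConfig::new(rate).init() call this commit
    /// adds in TextClassificationModelConfig::init, applied to a random tensor.
    fn dropout_demo<B: Backend>(device: &B::Device) {
        let dropout = DropoutConfig::new(0.1).init();
        let x = Tensor::<B, 2>::random([2, 4], Distribution::Default, device);
        // On a training (autodiff) backend roughly 10% of activations are zeroed and
        // the survivors rescaled; without autodiff the input should pass through unchanged.
        let y = dropout.forward(x);
        println!("{y}");
    }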

trainer/src/main.rs CHANGED
@@ -15,6 +15,7 @@ fn main() {
 mod training_runner {
     use crate::training;
     use burn::backend::{Autodiff, Wgpu};
+    use burn::grad_clipping::GradientClippingConfig;
     use burn::nn::transformer::TransformerEncoderConfig;
     use burn::optim::AdamConfig;
     use burn::optim::decay::WeightDecayConfig;
@@ -35,7 +36,12 @@ mod training_runner {
         TransformerEncoderConfig::new(256, 1024, 8, 4)
             .with_norm_first(true)
             .with_quiet_softmax(true),
-        AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))),
+        AdamConfig::new()
+            .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
+            .with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))) // clip gradients by L2 norm (max 1.0)
+            .with_beta_1(0.9)
+            .with_beta_2(0.999)
+            .with_epsilon(1e-8),
     );
 
     training::train::<B, TweetSentimentDataset>(
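Reviewer note: weight decay goes up from 5e-5 to 0.01, and GradientClippingConfig::Norm(1.0) caps the gradient L2 norm at 1.0 before the Adam step, which keeps an occasional large batch from destabilizing training. The beta/epsilon values spelled out here (0.9, 0.999, 1e-8) are the conventional Adam settings. Clip-by-norm itself is just a rescale; an illustration on a plain slice, not Burn's internal implementation:

    /// Illustration only: rescale a gradient so its L2 norm never exceeds `max_norm`.
    fn clip_by_norm(grad: &mut [f32], max_norm: f32) {
        let norm = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            for g in grad.iter_mut() {
                *g *= scale;
            }
        }
    }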

trainer/src/training.rs CHANGED
@@ -5,6 +5,8 @@
 // to build a learner, which is used to train the model. The trained model and the configuration are
 // then saved to the specified directory.
 
+use burn::train::metric::store::{Aggregate, Direction, Split};
+use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
 use burn::{
     data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},
     lr_scheduler::noam::NoamLrSchedulerConfig,
@@ -49,21 +51,21 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
         .init::<B>(&devices[0]);
 
     // Initialize data loaders for training and testing data
-    let dataloader_train = DataLoaderBuilder::new(batcher.clone())
+    let train_loader = DataLoaderBuilder::new(batcher.clone())
         .batch_size(config.batch_size)
         .num_workers(1)
         .build(SamplerDataset::new(dataset_train, 50_000));
-    let dataloader_test = DataLoaderBuilder::new(batcher)
+    let valid_loader = DataLoaderBuilder::new(batcher)
         .batch_size(config.batch_size)
         .num_workers(1)
         .build(SamplerDataset::new(dataset_test, 5_000));
-
+
     // Initialize optimizer
-    let optim = config.optimizer.init();
+    let optimizer = config.optimizer.init();
 
     // Initialize learning rate scheduler
-    let lr_scheduler = NoamLrSchedulerConfig::new(1e-2)
-        .with_warmup_steps(1000)
+    let lr_scheduler = NoamLrSchedulerConfig::new(1e-4)
+        .with_warmup_steps(8_000)
         .with_model_size(config.transformer.d_model)
         .init()
         .unwrap();
@@ -79,13 +81,20 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
         .metric_valid_numeric(AccuracyMetric::new())
         .metric_train_numeric(LearningRateMetric::new())
         .with_file_checkpointer(CompactRecorder::new())
+        .early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
+            &LossMetric::new(),
+            Aggregate::Mean,
+            Direction::Lowest,
+            Split::Valid,
+            StoppingCondition::NoImprovementSince { n_epochs: 2 },
+        )) // stop if no val loss improvement for 2 epochs
         .devices(devices)
         .num_epochs(config.num_epochs)
         .summary()
-        .build(model, optim, lr_scheduler);
+        .build(model, optimizer, lr_scheduler);
 
     // Train the model
-    let model_trained = learner.fit(dataloader_train, dataloader_test);
+    let model_trained = learner.fit(train_loader, valid_loader);
 
     // Save the configuration and the trained model
     config.save(format!("{artifact_dir}/config.json")).unwrap();
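Reviewer note: the Noam schedule warms the learning rate up roughly linearly and then decays it as the inverse square root of the step, so moving warmup_steps from 1_000 to 8_000 and the base rate from 1e-2 to 1e-4 gives a much slower, lower ramp. A sketch of the schedule as usually defined (Vaswani et al., 2017); treating the value passed to NoamLrSchedulerConfig::new as a plain multiplier is an assumption about Burn's implementation:

    /// Standard Noam schedule: linear warmup, then ~1/sqrt(step) decay.
    fn noam_lr(lr_init: f64, step: f64, warmup_steps: f64, d_model: f64) -> f64 {
        lr_init * d_model.powf(-0.5) * f64::min(step.powf(-0.5), step * warmup_steps.powf(-1.5))
    }

The new early_stopping hook complements this: training halts once the mean validation loss has not improved for two consecutive epochs, so the longer warmup does not simply turn into wasted epochs.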