Commit caf4734
Parent(s): 1c9f6e9

perf: tweak hyper-params

Files changed:
- model/src/inference.rs +35 -4
- model/src/lib.rs +1 -1
- model/src/model.rs +32 -7
- trainer/src/main.rs +7 -1
- trainer/src/training.rs +17 -8
model/src/inference.rs CHANGED

@@ -116,10 +116,41 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
         let class_index = prediction.argmax(1).squeeze::<1>(1).into_scalar(); // Get class index with the highest value
         let class = D::class_name(class_index.elem::<i32>() as usize); // Get class name

+        // Apply confidence threshold
+        let confidence_threshold = 0.6; // Define a confidence threshold
+        let max_logit = Tensor::<B, 2, Float>::from_data(logits.clone(), &device)
+            .max()
+            .to_data()
+            .iter()
+            .collect::<Vec<f64>>()[0];
+
         // Print sample text, predicted logits and predicted class
-        println!(
-
-
-
+        println!("\n=== Item {i} ===\n- Text: {text}");
+
+        println!("- Prediction:");
+        if max_logit < confidence_threshold {
+            println!(
+                "🤔 Model is unsure about the sentiment (confidence {:.2}).",
+                max_logit
+            );
+            println!(
+                "Top prediction would have been class {} with low confidence.",
+                class
+            );
+        } else {
+            println!(
+                "Predicted sentiment = {} (confidence {:.1}%)",
+                class,
+                max_logit * 100.0
+            );
+        }
+
+        // Print logits for each class alongside their labels
+        print!("- Logits: [");
+        for (j, logit) in logits.iter::<f64>().enumerate() {
+            let class_label = D::class_name(j);
+            print!(" ({}: {:.2}) ", class_label, logit);
+        }
+        print!("]\n====================");
     }
 }
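The confidence gate added above compares the largest raw logit against a fixed threshold of 0.6 and switches between an "unsure" message and the normal prediction report. Stripped of the Burn tensor plumbing, the branching logic looks roughly like the sketch below; `logits` and `class_names` are hypothetical stand-ins for the tensor data and the D::class_name lookups in the actual code.

// Minimal standalone sketch of the confidence gate from inference.rs.
// `logits` and `class_names` are illustrative stand-ins, not the real types.
fn report_prediction(logits: &[f64], class_names: &[&str], confidence_threshold: f64) {
    // Index and value of the highest logit (argmax over classes).
    let (best_idx, max_logit) = logits
        .iter()
        .copied()
        .enumerate()
        .fold((0, f64::NEG_INFINITY), |best, (i, v)| if v > best.1 { (i, v) } else { best });

    if max_logit < confidence_threshold {
        println!("Model is unsure about the sentiment (confidence {:.2}).", max_logit);
        println!("Top prediction would have been class {} with low confidence.", class_names[best_idx]);
    } else {
        println!("Predicted sentiment = {} (confidence {:.1}%)", class_names[best_idx], max_logit * 100.0);
    }

    // Logits for each class alongside their labels, as in the diff.
    print!("- Logits: [");
    for (label, logit) in class_names.iter().zip(logits) {
        print!(" ({}: {:.2}) ", label, logit);
    }
    println!("]");
}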
model/src/lib.rs CHANGED

@@ -1,3 +1,3 @@
 pub mod data;
 pub mod inference;
-pub mod model;
+pub mod model;
model/src/model.rs CHANGED

@@ -8,6 +8,7 @@ use crate::data::{
     TextClassificationInferenceBatch, TextClassificationTrainingBatch,
 };
 use crate::inference::ExperimentConfig;
+use burn::nn::{Dropout, DropoutConfig};
 use burn::{
     nn::{
         Embedding, EmbeddingConfig, Linear, LinearConfig,
@@ -24,6 +25,8 @@ use std::sync::Arc;
 #[derive(Config)]
 pub struct TextClassificationModelConfig {
     transformer: TransformerEncoderConfig,
+    #[config(default = 0.1)]
+    dropout_rate: f64,
     n_classes: usize,
     vocab_size: usize,
     max_seq_length: usize,
@@ -35,6 +38,8 @@ pub struct TextClassificationModel<B: Backend> {
     transformer: TransformerEncoder<B>,
     embedding_token: Embedding<B>,
     embedding_pos: Embedding<B>,
+    embed_dropout: Dropout,
+    output_dropout: Dropout,
     output: Linear<B>,
     n_classes: usize,
     max_seq_length: usize,
@@ -44,17 +49,28 @@ pub struct TextClassificationModel<B: Backend> {
 impl TextClassificationModelConfig {
     /// Initializes a model with default weights
     pub fn init<B: Backend>(&self, device: &B::Device) -> TextClassificationModel<B> {
-        let
-
+        let transformer = self
+            .transformer
+            .clone()
+            .with_dropout(self.dropout_rate)
+            .init(device);
+
         let embedding_token =
             EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(device);
         let embedding_pos =
             EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(device);

+        let embed_dropout = DropoutConfig::new(self.dropout_rate).init();
+        let output_dropout = DropoutConfig::new(self.dropout_rate).init();
+
+        let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(device);
+
         TextClassificationModel {
             transformer,
             embedding_token,
             embedding_pos,
+            embed_dropout,
+            output_dropout,
             output,
             n_classes: self.n_classes,
             max_seq_length: self.max_seq_length,
@@ -83,24 +99,33 @@ impl<B: Backend> TextClassificationModel<B> {
         let embedding_tokens = self.embedding_token.forward(tokens);
         let embedding = (embedding_positions + embedding_tokens) / 2;

-        //
+        // Apply dropout to the embeddings
+        let embedding = self.embed_dropout.forward(embedding);
+
+        // Perform transformer encoding
         let encoded = self
             .transformer
             .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad));
+
+        // Apply dropout to the output of the transformer
+        let encoded = self.output_dropout.forward(encoded);
+
+        // Calculate the output using the linear layer
         let output = self.output.forward(encoded);

-        let
+        let logits = output
            .slice([0..batch_size, 0..1])
            .reshape([batch_size, self.n_classes]);

+        // Compute the loss using Cross-Entropy
         let loss = CrossEntropyLossConfig::new()
-            .init(&
-            .forward(
+            .init(&logits.device())
+            .forward(logits.clone(), labels.clone());

         // Return the output and loss
         ClassificationOutput {
             loss,
-            output:
+            output: logits,
             targets: labels,
         }
     }
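The new dropout layers follow Burn's usual config-then-init pattern and are applied as plain tensor-to-tensor modules in the forward pass. A minimal sketch of that pattern in isolation (generic over any backend B, with the 0.1 rate matching the config default above; the function name is illustrative):

use burn::nn::{Dropout, DropoutConfig};
use burn::tensor::{backend::Backend, Tensor};

// Build a Dropout module from its config and apply it to a 3-D tensor,
// mirroring how embed_dropout / output_dropout are used in model.rs.
fn apply_dropout<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    let dropout: Dropout = DropoutConfig::new(0.1).init();
    // Randomly zeroes elements (and rescales the rest) during autodiff-enabled
    // training; acts as an identity at inference time.
    dropout.forward(x)
}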
trainer/src/main.rs CHANGED

@@ -15,6 +15,7 @@ fn main() {
 mod training_runner {
     use crate::training;
     use burn::backend::{Autodiff, Wgpu};
+    use burn::grad_clipping::GradientClippingConfig;
     use burn::nn::transformer::TransformerEncoderConfig;
     use burn::optim::AdamConfig;
     use burn::optim::decay::WeightDecayConfig;
@@ -35,7 +36,12 @@ mod training_runner {
         TransformerEncoderConfig::new(256, 1024, 8, 4)
             .with_norm_first(true)
             .with_quiet_softmax(true),
-        AdamConfig::new()
+        AdamConfig::new()
+            .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
+            .with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))) // clip gradients by L2 norm (max 1.0)
+            .with_beta_1(0.9)
+            .with_beta_2(0.999)
+            .with_epsilon(1e-8),
     );

     training::train::<B, TweetSentimentDataset>(
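Clipping by norm, as enabled here with GradientClippingConfig::Norm(1.0), rescales a gradient whenever its L2 norm exceeds the threshold, which bounds the size of each optimizer step. Burn applies this inside the optimizer; the following is only a plain-Rust illustration of the rule on a gradient stored as a flat vector:

// If the gradient's L2 norm exceeds `max_norm`, scale it so the norm equals `max_norm`.
fn clip_grad_norm(grad: &mut [f64], max_norm: f64) {
    let norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in grad.iter_mut() {
            *g *= scale;
        }
    }
}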
trainer/src/training.rs CHANGED

@@ -5,6 +5,8 @@
 // to build a learner, which is used to train the model. The trained model and the configuration are
 // then saved to the specified directory.

+use burn::train::metric::store::{Aggregate, Direction, Split};
+use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
 use burn::{
     data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},
     lr_scheduler::noam::NoamLrSchedulerConfig,
@@ -49,21 +51,21 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
         .init::<B>(&devices[0]);

     // Initialize data loaders for training and testing data
-    let
+    let train_loader = DataLoaderBuilder::new(batcher.clone())
         .batch_size(config.batch_size)
         .num_workers(1)
         .build(SamplerDataset::new(dataset_train, 50_000));
-    let
+    let valid_loader = DataLoaderBuilder::new(batcher)
         .batch_size(config.batch_size)
         .num_workers(1)
         .build(SamplerDataset::new(dataset_test, 5_000));
-
+
     // Initialize optimizer
-    let
+    let optimizer = config.optimizer.init();

     // Initialize learning rate scheduler
-    let lr_scheduler = NoamLrSchedulerConfig::new(1e-
-        .with_warmup_steps(
+    let lr_scheduler = NoamLrSchedulerConfig::new(1e-4)
+        .with_warmup_steps(8_000)
         .with_model_size(config.transformer.d_model)
         .init()
         .unwrap();
@@ -79,13 +81,20 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
         .metric_valid_numeric(AccuracyMetric::new())
         .metric_train_numeric(LearningRateMetric::new())
         .with_file_checkpointer(CompactRecorder::new())
+        .early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
+            &LossMetric::new(),
+            Aggregate::Mean,
+            Direction::Lowest,
+            Split::Valid,
+            StoppingCondition::NoImprovementSince { n_epochs: 2 },
+        )) // stop if no val loss improvement for 2 epochs
         .devices(devices)
         .num_epochs(config.num_epochs)
         .summary()
-        .build(model,
+        .build(model, optimizer, lr_scheduler);

     // Train the model
-    let model_trained = learner.fit(
+    let model_trained = learner.fit(train_loader, valid_loader);

     // Save the configuration and the trained model
     config.save(format!("{artifact_dir}/config.json")).unwrap();
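For context, the Noam scheduler configured above (base factor 1e-4, 8_000 warmup steps, and the transformer's d_model) ramps the learning rate up roughly linearly over the warmup steps and then decays it with the inverse square root of the step count. The sketch below shows the standard Noam formula under the assumption that Burn's NoamLrScheduler follows that textbook definition; the exact internals may differ.

// Standard Noam schedule: lr(step) = base * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
fn noam_lr(base: f64, d_model: usize, warmup_steps: usize, step: usize) -> f64 {
    let step = step.max(1) as f64;
    let warmup = warmup_steps as f64;
    base * (d_model as f64).powf(-0.5) * f64::min(step.powf(-0.5), step * warmup.powf(-1.5))
}

Similarly, the StoppingCondition::NoImprovementSince { n_epochs: 2 } rule amounts to tracking the best validation loss seen so far and stopping once two epochs pass without a new best. A plain-Rust illustration of that rule (not Burn's implementation):

// Stop when at least `n_epochs` epochs have passed since the best validation loss.
fn should_stop(val_losses: &[f64], n_epochs: usize) -> bool {
    match val_losses
        .iter()
        .enumerate()
        .min_by(|a, b| a.1.partial_cmp(b.1).unwrap())
    {
        Some((best_epoch, _)) => val_losses.len() - 1 - best_epoch >= n_epochs,
        None => false,
    }
}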