Commit 17101fd
Parent: caf4734

perf: add class balancing and label smoothing to the loss function per batch

Files changed:
- model/src/data/batcher.rs +21 -0
- model/src/model.rs +4 -1
- trainer/src/training.rs +7 -3
model/src/data/batcher.rs

@@ -28,6 +28,7 @@ pub struct TextClassificationTrainingBatch<B: Backend> {
     pub tokens: Tensor<B, 2, Int>,    // Tokenized text
     pub labels: Tensor<B, 1, Int>,    // Labels of the text
     pub mask_pad: Tensor<B, 2, Bool>, // Padding mask for the tokenized text
+    pub class_weights: Vec<f32>,      // Class weights for handling class imbalance
 }
 
 /// Struct for inference batch in text classification task
@@ -50,6 +51,9 @@ impl<B: Backend> Batcher<B, TextClassificationItem, TextClassificationTrainingBatch<B>>
         let mut tokens_list = Vec::with_capacity(items.len());
         let mut labels_list = Vec::with_capacity(items.len());
 
+        // Compute class weights based on the training dataset
+        let class_weights = compute_class_weights(&items);
+
         // Tokenize text and create label tensor for each item
         for item in items {
             tokens_list.push(self.tokenizer.encode(&item.text));
@@ -72,10 +76,27 @@ impl<B: Backend> Batcher<B, TextClassificationItem, TextClassificationTrainingBatch<B>>
             tokens: mask.tensor,
            labels: Tensor::cat(labels_list, 0),
             mask_pad: mask.mask,
+            class_weights,
         }
     }
 }
 
+// Function to compute class weights based on the training dataset
+fn compute_class_weights(items: &[TextClassificationItem]) -> Vec<f32> {
+    let num_classes = items.iter().map(|item| item.label).max().unwrap_or(0) + 1;
+    let mut class_counts = vec![0; num_classes];
+
+    for item in items {
+        class_counts[item.label] += 1;
+    }
+
+    let total_count = class_counts.iter().sum::<usize>() as f32;
+    class_counts
+        .iter()
+        .map(|&count| total_count / count as f32)
+        .collect()
+}
+
 /// Implement Batcher trait for TextClassificationBatcher struct for inference
 impl<B: Backend> Batcher<B, String, TextClassificationInferenceBatch<B>>
     for TextClassificationBatcher
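For reference, the weighting above is plain inverse class frequency over the items the batcher receives: weight_c = total / count_c, so rarer classes get proportionally larger weights. A minimal standalone sketch of the same logic (the local Item struct is a stand-in for TextClassificationItem, which is assumed here to carry a usize label):

// Minimal sketch of the inverse-frequency weighting used by the batcher.
// `Item` is a stand-in for TextClassificationItem.
struct Item {
    label: usize,
}

fn compute_class_weights(items: &[Item]) -> Vec<f32> {
    let num_classes = items.iter().map(|item| item.label).max().unwrap_or(0) + 1;
    let mut class_counts = vec![0usize; num_classes];
    for item in items {
        class_counts[item.label] += 1;
    }
    let total_count = class_counts.iter().sum::<usize>() as f32;
    class_counts
        .iter()
        // weight_c = N / count_c; a count of zero yields f32::INFINITY,
        // which can happen when a class is absent from the current batch
        .map(|&count| total_count / count as f32)
        .collect()
}

fn main() {
    // A batch of 4 items: three of class 0, one of class 1
    let batch = [0, 0, 0, 1].map(|label| Item { label });
    // Majority class gets weight 4/3, minority class gets 4.0
    assert_eq!(compute_class_weights(&batch), vec![4.0 / 3.0, 4.0]);
    println!("weights ok");
}

One consequence of computing the weights per batch is that a class absent from the batch gets a zero count and therefore an infinite weight; whether that matters depends on how the loss consumes the vector, since an absent class also never appears as a target in that batch.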
model/src/model.rs

@@ -119,7 +119,10 @@ impl<B: Backend> TextClassificationModel<B> {
 
         // Compute the loss using Cross-Entropy
         let loss = CrossEntropyLossConfig::new()
-            .init(device)
+            .with_weights(Some(item.class_weights))
+            .with_smoothing(Some(0.1))
+            .with_logits(true)
+            .init(device)
             .forward(logits.clone(), labels.clone());
 
         // Return the output and loss
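Two of the new builder calls deserve a note. with_weights makes the loss scale each example by the weight of its true class, and with_smoothing(Some(0.1)) replaces the one-hot target with a softened distribution, commonly (1 - alpha) on the true class plus alpha / K spread across all K classes. A plain-Rust sketch of cross-entropy against such a smoothed target (not Burn's internal implementation; the exact smoothing formula Burn applies may differ):

// Cross-entropy with label smoothing for a single example.
fn smoothed_cross_entropy(logits: &[f64], target: usize, alpha: f64) -> f64 {
    let k = logits.len() as f64;

    // Numerically stable log-sum-exp for the softmax denominator
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = logits.iter().map(|&x| (x - max).exp()).sum::<f64>().ln() + max;

    logits
        .iter()
        .enumerate()
        .map(|(i, &x)| {
            // Smoothed target: (1 - alpha) on the true class, alpha / K everywhere
            let hard = if i == target { 1.0 - alpha } else { 0.0 };
            let t = hard + alpha / k;
            -t * (x - log_sum) // x - log_sum is the log-softmax of x
        })
        .sum()
}

fn main() {
    let logits = [2.0, 0.5, -1.0];
    // alpha = 0.0 reduces to ordinary cross-entropy on the hard label
    println!("hard:   {:.4}", smoothed_cross_entropy(&logits, 0, 0.0));
    println!("smooth: {:.4}", smoothed_cross_entropy(&logits, 0, 0.1));
}

with_logits(true) tells the loss that raw logits are being passed in, so the softmax is applied internally, which is what the log-softmax step in the sketch does.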
trainer/src/training.rs

@@ -41,6 +41,10 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
     // Initialize batcher
     let batcher = TextClassificationBatcher::new(tokenizer.clone(), config.max_seq_length);
 
+    // Create data samplers for training and testing datasets
+    let train_sampler = SamplerDataset::new(dataset_train, 50_000);
+    let test_sampler = SamplerDataset::new(dataset_test, 5_000);
+
     // Initialize model
     let model = TextClassificationModelConfig::new(
         config.transformer.clone(),
@@ -54,12 +58,12 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
     let train_loader = DataLoaderBuilder::new(batcher.clone())
         .batch_size(config.batch_size)
         .num_workers(1)
-        .build(SamplerDataset::new(dataset_train, 50_000));
+        .build(train_sampler);
     let valid_loader = DataLoaderBuilder::new(batcher)
         .batch_size(config.batch_size)
         .num_workers(1)
-        .build(SamplerDataset::new(dataset_test, 5_000));
-
+        .build(test_sampler);
+
     // Initialize optimizer
     let optimizer = config.optimizer.init();
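SamplerDataset wraps a dataset and presents a fixed-size virtual epoch, 50,000 training and 5,000 validation samples here, drawn at random from the underlying items. A rough standalone sketch of the idea, assuming the rand crate and sampling with replacement (Burn's actual implementation differs):

use rand::Rng;

// Rough sketch of a fixed-size, sample-with-replacement view over a dataset,
// in the spirit of SamplerDataset. Not Burn's code.
struct Sampler<'a, T> {
    items: &'a [T],
    size: usize, // virtual epoch length, e.g. 50_000
}

impl<'a, T> Sampler<'a, T> {
    fn len(&self) -> usize {
        self.size
    }

    // Every index in 0..size resolves to a randomly drawn underlying item
    fn get(&self, _index: usize) -> &T {
        let i = rand::thread_rng().gen_range(0..self.items.len());
        &self.items[i]
    }
}

fn main() {
    let data = ["a", "b", "c"];
    let sampler = Sampler { items: &data, size: 10 };
    let epoch: Vec<_> = (0..sampler.len()).map(|i| sampler.get(i)).collect();
    println!("{epoch:?}"); // 10 items drawn from 3, with replacement
}

Fixing the epoch size this way decouples epoch length from dataset size, which is the usual reason to reach for a sampler wrapper.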