WarriorsSami commited on
Commit
17101fd
·
1 Parent(s): caf4734

perf: add class balancing and label smoothing to the loss function per batch

Browse files
model/src/data/batcher.rs CHANGED
@@ -28,6 +28,7 @@ pub struct TextClassificationTrainingBatch<B: Backend> {
28
  pub tokens: Tensor<B, 2, Int>, // Tokenized text
29
  pub labels: Tensor<B, 1, Int>, // Labels of the text
30
  pub mask_pad: Tensor<B, 2, Bool>, // Padding mask for the tokenized text
 
31
  }
32
 
33
  /// Struct for inference batch in text classification task
@@ -50,6 +51,9 @@ impl<B: Backend> Batcher<B, TextClassificationItem, TextClassificationTrainingBa
50
  let mut tokens_list = Vec::with_capacity(items.len());
51
  let mut labels_list = Vec::with_capacity(items.len());
52
 
 
 
 
53
  // Tokenize text and create label tensor for each item
54
  for item in items {
55
  tokens_list.push(self.tokenizer.encode(&item.text));
@@ -72,10 +76,27 @@ impl<B: Backend> Batcher<B, TextClassificationItem, TextClassificationTrainingBa
72
  tokens: mask.tensor,
73
  labels: Tensor::cat(labels_list, 0),
74
  mask_pad: mask.mask,
 
75
  }
76
  }
77
  }
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  /// Implement Batcher trait for TextClassificationBatcher struct for inference
80
  impl<B: Backend> Batcher<B, String, TextClassificationInferenceBatch<B>>
81
  for TextClassificationBatcher
 
28
  pub tokens: Tensor<B, 2, Int>, // Tokenized text
29
  pub labels: Tensor<B, 1, Int>, // Labels of the text
30
  pub mask_pad: Tensor<B, 2, Bool>, // Padding mask for the tokenized text
31
+ pub class_weights: Vec<f32>, // Class weights for handling class imbalance
32
  }
33
 
34
  /// Struct for inference batch in text classification task
 
51
  let mut tokens_list = Vec::with_capacity(items.len());
52
  let mut labels_list = Vec::with_capacity(items.len());
53
 
54
+ // Compute class weights based on the training dataset
55
+ let class_weights = compute_class_weights(&items);
56
+
57
  // Tokenize text and create label tensor for each item
58
  for item in items {
59
  tokens_list.push(self.tokenizer.encode(&item.text));
 
76
  tokens: mask.tensor,
77
  labels: Tensor::cat(labels_list, 0),
78
  mask_pad: mask.mask,
79
+ class_weights,
80
  }
81
  }
82
  }
83
 
84
+ // Function to compute class weights based on the training dataset
85
+ fn compute_class_weights(items: &[TextClassificationItem]) -> Vec<f32> {
86
+ let num_classes = items.iter().map(|item| item.label).max().unwrap_or(0) + 1;
87
+ let mut class_counts = vec![0; num_classes];
88
+
89
+ for item in items {
90
+ class_counts[item.label] += 1;
91
+ }
92
+
93
+ let total_count = class_counts.iter().sum::<usize>() as f32;
94
+ class_counts
95
+ .iter()
96
+ .map(|&count| total_count / count as f32)
97
+ .collect()
98
+ }
99
+
100
  /// Implement Batcher trait for TextClassificationBatcher struct for inference
101
  impl<B: Backend> Batcher<B, String, TextClassificationInferenceBatch<B>>
102
  for TextClassificationBatcher
model/src/model.rs CHANGED
@@ -119,7 +119,10 @@ impl<B: Backend> TextClassificationModel<B> {
119
 
120
  // Compute the loss using Cross-Entropy
121
  let loss = CrossEntropyLossConfig::new()
122
- .init(&logits.device())
 
 
 
123
  .forward(logits.clone(), labels.clone());
124
 
125
  // Return the output and loss
 
119
 
120
  // Compute the loss using Cross-Entropy
121
  let loss = CrossEntropyLossConfig::new()
122
+ .with_weights(Some(item.class_weights))
123
+ .with_smoothing(Some(0.1))
124
+ .with_logits(true)
125
+ .init(device)
126
  .forward(logits.clone(), labels.clone());
127
 
128
  // Return the output and loss
trainer/src/training.rs CHANGED
@@ -41,6 +41,10 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
41
  // Initialize batcher
42
  let batcher = TextClassificationBatcher::new(tokenizer.clone(), config.max_seq_length);
43
 
 
 
 
 
44
  // Initialize model
45
  let model = TextClassificationModelConfig::new(
46
  config.transformer.clone(),
@@ -54,12 +58,12 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
54
  let train_loader = DataLoaderBuilder::new(batcher.clone())
55
  .batch_size(config.batch_size)
56
  .num_workers(1)
57
- .build(SamplerDataset::new(dataset_train, 50_000));
58
  let valid_loader = DataLoaderBuilder::new(batcher)
59
  .batch_size(config.batch_size)
60
  .num_workers(1)
61
- .build(SamplerDataset::new(dataset_test, 5_000));
62
-
63
  // Initialize optimizer
64
  let optimizer = config.optimizer.init();
65
 
 
41
  // Initialize batcher
42
  let batcher = TextClassificationBatcher::new(tokenizer.clone(), config.max_seq_length);
43
 
44
+ // Create data samplers for training and testing datasets
45
+ let train_sampler = SamplerDataset::new(dataset_train, 50_000);
46
+ let test_sampler = SamplerDataset::new(dataset_test, 5_000);
47
+
48
  // Initialize model
49
  let model = TextClassificationModelConfig::new(
50
  config.transformer.clone(),
 
58
  let train_loader = DataLoaderBuilder::new(batcher.clone())
59
  .batch_size(config.batch_size)
60
  .num_workers(1)
61
+ .build(train_sampler);
62
  let valid_loader = DataLoaderBuilder::new(batcher)
63
  .batch_size(config.batch_size)
64
  .num_workers(1)
65
+ .build(test_sampler);
66
+
67
  // Initialize optimizer
68
  let optimizer = config.optimizer.init();
69