// The module defines two structs, TextClassificationTrainingBatch and TextClassificationInferenceBatch,
// to handle batches of data during training and inference respectively. The TextClassificationBatcher
// struct creates these batches; its Batcher implementations are generic over the type B: Backend so
// that different computation backends (e.g., CPU, CUDA) can be used.

// Two implementations of the Batcher trait are provided for TextClassificationBatcher: one for creating
// training batches and one for creating inference batches. In each implementation, the batch function
// converts a vector of items into a batch. For training, the items are instances of
// TextClassificationItem and include both the text and the corresponding label; per-batch class
// weights are also computed to mitigate class imbalance. For inference, the items are plain strings
// without labels. The function tokenizes the text, generates a padding mask, and returns a batch object.
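//
// Usage sketch (assumes the BertCasedTokenizer from this example's tokenizer module and
// an illustrative 512-token limit):
//
//     let batcher = TextClassificationBatcher::new(Arc::new(BertCasedTokenizer::default()), 512);
//     let training_batch = batcher.batch(items, &device); // items: Vec<TextClassificationItem>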

use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};
use derive_new::new;
use std::sync::Arc;

/// Struct for batching text classification items
#[derive(Clone, new)]
pub struct TextClassificationBatcher {
    tokenizer: Arc<dyn Tokenizer>, // Tokenizer for converting text to token IDs
    max_seq_length: usize,         // Maximum sequence length for tokenized text
}

/// Struct for training batch in text classification task
#[derive(Debug, Clone, new)]
pub struct TextClassificationTrainingBatch<B: Backend> {
    pub tokens: Tensor<B, 2, Int>,    // Tokenized text
    pub labels: Tensor<B, 1, Int>,    // Labels of the text
    pub mask_pad: Tensor<B, 2, Bool>, // Padding mask for the tokenized text
    pub class_weights: Vec<f32>,      // Class weights for handling class imbalance
}

/// Struct for inference batch in text classification task
#[derive(Debug, Clone, new)]
pub struct TextClassificationInferenceBatch<B: Backend> {
    pub tokens: Tensor<B, 2, Int>,    // Tokenized text
    pub mask_pad: Tensor<B, 2, Bool>, // Padding mask for the tokenized text
}

/// Implementation of the Batcher trait for TextClassificationBatcher that produces training batches
impl<B: Backend> Batcher<B, TextClassificationItem, TextClassificationTrainingBatch<B>>
    for TextClassificationBatcher
{
    /// Batches a vector of text classification items into a training batch
    fn batch(
        &self,
        items: Vec<TextClassificationItem>,
        device: &B::Device,
    ) -> TextClassificationTrainingBatch<B> {
        let mut tokens_list = Vec::with_capacity(items.len());
        let mut labels_list = Vec::with_capacity(items.len());

        // Compute class weights from the label distribution within this batch
        // (per-batch, not over the full training dataset)
        let class_weights = compute_class_weights(&items);

        // Tokenize text and create label tensor for each item
        for item in items {
            tokens_list.push(self.tokenizer.encode(&item.text));
            labels_list.push(Tensor::from_data(
                TensorData::from([(item.label as i64).elem::<B::IntElem>()]),
                device,
            ));
        }

        // Generate padding mask for tokenized text
        let mask = generate_padding_mask(
            self.tokenizer.pad_token(),
            tokens_list,
            Some(self.max_seq_length),
            device,
        );

        // Create and return training batch
        TextClassificationTrainingBatch {
            tokens: mask.tensor,
            labels: Tensor::cat(labels_list, 0),
            mask_pad: mask.mask,
            class_weights,
        }
    }
}
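
// Shape sketch for a batch of N items: `tokens` and `mask_pad` come out as [N, S],
// where S is the longest tokenized length in the batch capped at `max_seq_length`;
// `labels` is [N]. `class_weights` has one entry per class index 0..=max observed label.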

/// Computes inverse-frequency class weights (total count / per-class count) from the given
/// items. Note that it weighs whatever slice it is handed, here a single batch.
fn compute_class_weights(items: &[TextClassificationItem]) -> Vec<f32> {
    let num_classes = items.iter().map(|item| item.label).max().unwrap_or(0) + 1;
    let mut class_counts = vec![0; num_classes];

    for item in items {
        class_counts[item.label] += 1;
    }

    let total_count = class_counts.iter().sum::<usize>() as f32;
    class_counts
        .iter()
        .map(|&count| {
            if count == 0 {
                0.0 // Class absent from this batch: avoid division by zero
            } else {
                total_count / count as f32
            }
        })
        .collect()
}

/// Implementation of the Batcher trait for TextClassificationBatcher that produces inference batches
impl<B: Backend> Batcher<B, String, TextClassificationInferenceBatch<B>>
    for TextClassificationBatcher
{
    /// Batches a vector of strings into an inference batch
    fn batch(&self, items: Vec<String>, device: &B::Device) -> TextClassificationInferenceBatch<B> {
        let mut tokens_list = Vec::with_capacity(items.len());

        // Tokenize each string
        for item in items {
            tokens_list.push(self.tokenizer.encode(&item));
        }

        // Generate padding mask for tokenized text
        let mask = generate_padding_mask(
            self.tokenizer.pad_token(),
            tokens_list,
            Some(self.max_seq_length),
            device,
        );

        // Create and return inference batch
        TextClassificationInferenceBatch {
            tokens: mask.tensor.to_device(device),
            mask_pad: mask.mask.to_device(device),
        }
    }
}
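
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity check for the inverse-frequency weighting, assuming TextClassificationItem
    // exposes the `new(text, label)` constructor generated in the dataset module.
    // With labels [0, 0, 1]: counts = [2, 1], total = 3, weights = [3/2, 3/1] = [1.5, 3.0].
    #[test]
    fn class_weights_are_inverse_frequency() {
        let items = vec![
            TextClassificationItem::new("a".to_string(), 0),
            TextClassificationItem::new("b".to_string(), 0),
            TextClassificationItem::new("c".to_string(), 1),
        ];

        let weights = compute_class_weights(&items);

        assert_eq!(weights, vec![1.5, 3.0]);
    }
}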