use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};
use derive_new::new;
use std::sync::Arc;

/// Maps `TextClassificationItem`s (or plain strings at inference time) into
/// padded, device-resident tensor batches.
#[derive(Clone, new)]
pub struct TextClassificationBatcher {
    /// Tokenizer shared across dataloader workers.
    tokenizer: Arc<dyn Tokenizer>,
    /// Sequences are padded/truncated to at most this many tokens.
    max_seq_length: usize,
}

/// Training batch: padded token ids, integer labels, padding mask, and
/// per-batch inverse-frequency class weights.
#[derive(Debug, Clone, new)]
pub struct TextClassificationTrainingBatch<B: Backend> {
    pub tokens: Tensor<B, 2, Int>,
    pub labels: Tensor<B, 1, Int>,
    pub mask_pad: Tensor<B, 2, Bool>,
    pub class_weights: Vec<f32>,
}

/// Inference batch: padded token ids and padding mask only.
#[derive(Debug, Clone, new)]
pub struct TextClassificationInferenceBatch<B: Backend> {
    pub tokens: Tensor<B, 2, Int>,
    pub mask_pad: Tensor<B, 2, Bool>,
}

impl<B: Backend> Batcher<B, TextClassificationItem, TextClassificationTrainingBatch<B>>
    for TextClassificationBatcher
{
    fn batch(
        &self,
        items: Vec<TextClassificationItem>,
        device: &B::Device,
    ) -> TextClassificationTrainingBatch<B> {
        let mut tokens_list = Vec::with_capacity(items.len());
        let mut labels_list = Vec::with_capacity(items.len());

        // Weights reflect the label distribution of this batch only; they are
        // recomputed for every batch.
        let class_weights = compute_class_weights(&items);

        for item in items {
            tokens_list.push(self.tokenizer.encode(&item.text));
            labels_list.push(Tensor::from_data(
                TensorData::from([(item.label as i64).elem::<B::IntElem>()]),
                device,
            ));
        }

        // Pad (and truncate) every sequence to a shared length and build the
        // matching boolean padding mask on `device`.
        let mask = generate_padding_mask(
            self.tokenizer.pad_token(),
            tokens_list,
            Some(self.max_seq_length),
            device,
        );

        TextClassificationTrainingBatch {
            tokens: mask.tensor,
            labels: Tensor::cat(labels_list, 0),
            mask_pad: mask.mask,
            class_weights,
        }
    }
}
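
// Example (sketch): wiring this batcher into Burn's `DataLoaderBuilder` for
// training. `tokenizer`, `dataset`, and the hyperparameter values below are
// illustrative assumptions, not defined in this module.
//
//     use burn::data::dataloader::DataLoaderBuilder;
//
//     let batcher = TextClassificationBatcher::new(tokenizer.clone(), 256);
//     let dataloader = DataLoaderBuilder::new(batcher)
//         .batch_size(32)
//         .shuffle(42)
//         .num_workers(4)
//         .build(dataset);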

/// Inverse-frequency weights (`total / count`) for each class index up to the
/// largest label seen in `items`. A class absent from the slice gets weight
/// 0.0 rather than causing a division by zero.
fn compute_class_weights(items: &[TextClassificationItem]) -> Vec<f32> {
    let num_classes = items.iter().map(|item| item.label).max().unwrap_or(0) + 1;
    let mut class_counts = vec![0usize; num_classes];

    for item in items {
        class_counts[item.label] += 1;
    }

    let total_count = class_counts.iter().sum::<usize>() as f32;
    class_counts
        .iter()
        .map(|&count| {
            if count == 0 {
                0.0
            } else {
                total_count / count as f32
            }
        })
        .collect()
}
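
// Worked example: for batch labels [0, 0, 1] the counts are [2, 1] and the
// total is 3, so the weights come out as [1.5, 3.0]; the minority class
// contributes more to a weighted loss.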

impl<B: Backend> Batcher<B, String, TextClassificationInferenceBatch<B>>
    for TextClassificationBatcher
{
    fn batch(&self, items: Vec<String>, device: &B::Device) -> TextClassificationInferenceBatch<B> {
        let mut tokens_list = Vec::with_capacity(items.len());

        for item in items {
            tokens_list.push(self.tokenizer.encode(&item));
        }

        // Same padding scheme as training, minus labels and class weights.
        let mask = generate_padding_mask(
            self.tokenizer.pad_token(),
            tokens_list,
            Some(self.max_seq_length),
            device,
        );

        TextClassificationInferenceBatch {
            tokens: mask.tensor.to_device(device),
            mask_pad: mask.mask.to_device(device),
        }
    }
}
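
// A minimal sanity check of the weighting scheme. This is a sketch: it
// assumes `TextClassificationItem` exposes a `new(text, label)` constructor
// with a `usize` label, as in Burn's text-classification example.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn weights_are_inverse_frequency() {
        // Three items of class 0 and one of class 1 out of four total:
        // expected weights are 4/3 and 4/1.
        let items = vec![
            TextClassificationItem::new("a".to_string(), 0),
            TextClassificationItem::new("b".to_string(), 0),
            TextClassificationItem::new("c".to_string(), 0),
            TextClassificationItem::new("d".to_string(), 1),
        ];
        assert_eq!(compute_class_weights(&items), vec![4.0 / 3.0, 4.0]);
    }
}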