import torch

from transformers import PreTrainedModel, AutoTokenizer
from transformers.configuration_utils import PretrainedConfig

from .th_encoder import ThaiEncoder
from .projector_residual import HeadProjectorResidual


class WangchanbertaEncoderModel(PreTrainedModel):
    """Wangchanberta-based text encoder with a residual projection head.

    Wraps a Thai tokenizer/encoder pair and projects the encoded text into the
    output embedding space defined by the config.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Tokenizer and encoder are built from the same Thai base checkpoint.
        self.text_tokenizer = AutoTokenizer.from_pretrained(config.th_model_base)
        self.text_encoder = ThaiEncoder(model_name=config.th_model_base)
        # Projects encoder outputs to the target embedding dimension.
        self.text_projector = HeadProjectorResidual(
            input_embedding_dim=config.input_text_embedding_dim,
            output_embedding_dim=config.output_embedding_dim,
            dropout=config.dropout,
        )
        self.max_length = 200

    def forward(self, text: str):
        # Tokenize the raw string into fixed-length PyTorch tensors.
        tokened_word = self.text_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Keep the inputs on the same device as the model before encoding.
        text_vector = self.text_encoder(
            input_ids=tokened_word["input_ids"].to(self.device),
            attention_mask=tokened_word["attention_mask"].to(self.device),
        )
        # Project the encoded representation into the output embedding space.
        text_projected = self.text_projector(text_vector)
        return text_projected
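

# --- Illustrative usage sketch (an assumption, not part of the original module) ---
# The checkpoint name and dimensions below are hypothetical placeholders chosen for
# illustration; the real values come from the project's own PretrainedConfig.
if __name__ == "__main__":
    example_config = PretrainedConfig(
        th_model_base="airesearch/wangchanberta-base-att-spm-uncased",  # assumed base checkpoint
        input_text_embedding_dim=768,  # assumed encoder hidden size
        output_embedding_dim=512,      # assumed projection size
        dropout=0.1,                   # assumed projector dropout
    )
    model = WangchanbertaEncoderModel(example_config).eval()
    with torch.no_grad():
        embedding = model("ตัวอย่างข้อความภาษาไทย")  # "an example Thai sentence"
    print(embedding.shape)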