import torch
from transformers import PreTrainedModel, AutoTokenizer
from transformers.configuration_utils import PretrainedConfig
from .th_encoder import ThaiEncoder
from .projector_residual import HeadProjectorResidual
class WangchanbertaEncoderModel(PreTrainedModel):
    """Thai sentence encoder that projects WangchanBERTa embeddings.

    Wraps a pretrained Thai encoder (``config.th_model_base``) with a
    residual projection head, mapping raw text to a fixed-size embedding
    of ``config.output_embedding_dim``.
    """

    def __init__(self, config: PretrainedConfig):
        """Build tokenizer, encoder and projection head from *config*.

        Args:
            config: must provide ``th_model_base``,
                ``input_text_embedding_dim``, ``output_embedding_dim`` and
                ``dropout``; may optionally provide ``max_length``
                (defaults to 200, preserving the original behavior).
        """
        super().__init__(config)
        self.text_tokenizer = AutoTokenizer.from_pretrained(config.th_model_base)
        self.text_encoder = ThaiEncoder(model_name=config.th_model_base)
        self.text_projector = HeadProjectorResidual(
            input_embedding_dim=config.input_text_embedding_dim,
            output_embedding_dim=config.output_embedding_dim,
            dropout=config.dropout,
        )
        # Generalized: configurable tokenization length; 200 keeps the
        # previous hard-coded default for existing callers.
        self.max_length = getattr(config, "max_length", 200)

    def forward(self, text: str):
        """Encode *text* into a projected embedding.

        Args:
            text: a single input string (a list of strings also works,
                since the tokenizer batches either form).

        Returns:
            The projector output for the encoded text; presumably shape
            ``(batch, output_embedding_dim)`` — depends on
            ``HeadProjectorResidual``, confirm against its definition.
        """
        # return_tensors="pt" yields (1, max_length) tensors for a single
        # string — same shape the old torch.tensor([...]) wrapping produced.
        tokened_word = self.text_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Bug fix: the original always built CPU tensors, which fails when
        # the model has been moved to GPU. self.device tracks the model.
        input_ids = tokened_word["input_ids"].to(self.device)
        attention_mask = tokened_word["attention_mask"].to(self.device)
        text_vector = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        text_projected = self.text_projector(text_vector)
        return text_projected