import torch

from transformers import PreTrainedModel, AutoTokenizer
from transformers.configuration_utils import PretrainedConfig

from .th_encoder import ThaiEncoder
from .projector_residual import HeadProjectorResidual


class WangchanbertaEncoderModel(PreTrainedModel):
    """WangchanBERTa text encoder with a residual projection head."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Tokenizer and encoder are built from the same pretrained checkpoint.
        self.text_tokenizer = AutoTokenizer.from_pretrained(config.th_model_base)
        self.text_encoder = ThaiEncoder(model_name=config.th_model_base)
        # Residual head maps encoder embeddings to the shared output dimension.
        self.text_projector = HeadProjectorResidual(
            input_embedding_dim=config.input_text_embedding_dim,
            output_embedding_dim=config.output_embedding_dim,
            dropout=config.dropout,
        )
        # Fixed tokenization length used by forward().
        self.max_length = 200

    def forward(self, text: str) -> torch.Tensor:
        # Tokenize straight to PyTorch tensors; return_tensors="pt" already adds
        # the batch dimension, so no manual torch.tensor() wrapping is needed.
        tokens = self.text_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Encode, then project the sentence embedding through the residual head.
        text_vector = self.text_encoder(
            input_ids=tokens["input_ids"].to(self.device),
            attention_mask=tokens["attention_mask"].to(self.device),
        )
        return self.text_projector(text_vector)
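

# A minimal usage sketch, not part of the original module: the config attribute
# names mirror those read in __init__ above, while the checkpoint name and the
# embedding dimensions are assumptions that may differ in practice.
if __name__ == "__main__":
    config = PretrainedConfig(
        th_model_base="airesearch/wangchanberta-base-att-spm-uncased",  # assumed checkpoint
        input_text_embedding_dim=768,  # WangchanBERTa hidden size (assumed)
        output_embedding_dim=512,      # shared embedding size (assumed)
        dropout=0.1,
    )
    model = WangchanbertaEncoderModel(config).eval()
    with torch.no_grad():
        embedding = model("ตัวอย่างข้อความภาษาไทย")  # "An example Thai sentence"
    print(embedding.shape)  # expected: (1, config.output_embedding_dim)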