import torch
from transformers import PreTrainedModel, AutoTokenizer
from transformers.configuration_utils import PretrainedConfig

from .th_encoder import ThaiEncoder
from .projector_residual import HeadProjectorResidual


class WangchanbertaEncoderModel(PreTrainedModel):
    """Thai text encoder: tokenize -> ThaiEncoder -> residual projection head."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Tokenizer and encoder are built from the same base checkpoint.
        self.text_tokenizer = AutoTokenizer.from_pretrained(config.th_model_base)
        self.text_encoder = ThaiEncoder(model_name=config.th_model_base)
        # Projects the encoder embedding into the target embedding space.
        self.text_projector = HeadProjectorResidual(
            input_embedding_dim=config.input_text_embedding_dim,
            output_embedding_dim=config.output_embedding_dim,
            dropout=config.dropout,
        )
        self.max_length = 200

    def forward(self, text: str) -> torch.Tensor:
        # Tokenize a single string into a batch of one. return_tensors="pt"
        # already adds the batch dimension, so no manual torch.tensor([...])
        # wrapping is needed.
        tokened_word = self.text_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Keep the inputs on the same device as the model's parameters.
        text_vector = self.text_encoder(
            input_ids=tokened_word["input_ids"].to(self.device),
            attention_mask=tokened_word["attention_mask"].to(self.device),
        )
        # Map the encoder output to the shared embedding dimension.
        return self.text_projector(text_vector)
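
# Usage sketch (illustrative, not part of the original module). It assumes a
# PretrainedConfig carrying the fields read in __init__, that
# "airesearch/wangchanberta-base-att-spm-uncased" is the intended WangchanBERTa
# checkpoint, and placeholder embedding dimensions. Because of the relative
# imports above, run this via `python -m <package>.<this_module>` rather than
# as a plain script.
if __name__ == "__main__":
    config = PretrainedConfig(
        th_model_base="airesearch/wangchanberta-base-att-spm-uncased",
        input_text_embedding_dim=768,  # hidden size of the base encoder (assumed)
        output_embedding_dim=512,      # target embedding size (assumed)
        dropout=0.1,
    )
    model = WangchanbertaEncoderModel(config).eval()
    with torch.no_grad():
        embedding = model("สวัสดีครับ")  # Thai: "hello"
    # Expected shape: (1, output_embedding_dim), assuming ThaiEncoder returns
    # a pooled sentence vector.
    print(embedding.shape)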