# whangchanberta-multimodal-distill / wangchanberta_cross_clip.py
# Nachaphat's picture
# Upload model
# 0144345
from typing import List, Union

import torch
from transformers import PreTrainedModel, AutoTokenizer
from transformers.configuration_utils import PretrainedConfig

from .th_encoder import ThaiEncoder
from .projector_residual import HeadProjectorResidual
class WangchanbertaEncoderModel(PreTrainedModel):
    """Text tower: tokenizer -> ``ThaiEncoder`` -> ``HeadProjectorResidual``.

    The model owns its tokenizer, so ``forward`` takes raw text rather than
    pre-tokenized tensors.

    Config attributes read in ``__init__``:
        th_model_base: name/path passed to ``AutoTokenizer.from_pretrained``
            and to ``ThaiEncoder(model_name=...)``.
        input_text_embedding_dim: ``input_embedding_dim`` of the projector.
        output_embedding_dim: ``output_embedding_dim`` of the projector.
        dropout: ``dropout`` of the projector.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.text_tokenizer = AutoTokenizer.from_pretrained(config.th_model_base)
        self.text_encoder = ThaiEncoder(model_name=config.th_model_base)
        self.text_projector = HeadProjectorResidual(
            input_embedding_dim=config.input_text_embedding_dim,
            output_embedding_dim=config.output_embedding_dim,
            dropout=config.dropout,
        )
        # Every input is padded/truncated to exactly this many tokens.
        self.max_length = 200

    def forward(self, text: Union[str, List[str]]):
        """Tokenize ``text``, encode it and project the result.

        Args:
            text: A single string, or a list of strings to process as one
                batch.

        Returns:
            Whatever ``self.text_projector`` returns for the encoder output.
            NOTE(review): presumably a ``(batch, output_embedding_dim)``
            tensor -- confirm against ``HeadProjectorResidual``.
        """
        # return_tensors="pt" yields (1, max_length) for a single string and
        # (batch, max_length) for a list, so both input kinds get a proper
        # leading batch dimension without manual list-wrapping.
        encoded = self.text_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # The tokenizer always produces CPU tensors; move them to wherever
        # the model's parameters live so the model also works after .to("cuda").
        device = self.device
        text_vector = self.text_encoder(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
        )
        return self.text_projector(text_vector)