import torch.nn as nn
from transformers import AutoModel


class ThaiEncoder(nn.Module):
    """Wrap a pretrained ``AutoModel`` and return one token's hidden state.

    The wrapped model is loaded with ``AutoModel.from_pretrained`` and every
    one of its parameters has ``requires_grad`` set to ``trainable``.
    """

    def __init__(self, model_name: str, trainable: bool = False) -> None:
        """Load the pretrained model and set the trainability of its weights.

        Args:
            model_name: Identifier forwarded to ``AutoModel.from_pretrained``.
            trainable: Value assigned to ``requires_grad`` on every parameter
                of the loaded model; ``False`` (the default) freezes them.
        """
        super().__init__()
        # Index along dimension 1 of ``last_hidden_state`` that ``forward``
        # selects. NOTE(review): presumably the leading [CLS]/<s> token --
        # confirm against the tokenizer in use.
        self.target_token_idx = 0
        self.model = AutoModel.from_pretrained(model_name)
        for param in self.model.parameters():
            param.requires_grad_(trainable)

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and pick out the target token's hidden state.

        Args:
            input_ids: Passed to the wrapped model as ``input_ids``.
            attention_mask: Passed to the wrapped model as ``attention_mask``.

        Returns:
            ``last_hidden_state[:, self.target_token_idx, :]`` -- the slice
            of the model's last hidden state at the target index along
            dimension 1.
        """
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        return hidden_states[:, self.target_token_idx, :]