from transformers import PretrainedConfig


class WangchanbertaEncoderConfig(PretrainedConfig):
    """Configuration for a text encoder built on a WangchanBERTa backbone.

    Args:
        th_model_base: Hugging Face model ID of the Thai WangchanBERTa backbone.
        input_text_embedding_dim: Hidden size of the backbone's text embeddings.
        output_embedding_dim: Dimensionality of the projected output embeddings.
        dropout: Dropout probability applied in the projection head.
    """

    def __init__(
        self,
        th_model_base: str = "airesearch/wangchanberta-base-att-spm-uncased",
        input_text_embedding_dim: int = 768,
        output_embedding_dim: int = 512,
        dropout: float = 0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.th_model_base = th_model_base
        self.input_text_embedding_dim = input_text_embedding_dim
        self.output_embedding_dim = output_embedding_dim
        self.dropout = dropout
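

# Example usage (a minimal sketch; only `WangchanbertaEncoderConfig` from this
# module is assumed). Because the class follows the standard PretrainedConfig
# pattern, it can be serialized and reloaded like any other Hugging Face config:
#
#     config = WangchanbertaEncoderConfig(output_embedding_dim=256)
#     config.save_pretrained("./wangchanberta-encoder")
#     reloaded = WangchanbertaEncoderConfig.from_pretrained("./wangchanberta-encoder")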