from transformers.models.bert import BertConfig
from transformers.models.xlm_roberta import XLMRobertaConfig
def _init_function(
    self,
    entity_vocab_size: int | None = 10000,
    entity_embedding_size: int = 768,
    entity_fusion_method: str = "multihead_attention",
    use_entity_position_embeddings: bool = True,
    entity_fusion_activation: str = "softmax",
    num_entity_fusion_attention_heads: int = 12,
    similarity_function: str = "dot",
    similarity_temperature: float = 1.0,
    *args,
    **kwargs,
):
    """Shared ``__init__`` body for the KPR configuration classes.

    Stores the KPR-specific entity/fusion hyperparameters on the config
    instance, then forwards all remaining positional and keyword arguments
    to the backbone config class's ``__init__``.

    Args:
        entity_vocab_size: Size of the entity vocabulary. The annotation
            permits ``None`` — presumably meaning "no entity vocabulary";
            TODO confirm against the model code.
        entity_embedding_size: Dimensionality of the entity embeddings.
        entity_fusion_method: Name of the method used to fuse entity
            representations (default ``"multihead_attention"``).
        use_entity_position_embeddings: Whether entity position embeddings
            are enabled.
        entity_fusion_activation: Activation used by the fusion module
            (default ``"softmax"``).
        num_entity_fusion_attention_heads: Number of attention heads in the
            entity fusion attention.
        similarity_function: Name of the similarity function (default
            ``"dot"``).
        similarity_temperature: Temperature applied with the similarity
            function.
        *args: Forwarded to the parent config ``__init__``.
        **kwargs: Forwarded to the parent config ``__init__``.
    """
    self.entity_vocab_size = entity_vocab_size
    self.entity_embedding_size = entity_embedding_size
    self.entity_fusion_method = entity_fusion_method
    self.use_entity_position_embeddings = use_entity_position_embeddings
    self.entity_fusion_activation = entity_fusion_activation
    self.num_entity_fusion_attention_heads = num_entity_fusion_attention_heads
    self.similarity_function = similarity_function
    self.similarity_temperature = similarity_temperature
    # BUGFIX: the original called ``super(self.__class__, self).__init__``.
    # That recurses forever as soon as a class using this shared function is
    # subclassed: for a subclass instance, ``self.__class__`` is the subclass,
    # so ``super(...)`` resolves right back to this function.  Zero-argument
    # ``super()`` is unavailable here (no ``__class__`` cell — the function is
    # defined outside any class), so instead walk the MRO for the class that
    # actually owns this ``__init__`` and start the super() chain from it.
    for owner in type(self).__mro__:
        if owner.__dict__.get("__init__") is _init_function:
            super(owner, self).__init__(*args, **kwargs)
            break
    else:
        # Defensive fallback: only reachable if this function is invoked on
        # an instance of a class that never assigned it as ``__init__``.
        super(type(self), self).__init__(*args, **kwargs)
class KPRConfigForBert(BertConfig):
    """KPR configuration backed by a BERT backbone config.

    Extends :class:`BertConfig` with the KPR entity/fusion hyperparameters
    via the shared ``_init_function`` constructor body.
    """

    model_type = "kpr-bert"

    __init__ = _init_function
class KPRConfigForXLMRoberta(XLMRobertaConfig):
    """KPR configuration backed by an XLM-RoBERTa backbone config.

    Extends :class:`XLMRobertaConfig` with the KPR entity/fusion
    hyperparameters via the shared ``_init_function`` constructor body.
    """

    model_type = "kpr-xlm-roberta"

    __init__ = _init_function
