from transformers import RoFormerConfig, RoFormerForMaskedLM
import torch.nn as nn
import torch


class Roformer(nn.Module):
    # Wrapper around a randomly initialised RoFormer masked-language model,
    # configured from the project config and the tokenizer's vocabulary.
    def __init__(self, config, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        roformer_config = RoFormerConfig(
            vocab_size=self.tokenizer.vocab_size,
            embedding_size=config.roformer.hidden_size,
            hidden_size=config.roformer.hidden_size,
            num_hidden_layers=config.roformer.n_layers,
            num_attention_heads=config.roformer.n_heads,
            intermediate_size=config.roformer.hidden_size * 4,
            max_position_embeddings=config.roformer.max_position_embeddings,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            pad_token_id=0,
            rotary_value=False,
        )
        self.model = RoFormerForMaskedLM(roformer_config).to(self.device)

    # Freeze every parameter (e.g. before selectively unfreezing layers).
    def freeze_model(self):
        for param in self.model.parameters():
            param.requires_grad = False

    # Make every parameter trainable again.
    def unfreeze_all_layers(self):
        for param in self.model.parameters():
            param.requires_grad = True

    # Unfreeze the query/key projections of the final n encoder layers.
    def unfreeze_n_layers(self, n):
        # derive the layer count from the model instead of hard-coding it
        num_layers = len(self.model.roformer.encoder.layer)
        for i, layer in enumerate(self.model.roformer.encoder.layer):
            # fine-tune only the final n layers
            if i >= num_layers - n:
                # unfreeze query weights
                for param in layer.attention.self.query.parameters():
                    param.requires_grad = True
                # unfreeze key weights
                for param in layer.attention.self.key.parameters():
                    param.requires_grad = True

    def forward(self, input_ids, attn_mask):
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)
        # run the masked-LM head; output.logits has shape (batch, seq_len, vocab_size)
        output = self.model(input_ids=input_ids, attention_mask=attn_mask)
        return output.logits

    # Save model weights and tokenizer files to save_dir.
    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    # Rebuild the wrapper and load previously saved weights from save_dir.
    @classmethod
    def load_model(cls, save_dir, config, tokenizer):
        roformer = cls(config, tokenizer)
        roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir).to(roformer.device)
        return roformer
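

# --- Usage sketch (illustrative only) -------------------------------------
# A minimal example of how this wrapper might be driven. The project's real
# config object and tokenizer are not shown in this file, so the
# SimpleNamespace values and the "bert-base-uncased" tokenizer below are
# assumptions chosen purely for illustration; any Hugging Face tokenizer
# exposing vocab_size and save_pretrained should work the same way.
if __name__ == "__main__":
    from types import SimpleNamespace
    from transformers import AutoTokenizer

    cfg = SimpleNamespace(
        roformer=SimpleNamespace(
            hidden_size=256,
            n_layers=8,
            n_heads=4,
            max_position_embeddings=512,
        )
    )
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    model = Roformer(cfg, tokenizer)
    model.freeze_model()
    model.unfreeze_n_layers(2)  # train only the Q/K projections of the last 2 layers

    batch = tokenizer(["a short example sentence"], return_tensors="pt", padding=True)
    logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # (batch_size, seq_len, vocab_size)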