import copy

import torch
from torch import nn
from torch.optim import AdamW
from transformers import AutoModel, get_linear_schedule_with_warmup

try:
    from torchcrf import CRF  # only needed for the args.use_crf branch in load_model
except ImportError:
    CRF = None
class MyModel(nn.Module):
    def __init__(self, args, backbone):
        super().__init__()
        self.args = args
        self.backbone = backbone
        self.cls_id = 0
        hidden_dim = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, args.num_labels),
        )
        if args.distil_att:
            self.distil_att = nn.Parameter(torch.ones(self.backbone.config.hidden_size))

    def forward(self, x, mask):
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        return out, self.classifier(out.last_hidden_state)

    def decisions(self, x, mask):
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=False)
        return out, self.classifier(out.last_hidden_state)

    def phenos(self, x, mask):
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        return out, self.classifier(out.pooler_output)

    def generate(self, x, mask, choice=None):
        outs = []
        if self.args.task == 'seq' or choice == 'seq':
            # Slide over the full sequence in windows of max_len - 1 tokens,
            # prepending self.cls_id to every window after the first.
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len - 1)):
                if i == 0:
                    segment = x[:, offset:offset + self.args.max_len - 1]
                    segment_mask = mask[:, offset:offset + self.args.max_len - 1]
                else:
                    segment = torch.cat((torch.ones((x.shape[0], 1), dtype=int).to(x.device)
                                         * self.cls_id,
                                         x[:, offset:offset + self.args.max_len - 1]), axis=1)
                    segment_mask = torch.cat((torch.ones((mask.shape[0], 1)).to(mask.device),
                                              mask[:, offset:offset + self.args.max_len - 1]), axis=1)
                logits = self.phenos(segment, segment_mask)[1]
                outs.append(logits)
            # Max-pool the per-window logits across windows.
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
                segment = x[:, offset:offset + self.args.max_len]
                segment_mask = mask[:, offset:offset + self.args.max_len]
                h = self.decisions(segment, segment_mask)[0].last_hidden_state
                outs.append(h)
            h = torch.cat(outs, 1)
            return self.classifier(h)
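
# Note on MyModel.generate ('seq' path): the document is chunked with a stride of
# max_len - 1 so a fresh self.cls_id token can be prepended to every window after
# the first without exceeding max_len; e.g. with max_len = 512 (an assumed value)
# and 1200 input tokens the window offsets are list(range(0, 1200, 511)) ==
# [0, 511, 1022]. The per-window pooler logits from phenos are then max-pooled
# across windows.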
class CNN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.Sequential(
            nn.Conv1d(args.emb_size, args.hidden_size, args.kernels[0],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
            nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[1],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
            nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[2],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
        )
        if args.task == 'seq':
            # With 'valid' convolutions, each kernel of size k shortens a
            # 512-token input by k - 1 positions.
            out_shape = 512 - args.kernels[0] - args.kernels[1] - args.kernels[2] + 3
        elif args.task == 'token':
            out_shape = 1
        self.classifier = nn.Linear(args.hidden_size * out_shape, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None

    def forward(self, x, _):
        x = x.to(self.device)
        bs = x.shape[0]
        x = self.emb(x)
        x = x.transpose(1, 2)
        x = self.model(x)
        x = self.dropout(x)
        if self.args.task == 'token':
            x = x.transpose(1, 2)
            h = self.classifier(x)
            return x, h
        elif self.args.task == 'seq':
            x = x.reshape(bs, -1)
            x = self.classifier(x)
            return x

    def generate(self, x, _):
        outs = []
        for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
            segment = x[:, offset:offset + self.args.max_len]
            n = segment.shape[1]
            if n != self.args.max_len:
                # Pad the final segment so every window has the same length.
                segment = torch.nn.functional.pad(segment, (0, self.args.max_len - n))
            if self.args.task == 'seq':
                logits = self(segment, None)
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0]
                h = h[:, :n]
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)
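
# Note on CNN: for 'token' the convolutions use 'same' padding so the feature map
# keeps the input length and forward returns (features, per-token logits); for
# 'seq' they use 'valid' padding, the flattened feature map of an (assumed)
# 512-token input feeds the linear head, and forward returns the logits directly.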
class LSTM(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.LSTM(args.emb_size, args.hidden_size, num_layers=args.num_layers,
                             batch_first=True, bidirectional=True)
        dim = 2 * args.num_layers * args.hidden_size if args.task == 'seq' else 2 * args.hidden_size
        self.classifier = nn.Linear(dim, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None

    def forward(self, x, _):
        x = x.to(self.device)
        x = self.emb(x)
        o, (x, _) = self.model(x)
        o_out = self.classifier(o) if self.args.task == 'token' else None
        if self.args.task == 'seq':
            x = torch.cat([h for h in x], 1)
            x = self.dropout(x)
            x = self.classifier(x)
        return (x, o), o_out

    def generate(self, x, _):
        outs = []
        for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
            segment = x[:, offset:offset + self.args.max_len]
            if self.args.task == 'seq':
                logits = self(segment, None)[0][0]
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0][1]
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)
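
# Note on LSTM: forward returns ((x, o), o_out). For 'seq', x holds the
# sequence-level logits built from the concatenated final hidden states of all
# layers and directions; for 'token', o_out holds per-token logits computed from
# the full output sequence o. generate indexes this tuple per task.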
def load_model(args, device):
    if args.model == 'lstm':
        model = LSTM(args).to(device)
        model.device = device
    elif args.model == 'cnn':
        model = CNN(args).to(device)
        model.device = device
    else:
        model = MyModel(args, AutoModel.from_pretrained(args.model_name)).to(device)
    if args.ckpt:
        model.load_state_dict(torch.load(args.ckpt, map_location=device), strict=False)

    if args.distil:
        # Frozen token-level model used as the distillation source.
        args2 = copy.deepcopy(args)
        args2.task = 'token'
        # args2.num_labels = args.num_decs
        args2.num_labels = args.num_umls_tags
        model_B = MyModel(args2, AutoModel.from_pretrained(args.model_name)).to(device)
        model_B.load_state_dict(torch.load(args.distil_ckpt, map_location=device), strict=False)
        for p in model_B.parameters():
            p.requires_grad = False
    else:
        model_B = None

    if args.label_encoding == 'multiclass':
        if args.use_crf:
            crit = CRF(args.num_labels, batch_first=True).to(device)
        else:
            crit = nn.CrossEntropyLoss(reduction='none')
    else:
        crit = nn.BCEWithLogitsLoss(
            pos_weight=torch.ones(args.num_labels).to(device) * args.pos_weight,
            reduction='none',
        )

    optimizer = AdamW(model.parameters(), lr=args.lr)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   int(0.1 * args.total_steps), args.total_steps)
    return model, crit, optimizer, lr_scheduler, model_B
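
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline: the attribute names
    # below mirror what this module reads from `args`, but every value (model name,
    # label count, step count, ...) is an illustrative assumption.
    from argparse import Namespace

    demo_args = Namespace(
        model='bert', model_name='bert-base-uncased', ckpt=None,
        distil=False, distil_att=False, task='seq', max_len=512,
        num_labels=2, label_encoding='multilabel', use_crf=False,
        pos_weight=1.0, lr=2e-5, total_steps=1000,
    )
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, crit, optimizer, lr_scheduler, model_B = load_model(demo_args, device)
    print(model.__class__.__name__, crit, model_B)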