2-adapter-tuning-initial-impl (#30)

- 2 adapter tuning (3fd28cf83a7aeb3b39b4da99337ae29c84f1b424)

Co-authored-by: Jack Min Ong <[email protected]>

Files changed:
- block.py                  +11 -1
- embedding.py              +26 -4
- mha.py                    +37 -5
- mlp.py                    +21 -3
- modeling_lora.py           +0 -1
- modeling_xlm_roberta.py   +18 -5
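
This commit threads the `task_type` argument through the flash XLM-RoBERTa stack (embeddings, attention, MLP, pooler) so that two different LoRA adapters can be applied within a single forward pass. Whenever `task_type` is a 2-tuple, the batch is assumed to be packed so that the first ninth of the rows goes through the first adapter and the remaining eight ninths through the second; the `% 9 == 0` (row count) and `% 9 == 1` (`cu_seqlens` length) assertions enforce that 1:8 packing. What the two halves represent (presumably queries vs. documents) is not stated in the diff. Below is a minimal sketch of the split-and-merge pattern every patched module repeats; `run_with_two_adapters` is a hypothetical helper, not part of the repo, and the only assumption taken from the diff is that the patched layers accept a `task_type` keyword:

import torch

def run_with_two_adapters(layer, x, task_type, split=None):
    """Sketch of the split-and-merge routing used throughout this commit.

    `layer` is assumed to be a LoRA-aware module whose forward accepts a
    `task_type` keyword (as the patched Linear/Embedding layers do here).
    When `task_type` is a 2-tuple, rows [:split] are routed through the first
    adapter and rows [split:] through the second; otherwise the single-adapter
    (or no-adapter) path is taken.
    """
    if isinstance(task_type, tuple):
        if split is None:
            # Padded (batch, seqlen, ...) layout: the diff assumes the first
            # 1/9 of the rows belongs to the first adapter.
            assert x.shape[0] % 9 == 0
            split = x.shape[0] // 9
        out1 = layer(x[:split], task_type=task_type[0])
        out2 = layer(x[split:], task_type=task_type[1])
        return torch.cat((out1, out2), dim=0)
    lora_kwargs = {'task_type': task_type} if task_type is not None else {}
    return layer(x, **lora_kwargs)
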
block.py
CHANGED

@@ -233,7 +233,17 @@ class Block(nn.Module):
                     is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if not isinstance(self.mlp, nn.Identity):
-                mlp_out = self.mlp(hidden_states)
+                task_type = mixer_kwargs.get('task_type')
+                if task_type:
+                    if isinstance(task_type, tuple):
+                        assert mixer_kwargs['cu_seqlens'].shape[0] % 9 == 1
+                        split_index = int((mixer_kwargs['cu_seqlens'].shape[0] - 1) / 9)
+                        split = mixer_kwargs['cu_seqlens'][split_index]
+                        mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'), split=split)
+                    else:
+                        mlp_out = self.mlp(hidden_states, task_type=task_type)
+                else:
+                    mlp_out = self.mlp(hidden_states)
                 if self.return_residual:  # mlp out is actually a pair here
                     mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:

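In the unpadded flash-attention path there is no (batch, seqlen) layout: sequences are concatenated along a single token dimension, so the split point cannot be a row count and is instead read off `cu_seqlens` (the cumulative sequence-length offsets), both here and in mha.py. With N packed sequences `cu_seqlens` has N + 1 entries, which is what the `% 9 == 1` assertion checks; the token offset after the first N / 9 sequences is `cu_seqlens[N // 9]`. A small self-contained example of that arithmetic (the sequence lengths below are made up):

import torch

# Nine sequences packed into one token dimension; cu_seqlens holds the
# cumulative offsets, so it has N + 1 entries.
seqlens = torch.tensor([3, 5, 4, 2, 6, 1, 3, 2, 4])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0))  # [0, 3, 8, 12, ...]

assert cu_seqlens.shape[0] % 9 == 1            # N sequences -> N + 1 offsets
split_index = (cu_seqlens.shape[0] - 1) // 9   # first 1/9 of the sequences
split = cu_seqlens[split_index]                # token offset of that boundary

tokens_first_adapter = int(split)                     # tokens routed to task_type[0]
tokens_second_adapter = int(cu_seqlens[-1] - split)   # tokens routed to task_type[1]
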
embedding.py
CHANGED

@@ -47,8 +47,18 @@ class XLMRobertaEmbeddings(nn.Module):
             token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
-        embeddings = self.word_embeddings(input_ids, **lora_kwargs)
+        if isinstance(task_type, tuple):
+            assert input_ids.shape[0] % 9 == 0
+            split = int(input_ids.shape[0] / 9)
+            tensor1 = input_ids[:split, :]
+            tensor2 = input_ids[split:, :]
+            emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
+            emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
+            embeddings = torch.cat((emb1, emb2), dim=0)
+        else:
+            lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+            embeddings = self.word_embeddings(input_ids, **lora_kwargs)
+
         if self.max_position_embeddings > 0:
             if position_ids is None:
                 position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)

@@ -58,6 +68,18 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             if token_type_ids is None:
                 token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
-            token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
-            embeddings = embeddings + token_type_embeddings
+            if isinstance(task_type, tuple):
+                assert embeddings.shape[0] % 9 == 0
+                split = int(embeddings.shape[0] / 9)
+                emb1 = embeddings[:split, :, :]
+                emb2 = embeddings[split:, :, :]
+                token_type_embs1 = self.token_type_embeddings(token_type_ids, task_type=task_type[0])
+                token_type_embs2 = self.token_type_embeddings(token_type_ids, task_type=task_type[1])
+                emb1 = emb1 + token_type_embs1
+                emb2 = emb2 + token_type_embs2
+                embeddings = torch.cat((emb1, emb2), dim=0)
+            else:
+                lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+                token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
+                embeddings = embeddings + token_type_embeddings
         return embeddings

mha.py
CHANGED

@@ -643,15 +643,39 @@ class MHA(nn.Module):
             inference_params.max_sequence_len if inference_params is not None else max_seqlen
         )
         batch, seqlen = x.shape[:2]
+        lora_kwargs = {}
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
+
+            split = None
+            if isinstance(task_type, tuple):
+                assert cu_seqlens.shape[0] % 9 == 1
+                split_index = int((cu_seqlens.shape[0] - 1) / 9)
+                split = cu_seqlens[split_index]
+
             lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+
             if not self.return_residual:
-                qkv = self.Wqkv(x, **lora_kwargs)
+                if isinstance(task_type, tuple):
+                    tensor1 = x[:split, :]
+                    tensor2 = x[split:, :]
+                    qkv1 = self.Wqkv(tensor1, task_type=task_type[0])
+                    qkv2 = self.Wqkv(tensor2, task_type=task_type[1])
+                    qkv = torch.cat((qkv1, qkv2), dim=0)
+                else:
+                    qkv = self.Wqkv(x, **lora_kwargs)
             else:
-                if lora_kwargs:
-                    lora_kwargs['residual'] = True
-                qkv, x = self.Wqkv(x, **lora_kwargs)
+                if isinstance(task_type, tuple):
+                    tensor1 = x[:split, :]
+                    tensor2 = x[split:, :]
+                    qkv1, tensor1 = self.Wqkv(tensor1, task_type=task_type[0], residual=True)
+                    qkv2, tensor2 = self.Wqkv(tensor2, task_type=task_type[1], residual=True)
+                    qkv = torch.cat((qkv1, qkv2), dim=0)
+                    x = torch.cat((tensor1, tensor2), dim=0)
+                else:
+                    if lora_kwargs:
+                        lora_kwargs['residual'] = True
+                    qkv, x = self.Wqkv(x, **lora_kwargs)

             if self.dwconv:
                 qkv = rearrange(

@@ -739,5 +763,13 @@
             context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)

         lora_kwargs.pop('residual', None)
-        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
+        inp = rearrange(context, "... h d -> ... (h d)")
+        if isinstance(task_type, tuple):
+            tensor1 = inp[:split, :]
+            tensor2 = inp[split:, :]
+            out1 = self.out_proj(tensor1, task_type=task_type[0])
+            out2 = self.out_proj(tensor2, task_type=task_type[1])
+            out = torch.cat((out1, out2), dim=0)
+        else:
+            out = self.out_proj(inp, **lora_kwargs)
         return out if not self.return_residual else (out, x)

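The calls to `self.Wqkv(..., task_type=..., residual=True)` and `self.out_proj(..., task_type=...)` assume that the projections are LoRA-wrapped linears whose forward takes a `task_type` keyword to select the adapter and an optional `residual` flag that also returns the layer input (mirroring flash-attention's `LinearResidual`). That wrapper lives in modeling_lora.py and is not shown in this diff; the sketch below is only a rough stand-in for the assumed interface, with hypothetical names (`MultiTaskLoRALinear`, `lora_a`, `lora_b`), not the actual implementation:

import torch
import torch.nn as nn

class MultiTaskLoRALinear(nn.Linear):
    """Hypothetical stand-in for the LoRA-wrapped Linear assumed by mha.py.

    Holds one low-rank (A, B) pair per task and adds the selected pair's
    update to the frozen base projection.
    """

    def __init__(self, in_features, out_features, task_names, rank=4, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.lora_a = nn.ParameterDict(
            {t: nn.Parameter(torch.zeros(rank, in_features)) for t in task_names}
        )
        self.lora_b = nn.ParameterDict(
            {t: nn.Parameter(torch.zeros(out_features, rank)) for t in task_names}
        )

    def forward(self, x, task_type=None, residual=False):
        out = super().forward(x)
        if task_type is not None:
            a, b = self.lora_a[task_type], self.lora_b[task_type]
            out = out + x @ a.t() @ b.t()
        # residual=True mirrors flash-attention's LinearResidual: the caller
        # gets the input back alongside the projection, as (out, x).
        return (out, x) if residual else out
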
mlp.py
CHANGED

@@ -47,11 +47,29 @@ class Mlp(nn.Module):
         self.activation = activation
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

-    def forward(self, x, task_type=None):
+    def forward(self, x, task_type=None, split=None):
         lora_kwargs = {'task_type': task_type} if task_type is not None else {}
-        y = self.fc1(x, **lora_kwargs)
+        if split:
+            assert isinstance(task_type, tuple)
+            tensor1 = x[:split, :]
+            tensor2 = x[split:, :]
+            y1 = self.fc1(tensor1, task_type=task_type[0])
+            y2 = self.fc1(tensor2, task_type=task_type[1])
+            y = torch.cat((y1, y2), dim=0)
+        else:
+            y = self.fc1(x, **lora_kwargs)
+
         y = self.activation(y)
-        y = self.fc2(y, **lora_kwargs)
+
+        if split:
+            assert isinstance(task_type, tuple)
+            tensor1 = y[:split, :]
+            tensor2 = y[split:, :]
+            y1 = self.fc2(tensor1, task_type=task_type[0])
+            y2 = self.fc2(tensor2, task_type=task_type[1])
+            y = torch.cat((y1, y2), dim=0)
+        else:
+            y = self.fc2(y, **lora_kwargs)
         return y if not self.return_residual else (y, x)

modeling_lora.py
CHANGED

@@ -227,7 +227,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         roberta: Optional[XLMRobertaModel] = None
     ):
         super().__init__(config)
-
         if roberta is None:
             self.roberta = XLMRobertaModel(config)
         else:

modeling_xlm_roberta.py
CHANGED

@@ -210,10 +210,12 @@ class XLMRobertaEncoder(nn.Module):
             subset_mask: (batch, seqlen), dtype=torch.bool
         """
         if key_padding_mask is None or not self.use_flash_attn:
-            mixer_kwargs = (
-                {"key_padding_mask": key_padding_mask.bool()}
-                if key_padding_mask is not None else None
-            )
+            mixer_kwargs = (
+                {"key_padding_mask": key_padding_mask.bool()}
+                if key_padding_mask is not None
+                else None
+            )
+            mixer_kwargs['task_type'] = task_type
             for layer in self.layers:
                 if self._grad_checkpointing:
                     hidden_states = torch.utils.checkpoint.checkpoint(

@@ -314,7 +316,18 @@ class XLMRobertaPooler(nn.Module):
         lora_kwargs = {'task_type': task_type} if task_type is not None else {}

         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
-        pooled_output = self.dense(first_token_tensor, **lora_kwargs)
+
+        if isinstance(task_type, tuple):
+            assert first_token_tensor.shape[0] % 9 == 0
+            split = int(first_token_tensor.shape[0] / 9)
+            tensor1 = first_token_tensor[:split, :]
+            tensor2 = first_token_tensor[split:, :]
+            pooled_out1 = self.dense(tensor1, task_type=task_type[0])
+            pooled_out2 = self.dense(tensor2, task_type=task_type[1])
+            pooled_output = torch.cat((pooled_out1, pooled_out2), dim=0)
+        else:
+            pooled_output = self.dense(first_token_tensor, **lora_kwargs)
+
         pooled_output = self.activation(pooled_output)
         return pooled_output

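Taken together, a caller that wants both adapters active in one pass packs the batch with a 1:8 row ratio and hands the model a 2-tuple. The snippet below only illustrates that packing arithmetic; the adapter names and the final model call are placeholders, since neither appears in this diff:

import torch

# Hypothetical packing: the first 1/9 of the rows is routed to task_type[0],
# the remaining 8/9 to task_type[1], matching the % 9 checks in this commit.
queries = torch.randint(0, 250_002, (2, 128))       # rows for the first adapter
passages = torch.randint(0, 250_002, (16, 128))     # rows for the second adapter
input_ids = torch.cat((queries, passages), dim=0)   # shape (18, 128); 18 % 9 == 0

split = input_ids.shape[0] // 9                     # = 2, as computed in embedding.py
print(f"rows [:{split}] -> first adapter, rows [{split}:] -> second adapter")

# With a model built from these patched modules, the call would look roughly like:
#   outputs = model(input_ids, task_type=('adapter_a', 'adapter_b'))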