ariG23498 HF staff committed on
Commit
5af9f73
·
verified ·
1 Parent(s): 9805fb9

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +20 -1
README.md CHANGED
@@ -13,4 +13,23 @@ tags:
13
  - sft
14
  ---
15
 
16
- SFT with Layer Skip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  - sft
14
  ---
15
 
16
+ SFT with Layer Skip.
17
+
18
+ ```
19
class LayerSkipSFTTrainer(SFTTrainer):
    """SFT trainer whose loss is computed from one intermediate hidden state.

    On every ``compute_loss`` call, the hidden state at index
    ``self.early_exit_layer`` is projected through ``model.lm_head`` and
    scored against the labels. The index then advances by one, wrapping
    modulo ``model.config.num_hidden_layers``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``SFTTrainer`` and start the rotation at index 1."""
        super().__init__(*args, **kwargs)
        # Index into ``outputs["hidden_states"]`` used by the next loss computation.
        self.early_exit_layer = 1

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss from the currently selected hidden state.

        Args:
            model: Model being trained. It must expose ``lm_head``,
                ``loss_function``, ``vocab_size`` and
                ``config.num_hidden_layers``.
            inputs: Batch mapping; must contain ``"labels"``. The labels entry
                is removed from this mapping before the forward pass.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Accepted for signature compatibility; not used
                by this implementation.

        Returns:
            The loss, or a ``(loss, outputs)`` tuple when ``return_outputs``
            is true.
        """
        # Pop the labels so they are not forwarded to the model call below;
        # the loss is computed here from an intermediate hidden state instead.
        labels = inputs.pop("labels")
        outputs = model(**inputs, output_hidden_states=True)

        hidden_state = outputs["hidden_states"][self.early_exit_layer]
        logits = model.lm_head(hidden_state)
        loss = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)

        # Advance the rotation for the next call.
        # NOTE(review): this wraps back to index 0. In transformers,
        # ``hidden_states`` presumably holds the embedding output at index 0,
        # so the rotation would periodically train on raw embeddings and never
        # reach the last layer's entry - confirm this is intended.
        self.early_exit_layer = (self.early_exit_layer + 1) % model.config.num_hidden_layers

        # Honour the caller's request for the model outputs; callers that pass
        # ``return_outputs=True`` unpack a ``(loss, outputs)`` pair.
        return (loss, outputs) if return_outputs else loss
35
+ ```