Update modeling_rwkv5.py
Browse files
Files changed: modeling_rwkv5.py (+10 -2)
modeling_rwkv5.py
CHANGED
@@ -747,8 +747,16 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
|
|
747 |
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
748 |
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
749 |
else:
|
750 |
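-
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
751 |
-
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))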
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
752 |
|
753 |
self.layers_are_rescaled = not self.training
|
754 |
|
|
|
747 |
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
748 |
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
|
749 |
else:
|
750 |
+
# Deal with quantization statistics
|
751 |
+
if hasattr(block.attention.output.weight, "SCB"):
|
752 |
+
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
753 |
+
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
754 |
+
elif hasattr(block.attention.output.weight, "quant_state"):
|
755 |
+
self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
|
756 |
+
self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
|
757 |
+
else:
|
758 |
+
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
759 |
+
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
760 |
|
761 |
self.layers_are_rescaled = not self.training
|
762 |
|