Update modeling_internvl_chat.py
Browse files
modeling_internvl_chat.py
CHANGED
|
@@ -588,7 +588,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
| 588 |
input_ids = input_ids.reshape(B * N)
|
| 589 |
|
| 590 |
relative_threshold_value = torch.quantile(gate_result[:, 0].to(torch.float32), self.flash_relative_threshold)
|
| 591 |
-
gate_result = (gate_result[:, 0]
|
| 592 |
|
| 593 |
selected_embeds = []
|
| 594 |
for i in range(gate_result.size(0)):
|
|
|
|
| 588 |
input_ids = input_ids.reshape(B * N)
|
| 589 |
|
| 590 |
relative_threshold_value = torch.quantile(gate_result[:, 0].to(torch.float32), self.flash_relative_threshold)
|
| 591 |
+
gate_result = (gate_result[:, 0] > relative_threshold_value) & (gate_result[:, 0] >= self.flash_absolute_threshold)
|
| 592 |
|
| 593 |
selected_embeds = []
|
| 594 |
for i in range(gate_result.size(0)):
|