MathBite committed
Commit d048846 · verified · 1 Parent(s): fef3dc9

updated inference logic to make more robust correction

Files changed (1):
  modeling.py  +15 -3
modeling.py CHANGED
@@ -101,10 +101,22 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
                 deletion_tokens_boost,
                 torch.zeros_like(deletion_tokens_boost)
             )
-            logits[:, :, -self.num_new_tokens:].add_(to_add)
         else:
-            # Inference case: always add the deletion logits to the token logits
-            logits[:, :, -self.num_new_tokens:].add_(deletion_tokens_boost)
+            # Inference case: The hallucination detector's decision becomes a hard gate.
+            hallucination_decision = torch.argmax(all_hallucination_logits, dim=-1)
+
+            # Create a mask that is True only when a hallucination is detected (decision != 0)
+            hallucination_present_mask = (hallucination_decision != 0).unsqueeze(-1)
+
+            # Where the mask is True, use the softplus boost.
+            # Where the mask is False, use a large negative value to suppress deletion.
+            to_add = torch.where(
+                hallucination_present_mask,
+                deletion_tokens_boost,
+                torch.full_like(deletion_tokens_boost, -1e9)  # Suppress if no hallucination
+            )
+
+        logits[:, :, -self.num_new_tokens:].add_(to_add)
 
         # 6. Return the custom output object
         return SelfCorrectiveLlamaOutput(
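
For context, here is a minimal standalone sketch of the gating step introduced by this commit. The tensor shapes, the dummy random inputs, and the two-class detector head (class 0 = no hallucination) are illustrative assumptions; in the model these values come from SelfCorrectiveLlama's forward pass.

import torch

# Illustrative shapes only; real values come from the model's forward pass.
batch, seq_len, vocab_size, num_new_tokens = 2, 4, 32, 3

logits = torch.randn(batch, seq_len, vocab_size)            # token logits
all_hallucination_logits = torch.randn(batch, seq_len, 2)   # detector head (2 classes assumed; 0 = no hallucination)
deletion_tokens_boost = torch.nn.functional.softplus(
    torch.randn(batch, seq_len, num_new_tokens)              # softplus boost for the deletion tokens
)

# Inference case: the hallucination detector's decision becomes a hard gate.
hallucination_decision = torch.argmax(all_hallucination_logits, dim=-1)

# True only where a hallucination is detected (decision != 0).
hallucination_present_mask = (hallucination_decision != 0).unsqueeze(-1)

# Keep the softplus boost where a hallucination is flagged; otherwise push the
# deletion tokens' logits far down so they are effectively never sampled.
to_add = torch.where(
    hallucination_present_mask,
    deletion_tokens_boost,
    torch.full_like(deletion_tokens_boost, -1e9),
)

logits[:, :, -num_new_tokens:].add_(to_add)

Compared with the previous inference path, which always added the boost, the deletion tokens now only compete with regular tokens when the detector actually flags a hallucination; otherwise the -1e9 offset suppresses them entirely.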