refactor: clean up num_examples code (#14)
Browse files- refactor: clean up num_examples code (b67df152d4d5143d84ff7c023a245d301ecccb22)
- custom_st.py +1 -5
custom_st.py
CHANGED
|
@@ -104,11 +104,7 @@ class Transformer(nn.Module):
|
|
| 104 |
adapter_mask = None
|
| 105 |
if task_type:
|
| 106 |
task_id = self._adaptation_map[task_type]
|
| 107 |
-
num_examples =
|
| 108 |
-
if isinstance(features['input_ids'][0], list):
|
| 109 |
-
# If input_ids[0] is a list, it means multiple inputs (list of texts)
|
| 110 |
-
num_examples = len(features['input_ids'])
|
| 111 |
-
|
| 112 |
adapter_mask = torch.full(
|
| 113 |
(num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
|
| 114 |
)
|
|
|
|
| 104 |
adapter_mask = None
|
| 105 |
if task_type:
|
| 106 |
task_id = self._adaptation_map[task_type]
|
| 107 |
+
num_examples = features['input_ids'].size(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
adapter_mask = torch.full(
|
| 109 |
(num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
|
| 110 |
)
|