Update modeling_prismatic.py to account for the case where `input_ids` is `None `
#5
by
eliotj
- opened
- modeling_prismatic.py +5 -2
modeling_prismatic.py
CHANGED
|
@@ -322,7 +322,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
| 322 |
# => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
|
| 323 |
|
| 324 |
# === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
|
| 325 |
-
if input_ids.shape[1] == 1:
|
| 326 |
assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
|
| 327 |
assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
|
| 328 |
assert labels is None, "Unexpected key `labels` provided during cached generation!"
|
|
@@ -359,7 +359,10 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
| 359 |
)
|
| 360 |
|
| 361 |
# === Handle Multimodal Forward ===
|
| 362 |
-
elif (
|
|
|
|
|
|
|
|
|
|
| 363 |
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
|
| 364 |
|
| 365 |
# Visual Feature Extraction
|
|
|
|
| 322 |
# => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
|
| 323 |
|
| 324 |
# === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
|
| 325 |
+
if input_ids is not None and input_ids.shape[1] == 1:
|
| 326 |
assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
|
| 327 |
assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
|
| 328 |
assert labels is None, "Unexpected key `labels` provided during cached generation!"
|
|
|
|
| 359 |
)
|
| 360 |
|
| 361 |
# === Handle Multimodal Forward ===
|
| 362 |
+
elif (
|
| 363 |
+
(input_ids is not None and input_ids.shape[0] == pixel_values.shape[0]) or
|
| 364 |
+
(inputs_embeds is not None and inputs_embeds.shape[0] == pixel_values.shape[0])
|
| 365 |
+
):
|
| 366 |
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
|
| 367 |
|
| 368 |
# Visual Feature Extraction
|