Update modeling_prismatic.py to account for the case where `input_ids` is `None `
Browse filesInput Ids and Input Embeds are both marked `Optional[torch.LongTensor] = None,` however failing to pass in `input_ids` into the `forward()` method results in an error in the first block, since the code automatically checks if `input_ids.shape[1] == 1` without first checking to see if `input_ids is not None`. 
This pull request updates the logic to allow for this case in Generation with Cache and Multimodal Forward.
- modeling_prismatic.py +5 -2
    	
        modeling_prismatic.py
    CHANGED
    
    | @@ -322,7 +322,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel): | |
| 322 | 
             
                    #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
         | 
| 323 |  | 
| 324 | 
             
                    # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
         | 
| 325 | 
            -
                    if input_ids.shape[1] == 1:
         | 
| 326 | 
             
                        assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
         | 
| 327 | 
             
                        assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
         | 
| 328 | 
             
                        assert labels is None, "Unexpected key `labels` provided during cached generation!"
         | 
| @@ -359,7 +359,10 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel): | |
| 359 | 
             
                        )
         | 
| 360 |  | 
| 361 | 
             
                    # === Handle Multimodal Forward ===
         | 
| 362 | 
            -
                    elif ( | 
|  | |
|  | |
|  | |
| 363 | 
             
                        assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
         | 
| 364 |  | 
| 365 | 
             
                        # Visual Feature Extraction
         | 
|  | |
| 322 | 
             
                    #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
         | 
| 323 |  | 
| 324 | 
             
                    # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
         | 
| 325 | 
            +
                    if input_ids is not None and input_ids.shape[1] == 1:
         | 
| 326 | 
             
                        assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
         | 
| 327 | 
             
                        assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
         | 
| 328 | 
             
                        assert labels is None, "Unexpected key `labels` provided during cached generation!"
         | 
|  | |
| 359 | 
             
                        )
         | 
| 360 |  | 
| 361 | 
             
                    # === Handle Multimodal Forward ===
         | 
| 362 | 
            +
                    elif (
         | 
| 363 | 
            +
                            (input_ids is not None and input_ids.shape[0] == pixel_values.shape[0]) or
         | 
| 364 | 
            +
                            (inputs_embeds is not None and inputs_embeds.shape[0] == pixel_values.shape[0])
         | 
| 365 | 
            +
                        ):            
         | 
| 366 | 
             
                        assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
         | 
| 367 |  | 
| 368 | 
             
                        # Visual Feature Extraction
         | 
