Commit e96c608 · committed by x54-729
Parent(s): 9195687

update modeling file to newest

Files changed:
- configuration_internlm2.py (+1 -1)
- modeling_internlm2.py (+9 -1)
    	
configuration_internlm2.py CHANGED

@@ -177,4 +177,4 @@ class InternLM2Config(PretrainedConfig):
             raise ValueError(
                 f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
                 f"of type {type(rope_scaling_factor)}"
-            )
+            )
    	
modeling_internlm2.py CHANGED

@@ -59,6 +59,10 @@ try:
 except:
     pass
 
+try:
+    support_bf16_triu = torch.__version__ >= "2.1.0"
+except Exception:
+    support_bf16_triu = False
 
 logger = logging.get_logger(__name__)
 
@@ -1093,7 +1097,11 @@ class InternLM2Model(InternLM2PreTrainedModel):
         else:
             causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
             if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
+                if support_bf16_triu or dtype == torch.float32:
+                    causal_mask = torch.triu(causal_mask, diagonal=1)
+                else:
+                    triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool()
+                    causal_mask.masked_fill_(~triu_mask, 0)
             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
             if attention_mask is not None:
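The modeling change guards the causal-mask construction: torch.triu on a bfloat16 tensor is only assumed to be usable from torch 2.1.0 onward, so on older versions the commit builds the same strictly-upper-triangular mask with a boolean selector and masked_fill_. Below is a minimal standalone sketch of the two paths, not the full InternLM2Model code; the shapes, dtype, and device are illustrative assumptions.

import torch

# Version probe as in the commit: a plain string comparison of
# torch.__version__ (lexicographic, not a semantic version compare).
try:
    support_bf16_triu = torch.__version__ >= "2.1.0"
except Exception:
    support_bf16_triu = False

# Made-up example values; the real model derives these from the inputs.
sequence_length, target_length = 4, 4
dtype, device = torch.bfloat16, "cpu"
min_dtype = torch.finfo(dtype).min  # most negative representable value

# Start from a fully "masked" matrix, then keep only the strictly upper triangle.
causal_mask = torch.full((sequence_length, target_length),
                         fill_value=min_dtype, dtype=dtype, device=device)

if support_bf16_triu or dtype == torch.float32:
    # Fast path: torch.triu zeroes the diagonal and everything below it.
    causal_mask = torch.triu(causal_mask, diagonal=1)
else:
    # Fallback for older torch without bf16 triu: select the strictly upper
    # triangle with a boolean mask and zero out everything else in place.
    triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool()
    causal_mask.masked_fill_(~triu_mask, 0)

print(causal_mask)  # min_dtype above the diagonal, zeros on and below it

Both branches produce the same mask; the fallback only changes which ops are used, so behavior differs only where a bf16 triu would otherwise fail on an older torch build.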
