Correct the output dtype of rmsnorm_func
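Browse files
Currently the output dtype of `rmsnorm_func` is not the same as the input dtype. I'm not sure if this is the intended behaviour, but it looks like a bug. The cast back to the input dtype is applied to `hidden_states` before it is multiplied by `weight`, so a float32 `weight` promotes the result back to float32.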
How to reproduce:
```
import torch
# Minimal reproduction: float16 activations combined with a float32 weight.
hidden_size = 8
hidden_states = torch.rand((4, hidden_size), dtype=torch.float16)
weight = torch.ones(hidden_size, dtype=torch.float32)
variance_epsilon = torch.tensor(1e-6)
def rmsnorm_func(hidden_states, weight, variance_epsilon):
    """Scale ``hidden_states`` by the reciprocal root-mean-square of its last dimension.

    This is the pre-fix version kept for the reproduction: the statistics are
    computed in float32, the normalized tensor is cast back to the input
    dtype, and only then multiplied by ``weight``.
    """
    # Remember the caller's dtype so the result can be cast back later.
    input_dtype = hidden_states.dtype
    # Upcast so the squared mean and rsqrt are computed in float32.
    hidden_states = hidden_states.to(torch.float32)
    # Mean of squares over the last dimension; keepdim lets it broadcast below.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    # The cast happens BEFORE the multiply, so the product of a float32
    # weight and a float16 tensor is promoted back to float32.
    return weight * hidden_states.to(input_dtype)
# Print both dtypes to show the mismatch (float16 in, float32 out).
print('input', hidden_states.dtype)
print('output', rmsnorm_func(hidden_states, weight, variance_epsilon).dtype)
```
Result:
```
input torch.float16
output torch.float32
```
With this PR: 
```
input torch.float16
output torch.float16
```
- modeling_flash_llama.py +1 -1
| @@ -68,7 +68,7 @@ def rmsnorm_func(hidden_states, weight, variance_epsilon): | |
| 68 | 
             
                hidden_states = hidden_states.to(torch.float32)
         | 
| 69 | 
             
                variance = hidden_states.pow(2).mean(-1, keepdim=True)
         | 
| 70 | 
             
                hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
         | 
| 71 | 
            -
                return weight * hidden_states.to(input_dtype)
         | 
| 72 |  | 
| 73 |  | 
| 74 | 
             
            class LlamaRMSNorm(nn.Module):
         | 
|  | |
| 68 | 
             
                hidden_states = hidden_states.to(torch.float32)
         | 
| 69 | 
             
                variance = hidden_states.pow(2).mean(-1, keepdim=True)
         | 
| 70 | 
             
                hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
         | 
| 71 | 
            +
                return (weight * hidden_states).to(input_dtype)
         | 
| 72 |  | 
| 73 |  | 
| 74 | 
             
            class LlamaRMSNorm(nn.Module):
         | 
