Update README.md
README.md CHANGED
@@ -15,11 +15,10 @@ cd MemoryLLM
 Then simply use the following code to load the model:
 ```python
 import torch
-from modeling_memoryllm import MemoryLLM
 from transformers import AutoTokenizer
+from modeling_mplus import MPlus
 
 # load the model mplus-8b (currently we only have the pretrained version)
-from modeling_mplus import MPlus
 model = MPlus.from_pretrained("YuWangX/mplus-8b", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained("YuWangX/mplus-8b")
 model = model.to(torch.bfloat16) # need to call it again to cast the `inv_freq` in rotary_emb to bfloat16 as well
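For completeness, here is a minimal end-to-end sketch of loading and generating with the updated imports. It assumes MPlus keeps the standard `transformers` causal-LM `generate` interface; the prompt string is made up for illustration, and the MemoryLLM codebase's memory-injection helpers are not shown since their exact signatures are not given in this diff.

```python
import torch
from transformers import AutoTokenizer
from modeling_mplus import MPlus

# Load exactly as in the README snippet above.
model = MPlus.from_pretrained(
    "YuWangX/mplus-8b",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("YuWangX/mplus-8b")
model = model.to(torch.bfloat16)  # recast so rotary_emb's inv_freq is bfloat16 too
model = model.cuda()  # flash_attention_2 requires a CUDA device

# Assumption: MPlus behaves like a standard Hugging Face causal LM, so the
# usual tokenize -> generate -> decode loop applies unchanged.
prompt = "Question: What is MPlus? Answer:"  # made-up example prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```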