jordiclive committed • Commit 2e57dd6 • 1 parent: 2069da8
Update README.md
README.md CHANGED
````diff
@@ -84,7 +84,7 @@ repo_id = "jordiclive/lora-llama-33B-alpaca_gpt4-dolly_15k-vicuna-r64"
 base_model = "decapoda-research/llama-30b-hf"
 
 # Model Loading
-def transfer_embeddings(model, embed_path, tokenizer):
+def add_embeddings(model, embed_path, tokenizer):
     old_embeddings = model.get_input_embeddings()
     old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
     new_embeddings = torch.nn.Embedding(old_num_tokens, old_embedding_dim)
@@ -93,16 +93,17 @@ def transfer_embeddings(model, embed_path, tokenizer):
     embed_weights = torch.load(embed_path, map_location=old_embeddings.weight.device)
     vocab_size = tokenizer.vocab_size
     new_embeddings.weight.data[:vocab_size, :] = old_embeddings.weight.data[:vocab_size, :]
-    new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.
+    new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.to(
         new_embeddings.weight.dtype
     ).to(new_embeddings.weight.device)
     model.set_input_embeddings(new_embeddings)
     model.tie_weights()
 
 
+
 def load_peft_model(model, peft_model_path, tokenizer):
     embed_weights = hf_hub_download(peft_model_path, "extra_embeddings.pt")
-    model.resize_token_embeddings(tokenizer.vocab_size + embed_weights.shape[0])
+    model.resize_token_embeddings(tokenizer.vocab_size + torch.load(embed_weights).shape[0])
     model.config.eos_token_id = tokenizer.eos_token_id
     model.config.bos_token_id = tokenizer.bos_token_id
     model.config.pad_token_id = tokenizer.pad_token_id
@@ -112,20 +113,22 @@ def load_peft_model(model, peft_model_path, tokenizer):
         torch_dtype=model.dtype,
     )
     model.eos_token_id = tokenizer.eos_token_id
-    transfer_embeddings(model, embed_weights, tokenizer)
+    add_embeddings(model, embed_weights, tokenizer)
     return model
 
 
 tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)
 
 model = transformers.AutoModelForCausalLM.from_pretrained(
-    base_model, torch_dtype=dtype, trust_remote_code=True,
+    base_model, torch_dtype=dtype, trust_remote_code=True,
 )
 model = load_peft_model(model, repo_id, tokenizer)
 
 
 # device configuration
 model = model.to(device)
+if dtype == torch.float16:
+    model = model.half()
 
 
 # Choose Generation parameters
@@ -164,6 +167,4 @@ def generate(prompt, generation_config=generation_config, max_new_tokens=2048, d
 generate("What is a meme, and what's the history behind this word?")
 generate("What's the Earth total population")
 generate("Write a story about future of AI development")
-
-
 ```
````
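The substantive fix in this commit is in `load_peft_model`: `hf_hub_download` returns the local path of the cached file, not a tensor, so the old `embed_weights.shape[0]` would have raised an `AttributeError` on a string; the weights must first be materialized with `torch.load`. A minimal sketch of the corrected pattern (repo id and filename are taken from the diff; loading to CPU is an illustrative choice):

```python
import torch
from huggingface_hub import hf_hub_download

repo_id = "jordiclive/lora-llama-33B-alpaca_gpt4-dolly_15k-vicuna-r64"

# hf_hub_download returns the *path* of the cached file, not its contents
embed_path = hf_hub_download(repo_id, "extra_embeddings.pt")

# only after torch.load is there a tensor with a .shape to read
embed_weights = torch.load(embed_path, map_location="cpu")
num_extra_tokens = embed_weights.shape[0]
print(num_extra_tokens)  # rows to append to the base vocabulary
```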