jordiclive commited on
Commit
2e57dd6
1 Parent(s): 2069da8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -7
README.md CHANGED
@@ -84,7 +84,7 @@ repo_id = "jordiclive/lora-llama-33B-alpaca_gpt4-dolly_15k-vicuna-r64"
84
  base_model = "decapoda-research/llama-30b-hf"
85
 
86
  # Model Loading
87
- def transfer_embeddings(model, embed_path, tokenizer):
88
  old_embeddings = model.get_input_embeddings()
89
  old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
90
  new_embeddings = torch.nn.Embedding(old_num_tokens, old_embedding_dim)
@@ -93,16 +93,17 @@ def transfer_embeddings(model, embed_path, tokenizer):
93
  embed_weights = torch.load(embed_path, map_location=old_embeddings.weight.device)
94
  vocab_size = tokenizer.vocab_size
95
  new_embeddings.weight.data[:vocab_size, :] = old_embeddings.weight.data[:vocab_size, :]
96
- new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.weight.data.to(
97
  new_embeddings.weight.dtype
98
  ).to(new_embeddings.weight.device)
99
  model.set_input_embeddings(new_embeddings)
100
  model.tie_weights()
101
 
102
 
 
103
  def load_peft_model(model, peft_model_path, tokenizer):
104
  embed_weights = hf_hub_download(peft_model_path, "extra_embeddings.pt")
105
- model.resize_token_embeddings(tokenizer.vocab_size + embed_weights.shape[0])
106
  model.config.eos_token_id = tokenizer.eos_token_id
107
  model.config.bos_token_id = tokenizer.bos_token_id
108
  model.config.pad_token_id = tokenizer.pad_token_id
@@ -112,20 +113,22 @@ def load_peft_model(model, peft_model_path, tokenizer):
112
  torch_dtype=model.dtype,
113
  )
114
  model.eos_token_id = tokenizer.eos_token_id
115
- transfer_embeddings(model, Path(peft_model_path).joinpath("extra_embeddings.pt"), tokenizer)
116
  return model
117
 
118
 
119
  tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)
120
 
121
  model = transformers.AutoModelForCausalLM.from_pretrained(
122
- base_model, torch_dtype=dtype, trust_remote_code=True, cache_dir="/mnt/data/jordiclive/data_cache"
123
  )
124
  model = load_peft_model(model, repo_id, tokenizer)
125
 
126
 
127
  # device configuration
128
  model = model.to(device)
 
 
129
 
130
 
131
  # Choose Generation parameters
@@ -164,6 +167,4 @@ def generate(prompt, generation_config=generation_config, max_new_tokens=2048, d
164
  generate("What is a meme, and what's the history behind this word?")
165
  generate("What's the Earth total population")
166
  generate("Write a story about future of AI development")
167
-
168
-
169
  ```
 
84
  base_model = "decapoda-research/llama-30b-hf"
85
 
86
  # Model Loading
87
+ def add_embeddings(model, embed_path, tokenizer):
88
  old_embeddings = model.get_input_embeddings()
89
  old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
90
  new_embeddings = torch.nn.Embedding(old_num_tokens, old_embedding_dim)
 
93
  embed_weights = torch.load(embed_path, map_location=old_embeddings.weight.device)
94
  vocab_size = tokenizer.vocab_size
95
  new_embeddings.weight.data[:vocab_size, :] = old_embeddings.weight.data[:vocab_size, :]
96
+ new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.to(
97
  new_embeddings.weight.dtype
98
  ).to(new_embeddings.weight.device)
99
  model.set_input_embeddings(new_embeddings)
100
  model.tie_weights()
101
 
102
 
103
+
104
  def load_peft_model(model, peft_model_path, tokenizer):
105
  embed_weights = hf_hub_download(peft_model_path, "extra_embeddings.pt")
106
+ model.resize_token_embeddings(tokenizer.vocab_size + torch.load(embed_weights).shape[0])
107
  model.config.eos_token_id = tokenizer.eos_token_id
108
  model.config.bos_token_id = tokenizer.bos_token_id
109
  model.config.pad_token_id = tokenizer.pad_token_id
 
113
  torch_dtype=model.dtype,
114
  )
115
  model.eos_token_id = tokenizer.eos_token_id
116
+ add_embeddings(model, embed_weights, tokenizer)
117
  return model
118
 
119
 
120
  tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)
121
 
122
  model = transformers.AutoModelForCausalLM.from_pretrained(
123
+ base_model, torch_dtype=dtype, trust_remote_code=True,
124
  )
125
  model = load_peft_model(model, repo_id, tokenizer)
126
 
127
 
128
  # device configuration
129
  model = model.to(device)
130
+ if dtype == torch.float16:
131
+ model = model.half()
132
 
133
 
134
  # Choose Generation parameters
 
167
  generate("What is a meme, and what's the history behind this word?")
168
  generate("What's the Earth total population")
169
  generate("Write a story about future of AI development")
 
 
170
  ```