import copy
import json
import os.path
import re
import shutil
import inspect
from typing import Optional, Union
import torch
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers.loss.loss_utils import LOSS_MAPPING
from transformers.modeling_outputs import CausalLMOutput
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.modeling_utils import unwrap_model, logger
from functools import partial
from safetensors.torch import load_file as safe_load_file
try:
from flash_attn.models.gpt import GPTLMHeadModel
except ImportError:
GPTLMHeadModel = None
try:
from flash_attn.models.llama import llama_config_to_gpt2_config, inv_remap_state_dict_hf_llama
except ImportError:
llama_config_to_gpt2_config = None
inv_remap_state_dict_hf_llama = None


def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=None, dtype=None, **kwargs):
"""
code modified from: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/pretrained.py
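
    Example (illustrative; the repo id below is hypothetical):

        state_dict = state_dict_from_pretrained(
            "SomeOrg/NovoMolGen-checkpoint", dtype=torch.bfloat16, device="cuda"
        )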
"""
# If not fp32, then we don't want to load directly to the GPU
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
is_sharded = False
load_safe = False
    # Resolve the checkpoint file (single-file weights or a sharded index) through the HF Hub cache
resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_NAME),
_raise_exceptions_for_missing_entries=False, **kwargs)
if resolved_archive_file is None:
resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_INDEX_NAME),
_raise_exceptions_for_missing_entries=False, **kwargs)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
raise EnvironmentError(f"Model name {model_name} was not found.")
if load_safe:
loader = partial(safe_load_file, device=mapped_device)
else:
loader = partial(torch.load, map_location=mapped_device)
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different
# checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
model_name, resolved_archive_file
)
state_dict = {}
for sharded_file in resolved_archive_file:
state_dict.update(loader(sharded_file))
else:
state_dict = loader(resolved_archive_file)
# Convert dtype before moving to GPU to save memory
if dtype is not None:
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
return state_dict


class NovoMolGenConfig(LlamaConfig):
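    """
    Configuration class for NovoMolGen.

    Extends `LlamaConfig` with switches for the fused flash-attention kernels used by the
    flash-attn `GPTLMHeadModel` backbone (`use_flash_attn`, `fused_bias_fc`, `fused_mlp`,
    `fused_dropout_add_ln`, `residual_in_fp32`) and a `loss_type`.

    Example (illustrative hyper-parameters for a small SMILES model):

        config = NovoMolGenConfig(
            vocab_size=512,
            hidden_size=256,
            num_hidden_layers=8,
            num_attention_heads=8,
        )
    """
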
# model_type = "NovoMolGen"
def __init__(self,
use_flash_attn: bool = True,
fused_bias_fc: bool = True,
fused_mlp: bool = False,
fused_dropout_add_ln: bool = True,
residual_in_fp32: bool = True,
loss_type: str = 'ForCausalLM',
**kwargs
):
super().__init__(**kwargs)
self.use_flash_attn = use_flash_attn
self.fused_bias_fc = fused_bias_fc
self.fused_mlp = fused_mlp
self.fused_dropout_add_ln = fused_dropout_add_ln
self.residual_in_fp32 = residual_in_fp32
self.loss_type = loss_type
self.auto_map = {"AutoModelForCausalLM": "modeling_novomolgen.NovoMolGen"}

    @classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
checkpoint_path: str = "",
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
resolved_archive_config_file = cached_file(pretrained_model_name_or_path,
os.path.join(checkpoint_path, "config.json"),
_raise_exceptions_for_missing_entries=False, force_download=force_download)
if resolved_archive_config_file is not None:
with open(resolved_archive_config_file, "r", encoding="utf-8") as reader:
text = reader.read()
config_dict = json.loads(text)
else:
raise EnvironmentError(f"config for {pretrained_model_name_or_path} was not found.")
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)


class NovoMolGen(GPTLMHeadModel):
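    """
    Causal language model for molecule generation (SMILES by default), built on the flash-attn
    `GPTLMHeadModel` backbone and configured from a `NovoMolGenConfig` (a Llama-style config
    that is converted to a GPT-2-style config internally).
    """
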
def __init__(
self,
config: NovoMolGenConfig,
mol_type: str = "SMILES",
):
self.base_config = config
self.mol_type = mol_type
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = self.base_config.use_flash_attn
config.fused_bias_fc = self.base_config.fused_bias_fc
config.fused_mlp = self.base_config.fused_mlp
config.fused_dropout_add_ln = self.base_config.fused_dropout_add_ln
config.residual_in_fp32 = self.base_config.residual_in_fp32
GPTLMHeadModel.__init__(self, config)

    # TODO: we currently ignore `attention_mask` to stay compatible with the HF Trainer. The MHA in
    # flash-attention should be reimplemented to integrate `attention_mask`, as done here:
    # https://github.com/huggingface/transformers/blob/0864dd3beb238b7bec3528a3d1d6c17a28f51a51/src/transformers/models/llama/modeling_llama.py#L536
def forward(self, input_ids, attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None,
position_ids=None, inference_params=None, num_last_tokens=0, **loss_kwargs):
"""
input_ids: (batch, seqlen) int tensor
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
num_last_tokens: if > 0, only return the logits for the last n tokens
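
        Example (illustrative; `model` and `config` stand for an instantiated NovoMolGen and its config):
            out = model(input_ids=torch.randint(0, config.vocab_size, (2, 16)))
            out.logits.shape  # -> (2, 16, vocab_size)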
"""
assert (
input_ids.ndim == 2
), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
b, slen = input_ids.shape
hidden_states = self.transformer(
input_ids, position_ids=position_ids, inference_params=inference_params
)
if inference_params is not None:
assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
if self.output_scale != 1.0:
hidden_states = hidden_states * self.output_scale
if not self.norm_head:
lm_logits = self.lm_head(hidden_states)
else:
lm_head_weight = F.normalize(self.lm_head.weight)
# if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
# hidden_states = all_gather(hidden_states, self.lm_head.process_group)
lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
# During inference, we want the full logit for sampling
# if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
# lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
# lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
loss = None
if labels is not None:
loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.base_config.vocab_size,
**loss_kwargs)
return CausalLMOutput(
loss=loss,
logits=lm_logits,
hidden_states=hidden_states
)

    @property
def loss_function(self):
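        """
        Resolve the loss callable from `transformers`' LOSS_MAPPING based on
        `base_config.loss_type`, falling back to `ForCausalLM` when the configured type is
        missing or unrecognised.
        """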
if getattr(self.base_config, "loss_type", None) is not None:
loss_type = self.base_config.loss_type
else:
loss_type = self.__class__.__name__
if loss_type not in LOSS_MAPPING:
loss_groups = f"({'|'.join(LOSS_MAPPING)})"
loss_type = re.findall(loss_groups, self.__class__.__name__)
if len(loss_type) > 0:
loss_type = loss_type[0]
else:
loss_type = None
        if loss_type is None or (
            loss_type not in LOSS_MAPPING and getattr(self.base_config, "loss_type", None) is not None
        ):
            logger.warning(
                f"`loss_type={loss_type}` was set in the base_config but it is unrecognised. "
                f"Using the default loss: `ForCausalLMLoss`."
            )
loss_type = "ForCausalLM"
return LOSS_MAPPING[loss_type]

    def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
state_dict: Optional[dict] = None,
safe_serialization: bool = False,
**kwargs,
):
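        """
        Save the config, the model weights (PyTorch `.bin` format only, since safe serialization
        is not implemented), and a copy of this source file into `save_directory`.
        """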
        if safe_serialization:
            raise NotImplementedError("`safe_serialization` is not implemented yet.")
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
os.makedirs(save_directory, exist_ok=True)
# Save the config
if is_main_process:
self.base_config.save_pretrained(save_directory)
# Save the model
if state_dict is None:
            # Unwrap the model (e.g. from DDP or other wrappers) so we save the underlying weights
model_to_save = unwrap_model(self)
state_dict = model_to_save.state_dict()
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
torch.save(state_dict, os.path.join(save_directory, weights_name))
        # Copy the file where NovoMolGen is defined into the save directory so the checkpoint
        # can be loaded with `trust_remote_code=True` (see `auto_map` in the config)
src = inspect.getsourcefile(type(self))
if src:
dst = os.path.join(save_directory, os.path.basename(src))
shutil.copy(src, dst)

    @classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
checkpoint_path: str = "",
config: Optional[Union[NovoMolGenConfig, str, os.PathLike]] = None,
**kwargs,
):
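        """
        Load a NovoMolGen checkpoint from a local directory or the Hugging Face Hub.

        Example (illustrative; the repo id below is hypothetical):

            model = NovoMolGen.from_pretrained("SomeOrg/NovoMolGen")
        """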
if config is None:
config = NovoMolGenConfig.from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
model = cls(config)
if os.path.exists(pretrained_model_name_or_path):
state_dict = torch.load(os.path.join(pretrained_model_name_or_path, checkpoint_path, WEIGHTS_NAME))
else:
state_dict = state_dict_from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
model.load_state_dict(state_dict)
return model

    def sample(
self,
tokenizer,
batch_size: int = 4,
max_length: int = 64,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.95,
device: torch.device = torch.device("cuda"),
):
"""
Generate a batch of sequences from the model.
        Returns a dictionary with two keys:
{
"<mol_type>": <list of raw sequences in that moltype>,
"sequences": <torch.LongTensor of valid token IDs>
}
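
        Example (illustrative; assumes a tokenizer matching the checkpoint and `mol_type == "SMILES"`):
            out = model.sample(tokenizer, batch_size=2, max_length=32)
            out["SMILES"]     # list of 2 decoded molecule strings (key equals `self.mol_type`)
            out["sequences"]  # LongTensor of generated token ids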
"""
input_ids = tokenizer.encode("", return_tensors="pt").to(device)
# Repeat the prompt for the desired batch size
input_ids = input_ids.repeat_interleave(batch_size, dim=0)
# If the tokenizer includes an EOS token for an empty prompt, we remove it.
if input_ids.shape[1] > 1:
input_ids = input_ids[:, :-1]
generation_output = self.generate(
input_ids,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
)
sequences = self._filter_tokens_after_eos(
generation_output.sequences, eos_id=tokenizer.eos_token_id
)
decoded_strings = tokenizer.batch_decode(sequences, skip_special_tokens=True)
decoded_strings = [s.replace(" ", "") for s in decoded_strings]
result = {
self.mol_type: decoded_strings,
"sequences": sequences,
}
return result

    @staticmethod
def _filter_tokens_after_eos(sequences, eos_id):
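        """
        Replace every token after the first EOS in each row with the EOS id so that decoding
        stops cleanly at the end-of-sequence marker.
        """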
output = copy.deepcopy(sequences)
for i in range(sequences.size(0)):
row = sequences[i]
eos_position = (row == eos_id).nonzero()
if eos_position.numel() > 0:
eos_position = eos_position[0, 0].item() # Get the index of the first occurrence
output[i, eos_position + 1:] = eos_id
return output

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
# HF’s GenerationMixin would normally do more, but for a basic LM this usually suffices:
return {"input_ids": input_ids, "attention_mask": attention_mask}