Tao11 committed
Commit 7624425 · unverified · 2 parents: 18bd00e 5734c31

Fix training bug

.gitignore CHANGED
@@ -2,6 +2,8 @@

 wandb/

+checkpoints/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
configs/dataset_config.py CHANGED
@@ -51,14 +51,10 @@ visual_datasets = [
 language_datasets = [
     dict(
         type="dolly",
-        ann_paths=[
-            "data/dolly/databricks-dolly-15k.jsonl",
-        ],
+        ann_path="data/dolly/databricks-dolly-15k.jsonl",
     ),
     dict(
         type="alpaca_gpt4",
-        ann_paths=[
-            "data/alpaca_gpt4/alpaca_gpt4_data.json",
-        ],
+        ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
     ),
 ]
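
This hunk flattens the list-valued ann_paths into a single ann_path string per language dataset, presumably matching what the dataset builders now expect. A minimal sketch of a loader consuming the flattened field; load_annotations is an illustrative helper, not code from this repository:

import json

def load_annotations(ann_path: str):
    # Dolly ships as .jsonl (one JSON record per line);
    # the Alpaca-GPT4 file is a single JSON array.
    if ann_path.endswith(".jsonl"):
        with open(ann_path) as f:
            return [json.loads(line) for line in f]
    with open(ann_path) as f:
        return json.load(f)

# e.g. records = load_annotations("data/dolly/databricks-dolly-15k.jsonl")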
mmgpt/models/builder.py CHANGED
@@ -1,5 +1,6 @@
 from .open_flamingo import create_model_and_transforms as create_open_flamingo_model_and_transforms
-
+import torch.nn as nn
+from transformers import LlamaTokenizer, LlamaForCausalLM

 def create_model_and_transforms(
     model_name: str,
@@ -24,3 +25,50 @@ def create_model_and_transforms(
     # TODO: support BLIP2
     else:
         raise ValueError(f"Unknown model name: {model_name}")
+
+# only for debugging
+def create_toy_model_and_transforms(
+    model_name: str,
+    clip_vision_encoder_path: str,
+    clip_vision_encoder_pretrained: str,
+    lang_encoder_path: str,
+    tokenizer_path: str,
+    tuning_config,
+    pretrained_model_path,
+    **kwargs,
+):
+    print("init toy vision encoder")
+    import torchvision
+
+    image_processor = torchvision.transforms.Compose(
+        [
+            torchvision.transforms.Resize((224, 224)),
+            torchvision.transforms.ToTensor(),
+        ]
+    )
+    print("init tokenizer")
+    text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
+    # add Flamingo special tokens to the tokenizer
+    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
+    if text_tokenizer.pad_token is None:
+        # Issue: GPT models don't have a pad token, which we use to
+        # modify labels for the loss.
+        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+
+    class ToyModel(nn.Module):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+            self.input_embeddings = nn.Embedding(38000, 512)
+            self.layer = nn.Linear(512, 512)
+            self.config = {"hidden_size": 512}
+
+        def forward(self, lang_x, **kwargs):
+            x = self.input_embeddings(lang_x)
+            x = self.layer(x)
+            loss = x.sum()
+
+            return (loss,)
+
+    model = ToyModel()
+
+    return model, image_processor, text_tokenizer
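
create_toy_model_and_transforms mirrors the (model, image_processor, tokenizer) return signature of create_model_and_transforms, so it can be dropped into the trainer to smoke-test the data pipeline and optimization loop without loading a real LLaMA checkpoint. A usage sketch; the argument values below are placeholders, not values from this commit, and only tokenizer_path must point at a real LlamaTokenizer, since the other arguments are accepted purely for signature compatibility:

model, image_processor, tokenizer = create_toy_model_and_transforms(
    model_name="open_flamingo",
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="path/to/llama",
    tokenizer_path="path/to/llama",  # must hold real tokenizer files
    tuning_config=None,
    pretrained_model_path=None,
)

batch = tokenizer(["a quick pipeline check"], return_tensors="pt")
(loss,) = model(lang_x=batch["input_ids"])
loss.backward()  # backprop through Embedding + Linear only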
mmgpt/train/instruction_finetune.py CHANGED
@@ -21,6 +21,7 @@ from transformers import (
 )

 from mmgpt import create_model_and_transforms
+from mmgpt.models.builder import create_toy_model_and_transforms
 from mmgpt.datasets import InfiniteSampler, build_dataset
 from mmgpt.train.distributed import init_distributed_device, world_info_from_env
 from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint
@@ -185,7 +186,7 @@ def main():
     )

     # build language dataset and dataloader for multi-modality training
-    if dataset_config.get('language_datasets') is not None and len(args.language_datasets) > 0:
+    if dataset_config.get('language_datasets') is not None and len(dataset_config.language_datasets) > 0:
         lang_dataset = build_dataset(
             dataset_config=dataset_config.language_datasets,
             tokenizer=tokenizer,
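
The training bug named in the commit message is fixed in the second hunk: the guard read args.language_datasets, an attribute the trainer's argparse namespace presumably never defines, so any run whose config listed language datasets crashed before building them; the fix reads both the existence check and the length check from the same dataset_config object. A minimal repro of the old behavior, using a plain dict in place of the repository's config object:

from argparse import Namespace

args = Namespace()  # no language_datasets attribute, as in the trainer
dataset_config = {"language_datasets": [dict(type="dolly")]}

# old guard: len(args.language_datasets) raises AttributeError
# new guard: both checks consult the config itself
lang_cfgs = dataset_config.get("language_datasets")
if lang_cfgs is not None and len(lang_cfgs) > 0:
    print(f"building {len(lang_cfgs)} language dataset(s)")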