fix training bug
Files changed:
- .gitignore (+2 -0)
- configs/dataset_config.py (+2 -6)
- mmgpt/models/builder.py (+49 -1)
- mmgpt/train/instruction_finetune.py (+2 -1)

The commit ignores local checkpoints, collapses each language dataset's annotation-file list into a single ann_path string, adds a debug-only toy model builder, and repairs a truncated condition guarding the language dataloader.
.gitignore CHANGED

@@ -2,6 +2,8 @@
 
 wandb/
 
+checkpoints/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
configs/dataset_config.py CHANGED

@@ -51,14 +51,10 @@ visual_datasets = [
 language_datasets = [
     dict(
         type="dolly",
-        ann_paths=[
-            "data/dolly/databricks-dolly-15k.jsonl",
-        ],
+        ann_path="data/dolly/databricks-dolly-15k.jsonl",
     ),
     dict(
         type="alpaca_gpt4",
-        ann_paths=[
-            "data/alpaca_gpt4/alpaca_gpt4_data.json",
-        ],
+        ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
     ),
 ]
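Both language entries now pass a single ann_path string, presumably the keyword the dolly and alpaca_gpt4 dataset classes expect. As a sanity check on the two annotation formats the paths point at, a minimal sketch — load_annotations is a hypothetical helper for illustration, not the repo's build_dataset:

import json

def load_annotations(ann_path: str):
    """Read a .jsonl file (one JSON object per line) or a plain .json array."""
    with open(ann_path) as f:
        if ann_path.endswith(".jsonl"):
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)

# dolly ships as JSONL, alpaca_gpt4 as a single JSON array:
# records = load_annotations("data/dolly/databricks-dolly-15k.jsonl")
# records = load_annotations("data/alpaca_gpt4/alpaca_gpt4_data.json")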
mmgpt/models/builder.py CHANGED

@@ -1,5 +1,6 @@
 from .open_flamingo import create_model_and_transforms as create_open_flamingo_model_and_transforms
-
+import torch.nn as nn
+from transformers import LlamaTokenizer, LlamaForCausalLM
 
 def create_model_and_transforms(
     model_name: str,
@@ -24,3 +25,50 @@ def create_model_and_transforms(
     # TODO: support BLIP2
     else:
         raise ValueError(f"Unknown model name: {model_name}")
+
+# only for debugging
+def create_toy_model_and_transforms(
+    model_name: str,
+    clip_vision_encoder_path: str,
+    clip_vision_encoder_pretrained: str,
+    lang_encoder_path: str,
+    tokenizer_path: str,
+    tuning_config,
+    pretrained_model_path,
+    **kwargs,
+):
+    print("init toy vision encoder")
+    import torchvision
+
+    image_processor = torchvision.transforms.Compose(
+        [
+            torchvision.transforms.Resize((224, 224)),
+            torchvision.transforms.ToTensor(),
+        ]
+    )
+    print("init tokenizer")
+    text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
+    # add Flamingo special tokens to the tokenizer
+    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
+    if text_tokenizer.pad_token is None:
+        # Issue: GPT models don't have a pad token, which we use to
+        # modify labels for the loss.
+        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+
+    class ToyModel(nn.Module):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+            self.input_embeddings = nn.Embedding(38000, 512)
+            self.layer = nn.Linear(512, 512)
+            self.config = {"hidden_size": 512}
+
+        def forward(self, lang_x, **kwargs):
+            x = self.input_embeddings(lang_x)
+            x = self.layer(x)
+            loss = x.sum()
+
+            return (loss,)
+
+    model = ToyModel()
+
+    return model, image_processor, text_tokenizer
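The new create_toy_model_and_transforms keeps a real LlamaTokenizer but swaps the OpenFlamingo stack for a tiny embedding-plus-linear network whose forward returns a (loss,) tuple, so the training loop can be smoke-tested in seconds. A hedged usage sketch, assuming only that tokenizer_path points at a valid LLaMA tokenizer directory (the remaining arguments are accepted for signature parity but unused):

from mmgpt.models.builder import create_toy_model_and_transforms

model, image_processor, tokenizer = create_toy_model_and_transforms(
    model_name="open_flamingo",         # unused by the toy builder
    clip_vision_encoder_path="",        # unused placeholder
    clip_vision_encoder_pretrained="",  # unused placeholder
    lang_encoder_path="",               # unused placeholder
    tokenizer_path="path/to/llama-7b",  # placeholder: a real LlamaTokenizer dir
    tuning_config=None,
    pretrained_model_path=None,
)

batch = tokenizer(["hello world"], return_tensors="pt")
loss = model(lang_x=batch["input_ids"])[0]  # ToyModel returns (loss,)
loss.backward()                             # exercises the backward pass too

Returning the loss as a one-element tuple mirrors the output[0] indexing convention of HuggingFace causal LMs, so the training loop needs no special-casing for the toy model.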
mmgpt/train/instruction_finetune.py CHANGED

@@ -21,6 +21,7 @@ from transformers import (
 )
 
 from mmgpt import create_model_and_transforms
+from mmgpt.models.builder import create_toy_model_and_transforms
 from mmgpt.datasets import InfiniteSampler, build_dataset
 from mmgpt.train.distributed import init_distributed_device, world_info_from_env
 from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint
@@ -185,7 +186,7 @@ def main():
     )
 
     # build language dataset and dataloader for multi-modality training
-    if dataset_config.get('language_datasets') is not None and len(
+    if dataset_config.get('language_datasets') is not None and len(dataset_config.language_datasets) > 0:
         lang_dataset = build_dataset(
             dataset_config=dataset_config.language_datasets,
             tokenizer=tokenizer,