Spaces:
Build error
Build error
fix training bug
Browse files
- .gitignore +2 -0
- configs/dataset_config.py +2 -6
- mmgpt/models/builder.py +49 -1
- mmgpt/train/instruction_finetune.py +2 -1
.gitignore
CHANGED
|
@@ -2,6 +2,8 @@
|
|
| 2 |
|
| 3 |
wandb/
|
| 4 |
|
|
|
|
|
|
|
| 5 |
# Byte-compiled / optimized / DLL files
|
| 6 |
__pycache__/
|
| 7 |
*.py[cod]
|
|
|
|
| 2 |
|
| 3 |
wandb/
|
| 4 |
|
| 5 |
+
checkpoints/
|
| 6 |
+
|
| 7 |
# Byte-compiled / optimized / DLL files
|
| 8 |
__pycache__/
|
| 9 |
*.py[cod]
|
configs/dataset_config.py
CHANGED
|
@@ -51,14 +51,10 @@ visual_datasets = [
|
|
| 51 |
language_datasets = [
|
| 52 |
dict(
|
| 53 |
type="dolly",
|
| 54 |
-
|
| 55 |
-
"data/dolly/databricks-dolly-15k.jsonl",
|
| 56 |
-
],
|
| 57 |
),
|
| 58 |
dict(
|
| 59 |
type="alpaca_gpt4",
|
| 60 |
-
|
| 61 |
-
"data/alpaca_gpt4/alpaca_gpt4_data.json",
|
| 62 |
-
],
|
| 63 |
),
|
| 64 |
]
|
|
|
|
| 51 |
language_datasets = [
|
| 52 |
dict(
|
| 53 |
type="dolly",
|
| 54 |
+
ann_path="data/dolly/databricks-dolly-15k.jsonl",
|
|
|
|
|
|
|
| 55 |
),
|
| 56 |
dict(
|
| 57 |
type="alpaca_gpt4",
|
| 58 |
+
ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
|
|
|
|
|
|
|
| 59 |
),
|
| 60 |
]
|
mmgpt/models/builder.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from .open_flamingo import create_model_and_transforms as create_open_flamingo_model_and_transforms
|
| 2 |
-
|
|
|
|
| 3 |
|
| 4 |
def create_model_and_transforms(
|
| 5 |
model_name: str,
|
|
@@ -24,3 +25,50 @@ def create_model_and_transforms(
|
|
| 24 |
# TODO: support BLIP2
|
| 25 |
else:
|
| 26 |
raise ValueError(f"Unknown model name: {model_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from .open_flamingo import create_model_and_transforms as create_open_flamingo_model_and_transforms
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import LlamaTokenizer, LlamaForCausalLM
|
| 4 |
|
| 5 |
def create_model_and_transforms(
|
| 6 |
model_name: str,
|
|
|
|
| 25 |
# TODO: support BLIP2
|
| 26 |
else:
|
| 27 |
raise ValueError(f"Unknown model name: {model_name}")
|
| 28 |
+
|
| 29 |
+
# only for debugging
|
| 30 |
+
def create_toy_model_and_transforms(
|
| 31 |
+
model_name: str,
|
| 32 |
+
clip_vision_encoder_path: str,
|
| 33 |
+
clip_vision_encoder_pretrained: str,
|
| 34 |
+
lang_encoder_path: str,
|
| 35 |
+
tokenizer_path: str,
|
| 36 |
+
tuning_config,
|
| 37 |
+
pretrained_model_path,
|
| 38 |
+
**kwargs,
|
| 39 |
+
):
|
| 40 |
+
print("init toy vision encoder")
|
| 41 |
+
import torchvision
|
| 42 |
+
|
| 43 |
+
image_processor = torchvision.transforms.Compose(
|
| 44 |
+
[
|
| 45 |
+
torchvision.transforms.Resize((224, 224)),
|
| 46 |
+
torchvision.transforms.ToTensor(),
|
| 47 |
+
]
|
| 48 |
+
)
|
| 49 |
+
print("init tokenizer")
|
| 50 |
+
text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
|
| 51 |
+
# add Flamingo special tokens to the tokenizer
|
| 52 |
+
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
|
| 53 |
+
if text_tokenizer.pad_token is None:
|
| 54 |
+
# Issue: GPT models don't have a pad token, which we use to
|
| 55 |
+
# modify labels for the loss.
|
| 56 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
| 57 |
+
|
| 58 |
+
class ToyModel(nn.Module):
|
| 59 |
+
def __init__(self, *args, **kwargs):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.input_embeddings = nn.Embedding(38000, 512)
|
| 62 |
+
self.layer = nn.Linear(512, 512)
|
| 63 |
+
self.config = {"hidden_size": 512}
|
| 64 |
+
|
| 65 |
+
def forward(self, lang_x, **kwargs):
|
| 66 |
+
x = self.input_embeddings(lang_x)
|
| 67 |
+
x = self.layer(x)
|
| 68 |
+
loss = x.sum()
|
| 69 |
+
|
| 70 |
+
return (loss,)
|
| 71 |
+
|
| 72 |
+
model = ToyModel()
|
| 73 |
+
|
| 74 |
+
return model, image_processor, text_tokenizer
|
mmgpt/train/instruction_finetune.py
CHANGED
|
@@ -21,6 +21,7 @@ from transformers import (
|
|
| 21 |
)
|
| 22 |
|
| 23 |
from mmgpt import create_model_and_transforms
|
|
|
|
| 24 |
from mmgpt.datasets import InfiniteSampler, build_dataset
|
| 25 |
from mmgpt.train.distributed import init_distributed_device, world_info_from_env
|
| 26 |
from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint
|
|
@@ -185,7 +186,7 @@ def main():
|
|
| 185 |
)
|
| 186 |
|
| 187 |
# build language dataset and dataloader for multi-modality training
|
| 188 |
-
if dataset_config.get('language_datasets') is not None and len(
|
| 189 |
lang_dataset = build_dataset(
|
| 190 |
dataset_config=dataset_config.language_datasets,
|
| 191 |
tokenizer=tokenizer,
|
|
|
|
| 21 |
)
|
| 22 |
|
| 23 |
from mmgpt import create_model_and_transforms
|
| 24 |
+
from mmgpt.models.builder import create_toy_model_and_transforms
|
| 25 |
from mmgpt.datasets import InfiniteSampler, build_dataset
|
| 26 |
from mmgpt.train.distributed import init_distributed_device, world_info_from_env
|
| 27 |
from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint
|
|
|
|
| 186 |
)
|
| 187 |
|
| 188 |
# build language dataset and dataloader for multi-modality training
|
| 189 |
+
if dataset_config.get('language_datasets') is not None and len(dataset_config.language_datasets) > 0:
|
| 190 |
lang_dataset = build_dataset(
|
| 191 |
dataset_config=dataset_config.language_datasets,
|
| 192 |
tokenizer=tokenizer,
|