tuandunghcmut committed (verified)
Commit e7887f2 · 1 Parent(s): 0b87f0f

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. InternVL/.github/ISSUE_TEMPLATE/1-bug-report.yml +54 -0
  2. InternVL/.github/ISSUE_TEMPLATE/2-feature-request.yml +31 -0
  3. InternVL/.github/ISSUE_TEMPLATE/3-documentation.yml +23 -0
  4. InternVL/internvl_g/eval/evaluate_caption.py +237 -0
  5. InternVL/internvl_g/internvl/dist_utils.py +101 -0
  6. InternVL/internvl_g/internvl/model/__init__.py +0 -0
  7. InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/__init__.py +87 -0
  8. InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_intern_vit.py +342 -0
  9. InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_internvl.py +669 -0
  10. InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_qllama.py +1073 -0
  11. InternVL/internvl_g/internvl/train/__init__.py +0 -0
  12. InternVL/internvl_g/internvl/train/dataset.py +283 -0
  13. InternVL/internvl_g/internvl/train/internvl_stage2_finetune.py +286 -0
  14. InternVL/internvl_g/internvl/train/trainer_monkey_patch.py +150 -0
  15. InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh +58 -0
  16. InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh +58 -0
  17. InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_flickrcn_364_bs1024_ep10.sh +58 -0
  18. InternVL/internvl_g/shell/head_finetune/internvl_stage2_finetune_coco_224_bs1024_ep5_head_4gpu.sh +59 -0
  19. InternVL/internvl_g/shell/head_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_head_4gpu.sh +59 -0
  20. InternVL/internvl_g/shell/head_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_head_4gpu.sh +59 -0
  21. InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_coco_224_bs1024_ep5_lora16_4gpu.sh +61 -0
  22. InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora16_4gpu.sh +61 -0
  23. InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora16_4gpu.sh +61 -0
  24. InternVL/segmentation/configs/_base_/datasets/ade20k_504x504.py +56 -0
  25. InternVL/segmentation/configs/_base_/datasets/ade20k_504x504_1of16.py +56 -0
  26. InternVL/segmentation/configs/_base_/datasets/cityscapes_1024x1024.py +35 -0
  27. InternVL/segmentation/configs/_base_/models/apcnet_r50-d8.py +44 -0
  28. InternVL/segmentation/configs/_base_/models/bisenetv1_r18-d32.py +68 -0
  29. InternVL/segmentation/configs/_base_/models/danet_r50-d8.py +44 -0
  30. InternVL/segmentation/configs/_base_/models/deeplabv3plus_r50-d8.py +46 -0
  31. InternVL/segmentation/configs/_base_/models/dmnet_r50-d8.py +44 -0
  32. InternVL/segmentation/configs/_base_/models/encnet_r50-d8.py +48 -0
  33. InternVL/segmentation/configs/_base_/models/erfnet_fcn.py +32 -0
  34. InternVL/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py +53 -0
  35. InternVL/segmentation/configs/_base_/models/fcn_hr18.py +52 -0
  36. InternVL/segmentation/configs/_base_/models/fpn_r50.py +36 -0
  37. InternVL/segmentation/configs/_base_/models/isanet_r50-d8.py +45 -0
  38. InternVL/segmentation/configs/_base_/models/lraspp_m-v3-d8.py +25 -0
  39. InternVL/segmentation/configs/_base_/models/pointrend_r50.py +56 -0
  40. InternVL/segmentation/configs/_base_/models/pspnet_unet_s5-d16.py +50 -0
  41. InternVL/segmentation/configs/_base_/models/upernet_r50.py +44 -0
  42. InternVL/segmentation/configs/_base_/schedules/schedule_10k.py +9 -0
  43. InternVL/segmentation/configs/_base_/schedules/schedule_160k.py +9 -0
  44. InternVL/segmentation/configs/_base_/schedules/schedule_20k.py +9 -0
  45. InternVL/segmentation/configs/_base_/schedules/schedule_320k.py +9 -0
  46. InternVL/segmentation/configs/_base_/schedules/schedule_40k.py +9 -0
  47. InternVL/segmentation/configs/_base_/schedules/schedule_5k.py +9 -0
  48. InternVL/segmentation/configs/_base_/schedules/schedule_80k.py +9 -0
  49. InternVL/segmentation/configs/intern_vit_6b/few_shot/linear_intern_vit_6b_504_10k_ade20k_bs16_lr4e-5_1of8.py +72 -0
  50. InternVL/segmentation/configs/intern_vit_6b/few_shot/linear_intern_vit_6b_504_20k_ade20k_bs16_lr4e-5_1of4.py +72 -0
InternVL/.github/ISSUE_TEMPLATE/1-bug-report.yml ADDED
@@ -0,0 +1,54 @@
1
+ name: 🐞 Bug report
2
+ description: Create a report to help us reproduce and fix the bug
3
+ title: "[Bug] "
4
+ labels: ['Bug']
5
+
6
+ body:
7
+ - type: checkboxes
8
+ attributes:
9
+ label: Checklist
10
+ options:
11
+ - label: 1. I have searched related issues but cannot get the expected help.
12
+ - label: 2. The bug has not been fixed in the latest version.
13
+ - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
14
+ - type: textarea
15
+ attributes:
16
+ label: Describe the bug
17
+ description: A clear and concise description of what the bug is.
18
+ validations:
19
+ required: true
20
+ - type: textarea
21
+ attributes:
22
+ label: Reproduction
23
+ description: |
24
+ 1. What command or script did you run?
25
+ placeholder: |
26
+ A placeholder for the command.
27
+ validations:
28
+ required: true
29
+ - type: textarea
30
+ attributes:
31
+ label: Environment
32
+ description: |
33
+ 1. Please run `lmdeploy check_env` to collect necessary environment information and paste it here.
34
+ 2. You may add additional information that may be helpful for locating the problem, such as
35
+ - Which **model** are you using?
36
+ - How you installed PyTorch \[e.g., pip, conda, source\]
37
+ - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
38
+ placeholder: Environment here.
39
+ render: Shell
40
+ validations:
41
+ required: true
42
+ - type: textarea
43
+ attributes:
44
+ label: Error traceback
45
+ description: |
46
+ If applicable, paste the error traceback here.
47
+ placeholder: Logs and traceback here.
48
+ render: Shell
49
+ - type: markdown
50
+ attributes:
51
+ value: >
52
+ If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
53
+
54
+ Thanks for your bug report. We appreciate it a lot.
InternVL/.github/ISSUE_TEMPLATE/2-feature-request.yml ADDED
@@ -0,0 +1,31 @@
1
+ name: 🚀 Feature request
2
+ description: Suggest an idea for this project
3
+ title: "[Feature] "
4
+
5
+ body:
6
+ - type: markdown
7
+ attributes:
8
+ value: |
9
+ We strongly appreciate you creating a PR to implement this feature [here](https://github.com/OpenGVLab/InternVL/pulls)!
10
+ If you need our help, please fill in as much of the following form as you're able to.
11
+
12
+ **The less clear the description, the longer it will take to solve it.**
13
+ - type: textarea
14
+ attributes:
15
+ label: Motivation
16
+ description: |
17
+ A clear and concise description of the motivation of the feature.
18
+ Ex1. It is inconvenient when \[....\].
19
+ validations:
20
+ required: true
21
+ - type: textarea
22
+ attributes:
23
+ label: Related resources
24
+ description: |
25
+ If there are official code releases or third-party implementations, please also provide the information here, which would be very helpful.
26
+ - type: textarea
27
+ attributes:
28
+ label: Additional context
29
+ description: |
30
+ Add any other context or screenshots about the feature request here.
31
+ If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
InternVL/.github/ISSUE_TEMPLATE/3-documentation.yml ADDED
@@ -0,0 +1,23 @@
1
+ name: 📚 Documentation
2
+ description: Report an issue related to the documentation.
3
+ labels: "kind/doc,status/unconfirmed"
4
+ title: "[Docs] "
5
+
6
+ body:
7
+ - type: textarea
8
+ attributes:
9
+ label: 📚 The doc issue
10
+ description: >
11
+ A clear and concise description of the issue.
12
+ validations:
13
+ required: true
14
+
15
+ - type: textarea
16
+ attributes:
17
+ label: Suggest a potential alternative/fix
18
+ description: >
19
+ Tell us how we could improve the documentation in this regard.
20
+ - type: markdown
21
+ attributes:
22
+ value: >
23
+ Thanks for contributing 🎉!
InternVL/internvl_g/eval/evaluate_caption.py ADDED
@@ -0,0 +1,237 @@
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import time
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torchvision.transforms as T
11
+ from internvl.model.internvl_stage2 import InternVLConfig, InternVLModel
12
+ from PIL import Image
13
+ from pycocoevalcap.eval import COCOEvalCap
14
+ from pycocotools.coco import COCO
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from tqdm import tqdm
17
+ from transformers import LlamaTokenizer
18
+
19
+ ds_collections = {
20
+ 'flickr30k': {
21
+ 'root': 'data/flickr30k/',
22
+ 'annotation': 'data/flickr30k/flickr30k_test_karpathy.json',
23
+ },
24
+ 'coco': {
25
+ 'root': 'data/coco/',
26
+ 'annotation': ['data/coco/annotations/coco_karpathy_test.json',
27
+ 'data/coco/annotations/coco_karpathy_test_gt.json'],
28
+ },
29
+ 'nocaps': {
30
+ 'root': 'data/nocaps/images',
31
+ 'annotation': 'data/nocaps/nocaps_val_4500_captions.json',
32
+ },
33
+ }
34
+
35
+
36
+ class CaptionDataset(torch.utils.data.Dataset):
37
+
38
+ def __init__(self, name, root, annotation, prompt, input_size=224):
39
+ if name == 'coco':
40
+ self.images = json.load(open(annotation))
41
+ else:
42
+ self.images = json.load(open(annotation))['images']
43
+ self.name = name
44
+ self.prompt = prompt
45
+ self.root = root
46
+ self.transform = T.Compose([
47
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
48
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
49
+ T.ToTensor(),
50
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
51
+ ])
52
+
53
+ def __len__(self):
54
+ return len(self.images)
55
+
56
+ def __getitem__(self, idx):
57
+ if self.name == 'coco':
58
+ filename = self.images[idx]['image']
59
+ image_id = int(filename.split('_')[-1].replace('.jpg', ''))
60
+ image_path = os.path.join(self.root, filename)
61
+ else:
62
+ image_id = self.images[idx]['id']
63
+ if 'file_name' in self.images[idx]:
64
+ image_path = os.path.join(self.root, self.images[idx]['file_name'])
65
+ else:
66
+ image_path = os.path.join(self.root, self.images[idx]['image'])
67
+ image = Image.open(image_path)
68
+ pixel_values = self.transform(image).unsqueeze(0)
69
+ return {
70
+ 'image_id': image_id,
71
+ 'input_text': self.prompt,
72
+ 'pixel_values': pixel_values
73
+ }
74
+
75
+
76
+ def collate_fn(inputs, tokenizer):
77
+ pixel_values = torch.cat([_['pixel_values'] for _ in inputs], dim=0)
78
+ image_ids = [_['image_id'] for _ in inputs]
79
+ input_texts = [_['input_text'] for _ in inputs]
80
+ input_tokens = tokenizer(input_texts, return_tensors='pt')
81
+
82
+ return pixel_values, image_ids, input_tokens.input_ids, input_tokens.attention_mask
83
+
84
+
85
+ class InferenceSampler(torch.utils.data.sampler.Sampler):
86
+
87
+ def __init__(self, size):
88
+ self._size = int(size)
89
+ assert size > 0
90
+ self._rank = torch.distributed.get_rank()
91
+ self._world_size = torch.distributed.get_world_size()
92
+ self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
93
+
94
+ @staticmethod
95
+ def _get_local_indices(total_size, world_size, rank):
96
+ shard_size = total_size // world_size
97
+ left = total_size % world_size
98
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
99
+
100
+ begin = sum(shard_sizes[:rank])
101
+ end = min(sum(shard_sizes[:rank + 1]), total_size)
102
+ return range(begin, end)
103
+
104
+ def __iter__(self):
105
+ yield from self._local_indices
106
+
107
+ def __len__(self):
108
+ return len(self._local_indices)
109
+
110
+
111
+ def evaluate_qllama_model():
112
+ prompts = ['English caption:']
113
+ print('prompts:', prompts)
114
+
115
+ config = InternVLConfig.from_pretrained(args.checkpoint)
116
+ model = InternVLModel.from_pretrained(args.checkpoint, config=config).eval()
117
+ model = model.to(torch.float16).cuda()
118
+ tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
119
+ tokenizer.add_eos_token = False
120
+
121
+ random.seed(args.seed)
122
+ summaries = []
123
+ for prompt in prompts:
124
+ for ds_name in args.datasets:
125
+ annotation = ds_collections[ds_name]['annotation']
126
+ if type(annotation) == list:
127
+ annotation = annotation[0]
128
+ if model.config.force_image_size is not None:
129
+ image_size = model.config.force_image_size
130
+ else:
131
+ image_size = model.config.vision_config.image_size
132
+ dataset = CaptionDataset(
133
+ name=ds_name,
134
+ root=ds_collections[ds_name]['root'],
135
+ annotation=annotation,
136
+ prompt=prompt,
137
+ input_size=image_size,
138
+ )
139
+ dataloader = torch.utils.data.DataLoader(
140
+ dataset=dataset,
141
+ sampler=InferenceSampler(len(dataset)),
142
+ batch_size=args.batch_size,
143
+ num_workers=args.num_workers,
144
+ pin_memory=True,
145
+ drop_last=False,
146
+ collate_fn=partial(collate_fn, tokenizer=tokenizer),
147
+ )
148
+
149
+ image_ids, captions = [], []
150
+ for _, (pixel_values, ids, input_ids, attention_mask) in tqdm(enumerate(dataloader)):
151
+ pred = model.generate(
152
+ pixel_values=pixel_values.cuda().to(torch.float16),
153
+ input_ids=input_ids.cuda(),
154
+ attention_mask=attention_mask.cuda(),
155
+ do_sample=False,
156
+ num_beams=args.num_beams,
157
+ max_new_tokens=30,
158
+ min_new_tokens=8,
159
+ use_cache=True
160
+ )
161
+ image_ids.extend(ids)
162
+ caption = [tokenizer.decode(_.cpu(), skip_special_tokens=True).strip() for _ in pred]
163
+ captions.extend(caption)
164
+ print(caption)
165
+
166
+ torch.distributed.barrier()
167
+
168
+ world_size = torch.distributed.get_world_size()
169
+ merged_ids = [None for _ in range(world_size)]
170
+ merged_captions = [None for _ in range(world_size)]
171
+ torch.distributed.all_gather_object(merged_ids, image_ids)
172
+ torch.distributed.all_gather_object(merged_captions, captions)
173
+
174
+ merged_ids = [_ for _ in itertools.chain.from_iterable(merged_ids)]
175
+ merged_captions = [_ for _ in itertools.chain.from_iterable(merged_captions)]
176
+ average_length = sum(len(x.split()) for x in merged_captions) / len(merged_captions)
177
+ print(f'Average length: {average_length}')
178
+
179
+ if torch.distributed.get_rank() == 0:
180
+ print(f'Evaluating {ds_name} ...')
181
+
182
+ results = []
183
+ for image_id, caption in zip(merged_ids, merged_captions):
184
+ results.append({
185
+ 'image_id': int(image_id),
186
+ 'caption': caption,
187
+ })
188
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
189
+ results_file = f'{ds_name}_{time_prefix}.json'
190
+ results_file = os.path.join(args.out_dir, results_file)
191
+ json.dump(results, open(results_file, 'w'))
192
+
193
+ annotation = ds_collections[ds_name]['annotation']
194
+ if type(annotation) == list:
195
+ annotation = annotation[-1]
196
+ coco = COCO(annotation)
197
+ coco_result = coco.loadRes(results_file)
198
+ coco_eval = COCOEvalCap(coco, coco_result)
199
+ coco_eval.evaluate()
200
+
201
+ summary = coco_eval.eval.items()
202
+ print([ds_name, prompt, average_length, summary])
203
+ summaries.append([ds_name, prompt, average_length, summary])
204
+
205
+ torch.distributed.barrier()
206
+
207
+ for summary in summaries:
208
+ print(summary)
209
+
210
+
211
+ if __name__ == '__main__':
212
+
213
+ parser = argparse.ArgumentParser()
214
+ parser.add_argument('--checkpoint', type=str, default='')
215
+ parser.add_argument('--datasets', type=str, default='coco,flickr30k,nocaps')
216
+ parser.add_argument('--batch-size', type=int, default=1)
217
+ parser.add_argument('--num-workers', type=int, default=1)
218
+ parser.add_argument('--num-beams', type=int, default=5)
219
+ parser.add_argument('--out-dir', type=str, default='results')
220
+ parser.add_argument('--seed', type=int, default=0)
221
+ args = parser.parse_args()
222
+
223
+ os.makedirs(args.out_dir, exist_ok=True)
224
+
225
+ args.datasets = args.datasets.split(',')
226
+ print('datasets:', args.datasets)
227
+ assert args.batch_size == 1, 'Only batch size 1 is supported'
228
+
229
+ torch.distributed.init_process_group(
230
+ backend='nccl',
231
+ world_size=int(os.getenv('WORLD_SIZE', '1')),
232
+ rank=int(os.getenv('RANK', '0')),
233
+ )
234
+
235
+ torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
236
+
237
+ evaluate_qllama_model()
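The `InferenceSampler` defined above hands each rank a contiguous slice of the evaluation set, with the first `total % world_size` ranks taking one extra sample. A minimal self-contained sketch of that splitting logic (the rank and world-size values below are illustrative, not read from a real process group):

# Sketch of InferenceSampler._get_local_indices: contiguous sharding across ranks.
def get_local_indices(total_size: int, world_size: int, rank: int) -> range:
    shard_size = total_size // world_size          # base shard size per rank
    left = total_size % world_size                 # leftover samples
    shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
    begin = sum(shard_sizes[:rank])
    end = min(sum(shard_sizes[:rank + 1]), total_size)
    return range(begin, end)

# Example: 10 samples over 4 ranks -> [0, 1, 2], [3, 4, 5], [6, 7], [8, 9]
for rank in range(4):
    print(rank, list(get_local_indices(10, 4, rank)))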
InternVL/internvl_g/internvl/dist_utils.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ import socket
3
+ import subprocess
4
+ from datetime import timedelta
5
+
6
+ import torch
7
+ import torch.multiprocessing as mp
8
+ from torch import distributed as dist
9
+
10
+ timeout = timedelta(minutes=60)
11
+
12
+
13
+ def _find_free_port():
14
+ # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
15
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
16
+ # Binding to port 0 will cause the OS to find an available port for us
17
+ sock.bind(('', 0))
18
+ port = sock.getsockname()[1]
19
+ sock.close()
20
+ # NOTE: there is still a chance the port could be taken by other processes.
21
+ return port
22
+
23
+
24
+ def _is_free_port(port):
25
+ ips = socket.gethostbyname_ex(socket.gethostname())[-1]
26
+ ips.append('localhost')
27
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
28
+ return all(s.connect_ex((ip, port)) != 0 for ip in ips)
29
+
30
+
31
+ def init_dist(launcher, backend='nccl', **kwargs):
32
+ if mp.get_start_method(allow_none=True) is None:
33
+ mp.set_start_method('spawn')
34
+ if launcher == 'pytorch':
35
+ _init_dist_pytorch(backend, **kwargs)
36
+ elif launcher == 'mpi':
37
+ _init_dist_mpi(backend, **kwargs)
38
+ elif launcher == 'slurm':
39
+ _init_dist_slurm(backend, **kwargs)
40
+ else:
41
+ raise ValueError(f'Invalid launcher type: {launcher}')
42
+
43
+
44
+ def _init_dist_pytorch(backend, **kwargs):
45
+ # TODO: use local_rank instead of rank % num_gpus
46
+ rank = int(os.environ['RANK'])
47
+ num_gpus = torch.cuda.device_count()
48
+ torch.cuda.set_device(rank % num_gpus)
49
+ dist.init_process_group(backend=backend, **kwargs)
50
+
51
+
52
+ def _init_dist_mpi(backend, **kwargs):
53
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
54
+ torch.cuda.set_device(local_rank)
55
+ if 'MASTER_PORT' not in os.environ:
56
+ # 29500 is torch.distributed default port
57
+ os.environ['MASTER_PORT'] = '29500'
58
+ if 'MASTER_ADDR' not in os.environ:
59
+ raise KeyError('The environment variable MASTER_ADDR is not set')
60
+ os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
61
+ os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
62
+ dist.init_process_group(backend=backend, **kwargs)
63
+
64
+
65
+ def _init_dist_slurm(backend, port=None):
66
+ """Initialize slurm distributed training environment.
67
+
68
+ If argument ``port`` is not specified, then the master port will be system
69
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
70
+ environment variable, then a default port ``29500`` will be used.
71
+
72
+ Args:
73
+ backend (str): Backend of torch.distributed.
74
+ port (int, optional): Master port. Defaults to None.
75
+ """
76
+ proc_id = int(os.environ['SLURM_PROCID'])
77
+ ntasks = int(os.environ['SLURM_NTASKS'])
78
+ node_list = os.environ['SLURM_NODELIST']
79
+ num_gpus = torch.cuda.device_count()
80
+ torch.cuda.set_device(proc_id % num_gpus)
81
+ addr = subprocess.getoutput(
82
+ f'scontrol show hostname {node_list} | head -n1')
83
+ # specify master port
84
+ if port is not None:
85
+ os.environ['MASTER_PORT'] = str(port)
86
+ elif 'MASTER_PORT' in os.environ:
87
+ pass # use MASTER_PORT in the environment variable
88
+ else:
89
+ # if torch.distributed default port(29500) is available
90
+ # then use it, else find a free port
91
+ if _is_free_port(29500):
92
+ os.environ['MASTER_PORT'] = '29500'
93
+ else:
94
+ os.environ['MASTER_PORT'] = str(_find_free_port())
95
+ # use MASTER_ADDR in the environment variable if it already exists
96
+ if 'MASTER_ADDR' not in os.environ:
97
+ os.environ['MASTER_ADDR'] = addr
98
+ os.environ['WORLD_SIZE'] = str(ntasks)
99
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
100
+ os.environ['RANK'] = str(proc_id)
101
+ dist.init_process_group(backend=backend, timeout=timeout)
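`_init_dist_slurm` above picks the master port by preferring the torch.distributed default (29500) and falling back to an OS-assigned free port. A standalone sketch of that port-selection step, mirroring `_is_free_port` and `_find_free_port` (run outside any scheduler, purely to illustrate the check):

import socket

def find_free_port() -> int:
    # Binding to port 0 asks the OS for any available port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('', 0))
        return sock.getsockname()[1]

def is_free_port(port: int) -> bool:
    # The port counts as free only if no local address accepts a connection on it.
    ips = socket.gethostbyname_ex(socket.gethostname())[-1] + ['localhost']
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return all(s.connect_ex((ip, port)) != 0 for ip in ips)

master_port = 29500 if is_free_port(29500) else find_free_port()
print('MASTER_PORT would be set to', master_port)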
InternVL/internvl_g/internvl/model/__init__.py ADDED
File without changes
InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/__init__.py ADDED
@@ -0,0 +1,87 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision.transforms as T
10
+ from torchvision.transforms import InterpolationMode
11
+ from transformers import LlamaTokenizer
12
+
13
+ from .configuration_intern_vit import InternVisionConfig
14
+ from .configuration_internvl import InternVLConfig
15
+ from .modeling_intern_vit import InternVisionModel
16
+ from .modeling_internvl import InternVL_C, InternVL_G, InternVLModel
17
+
18
+ __all__ = ['InternVisionConfig', 'InternVisionModel', 'InternVLConfig',
19
+ 'InternVLModel', 'InternVL_C', 'InternVL_G']
20
+
21
+
22
+ # Prefix the text "summarize:"
23
+ class InternVLTokenizer(nn.Module):
24
+ def __init__(self, model_path):
25
+ super(InternVLTokenizer, self).__init__()
26
+ self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
27
+ self.tokenizer.pad_token = ' ' # allow padding
28
+ self.tokenizer.add_eos_token = True
29
+
30
+ def forward(self, text, prefix='summarize:'):
31
+ if type(text) == str:
32
+ text = prefix + text
33
+ elif type(text) == list:
34
+ text = [prefix + item for item in text]
35
+ text = self.tokenizer(text, return_tensors='pt', max_length=80, truncation=True, padding='max_length').input_ids
36
+ return text
37
+
38
+
39
+ def build_transform(task, image_size=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
40
+ if task == 'retrieval':
41
+ transform = T.Compose([
42
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
43
+ T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
44
+ T.ToTensor(),
45
+ T.Normalize(mean=mean, std=std)])
46
+ else:
47
+ transform = T.Compose([
48
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
49
+ T.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
50
+ T.CenterCrop(image_size),
51
+ T.ToTensor(),
52
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
53
+ return transform
54
+
55
+
56
+ def load_internvl_c_huggingface(ckpt_path, device, task):
57
+ model = InternVL_C.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
58
+ if model.config.use_backbone_lora:
59
+ model.vision_model.merge_and_unload()
60
+ model.vision_model = model.vision_model.model
61
+ if model.config.use_qllama_lora:
62
+ model.qllama.merge_and_unload()
63
+ model.qllama = model.qllama.model
64
+ if model.config.force_image_size is not None:
65
+ image_size = model.config.force_image_size
66
+ else:
67
+ image_size = model.config.vision_config.image_size
68
+ transform = build_transform(task, image_size)
69
+ tokenizer = InternVLTokenizer(ckpt_path)
70
+ return model, transform, tokenizer
71
+
72
+
73
+ def load_internvl_g_huggingface(ckpt_path, device, task):
74
+ model = InternVL_G.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
75
+ if model.config.use_backbone_lora:
76
+ model.vision_model.merge_and_unload()
77
+ model.vision_model = model.vision_model.model
78
+ if model.config.use_qllama_lora:
79
+ model.qllama.merge_and_unload()
80
+ model.qllama = model.qllama.model
81
+ if model.config.force_image_size is not None:
82
+ image_size = model.config.force_image_size
83
+ else:
84
+ image_size = model.config.vision_config.image_size
85
+ transform = build_transform(task, image_size)
86
+ tokenizer = InternVLTokenizer(ckpt_path)
87
+ return model, transform, tokenizer
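A hypothetical usage sketch of the loaders defined above for the retrieval task; the checkpoint directory and image file are placeholders, and the actual image/text encoding entry points live in `modeling_internvl.py` rather than in this file:

import torch
from PIL import Image
from internvl.model.internvl_stage2_retrieval import load_internvl_c_huggingface

# Placeholder checkpoint path and image; point these at a real InternVL-C checkpoint.
model, transform, tokenizer = load_internvl_c_huggingface(
    'path/to/internvl_c_checkpoint', device='cuda', task='retrieval')

pixel_values = transform(Image.open('example.jpg')).unsqueeze(0).to('cuda', torch.float16)
input_ids = tokenizer(['a photo of a dog', 'a photo of a cat']).to('cuda')
# pixel_values and input_ids can then be passed to the model's encoders for
# CLIP-style image-text similarity scoring (API defined in InternVL_C / InternVL_G).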
InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_intern_vit.py ADDED
@@ -0,0 +1,342 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from einops import rearrange
12
+ from timm.models.layers import DropPath
13
+ from torch import nn
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import (BaseModelOutput,
16
+ BaseModelOutputWithPooling)
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+
20
+ from .configuration_intern_vit import InternVisionConfig
21
+
22
+ try:
23
+ from .flash_attention import FlashAttention
24
+ has_flash_attn = True
25
+ except:
26
+ print('FlashAttention is not installed.')
27
+ has_flash_attn = False
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class InternRMSNorm(nn.Module):
34
+ def __init__(self, hidden_size, eps=1e-6):
35
+ super().__init__()
36
+ self.weight = nn.Parameter(torch.ones(hidden_size))
37
+ self.variance_epsilon = eps
38
+
39
+ def forward(self, hidden_states):
40
+ input_dtype = hidden_states.dtype
41
+ hidden_states = hidden_states.to(torch.float32)
42
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
43
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
44
+ return self.weight * hidden_states.to(input_dtype)
45
+
46
+
47
+ try:
48
+ from apex.normalization import FusedRMSNorm
49
+
50
+ InternRMSNorm = FusedRMSNorm # noqa
51
+
52
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
53
+ except ImportError:
54
+ # using the normal InternRMSNorm
55
+ pass
56
+ except Exception:
57
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
58
+ pass
59
+
60
+
61
+ class InternVisionEmbeddings(nn.Module):
62
+ def __init__(self, config: InternVisionConfig):
63
+ super().__init__()
64
+ self.config = config
65
+ self.embed_dim = config.hidden_size
66
+ self.image_size = config.image_size
67
+ self.patch_size = config.patch_size
68
+
69
+ self.class_embedding = nn.Parameter(
70
+ torch.randn(1, 1, self.embed_dim),
71
+ )
72
+
73
+ self.patch_embedding = nn.Conv2d(
74
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
75
+ )
76
+
77
+ self.num_patches = (self.image_size // self.patch_size) ** 2
78
+ self.num_positions = self.num_patches + 1
79
+
80
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
81
+
82
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
83
+ batch_size = pixel_values.shape[0]
84
+ target_dtype = self.patch_embedding.weight.dtype
85
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
86
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
87
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
88
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
89
+ embeddings = embeddings + self.position_embedding.to(target_dtype)
90
+ return embeddings
91
+
92
+
93
+ class InternAttention(nn.Module):
94
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
95
+
96
+ def __init__(self, config: InternVisionConfig):
97
+ super().__init__()
98
+ self.config = config
99
+ self.embed_dim = config.hidden_size
100
+ self.num_heads = config.num_attention_heads
101
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
102
+ if config.use_flash_attn and not has_flash_attn:
103
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
104
+ self.head_dim = self.embed_dim // self.num_heads
105
+ if self.head_dim * self.num_heads != self.embed_dim:
106
+ raise ValueError(
107
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
108
+ f' {self.num_heads}).'
109
+ )
110
+
111
+ self.scale = self.head_dim ** -0.5
112
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
113
+ self.attn_drop = nn.Dropout(config.attention_dropout)
114
+ self.proj_drop = nn.Dropout(config.dropout)
115
+
116
+ self.qk_normalization = config.qk_normalization
117
+
118
+ if self.qk_normalization:
119
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
120
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
121
+
122
+ if self.use_flash_attn:
123
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
124
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
125
+
126
+ def _naive_attn(self, x):
127
+ B, N, C = x.shape
128
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
129
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
130
+
131
+ if self.qk_normalization:
132
+ B_, H_, N_, D_ = q.shape
133
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
134
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
135
+
136
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
137
+ attn = attn.softmax(dim=-1)
138
+ attn = self.attn_drop(attn)
139
+
140
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
141
+ x = self.proj(x)
142
+ x = self.proj_drop(x)
143
+ return x
144
+
145
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
146
+ qkv = self.qkv(x)
147
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
148
+
149
+ if self.qk_normalization:
150
+ q, k, v = qkv.unbind(2)
151
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
152
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
153
+ qkv = torch.stack([q, k, v], dim=2)
154
+
155
+ context, _ = self.inner_attn(
156
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
157
+ )
158
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
159
+ outs = self.proj_drop(outs)
160
+ return outs
161
+
162
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
163
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
164
+ return x
165
+
166
+
167
+ class InternMLP(nn.Module):
168
+ def __init__(self, config: InternVisionConfig):
169
+ super().__init__()
170
+ self.config = config
171
+ self.act = ACT2FN[config.hidden_act]
172
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
173
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
174
+
175
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
176
+ hidden_states = self.fc1(hidden_states)
177
+ hidden_states = self.act(hidden_states)
178
+ hidden_states = self.fc2(hidden_states)
179
+ return hidden_states
180
+
181
+
182
+ class InternVisionEncoderLayer(nn.Module):
183
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
184
+ super().__init__()
185
+ self.embed_dim = config.hidden_size
186
+ self.intermediate_size = config.intermediate_size
187
+
188
+ self.attn = InternAttention(config)
189
+ self.mlp = InternMLP(config)
190
+ self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
191
+ self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
192
+
193
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
194
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
195
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
196
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
197
+
198
+ def forward(
199
+ self,
200
+ hidden_states: torch.Tensor,
201
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
202
+ """
203
+ Args:
204
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
205
+ """
206
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
207
+
208
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
209
+
210
+ return hidden_states
211
+
212
+
213
+ class InternVisionEncoder(nn.Module):
214
+ """
215
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
216
+ [`InternEncoderLayer`].
217
+
218
+ Args:
219
+ config (`InternConfig`):
220
+ The corresponding vision configuration for the `InternEncoder`.
221
+ """
222
+
223
+ def __init__(self, config: InternVisionConfig):
224
+ super().__init__()
225
+ self.config = config
226
+ # stochastic depth decay rule
227
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
228
+ self.layers = nn.ModuleList([
229
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
230
+ self.gradient_checkpointing = True
231
+
232
+ def forward(
233
+ self,
234
+ inputs_embeds,
235
+ output_hidden_states: Optional[bool] = None,
236
+ return_dict: Optional[bool] = None,
237
+ ) -> Union[Tuple, BaseModelOutput]:
238
+ r"""
239
+ Args:
240
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
241
+ Embedded representation of the inputs. Should be float, not int tokens.
242
+ output_hidden_states (`bool`, *optional*):
243
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
244
+ for more detail.
245
+ return_dict (`bool`, *optional*):
246
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
247
+ """
248
+ output_hidden_states = (
249
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
250
+ )
251
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
252
+
253
+ encoder_states = () if output_hidden_states else None
254
+ hidden_states = inputs_embeds
255
+
256
+ for idx, encoder_layer in enumerate(self.layers):
257
+ if output_hidden_states:
258
+ encoder_states = encoder_states + (hidden_states,)
259
+ if self.gradient_checkpointing and self.training:
260
+ layer_outputs = torch.utils.checkpoint.checkpoint(
261
+ encoder_layer,
262
+ hidden_states)
263
+ else:
264
+ layer_outputs = encoder_layer(
265
+ hidden_states,
266
+ )
267
+ hidden_states = layer_outputs
268
+
269
+ if output_hidden_states:
270
+ encoder_states = encoder_states + (hidden_states,)
271
+
272
+ if not return_dict:
273
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
274
+ return BaseModelOutput(
275
+ last_hidden_state=hidden_states, hidden_states=encoder_states
276
+ )
277
+
278
+
279
+ class InternVisionModel(PreTrainedModel):
280
+ main_input_name = 'pixel_values'
281
+ config_class = InternVisionConfig
282
+
283
+ def __init__(self, config: InternVisionConfig):
284
+ super().__init__(config)
285
+ self.config = config
286
+
287
+ self.embeddings = InternVisionEmbeddings(config)
288
+ self.encoder = InternVisionEncoder(config)
289
+
290
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
291
+ pos_emb = self.embeddings.position_embedding
292
+ _, num_positions, embed_dim = pos_emb.shape
293
+ cls_emb = pos_emb[:, :1, :]
294
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
295
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
296
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
297
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
298
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
299
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
300
+
301
+ def get_input_embeddings(self):
302
+ return self.embeddings
303
+
304
+ def forward(
305
+ self,
306
+ pixel_values: Optional[torch.FloatTensor] = None,
307
+ output_hidden_states: Optional[bool] = None,
308
+ return_dict: Optional[bool] = None,
309
+ pixel_embeds: Optional[torch.FloatTensor] = None,
310
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
311
+ output_hidden_states = (
312
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
313
+ )
314
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
315
+
316
+ if pixel_values is None and pixel_embeds is None:
317
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
318
+
319
+ if pixel_embeds is not None:
320
+ hidden_states = pixel_embeds
321
+ else:
322
+ if len(pixel_values.shape) == 4:
323
+ hidden_states = self.embeddings(pixel_values)
324
+ else:
325
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
326
+ encoder_outputs = self.encoder(
327
+ inputs_embeds=hidden_states,
328
+ output_hidden_states=output_hidden_states,
329
+ return_dict=return_dict,
330
+ )
331
+ last_hidden_state = encoder_outputs.last_hidden_state
332
+ pooled_output = last_hidden_state[:, 0, :]
333
+
334
+ if not return_dict:
335
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
336
+
337
+ return BaseModelOutputWithPooling(
338
+ last_hidden_state=last_hidden_state,
339
+ pooler_output=pooled_output,
340
+ hidden_states=encoder_outputs.hidden_states,
341
+ attentions=encoder_outputs.attentions,
342
+ )
InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_internvl.py ADDED
@@ -0,0 +1,669 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+ from typing import Any, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint
15
+ from peft import LoraConfig, get_peft_model
16
+ from timm.models.layers import DropPath
17
+ from torch import nn
18
+ from transformers import GenerationConfig
19
+ from transformers.modeling_utils import PreTrainedModel
20
+ from transformers.utils import ModelOutput, logging
21
+
22
+ from .configuration_internvl import InternVLConfig
23
+ from .modeling_intern_vit import (InternVisionEmbeddings, InternVisionEncoder,
24
+ InternVisionModel)
25
+ from .modeling_qllama import LlamaForCausalLM, _expand_mask, _make_causal_mask
26
+
27
+ try:
28
+ from .flash_attention import FlashAttention # v1/v2
29
+ except:
30
+ print('FlashAttention is not installed.')
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class InternVLPreTrainedModel(PreTrainedModel):
36
+ """
37
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
38
+ models.
39
+ """
40
+
41
+ config_class = InternVLConfig
42
+ base_model_prefix = 'internvl'
43
+ supports_gradient_checkpointing = True
44
+ _keys_to_ignore_on_load_missing = [
45
+ r'position_ids',
46
+ ]
47
+ _no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
48
+ _skip_keys_device_placement = 'past_key_values'
49
+ _keep_in_fp32_modules = ['wo']
50
+
51
+ # def _init_weights(self, module):
52
+ # """Initialize the weights"""
53
+ # factor = self.config.initializer_range
54
+ # if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
55
+ # module.weight.data.normal_(mean=0.0, std=factor)
56
+ # if hasattr(module, 'bias') and module.bias is not None:
57
+ # module.bias.data.zero_()
58
+ # if isinstance(module, InternVisionEmbeddings):
59
+ # if hasattr(self.config, 'vision_config'):
60
+ # factor = self.config.vision_config.initializer_range
61
+ # nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
62
+ # nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
63
+ # elif isinstance(module, nn.LayerNorm):
64
+ # module.bias.data.zero_()
65
+ # module.weight.data.fill_(1.0)
66
+ # elif isinstance(module, nn.Linear) and module.bias is not None:
67
+ # module.bias.data.zero_()
68
+
69
+ def _set_gradient_checkpointing(self, module, value=False):
70
+ if isinstance(module, InternVisionModel):
71
+ module.gradient_checkpointing = value
72
+ if isinstance(module, InternVisionEncoder):
73
+ module.gradient_checkpointing = value
74
+
75
+
76
+ class CrossAttention(nn.Module):
77
+ def __init__(
78
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
79
+ proj_drop=0., attn_head_dim=None, out_dim=None):
80
+ super().__init__()
81
+ if out_dim is None:
82
+ out_dim = dim
83
+ self.num_heads = num_heads
84
+ head_dim = dim // num_heads
85
+ if attn_head_dim is not None:
86
+ head_dim = attn_head_dim
87
+ all_head_dim = head_dim * self.num_heads
88
+ self.scale = qk_scale or head_dim ** -0.5
89
+ assert all_head_dim == dim
90
+
91
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
92
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
93
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
94
+
95
+ if qkv_bias:
96
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
97
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
98
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
99
+ else:
100
+ self.q_bias = None
101
+ self.k_bias = None
102
+ self.v_bias = None
103
+
104
+ self.attn_drop = nn.Dropout(attn_drop)
105
+ self.proj = nn.Linear(all_head_dim, out_dim)
106
+ self.proj_drop = nn.Dropout(proj_drop)
107
+
108
+ def forward(self, x, k=None, v=None):
109
+ B, N, C = x.shape
110
+ N_k = k.shape[1]
111
+ N_v = v.shape[1]
112
+
113
+ q_bias, k_bias, v_bias = None, None, None
114
+ if self.q_bias is not None:
115
+ q_bias = self.q_bias
116
+ k_bias = self.k_bias
117
+ v_bias = self.v_bias
118
+
119
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
120
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
121
+
122
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
123
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
124
+
125
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
126
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
130
+
131
+ attn = attn.softmax(dim=-1)
132
+ attn = self.attn_drop(attn)
133
+
134
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
135
+ x = self.proj(x)
136
+ x = self.proj_drop(x)
137
+
138
+ return x
139
+
140
+
141
+ class AttentiveBlock(nn.Module):
142
+
143
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
144
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
145
+ super().__init__()
146
+
147
+ self.norm1_q = norm_layer(dim)
148
+ self.norm1_k = norm_layer(dim)
149
+ self.norm1_v = norm_layer(dim)
150
+ self.cross_attn = CrossAttention(
151
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
152
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
153
+
154
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
155
+
156
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
157
+ x_q = self.norm1_q(x_q + pos_q)
158
+ x_k = self.norm1_k(x_kv + pos_k)
159
+ x_v = self.norm1_v(x_kv)
160
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
161
+
162
+ return x
163
+
164
+
165
+ class AttentionPoolingBlock(AttentiveBlock):
166
+
167
+ def forward(self, x):
168
+ x_q = x.mean(1, keepdim=True)
169
+ x_kv, pos_q, pos_k = x, 0, 0
170
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
171
+ x = x.squeeze(1)
172
+ return x
173
+
174
+
175
+ @dataclass
176
+ class InternVLModelOutput(ModelOutput):
177
+ """
178
+ Class defining the outputs of [`InternVLModelOutput`].
179
+ """
180
+
181
+ loss: Optional[torch.FloatTensor] = None
182
+ loss_itm: Optional[torch.FloatTensor] = None
183
+ loss_itc: Optional[torch.FloatTensor] = None
184
+ loss_itg: Optional[torch.FloatTensor] = None
185
+
186
+ def to_tuple(self) -> Tuple[Any]:
187
+ return tuple(
188
+ self[k]
189
+ if k not in ['loss', 'loss_itm', 'loss_itc', 'loss_itg']
190
+ else getattr(self, k).to_tuple()
191
+ for k in self.keys()
192
+ )
193
+
194
+
195
+ class GatherLayer(torch.autograd.Function):
196
+ """Gather tensors from all process, supporting backward propagation.
197
+ """
198
+
199
+ @staticmethod
200
+ def forward(ctx, input):
201
+ ctx.save_for_backward(input)
202
+ output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
203
+ dist.all_gather(output, input)
204
+ return torch.stack(output, 0)
205
+
206
+ @staticmethod
207
+ def backward(ctx, grads):
208
+ input, = ctx.saved_tensors
209
+ dist.all_reduce(grads)
210
+ grad_out = torch.zeros_like(input)
211
+ grad_out[:] = grads[dist.get_rank()]
212
+ return grad_out
213
+
214
+
215
+ class InternVLModel(InternVLPreTrainedModel):
216
+ config_class = InternVLConfig
217
+ main_input_name = 'pixel_values'
218
+
219
+ def __init__(self, config: InternVLConfig):
220
+ super().__init__(config)
221
+
222
+ text_hidden_size = config.qllama_config.hidden_size
223
+ vision_hidden_size = config.vision_config.hidden_size
224
+ clip_embed_dim = config.clip_embed_dim
225
+ attn_pool_num_heads = config.attn_pool_num_heads
226
+ config.qllama_config.num_query_token = config.num_query_token
227
+ self.num_query_token = config.num_query_token
228
+ self.label_smoothing = config.label_smoothing
229
+
230
+ self.vision_model = InternVisionModel(config.vision_config) # frozen
231
+ self.qllama = LlamaForCausalLM(config.qllama_config) # frozen
232
+ self.query_tokens = nn.Parameter( # trainable
233
+ torch.zeros(1, config.num_query_token, text_hidden_size)
234
+ )
235
+
236
+ self.text_projection = nn.Parameter(torch.empty(text_hidden_size, clip_embed_dim)) # frozen
237
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # trainable
238
+ self.clip_projector = AttentionPoolingBlock( # frozen
239
+ dim=vision_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
240
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
241
+ self.clip_projector2 = AttentionPoolingBlock( # trainable
242
+ dim=text_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
243
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
244
+ self.itm_head = nn.Linear(text_hidden_size, 2) # trainable
245
+ self.gradient_checkpointing = True
246
+
247
+ # Initialize weights and apply final processing
248
+ # self.post_init()
249
+
250
+ if config.use_backbone_lora:
251
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=config.use_backbone_lora * 2)
252
+ if config.use_qllama_lora:
253
+ self.wrap_qllama_lora(r=config.use_qllama_lora, lora_alpha=config.use_qllama_lora * 2)
254
+ if config.force_image_size:
255
+ self.vision_model.resize_pos_embeddings(
256
+ old_size=config.vision_config.image_size,
257
+ new_size=config.force_image_size,
258
+ patch_size=config.vision_config.patch_size
259
+ )
260
+
261
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
262
+ lora_config = LoraConfig(
263
+ r=r,
264
+ target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
265
+ lora_alpha=lora_alpha,
266
+ lora_dropout=lora_dropout,
267
+ )
268
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
269
+ self.vision_model.print_trainable_parameters()
270
+
271
+ def wrap_qllama_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
272
+ lora_config = LoraConfig(
273
+ r=r,
274
+ target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
275
+ 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
276
+ lora_alpha=lora_alpha,
277
+ lora_dropout=lora_dropout,
278
+ )
279
+ self.qllama = get_peft_model(self.qllama, lora_config)
280
+ self.qllama.print_trainable_parameters()
281
+
282
+ def get_input_embeddings(self):
283
+ return self.qllama.get_input_embeddings()
284
+
285
+ def set_input_embeddings(self, value):
286
+ self.qllama.set_input_embeddings(value)
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.qllama.set_output_embeddings(new_embeddings)
290
+
291
+ def get_output_embeddings(self) -> nn.Module:
292
+ return self.qllama.get_output_embeddings()
293
+
294
+ @torch.no_grad()
295
+ def _prepare_attention_mask(
296
+ self,
297
+ image_attention_mask: torch.LongTensor,
298
+ attention_mask: torch.LongTensor,
299
+ input_embeds: torch.FloatTensor,
300
+ repeat_time: int,
301
+ ):
302
+ # itm, itc
303
+ attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
304
+ expand_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
305
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
306
+ itm_mask_neg, itm_mask_pos, itc_mask = torch.chunk(expand_mask, repeat_time, dim=0)
307
+
308
+ itc_mask[:, :, :self.num_query_token, self.num_query_token:] = torch.finfo(input_embeds.dtype).min
309
+ itc_mask[:, :, self.num_query_token:, :self.num_query_token] = torch.finfo(input_embeds.dtype).min
310
+ itc_mask_causal = _make_causal_mask(
311
+ (itc_mask.shape[0], itc_mask.shape[2] - self.num_query_token),
312
+ input_embeds.dtype,
313
+ device=input_embeds.device
314
+ )
315
+ # use causal mask for text in itc
316
+ itc_mask[:, :, self.num_query_token:, self.num_query_token:] += itc_mask_causal
317
+
318
+ attention_mask = torch.cat([itm_mask_neg, itm_mask_pos, itc_mask], dim=0)
319
+
320
+ return attention_mask
321
+
322
+ def forward(
323
+ self,
324
+ pixel_values: torch.FloatTensor,
325
+ positive_input_ids: torch.FloatTensor,
326
+ positive_attention_mask: torch.LongTensor,
327
+ negative_input_ids: torch.FloatTensor,
328
+ negative_attention_mask: torch.LongTensor,
329
+ summarize_input_ids: torch.FloatTensor,
330
+ summarize_attention_mask: torch.LongTensor,
331
+ input_ids: torch.FloatTensor,
332
+ attention_mask: torch.LongTensor,
333
+ image_ids: torch.LongTensor,
334
+ labels: torch.LongTensor,
335
+ output_attentions: Optional[bool] = None,
336
+ output_hidden_states: Optional[bool] = None,
337
+ return_dict: Optional[bool] = None,
338
+ ) -> Union[Tuple, InternVLModelOutput]:
339
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
340
+
341
+ # step 1: forward the images through the vision encoder,
342
+ # to get image embeddings of shape (batch_size, seq_len, hidden_size)
343
+ vision_outputs = self.vision_model(
344
+ pixel_values=pixel_values,
345
+ output_hidden_states=output_hidden_states,
346
+ return_dict=return_dict)
347
+ image_embeds = vision_outputs[0]
348
+ backbone_embeds = self.clip_projector(image_embeds)
349
+
350
+ # step 2: prepare input_ids and attention_mask for two sub-tasks:
351
+ # 1) image-text matching; 2) image-text contrastive learning.
352
+ batch_size = input_ids.shape[0]
353
+ input_ids = torch.cat([negative_input_ids, positive_input_ids,
354
+ summarize_input_ids], dim=0) # [3 * batch_size, seq_len]
355
+ itm_attention_mask = torch.cat(
356
+ [negative_attention_mask, positive_attention_mask], dim=0)
357
+ attention_mask = torch.cat(
358
+ [itm_attention_mask, summarize_attention_mask], dim=0) # [3 * batch_size, seq_len]
359
+
360
+ repeat_time = input_ids.size(0) // batch_size
361
+ # step 3: forward the input_ids and attention_mask through the text encoder.
362
+ input_embeds = self.get_input_embeddings()(input_ids)
363
+ query_tokens = self.query_tokens.repeat(repeat_time * batch_size, 1, 1)
364
+ input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
365
+ image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
366
+ attention_mask = self._prepare_attention_mask(
367
+ image_attention_mask, attention_mask, input_embeds, repeat_time
368
+ )
369
+ if type(self.qllama.model) == LlamaForCausalLM:
370
+ outputs = self.qllama.model.model.forward_train(
371
+ inputs_embeds=input_embeds,
372
+ vision_hidden_states=image_embeds,
373
+ attention_mask=attention_mask,
374
+ output_attentions=output_attentions,
375
+ output_hidden_states=output_hidden_states,
376
+ return_dict=return_dict,
377
+ repeat_time=repeat_time,
378
+ ).last_hidden_state
379
+ else:
380
+ outputs = self.qllama.model.forward_train(
381
+ inputs_embeds=input_embeds,
382
+ vision_hidden_states=image_embeds,
383
+ attention_mask=attention_mask,
384
+ output_attentions=output_attentions,
385
+ output_hidden_states=output_hidden_states,
386
+ return_dict=return_dict,
387
+ repeat_time=repeat_time,
388
+ ).last_hidden_state
389
+ image_embeds = outputs[:, :self.num_query_token]
390
+ text_embeds = outputs[:, self.num_query_token:]
391
+ image_itm_neg, image_itm_pos, image_itc = image_embeds.chunk(repeat_time, dim=0)
392
+ text_itm_neg, text_itm_pos, text_itc = text_embeds.chunk(repeat_time, dim=0)
393
+ image_itm = torch.cat([image_itm_neg, image_itm_pos], dim=0)
394
+
395
+ ###============== Image-Text Matching ===================###
396
+ image_itm = self.itm_head(image_itm)
397
+ logits = image_itm.mean(dim=1)
398
+ itm_labels = torch.cat([
399
+ torch.zeros(batch_size, dtype=torch.long, device=logits.device),
400
+ torch.ones(batch_size, dtype=torch.long, device=logits.device)
401
+ ], dim=0)
402
+ loss_itm = F.cross_entropy(logits, itm_labels)
403
+ neg_match_acc = ((logits[:batch_size].argmax(dim=-1) == 0) / batch_size).sum()
404
+ pos_match_acc = ((logits[batch_size:].argmax(dim=-1) == 1) / batch_size).sum()
405
+
406
+ ###============== Image-Text Contrastive ===================###
407
+ image_itc = self.clip_projector2(image_itc)
408
+
409
+ selected = summarize_attention_mask.sum(1) - 1
410
+ text_itc = text_itc[torch.arange(text_itc.shape[0]), selected]
411
+ text_itc = text_itc @ self.text_projection
412
+
413
+ # normalized features
414
+ backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
415
+ image_itc = image_itc / image_itc.norm(dim=1, keepdim=True)
416
+ text_itc = text_itc / text_itc.norm(dim=1, keepdim=True)
417
+ backbone_embeds_all = GatherLayer.apply(backbone_embeds).flatten(0, 1)
418
+ image_itc_all = GatherLayer.apply(image_itc).flatten(0, 1)
419
+ text_itc_all = GatherLayer.apply(text_itc).flatten(0, 1)
420
+
421
+ # cosine similarity as logits
422
+ logit_scale = self.logit_scale.exp()
423
+ sim_i2t = logit_scale * (image_itc @ text_itc_all.t())
424
+ sim_t2i = logit_scale * (text_itc @ image_itc_all.t())
425
+ backbone_i2t = logit_scale * (backbone_embeds @ text_itc_all.t())
426
+ backbone_t2i = logit_scale * (text_itc @ backbone_embeds_all.t())
427
+
428
+ image_ids = image_ids.view(-1, 1)
429
+ image_ids_all = GatherLayer.apply(image_ids).flatten(0, 1)
430
+ pos_idx = torch.eq(image_ids, image_ids_all.t()).float()
431
+ sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
432
+
433
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
434
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
435
+ loss_backbone_t2i = -torch.sum(F.log_softmax(backbone_t2i, dim=1) * sim_targets, dim=1).mean()
436
+ loss_backbone_i2t = -torch.sum(F.log_softmax(backbone_i2t, dim=1) * sim_targets, dim=1).mean()
437
+ loss_itc = (loss_t2i + loss_i2t) / 2 + (loss_backbone_t2i + loss_backbone_i2t) / 2
438
+
439
+ vision_sim = F.cosine_similarity(backbone_embeds.detach(), image_itc).mean()
440
+
441
+ loss = loss_itm + loss_itc
442
+ if dist.get_rank() == 0:
443
+ print(f'loss: {loss.item()}, loss_itm: {loss_itm.item()}, loss_itc: {loss_itc.item()}, '
444
+ f'vision_similarity: {round(vision_sim.item(), 5)}, '
445
+ f'logit scale: {round(1.0 / logit_scale.item(), 5)}, '
446
+ f'pos_match_acc: {round(pos_match_acc.item(), 4)}, '
447
+ f'neg_match_acc: {round(neg_match_acc.item(), 4)}')
448
+
449
+ return InternVLModelOutput(
450
+ loss=loss,
451
+ loss_itc=loss_itc.detach(),
452
+ loss_itm=loss_itm.detach(),
453
+ )
454
+
455
+ @torch.no_grad()
456
+ def generate(
457
+ self,
458
+ pixel_values: torch.FloatTensor,
459
+ input_ids: torch.FloatTensor,
460
+ attention_mask: torch.LongTensor,
461
+ generation_config: Optional[GenerationConfig] = None,
462
+ output_hidden_states: Optional[bool] = None,
463
+ return_dict: Optional[bool] = None,
464
+ **generate_kwargs,
465
+ ) -> torch.LongTensor:
466
+
467
+ vision_outputs = self.vision_model(
468
+ pixel_values=pixel_values,
469
+ output_hidden_states=output_hidden_states,
470
+ return_dict=return_dict)
471
+ image_embeds = vision_outputs[0]
472
+
473
+ batch_size = image_embeds.shape[0]
474
+ input_embeds = self.get_input_embeddings()(input_ids)
475
+ query_tokens = self.query_tokens.repeat(batch_size, 1, 1)
476
+ input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
477
+ image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
478
+ attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
479
+
480
+ outputs = self.qllama.generate(
481
+ inputs_embeds=input_embeds,
482
+ attention_mask=attention_mask,
483
+ vision_hidden_states=image_embeds,
484
+ generation_config=generation_config,
485
+ use_zero_attention_mask=True,
486
+ **generate_kwargs,
487
+ )
488
+
489
+ return outputs
490
+
491
+ def get_text_features(
492
+ self,
493
+ input_ids: torch.Tensor,
494
+ attention_mask: torch.Tensor,
495
+ output_attentions: Optional[bool] = None,
496
+ output_hidden_states: Optional[bool] = None,
497
+ return_dict: Optional[bool] = None,
498
+ ):
499
+ r"""
500
+ Returns:
501
+ text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
502
+ The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
503
+ contains the language model logits, the past key values and the hidden states if
504
+ `output_hidden_states=True`.
505
+ """
506
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
507
+ output_hidden_states = (
508
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
509
+ )
510
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
511
+
512
+ input_embeds = self.get_input_embeddings()(input_ids)
513
+ attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
514
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
515
+ attention_mask += _make_causal_mask(
516
+ (attention_mask.shape[0], attention_mask.shape[2]),
517
+ input_embeds.dtype,
518
+ device=input_embeds.device
519
+ )
520
+ if type(self.qllama.model) == LlamaForCausalLM:
521
+ outputs = self.qllama.model.model.forward_train(
522
+ inputs_embeds=input_embeds,
523
+ vision_hidden_states=None,
524
+ attention_mask=attention_mask,
525
+ output_attentions=output_attentions,
526
+ output_hidden_states=output_hidden_states,
527
+ return_dict=return_dict,
528
+ ).last_hidden_state
529
+ else:
530
+ outputs = self.qllama.model.forward_train(
531
+ inputs_embeds=input_embeds,
532
+ vision_hidden_states=None,
533
+ attention_mask=attention_mask,
534
+ output_attentions=output_attentions,
535
+ output_hidden_states=output_hidden_states,
536
+ return_dict=return_dict,
537
+ ).last_hidden_state
538
+ return outputs
539
+
540
+ def get_image_features(
541
+ self,
542
+ pixel_values: torch.FloatTensor,
543
+ output_attentions: Optional[bool] = None,
544
+ output_hidden_states: Optional[bool] = None,
545
+ return_dict: Optional[bool] = None,
546
+ ):
547
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
548
+ output_hidden_states = (
549
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
550
+ )
551
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
552
+
553
+ vision_outputs = self.vision_model(
554
+ pixel_values=pixel_values,
555
+ output_hidden_states=output_hidden_states,
556
+ return_dict=return_dict)
557
+ image_embeds = vision_outputs[0]
558
+ backbone_embeds = image_embeds
559
+
560
+ batch_size = image_embeds.shape[0]
561
+ input_embeds = self.query_tokens.repeat(batch_size, 1, 1)
562
+
563
+ attention_mask = torch.ones(input_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
564
+ attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
565
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
566
+ if type(self.qllama.model) == LlamaForCausalLM:
567
+ outputs = self.qllama.model.model.forward_train(
568
+ inputs_embeds=input_embeds,
569
+ vision_hidden_states=image_embeds,
570
+ attention_mask=attention_mask,
571
+ output_attentions=output_attentions,
572
+ output_hidden_states=output_hidden_states,
573
+ return_dict=return_dict,
574
+ ).last_hidden_state
575
+ else:
576
+ outputs = self.qllama.model.forward_train(
577
+ inputs_embeds=input_embeds,
578
+ vision_hidden_states=image_embeds,
579
+ attention_mask=attention_mask,
580
+ output_attentions=output_attentions,
581
+ output_hidden_states=output_hidden_states,
582
+ return_dict=return_dict,
583
+ ).last_hidden_state
584
+ return backbone_embeds, outputs
585
+
586
+
587
+ class InternVL_C(InternVLModel):
588
+
589
+ def encode_image(self, image):
590
+ vision_outputs = self.vision_model(
591
+ pixel_values=image,
592
+ output_hidden_states=False,
593
+ return_dict=True)
594
+ image_embeds = vision_outputs[0]
595
+ image_embeds = self.clip_projector(image_embeds)
596
+ return image_embeds
597
+
598
+ def encode_text(self, text):
599
+ attention_mask = text > 0
600
+ text_embeds = self.get_text_features(
601
+ input_ids=text,
602
+ attention_mask=attention_mask,
603
+ output_attentions=False,
604
+ output_hidden_states=False,
605
+ return_dict=True,
606
+ )
607
+ text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
608
+ text_embeds = text_embeds @ self.text_projection
609
+ return text_embeds
610
+
611
+ def forward(self, image, text):
612
+ image_features = self.encode_image(image)
613
+ text_features = self.encode_text(text)
614
+
615
+ # normalized features
616
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
617
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
618
+
619
+ # cosine similarity as logits
620
+ logit_scale = self.logit_scale.exp()
621
+ logits_per_image = logit_scale * image_features @ text_features.t()
622
+ logits_per_text = logits_per_image.t()
623
+
624
+ return logits_per_image, logits_per_text
625
+
626
+
627
+ class InternVL_G(InternVLModel):
628
+
629
+ def encode_image(self, image):
630
+ backbone_embeds, image_embeds = self.get_image_features(
631
+ pixel_values=image,
632
+ output_hidden_states=False,
633
+ return_dict=True,
634
+ )
635
+ backbone_embeds = self.clip_projector(backbone_embeds)
636
+ image_embeds = self.clip_projector2(image_embeds)
637
+ # ensemble
638
+ backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
639
+ image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
640
+ image_embeds = image_embeds + backbone_embeds
641
+ return image_embeds
642
+
643
+ def encode_text(self, text):
644
+ attention_mask = text > 0
645
+ text_embeds = self.get_text_features(
646
+ input_ids=text,
647
+ attention_mask=attention_mask,
648
+ output_attentions=False,
649
+ output_hidden_states=False,
650
+ return_dict=True,
651
+ )
652
+ text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
653
+ text_embeds = text_embeds @ self.text_projection
654
+ return text_embeds
655
+
656
+ def forward(self, image, text):
657
+ image_features = self.encode_image(image)
658
+ text_features = self.encode_text(text)
659
+
660
+ # normalized features
661
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
662
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
663
+
664
+ # cosine similarity as logits
665
+ logit_scale = self.logit_scale.exp()
666
+ logits_per_image = logit_scale * image_features @ text_features.t()
667
+ logits_per_text = logits_per_image.t()
668
+
669
+ return logits_per_image, logits_per_text
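Note on the contrastive branch above: `forward` builds soft targets from duplicated `image_ids`, because one image may pair with several captions in a batch. Below is a minimal single-process sketch of that loss for readers following along; `itc_loss` and its input tensors are illustrative names only, and the real code additionally gathers features across GPUs with `GatherLayer` and adds the backbone-to-text terms.

```python
# Minimal sketch (assumed shapes: image_feats/text_feats [B, D], image_ids [B]).
import torch
import torch.nn.functional as F

def itc_loss(image_feats, text_feats, image_ids, logit_scale):
    # cosine similarities after L2 normalization
    image_feats = F.normalize(image_feats, dim=-1)
    text_feats = F.normalize(text_feats, dim=-1)
    sim_i2t = logit_scale * image_feats @ text_feats.t()   # [B, B]
    sim_t2i = logit_scale * text_feats @ image_feats.t()   # [B, B]

    # captions sharing an image id are all positives ->
    # row-normalized multi-hot target distribution
    image_ids = image_ids.view(-1, 1)
    pos = torch.eq(image_ids, image_ids.t()).float()
    targets = pos / pos.sum(dim=1, keepdim=True)

    loss_i2t = -(F.log_softmax(sim_i2t, dim=1) * targets).sum(dim=1).mean()
    loss_t2i = -(F.log_softmax(sim_t2i, dim=1) * targets).sum(dim=1).mean()
    return (loss_i2t + loss_t2i) / 2
```

When every image id is unique in the batch, `targets` reduces to the identity matrix and this falls back to the usual CLIP-style InfoNCE objective.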
InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_qllama.py ADDED
@@ -0,0 +1,1073 @@
1
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """ PyTorch QLLaMA model."""
20
+ import math
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+ from transformers import LlamaConfig
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
30
+ CausalLMOutputWithPast)
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (add_start_docstrings,
33
+ add_start_docstrings_to_model_forward, logging,
34
+ replace_return_docstrings)
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = 'LlamaConfig'
39
+
40
+
41
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
+ def _make_causal_mask(
43
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
+ ):
45
+ """
46
+ Make the causal mask used for uni-directional (decoder) self-attention.
47
+ """
48
+ bsz, tgt_len = input_ids_shape
49
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
50
+ mask_cond = torch.arange(mask.size(-1), device=device)
51
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
+ mask = mask.to(dtype)
53
+
54
+ if past_key_values_length > 0:
55
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72
+
73
+
74
+ class LlamaRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ LlamaRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
85
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
86
+
87
+ # convert into half-precision if necessary
88
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
89
+ hidden_states = hidden_states.to(self.weight.dtype)
90
+
91
+ return self.weight * hidden_states
92
+
93
+
94
+ try:
95
+ from functools import partial
96
+
97
+ from apex.normalization import FusedRMSNorm
98
+
99
+ LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
100
+ print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
101
+ except ImportError:
102
+ # using the normal LlamaRMSNorm
103
+ pass
104
+ except Exception:
105
+ print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
106
+ pass
107
+
108
+
109
+ class LlamaRotaryEmbedding(torch.nn.Module):
110
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
111
+ super().__init__()
112
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
113
+ self.register_buffer('inv_freq', inv_freq)
114
+
115
+ # Build here to make `torch.jit.trace` work.
116
+ self.max_seq_len_cached = max_position_embeddings
117
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
118
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
119
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
120
+ emb = torch.cat((freqs, freqs), dim=-1)
121
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
122
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
123
+
124
+ def forward(self, x, seq_len=None):
125
+ # x: [bs, num_attention_heads, seq_len, head_size]
126
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
127
+ if seq_len > self.max_seq_len_cached:
128
+ self.max_seq_len_cached = seq_len
129
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
130
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
131
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
132
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
133
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
134
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
135
+ return (
136
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
137
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
138
+ )
139
+
140
+
141
+ class FixedLlamaRotaryEmbedding(torch.nn.Module):
142
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
143
+ super().__init__()
144
+
145
+ self.dim = dim
146
+ self.max_position_embeddings = max_position_embeddings
147
+ self.base = base
148
+ self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
149
+
150
+ # Build here to make `torch.jit.trace` work.
151
+ self._set_cos_sin_cache(
152
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
153
+ )
154
+
155
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
156
+ self.max_seq_len_cached = seq_len
157
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
158
+
159
+ freqs = torch.outer(t, self.inv_freq)
160
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
161
+ emb = torch.cat((freqs, freqs), dim=-1)
162
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
163
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
164
+
165
+ def forward(self, x, seq_len=None):
166
+ # x: [bs, num_attention_heads, seq_len, head_size]
167
+ if seq_len > self.max_seq_len_cached:
168
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
169
+
170
+ return (
171
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
172
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
173
+ )
174
+
175
+
176
+ LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding
177
+
178
+
179
+ def rotate_half(x):
180
+ """Rotates half the hidden dims of the input."""
181
+ x1 = x[..., : x.shape[-1] // 2]
182
+ x2 = x[..., x.shape[-1] // 2:]
183
+ return torch.cat((-x2, x1), dim=-1)
184
+
185
+
186
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
187
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
188
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
189
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
190
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
191
+ q_embed = (q * cos) + (rotate_half(q) * sin)
192
+ k_embed = (k * cos) + (rotate_half(k) * sin)
193
+ return q_embed, k_embed
194
+
195
+
196
+ class LlamaMLP(nn.Module):
197
+ def __init__(
198
+ self,
199
+ hidden_size: int,
200
+ intermediate_size: int,
201
+ hidden_act: str,
202
+ ):
203
+ super().__init__()
204
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
205
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
206
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
207
+ self.act_fn = ACT2FN[hidden_act]
208
+
209
+ def forward(self, x):
210
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
211
+
212
+
213
+ class LlamaAttention(nn.Module):
214
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
215
+
216
+ def __init__(self, config: LlamaConfig):
217
+ super().__init__()
218
+ self.config = config
219
+ self.hidden_size = config.hidden_size
220
+ self.num_heads = config.num_attention_heads
221
+ self.head_dim = self.hidden_size // self.num_heads
222
+ self.max_position_embeddings = config.max_position_embeddings
223
+
224
+ if (self.head_dim * self.num_heads) != self.hidden_size:
225
+ raise ValueError(
226
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
227
+ f' and `num_heads`: {self.num_heads}).'
228
+ )
229
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
230
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
231
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
232
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
233
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
234
+
235
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
236
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.LongTensor] = None,
243
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
244
+ output_attentions: bool = False,
245
+ use_cache: bool = False,
246
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
247
+ bsz, q_len, _ = hidden_states.size()
248
+
249
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
250
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
251
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
252
+
253
+ kv_seq_len = key_states.shape[-2]
254
+ if past_key_value is not None:
255
+ kv_seq_len += past_key_value[0].shape[-2]
256
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
257
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
258
+ # [bsz, nh, t, hd]
259
+
260
+ if past_key_value is not None:
261
+ # reuse k, v, self_attention
262
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
263
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
264
+
265
+ past_key_value = (key_states, value_states) if use_cache else None
266
+
267
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
268
+
269
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
270
+ raise ValueError(
271
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
272
+ f' {attn_weights.size()}'
273
+ )
274
+
275
+ if attention_mask is not None:
276
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
277
+ raise ValueError(
278
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
279
+ )
280
+ attn_weights = attn_weights + attention_mask
281
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
282
+
283
+ # upcast attention to fp32
284
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
285
+ attn_output = torch.matmul(attn_weights, value_states)
286
+
287
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
288
+ raise ValueError(
289
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
290
+ f' {attn_output.size()}'
291
+ )
292
+
293
+ attn_output = attn_output.transpose(1, 2)
294
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
295
+
296
+ attn_output = self.o_proj(attn_output)
297
+
298
+ if not output_attentions:
299
+ attn_weights = None
300
+
301
+ return attn_output, attn_weights, past_key_value
302
+
303
+
304
+ class LlamaCrossAttention(nn.Module):
305
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
306
+
307
+ def __init__(self, config: LlamaConfig):
308
+ super().__init__()
309
+ self.config = config
310
+ self.hidden_size = config.hidden_size
311
+ self.num_heads = config.num_attention_heads
312
+ self.head_dim = self.hidden_size // self.num_heads
313
+ self.max_position_embeddings = config.max_position_embeddings
314
+ self.vision_hidden_size = 3200
315
+
316
+ if (self.head_dim * self.num_heads) != self.hidden_size:
317
+ raise ValueError(
318
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
319
+ f' and `num_heads`: {self.num_heads}).'
320
+ )
321
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
322
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
323
+ self.norm1 = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
324
+
325
+ self.k_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
326
+ self.v_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
327
+ self.norm2 = LlamaRMSNorm(self.vision_hidden_size, eps=config.rms_norm_eps)
328
+
329
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
330
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
331
+
332
+ def forward(
333
+ self,
334
+ hidden_states: torch.Tensor,
335
+ vision_hidden_states: torch.Tensor,
336
+ repeat_time: int = 1,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
339
+ output_attentions: bool = False,
340
+ use_cache: bool = False,
341
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
342
+ hidden_states = self.norm1(hidden_states)
343
+
344
+ bsz, q_len, _ = hidden_states.size()
345
+
346
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
347
+
348
+ vision_hidden_states = self.norm2(vision_hidden_states)
349
+
350
+ bs_v, kv_len, _ = vision_hidden_states.size()
351
+
352
+ key_states = self.k_proj(vision_hidden_states).view(
353
+ bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
354
+ value_states = self.v_proj(vision_hidden_states).view(
355
+ bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
356
+
357
+ key_states = key_states.repeat(repeat_time, 1, 1, 1)
358
+ value_states = value_states.repeat(repeat_time, 1, 1, 1)
359
+
360
+ kv_seq_len = key_states.shape[-2]
361
+ if past_key_value is not None:
362
+ kv_seq_len += past_key_value[0].shape[-2]
363
+
364
+ if past_key_value is not None:
365
+ # reuse k, v, self_attention
366
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
367
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
368
+
369
+ past_key_value = (key_states, value_states) if use_cache else None
370
+
371
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
372
+
373
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
374
+ raise ValueError(
375
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
376
+ f' {attn_weights.size()}'
377
+ )
378
+
379
+ if attention_mask is not None:
380
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
381
+ raise ValueError(
382
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
383
+ )
384
+ attn_weights = attn_weights + attention_mask
385
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
386
+
387
+ # upcast attention to fp32
388
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
389
+ attn_output = torch.matmul(attn_weights, value_states)
390
+
391
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
392
+ raise ValueError(
393
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
394
+ f' {attn_output.size()}'
395
+ )
396
+
397
+ attn_output = attn_output.transpose(1, 2)
398
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
399
+
400
+ attn_output = self.o_proj(attn_output)
401
+
402
+ if not output_attentions:
403
+ attn_weights = None
404
+
405
+ return attn_output, attn_weights, past_key_value
406
+
407
+
408
+ class LlamaDecoderLayer(nn.Module):
409
+ def __init__(self, config: LlamaConfig, use_cross_attn: bool):
410
+ super().__init__()
411
+ self.hidden_size = config.hidden_size
412
+ self.self_attn = LlamaAttention(config=config)
413
+ self.cross_attn = LlamaCrossAttention(config=config) if use_cross_attn else None
414
+ self.mlp = LlamaMLP(
415
+ hidden_size=self.hidden_size,
416
+ intermediate_size=config.intermediate_size,
417
+ hidden_act=config.hidden_act,
418
+ )
419
+ self.num_query_token = 96
420
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
421
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
422
+
423
+ def forward(
424
+ self,
425
+ hidden_states: torch.Tensor,
426
+ vision_hidden_states: torch.Tensor,
427
+ attention_mask: Optional[torch.Tensor] = None,
428
+ position_ids: Optional[torch.LongTensor] = None,
429
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
430
+ output_attentions: Optional[bool] = False,
431
+ use_cache: Optional[bool] = False,
432
+ repeat_time: int = 1,
433
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
434
+ """
435
+ Args:
436
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
437
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
438
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
439
+ output_attentions (`bool`, *optional*):
440
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
441
+ returned tensors for more detail.
442
+ use_cache (`bool`, *optional*):
443
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
444
+ (see `past_key_values`).
445
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
446
+ """
447
+
448
+ residual = hidden_states
449
+
450
+ hidden_states = self.input_layernorm(hidden_states)
451
+
452
+ # Self Attention
453
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
454
+ hidden_states=hidden_states,
455
+ attention_mask=attention_mask,
456
+ position_ids=position_ids,
457
+ past_key_value=past_key_value,
458
+ output_attentions=output_attentions,
459
+ use_cache=use_cache,
460
+ )
461
+ hidden_states = residual + hidden_states
462
+
463
+ # when using generate function and cache mode, the size of hidden_states is 1,
464
+ # so we should not use cross attention
465
+ if self.cross_attn is not None and hidden_states.size(1) >= self.num_query_token \
466
+ and vision_hidden_states is not None:
467
+ query_feats = hidden_states[:, :self.num_query_token, :]
468
+ text_feats = hidden_states[:, self.num_query_token:, :]
469
+ residual = query_feats
470
+ query_feats, _, _ = self.cross_attn(
471
+ hidden_states=query_feats,
472
+ vision_hidden_states=vision_hidden_states,
473
+ attention_mask=None, # not use attention mask in cross attention
474
+ past_key_value=past_key_value,
475
+ output_attentions=output_attentions,
476
+ use_cache=use_cache,
477
+ repeat_time=repeat_time,
478
+ )
479
+ query_feats = residual + query_feats
480
+ hidden_states = torch.cat([query_feats, text_feats], dim=1)
481
+
482
+ # Fully Connected
483
+ residual = hidden_states
484
+ hidden_states = self.post_attention_layernorm(hidden_states)
485
+ hidden_states = self.mlp(hidden_states)
486
+ hidden_states = residual + hidden_states
487
+
488
+ outputs = (hidden_states,)
489
+
490
+ if output_attentions:
491
+ outputs += (self_attn_weights,)
492
+
493
+ if use_cache:
494
+ outputs += (present_key_value,)
495
+
496
+ return outputs
497
+
498
+
499
+ LLAMA_START_DOCSTRING = r"""
500
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
501
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
502
+ etc.)
503
+
504
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
505
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
506
+ and behavior.
507
+
508
+ Parameters:
509
+ config ([`LlamaConfig`]):
510
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
511
+ load the weights associated with the model, only the configuration. Check out the
512
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
513
+ """
514
+
515
+
516
+ @add_start_docstrings(
517
+ 'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
518
+ LLAMA_START_DOCSTRING,
519
+ )
520
+ class LlamaPreTrainedModel(PreTrainedModel):
521
+ config_class = LlamaConfig
522
+ base_model_prefix = 'model'
523
+ supports_gradient_checkpointing = True
524
+ _no_split_modules = ['LlamaDecoderLayer']
525
+ _keys_to_ignore_on_load_unexpected = [r'decoder\.version']
526
+
527
+ def _init_weights(self, module):
528
+ std = self.config.initializer_range
529
+ if isinstance(module, nn.Linear):
530
+ module.weight.data.normal_(mean=0.0, std=std)
531
+ if module.bias is not None:
532
+ module.bias.data.zero_()
533
+ elif isinstance(module, nn.Embedding):
534
+ module.weight.data.normal_(mean=0.0, std=std)
535
+ if module.padding_idx is not None:
536
+ module.weight.data[module.padding_idx].zero_()
537
+
538
+ def _set_gradient_checkpointing(self, module, value=False):
539
+ if isinstance(module, LlamaModel):
540
+ module.gradient_checkpointing = value
541
+ if isinstance(module, LlamaDecoderLayer):
542
+ module.gradient_checkpointing = value
543
+
544
+
545
+ LLAMA_INPUTS_DOCSTRING = r"""
546
+ Args:
547
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
548
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
549
+ it.
550
+
551
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
552
+ [`PreTrainedTokenizer.__call__`] for details.
553
+
554
+ [What are input IDs?](../glossary#input-ids)
555
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
556
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
557
+
558
+ - 1 for tokens that are **not masked**,
559
+ - 0 for tokens that are **masked**.
560
+
561
+ [What are attention masks?](../glossary#attention-mask)
562
+
563
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
564
+ [`PreTrainedTokenizer.__call__`] for details.
565
+
566
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
567
+ `past_key_values`).
568
+
569
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
570
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
571
+ information on the default strategy.
572
+
573
+ - 1 indicates the head is **not masked**,
574
+ - 0 indicates the head is **masked**.
575
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
576
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
577
+ config.n_positions - 1]`.
578
+
579
+ [What are position IDs?](../glossary#position-ids)
580
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
581
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
582
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
583
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
584
+
585
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
586
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
587
+
588
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
589
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
590
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
591
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
592
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
593
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
594
+ model's internal embedding lookup matrix.
595
+ use_cache (`bool`, *optional*):
596
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
597
+ `past_key_values`).
598
+ output_attentions (`bool`, *optional*):
599
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
600
+ tensors for more detail.
601
+ output_hidden_states (`bool`, *optional*):
602
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
603
+ more detail.
604
+ return_dict (`bool`, *optional*):
605
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
606
+ """
607
+
608
+
609
+ @add_start_docstrings(
610
+ 'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
611
+ LLAMA_START_DOCSTRING,
612
+ )
613
+ class LlamaModel(LlamaPreTrainedModel):
614
+ """
615
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
616
+
617
+ Args:
618
+ config: LlamaConfig
619
+ """
620
+
621
+ def __init__(self, config: LlamaConfig):
622
+ super().__init__(config)
623
+ self.padding_idx = config.pad_token_id
624
+ self.vocab_size = config.vocab_size
625
+ self.cross_attention_frequency = config.cross_attention_frequency
626
+ self.num_query_token = config.num_query_token
627
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
628
+ use_cross_attn = [idx % self.cross_attention_frequency == 0 for idx in range(config.num_hidden_layers)]
629
+ self.layers = nn.ModuleList(
630
+ [LlamaDecoderLayer(config, use_cross_attn[idx]) for idx in range(config.num_hidden_layers)])
631
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
632
+ self.gradient_checkpointing = False
633
+ # Initialize weights and apply final processing
634
+ # self.post_init()
635
+
636
+ def get_input_embeddings(self):
637
+ return self.embed_tokens
638
+
639
+ def set_input_embeddings(self, value):
640
+ self.embed_tokens = value
641
+
642
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
643
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
644
+ # create causal mask
645
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
646
+ combined_attention_mask = None
647
+ if input_shape[-1] > 1:
648
+ combined_attention_mask = _make_causal_mask(
649
+ input_shape,
650
+ inputs_embeds.dtype,
651
+ device=inputs_embeds.device,
652
+ past_key_values_length=past_key_values_length,
653
+ )
654
+
655
+ if attention_mask is not None:
656
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
657
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
658
+ inputs_embeds.device
659
+ )
660
+ combined_attention_mask = (
661
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
662
+ )
663
+
664
+ return combined_attention_mask
665
+
666
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
667
+ def forward(
668
+ self,
669
+ input_ids: torch.LongTensor = None,
670
+ attention_mask: Optional[torch.Tensor] = None,
671
+ position_ids: Optional[torch.LongTensor] = None,
672
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
673
+ inputs_embeds: Optional[torch.FloatTensor] = None,
674
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
675
+ repeat_time: Optional[int] = 1,
676
+ use_cache: Optional[bool] = None,
677
+ output_attentions: Optional[bool] = None,
678
+ output_hidden_states: Optional[bool] = None,
679
+ use_zero_attention_mask: Optional[bool] = None,
680
+ return_dict: Optional[bool] = None,
681
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
682
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
683
+ output_hidden_states = (
684
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
685
+ )
686
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
687
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
688
+
689
+ # retrieve input_ids and inputs_embeds
690
+ if input_ids is not None and inputs_embeds is not None:
691
+ raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
692
+ elif input_ids is not None:
693
+ batch_size, seq_length = input_ids.shape
694
+ elif inputs_embeds is not None:
695
+ batch_size, seq_length, _ = inputs_embeds.shape
696
+ else:
697
+ raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
698
+ seq_length_with_past = seq_length
699
+ past_key_values_length = 0
700
+
701
+ if past_key_values is not None:
702
+ past_key_values_length = past_key_values[0][0].shape[2]
703
+ seq_length_with_past = seq_length_with_past + past_key_values_length
704
+
705
+ if position_ids is None:
706
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
707
+ position_ids = torch.arange(
708
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
709
+ )
710
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
711
+ else:
712
+ position_ids = position_ids.view(-1, seq_length).long()
713
+
714
+ if inputs_embeds is None:
715
+ inputs_embeds = self.embed_tokens(input_ids)
716
+ # embed positions
717
+ if attention_mask is None:
718
+ attention_mask = torch.ones(
719
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
720
+ )
721
+ attention_mask = self._prepare_decoder_attention_mask(
722
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
723
+ )
724
+ if use_zero_attention_mask:
725
+ attention_mask[:, :, :self.num_query_token, :self.num_query_token] = 0
726
+
727
+ hidden_states = inputs_embeds
728
+
729
+ if self.gradient_checkpointing and self.training:
730
+ if use_cache:
731
+ logger.warning_once(
732
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
733
+ )
734
+ use_cache = False
735
+
736
+ # decoder layers
737
+ all_hidden_states = () if output_hidden_states else None
738
+ all_self_attns = () if output_attentions else None
739
+ next_decoder_cache = () if use_cache else None
740
+
741
+ for idx, decoder_layer in enumerate(self.layers):
742
+ if output_hidden_states:
743
+ all_hidden_states += (hidden_states,)
744
+
745
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
746
+
747
+ layer_outputs = decoder_layer(
748
+ hidden_states,
749
+ vision_hidden_states,
750
+ attention_mask=attention_mask,
751
+ position_ids=position_ids,
752
+ past_key_value=past_key_value,
753
+ output_attentions=output_attentions,
754
+ use_cache=use_cache,
755
+ repeat_time=repeat_time,
756
+ )
757
+
758
+ hidden_states = layer_outputs[0]
759
+
760
+ if use_cache:
761
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
762
+
763
+ if output_attentions:
764
+ all_self_attns += (layer_outputs[1],)
765
+
766
+ hidden_states = self.norm(hidden_states)
767
+
768
+ # add hidden states from the last decoder layer
769
+ if output_hidden_states:
770
+ all_hidden_states += (hidden_states,)
771
+
772
+ next_cache = next_decoder_cache if use_cache else None
773
+ if not return_dict:
774
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
775
+ return BaseModelOutputWithPast(
776
+ last_hidden_state=hidden_states,
777
+ past_key_values=next_cache,
778
+ hidden_states=all_hidden_states,
779
+ attentions=all_self_attns,
780
+ )
781
+
782
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
783
+ def forward_train(
784
+ self,
785
+ input_ids: torch.LongTensor = None,
786
+ attention_mask: Optional[torch.Tensor] = None,
787
+ position_ids: Optional[torch.LongTensor] = None,
788
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
789
+ inputs_embeds: Optional[torch.FloatTensor] = None,
790
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
791
+ repeat_time: Optional[int] = 1,
792
+ use_cache: Optional[bool] = None,
793
+ output_attentions: Optional[bool] = None,
794
+ output_hidden_states: Optional[bool] = None,
795
+ return_dict: Optional[bool] = None,
796
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
797
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
798
+ output_hidden_states = (
799
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
800
+ )
801
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
802
+
803
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
804
+
805
+ # retrieve input_ids and inputs_embeds
806
+ if input_ids is not None and inputs_embeds is not None:
807
+ raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
808
+ elif input_ids is not None:
809
+ batch_size, seq_length = input_ids.shape
810
+ elif inputs_embeds is not None:
811
+ batch_size, seq_length, _ = inputs_embeds.shape
812
+ else:
813
+ raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
814
+
815
+ seq_length_with_past = seq_length
816
+ past_key_values_length = 0
817
+
818
+ if past_key_values is not None:
819
+ past_key_values_length = past_key_values[0][0].shape[2]
820
+ seq_length_with_past = seq_length_with_past + past_key_values_length
821
+
822
+ if position_ids is None:
823
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
824
+ position_ids = torch.arange(
825
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
826
+ )
827
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
828
+ else:
829
+ position_ids = position_ids.view(-1, seq_length).long()
830
+
831
+ if inputs_embeds is None:
832
+ inputs_embeds = self.embed_tokens(input_ids)
833
+ # embed positions
834
+ # if attention_mask is None:
835
+ # attention_mask = torch.ones(
836
+ # (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
837
+ # )
838
+ # attention_mask = self._prepare_decoder_attention_mask(
839
+ # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
840
+ # )
841
+ hidden_states = inputs_embeds
842
+
843
+ if self.gradient_checkpointing and self.training:
844
+ if use_cache:
845
+ logger.warning_once(
846
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
847
+ )
848
+ use_cache = False
849
+
850
+ # decoder layers
851
+ all_hidden_states = () if output_hidden_states else None
852
+ all_self_attns = () if output_attentions else None
853
+ next_decoder_cache = () if use_cache else None
854
+
855
+ for idx, decoder_layer in enumerate(self.layers):
856
+ if output_hidden_states:
857
+ all_hidden_states += (hidden_states,)
858
+
859
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
860
+
861
+ if self.gradient_checkpointing and self.training:
862
+
863
+ def create_custom_forward(module):
864
+ def custom_forward(*inputs):
865
+ # None for past_key_value
866
+ return module(*inputs, output_attentions, None, repeat_time)
867
+
868
+ return custom_forward
869
+
870
+ layer_outputs = torch.utils.checkpoint.checkpoint(
871
+ create_custom_forward(decoder_layer),
872
+ hidden_states,
873
+ vision_hidden_states,
874
+ attention_mask,
875
+ position_ids,
876
+ None,
877
+ )
878
+ else:
879
+ layer_outputs = decoder_layer(
880
+ hidden_states,
881
+ vision_hidden_states,
882
+ attention_mask=attention_mask,
883
+ position_ids=position_ids,
884
+ past_key_value=past_key_value,
885
+ output_attentions=output_attentions,
886
+ use_cache=use_cache,
887
+ repeat_time=repeat_time,
888
+ )
889
+
890
+ hidden_states = layer_outputs[0]
891
+
892
+ if use_cache:
893
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
894
+
895
+ if output_attentions:
896
+ all_self_attns += (layer_outputs[1],)
897
+
898
+ hidden_states = self.norm(hidden_states)
899
+
900
+ # add hidden states from the last decoder layer
901
+ if output_hidden_states:
902
+ all_hidden_states += (hidden_states,)
903
+
904
+ next_cache = next_decoder_cache if use_cache else None
905
+ if not return_dict:
906
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
907
+ return BaseModelOutputWithPast(
908
+ last_hidden_state=hidden_states,
909
+ past_key_values=next_cache,
910
+ hidden_states=all_hidden_states,
911
+ attentions=all_self_attns,
912
+ )
913
+
914
+
915
+ class LlamaForCausalLM(LlamaPreTrainedModel):
916
+ def __init__(self, config):
917
+ super().__init__(config)
918
+ self.model = LlamaModel(config)
919
+
920
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
921
+
922
+ # Initialize weights and apply final processing
923
+ # self.post_init()
924
+
925
+ def get_input_embeddings(self):
926
+ return self.model.embed_tokens
927
+
928
+ def set_input_embeddings(self, value):
929
+ self.model.embed_tokens = value
930
+
931
+ def get_output_embeddings(self):
932
+ return self.lm_head
933
+
934
+ def set_output_embeddings(self, new_embeddings):
935
+ self.lm_head = new_embeddings
936
+
937
+ def set_decoder(self, decoder):
938
+ self.model = decoder
939
+
940
+ def get_decoder(self):
941
+ return self.model
942
+
943
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
944
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
945
+ def forward(
946
+ self,
947
+ input_ids: torch.LongTensor = None,
948
+ attention_mask: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.LongTensor] = None,
950
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
951
+ inputs_embeds: Optional[torch.FloatTensor] = None,
952
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
953
+ labels: Optional[torch.LongTensor] = None,
954
+ use_cache: Optional[bool] = None,
955
+ output_attentions: Optional[bool] = None,
956
+ output_hidden_states: Optional[bool] = None,
957
+ use_zero_attention_mask: Optional[bool] = None,
958
+ return_dict: Optional[bool] = None,
959
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
960
+ r"""
961
+ Args:
962
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
963
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
964
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
965
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
966
+
967
+ Returns:
968
+
969
+ Example:
970
+
971
+ ```python
972
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
973
+
974
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
975
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
976
+
977
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
978
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
979
+
980
+ >>> # Generate
981
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
982
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
983
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
984
+ ```"""
985
+
986
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
987
+ output_hidden_states = (
988
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
989
+ )
990
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
991
+
992
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
993
+ outputs = self.model(
994
+ input_ids=input_ids,
995
+ attention_mask=attention_mask,
996
+ position_ids=position_ids,
997
+ past_key_values=past_key_values,
998
+ inputs_embeds=inputs_embeds,
999
+ vision_hidden_states=vision_hidden_states,
1000
+ use_cache=use_cache,
1001
+ output_attentions=output_attentions,
1002
+ output_hidden_states=output_hidden_states,
1003
+ return_dict=return_dict,
1004
+ use_zero_attention_mask=use_zero_attention_mask,
1005
+ )
1006
+
1007
+ hidden_states = outputs[0]
1008
+ logits = self.lm_head(hidden_states)
1009
+
1010
+ loss = None
1011
+ if labels is not None:
1012
+ # Shift so that tokens < n predict n
1013
+ shift_logits = logits[..., :-1, :].contiguous()
1014
+ shift_labels = labels[..., 1:].contiguous()
1015
+ # Flatten the tokens
1016
+ loss_fct = CrossEntropyLoss()
1017
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1018
+ shift_labels = shift_labels.view(-1)
1019
+ # Enable model parallelism
1020
+ shift_labels = shift_labels.to(shift_logits.device)
1021
+ loss = loss_fct(shift_logits, shift_labels)
1022
+
1023
+ if not return_dict:
1024
+ output = (logits,) + outputs[1:]
1025
+ return (loss,) + output if loss is not None else output
1026
+
1027
+ return CausalLMOutputWithPast(
1028
+ loss=loss,
1029
+ logits=logits,
1030
+ past_key_values=outputs.past_key_values,
1031
+ hidden_states=outputs.hidden_states,
1032
+ attentions=outputs.attentions,
1033
+ )
1034
+
1035
+ def prepare_inputs_for_generation(
1036
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None,
1037
+ vision_hidden_states=None, use_zero_attention_mask=None, **kwargs
1038
+ ):
1039
+ if past_key_values:
1040
+ input_ids = input_ids[:, -1:]
1041
+
1042
+ position_ids = kwargs.get('position_ids', None)
1043
+ if attention_mask is not None and position_ids is None:
1044
+ # create position_ids on the fly for batch generation
1045
+ position_ids = attention_mask.long().cumsum(-1) - 1
1046
+ position_ids.masked_fill_(attention_mask == 0, 1)
1047
+ if past_key_values:
1048
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1049
+
1050
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1051
+ if inputs_embeds is not None and past_key_values is None:
1052
+ model_inputs = {'inputs_embeds': inputs_embeds}
1053
+ else:
1054
+ model_inputs = {'input_ids': input_ids}
1055
+
1056
+ model_inputs.update(
1057
+ {
1058
+ 'position_ids': position_ids,
1059
+ 'past_key_values': past_key_values,
1060
+ 'use_cache': kwargs.get('use_cache'),
1061
+ 'attention_mask': attention_mask,
1062
+ 'vision_hidden_states': vision_hidden_states,
1063
+ 'use_zero_attention_mask': use_zero_attention_mask,
1064
+ }
1065
+ )
1066
+ return model_inputs
1067
+
1068
+ @staticmethod
1069
+ def _reorder_cache(past_key_values, beam_idx):
1070
+ reordered_past = ()
1071
+ for layer_past in past_key_values:
1072
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1073
+ return reordered_past
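A minimal, self-contained sketch of the shift-by-one loss used in LlamaForCausalLM.forward above: the logits at position t are scored against the label at position t+1, and entries set to -100 are ignored by CrossEntropyLoss. Shapes and token ids below are illustrative assumptions, not values from the model.

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 8
logits = torch.randn(1, 5, vocab_size)       # (batch, seq_len, vocab)
labels = torch.tensor([[3, 5, 2, -100, 7]])  # -100 marks ignored positions

# Shift so that tokens < n predict n, as in forward()
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)  # ignore_index defaults to -100
print(loss.item())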
InternVL/internvl_g/internvl/train/__init__.py ADDED
File without changes
InternVL/internvl_g/internvl/train/dataset.py ADDED
@@ -0,0 +1,283 @@
1
+ import json
2
+ import random
3
+ import re
4
+ from typing import Dict
5
+
6
+ import torch
7
+ import torchvision.transforms as T
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+ from torchvision.transforms.functional import InterpolationMode
11
+
12
+
13
+ def build_transform(input_size):
14
+ # match fine-tune setting with blip2
15
+ # https://github.com/salesforce/LAVIS/blob/main/lavis/processors/blip_processors.py
16
+ transform = T.Compose([
17
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
18
+ T.RandomResizedCrop(input_size, scale=(0.5, 1.0),
19
+ interpolation=InterpolationMode.BICUBIC),
20
+ T.RandomHorizontalFlip(),
21
+ T.ToTensor(),
22
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
23
+ ])
24
+ return transform
25
+
26
+
27
+ class FlickrDataset(Dataset):
28
+ """Dataset for supervised fine-tuning."""
29
+
30
+ def __init__(self, metas, tokenizer, data_args):
31
+ super(FlickrDataset, self).__init__()
32
+
33
+ f = open(metas['annotation'])
34
+ lines = f.readlines()[1:]
35
+
36
+ self.data_args = data_args
37
+ self.tokenizer = tokenizer
38
+ self.images = []
39
+ self.image_ids = []
40
+ self.captions = []
41
+
42
+ for line in lines:
43
+ image, caption = line.strip().split('.jpg,')
44
+ image_id = int(image)
45
+ caption = self.process_single_caption(caption)
46
+ image = image + '.jpg'
47
+ image_path = metas['root'] + '/' + image
48
+ self.images.append(image_path)
49
+ self.image_ids.append(image_id)
50
+ self.captions.append(caption)
51
+ print(f'There are {len(self.images)} images.')
52
+ print(f'There are {len(self.captions)} captions.')
53
+
54
+ def __len__(self):
55
+ return len(self.images)
56
+
57
+ def process_single_caption(self, caption, max_words=50):
58
+ caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
59
+ caption = re.sub(r'\s{2,}', ' ', caption)
60
+ caption = caption.rstrip('\n')
61
+ caption = caption.strip(' ')
62
+
63
+ # truncate caption
64
+ caption_words = caption.split(' ')
65
+ if len(caption_words) > max_words:
66
+ caption = ' '.join(caption_words[: max_words])
67
+ return caption
68
+
69
+ def preprocess(self, image, caption, neg_caption):
70
+ model_inputs = dict()
71
+
72
+ # input image
73
+ image_transform = build_transform(input_size=self.data_args.force_image_size)
74
+ image = Image.open(image)
75
+ image = image.convert('RGB')
76
+ pixel_values = image_transform(image)
77
+ model_inputs['pixel_values'] = pixel_values
78
+
79
+ # for image-text matching
80
+ pos_model_inputs = self.tokenizer(
81
+ caption,
82
+ max_length=self.data_args.max_seq_length,
83
+ padding='max_length' if self.data_args.pad_to_max_length else False,
84
+ truncation=True,
85
+ return_tensors='pt',
86
+ )
87
+ model_inputs['positive_input_ids'] = pos_model_inputs['input_ids']
88
+ model_inputs['positive_attention_mask'] = pos_model_inputs['attention_mask']
89
+ neg_model_inputs = self.tokenizer(
90
+ neg_caption,
91
+ max_length=self.data_args.max_seq_length,
92
+ padding='max_length' if self.data_args.pad_to_max_length else False,
93
+ truncation=True,
94
+ return_tensors='pt',
95
+ )
96
+ model_inputs['negative_input_ids'] = neg_model_inputs['input_ids']
97
+ model_inputs['negative_attention_mask'] = neg_model_inputs['attention_mask']
98
+
99
+ # for image-text contrastive learning
100
+ summarize_model_inputs = self.tokenizer(
101
+ 'summarize:' + caption,
102
+ max_length=self.data_args.max_seq_length,
103
+ padding='max_length' if self.data_args.pad_to_max_length else False,
104
+ truncation=True,
105
+ return_tensors='pt',
106
+ )
107
+ model_inputs['summarize_input_ids'] = summarize_model_inputs['input_ids']
108
+ model_inputs['summarize_attention_mask'] = summarize_model_inputs['attention_mask']
109
+
110
+ # for image-grounded text generation
111
+ prefix = f'English caption:'
112
+ content = caption
113
+ tokenized_prefix = self.tokenizer(
114
+ prefix, padding=False, truncation=True, return_tensors='pt',
115
+ )
116
+ prefix_input_ids = tokenized_prefix['input_ids'][:, :-1] # remove eos
117
+ prefix_attention_mask = tokenized_prefix['attention_mask'][:, :-1] # remove eos
118
+ tokenized_content = self.tokenizer(
119
+ content,
120
+ max_length=self.data_args.max_seq_length - prefix_input_ids.size(1) + 1,
121
+ padding='max_length' if self.data_args.pad_to_max_length else False,
122
+ truncation=True,
123
+ return_tensors='pt',
124
+ )
125
+ content_input_ids = tokenized_content['input_ids'][:, 1:] # remove bos
126
+ content_attention_mask = tokenized_content['attention_mask'][:, 1:] # remove bos
127
+ model_inputs['input_ids'] = torch.cat([prefix_input_ids, content_input_ids], dim=1)
128
+ model_inputs['attention_mask'] = torch.cat([prefix_attention_mask, content_attention_mask], dim=1)
129
+ labels = model_inputs['input_ids'].clone()
130
+ labels[labels == self.tokenizer.pad_token_id] = -100
131
+ labels[:, :prefix_input_ids.size(1) - 1] = -100
132
+ model_inputs['labels'] = labels
133
+ return model_inputs
134
+
135
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
136
+ i = i % len(self.images)
137
+ j = random.randint(0, len(self.images) - 1)
138
+ while self.image_ids[j] == self.image_ids[i]:
139
+ j = random.randint(0, len(self.images) - 1)
140
+ ret = self.preprocess(self.images[i], self.captions[i], self.captions[j])
141
+ # for image-text matching
142
+ ret['positive_input_ids'] = ret['positive_input_ids'][0]
143
+ ret['positive_attention_mask'] = ret['positive_attention_mask'][0]
144
+ ret['negative_input_ids'] = ret['negative_input_ids'][0]
145
+ ret['negative_attention_mask'] = ret['negative_attention_mask'][0]
146
+ # for image-text contrastive learning
147
+ ret['summarize_input_ids'] = ret['summarize_input_ids'][0]
148
+ ret['summarize_attention_mask'] = ret['summarize_attention_mask'][0]
149
+ # for image-grounded text generation
150
+ ret['input_ids'] = ret['input_ids'][0]
151
+ ret['attention_mask'] = ret['attention_mask'][0]
152
+ ret['labels'] = ret['labels'][0]
153
+ ret['image_ids'] = torch.Tensor([self.image_ids[i]]).long()
154
+ return ret
155
+
156
+
157
+ class COCODataset(Dataset):
158
+ """Dataset for supervised fine-tuning."""
159
+
160
+ def __init__(self, metas, tokenizer, data_args):
161
+ super(COCODataset, self).__init__()
162
+
163
+ annotations = json.load(open(metas['annotation']))
164
+
165
+ self.data_args = data_args
166
+ self.tokenizer = tokenizer
167
+ self.images = []
168
+ self.image_ids = []
169
+ self.captions = []
170
+
171
+ for annotation in annotations:
172
+ image_id = int(annotation['image_id'].split('_')[-1])
173
+ caption = annotation['caption']
174
+ caption = self.process_single_caption(caption)
175
+ image = annotation['image']
176
+ image_path = metas['root'] + '/' + image
177
+ self.images.append(image_path)
178
+ self.image_ids.append(image_id)
179
+ self.captions.append(caption)
180
+ print(f'There are {len(self.images)} images.')
181
+ print(f'There are {len(self.captions)} captions.')
182
+
183
+ def __len__(self):
184
+ return len(self.images)
185
+
186
+ def process_single_caption(self, caption, max_words=50):
187
+ caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
188
+ caption = re.sub(r'\s{2,}', ' ', caption)
189
+ caption = caption.rstrip('\n')
190
+ caption = caption.strip(' ')
191
+
192
+ # truncate caption
193
+ caption_words = caption.split(' ')
194
+ if len(caption_words) > max_words:
195
+ caption = ' '.join(caption_words[: max_words])
196
+ return caption
197
+
198
+ def preprocess(self, image, caption, neg_caption):
199
+ model_inputs = dict()
200
+
201
+ # input image
202
+ image_transform = build_transform(input_size=self.data_args.force_image_size)
203
+ image = Image.open(image)
204
+ image = image.convert('RGB')
205
+ pixel_values = image_transform(image)
206
+ model_inputs['pixel_values'] = pixel_values
207
+
208
+ # for image-text matching
209
+ pos_model_inputs = self.tokenizer(
210
+ caption,
211
+ max_length=self.data_args.max_seq_length,
212
+ padding='max_length' if self.data_args.pad_to_max_length else False,
213
+ truncation=True,
214
+ return_tensors='pt',
215
+ )
216
+ model_inputs['positive_input_ids'] = pos_model_inputs['input_ids']
217
+ model_inputs['positive_attention_mask'] = pos_model_inputs['attention_mask']
218
+ neg_model_inputs = self.tokenizer(
219
+ neg_caption,
220
+ max_length=self.data_args.max_seq_length,
221
+ padding='max_length' if self.data_args.pad_to_max_length else False,
222
+ truncation=True,
223
+ return_tensors='pt',
224
+ )
225
+ model_inputs['negative_input_ids'] = neg_model_inputs['input_ids']
226
+ model_inputs['negative_attention_mask'] = neg_model_inputs['attention_mask']
227
+
228
+ # for image-text contrastive learning
229
+ summarize_model_inputs = self.tokenizer(
230
+ 'summarize:' + caption,
231
+ max_length=self.data_args.max_seq_length,
232
+ padding='max_length' if self.data_args.pad_to_max_length else False,
233
+ truncation=True,
234
+ return_tensors='pt',
235
+ )
236
+ model_inputs['summarize_input_ids'] = summarize_model_inputs['input_ids']
237
+ model_inputs['summarize_attention_mask'] = summarize_model_inputs['attention_mask']
238
+
239
+ # for image-grounded text generation
240
+ prefix = f'English caption:'
241
+ content = caption
242
+ tokenized_prefix = self.tokenizer(
243
+ prefix, padding=False, truncation=True, return_tensors='pt',
244
+ )
245
+ prefix_input_ids = tokenized_prefix['input_ids'][:, :-1] # remove eos
246
+ prefix_attention_mask = tokenized_prefix['attention_mask'][:, :-1] # remove eos
247
+ tokenized_content = self.tokenizer(
248
+ content,
249
+ max_length=self.data_args.max_seq_length - prefix_input_ids.size(1) + 1,
250
+ padding='max_length' if self.data_args.pad_to_max_length else False,
251
+ truncation=True,
252
+ return_tensors='pt',
253
+ )
254
+ content_input_ids = tokenized_content['input_ids'][:, 1:] # remove bos
255
+ content_attention_mask = tokenized_content['attention_mask'][:, 1:] # remove bos
256
+ model_inputs['input_ids'] = torch.cat([prefix_input_ids, content_input_ids], dim=1)
257
+ model_inputs['attention_mask'] = torch.cat([prefix_attention_mask, content_attention_mask], dim=1)
258
+ labels = model_inputs['input_ids'].clone()
259
+ labels[labels == self.tokenizer.pad_token_id] = -100
260
+ labels[:, :prefix_input_ids.size(1) - 1] = -100
261
+ model_inputs['labels'] = labels
262
+ return model_inputs
263
+
264
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
265
+ i = i % len(self.images)
266
+ j = random.randint(0, len(self.images) - 1)
267
+ while self.image_ids[j] == self.image_ids[i]:
268
+ j = random.randint(0, len(self.images) - 1)
269
+ ret = self.preprocess(self.images[i], self.captions[i], self.captions[j])
270
+ # for image-text matching
271
+ ret['positive_input_ids'] = ret['positive_input_ids'][0]
272
+ ret['positive_attention_mask'] = ret['positive_attention_mask'][0]
273
+ ret['negative_input_ids'] = ret['negative_input_ids'][0]
274
+ ret['negative_attention_mask'] = ret['negative_attention_mask'][0]
275
+ # for image-text contrastive learning
276
+ ret['summarize_input_ids'] = ret['summarize_input_ids'][0]
277
+ ret['summarize_attention_mask'] = ret['summarize_attention_mask'][0]
278
+ # for image-grounded text generation
279
+ ret['input_ids'] = ret['input_ids'][0]
280
+ ret['attention_mask'] = ret['attention_mask'][0]
281
+ ret['labels'] = ret['labels'][0]
282
+ ret['image_ids'] = torch.Tensor([self.image_ids[i]]).long()
283
+ return ret
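A hedged sketch of the label masking that FlickrDataset.preprocess and COCODataset.preprocess apply for image-grounded text generation: padding and all but the last prefix token are set to -100, so the loss covers only the caption tokens. The token ids and pad id below are made-up placeholders, not real tokenizer output.

import torch

pad_token_id = 0
prefix_input_ids = torch.tensor([[1, 450, 423]])            # "English caption:" (hypothetical ids, eos removed)
content_input_ids = torch.tensor([[88, 91, 107, 2, 0, 0]])  # caption ids + eos + padding (bos removed)

input_ids = torch.cat([prefix_input_ids, content_input_ids], dim=1)
labels = input_ids.clone()
labels[labels == pad_token_id] = -100            # never learn to predict padding
labels[:, :prefix_input_ids.size(1) - 1] = -100  # loss starts at the last prefix token
print(labels)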
InternVL/internvl_g/internvl/train/internvl_stage2_finetune.py ADDED
@@ -0,0 +1,286 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ import warnings
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, Optional
7
+
8
+ import torch.distributed as dist
9
+ import transformers
10
+ from internvl.dist_utils import init_dist
11
+ from internvl.model.internvl_stage2_retrieval import (InternVLConfig,
12
+ InternVLModel)
13
+ from internvl.train.dataset import COCODataset, FlickrDataset
14
+ from internvl.train.trainer_monkey_patch import replace_create_optimizer
15
+ from PIL import Image, ImageFile, PngImagePlugin
16
+ from transformers import (HfArgumentParser, LlamaTokenizer, Trainer,
17
+ TrainingArguments, default_data_collator, set_seed)
18
+ from transformers.trainer_utils import get_last_checkpoint
19
+ from transformers.utils.logging import (enable_default_handler,
20
+ enable_explicit_format, set_verbosity)
21
+
22
+ IGNORE_INDEX = -100
23
+ Image.MAX_IMAGE_PIXELS = None
24
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
25
+ MaximumDecompressedSize = 1024
26
+ MegaByte = 2 ** 20
27
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
28
+
29
+ warnings.filterwarnings('ignore')
30
+ logger = logging.getLogger(__name__)
31
+
32
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
33
+
34
+ ds_collections = {
35
+ 'flickr30k_en_train': {
36
+ 'root': './data/flickr30k/Images/',
37
+ 'annotation': './data/flickr30k/flickr30k_train_karpathy.txt',
38
+ },
39
+ 'flickr30k_cn_train': {
40
+ 'root': './data/flickr30k/Images/',
41
+ 'annotation': './data/flickr30k/flickr30k_cn_train.txt',
42
+ },
43
+ 'coco_karpathy_train': {
44
+ 'root': './data/coco/',
45
+ 'annotation': './data/coco/annotations/coco_karpathy_train.json',
46
+ },
47
+ }
48
+
49
+
50
+ @dataclass
51
+ class ModelArguments:
52
+ """
53
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
54
+ """
55
+ model_name_or_path: str = field(
56
+ metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
57
+ )
58
+ freeze_model: bool = field(
59
+ default=False,
60
+ metadata={'help': 'Set to True to freeze the entire model.'},
61
+ )
62
+ freeze_vision_model: bool = field(
63
+ default=False,
64
+ metadata={'help': 'Set to True to freeze the vision backbone of the model.'},
65
+ )
66
+ freeze_qllama: bool = field(
67
+ default=False,
68
+ metadata={'help': 'Set to True to freeze the QLLaMA of the model.'},
69
+ )
70
+ unfreeze_qllama_head: bool = field(
71
+ default=False,
72
+ metadata={'help': 'Set to True to unfreeze the head of the QLLaMA.'},
73
+ )
74
+ unfreeze_crossattn: bool = field(
75
+ default=False,
76
+ metadata={'help': 'Set to True to unfreeze the cross attention layers in the QLLaMA.'},
77
+ )
78
+ use_backbone_lora: int = field(
79
+ default=0, metadata={'help': 'If non-zero, indicates the use of LoRA in the vision backbone of the model'}
80
+ )
81
+ use_qllama_lora: int = field(
82
+ default=0, metadata={'help': 'If non-zero, indicates the use of LoRA in the QLLaMA of the model'}
83
+ )
84
+ use_custom_trainer: bool = field(
85
+ default=False, metadata={'help': 'Set to True to enable the use of a custom trainer.'},
86
+ )
87
+ drop_path_rate: float = field(
88
+ default=0.0, metadata={'help': 'Specify the value of drop path rate in the vision backbone. Default is 0.'}
89
+ )
90
+
91
+
92
+ @dataclass
93
+ class DataTrainingArguments:
94
+ """
95
+ Arguments pertaining to what data we are going to input our model for training and eval.
96
+ """
97
+ dataset_name: Optional[str] = field(
98
+ default='flickr30k_en_train',
99
+ metadata={'help': 'Specify the name of dataset to be used.'},
100
+ )
101
+ max_seq_length: Optional[int] = field(
102
+ default=80,
103
+ metadata={
104
+ 'help': (
105
+ 'The maximum total input sequence length after tokenization. Sequences longer '
106
+ 'than this will be truncated, sequences shorter will be padded.'
107
+ )
108
+ },
109
+ )
110
+ force_image_size: Optional[int] = field(
111
+ default=224,
112
+ metadata={'help': 'Specify the image size for training models.'},
113
+ )
114
+ pad_to_max_length: bool = field(
115
+ default=False,
116
+ metadata={
117
+ 'help': (
118
+ 'Whether to pad all samples to model maximum sentence length. '
119
+ 'If False, will pad the samples dynamically when batching to the maximum length in the batch. More '
120
+ 'efficient on GPU but very bad for TPU.'
121
+ )
122
+ },
123
+ )
124
+
125
+
126
+ def main():
127
+ # Parse input arguments
128
+ # See all possible arguments in src/transformers/training_args.py
129
+ # If use DeepSpeed zero3, init_dist must before HfArgumentParser
130
+ launcher = os.environ.get('LAUNCHER', 'slurm')
131
+ init_dist(launcher=launcher, backend='nccl')
132
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
133
+ if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
134
+ # If we pass only one argument to the script, and it's the path to a json file,
135
+ # let's parse it to get our arguments.
136
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
137
+ else:
138
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
139
+
140
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
141
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
142
+ # send_example_telemetry('finetune Flickr30K', model_args, data_args)
143
+
144
+ # Setup logging
145
+ logging.basicConfig(
146
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
147
+ datefmt='%m/%d/%Y %H:%M:%S',
148
+ handlers=[logging.StreamHandler(sys.stdout)],
149
+ )
150
+
151
+ if training_args.should_log:
152
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
153
+ transformers.utils.logging.set_verbosity_info()
154
+
155
+ log_level = training_args.get_process_log_level()
156
+ logger.setLevel(log_level)
157
+ set_verbosity(log_level)
158
+ enable_default_handler()
159
+ enable_explicit_format()
160
+
161
+ # Log on each process the small summary:
162
+ logger.warning(
163
+ f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
164
+ + f', distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
165
+ )
166
+ logger.info(f'Training/evaluation parameters {training_args}')
167
+
168
+ # Detecting last checkpoint and eventually continue from last checkpoint.
169
+ last_checkpoint = None
170
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
171
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
172
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
173
+ raise ValueError(
174
+ f'Output directory ({training_args.output_dir}) already exists and is not empty. '
175
+ 'Use --overwrite_output_dir to overcome.'
176
+ )
177
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
178
+ logger.info(
179
+ f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
180
+ 'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
181
+ )
182
+ # Set seed before initializing model.
183
+ set_seed(training_args.seed)
184
+
185
+ # Load pretrained model, tokenizer, and image processor
186
+ tokenizer = LlamaTokenizer.from_pretrained(
187
+ model_args.model_name_or_path,
188
+ add_eos_token=True
189
+ )
190
+
191
+ if 'flickr' in data_args.dataset_name:
192
+ train_dataset = FlickrDataset(metas=ds_collections[data_args.dataset_name],
193
+ tokenizer=tokenizer, data_args=data_args)
194
+ elif 'coco' in data_args.dataset_name:
195
+ train_dataset = COCODataset(metas=ds_collections[data_args.dataset_name],
196
+ tokenizer=tokenizer, data_args=data_args)
197
+ config = InternVLConfig.from_pretrained(model_args.model_name_or_path)
198
+ config.vision_config.drop_path_rate = model_args.drop_path_rate
199
+ model = InternVLModel.from_pretrained(
200
+ model_args.model_name_or_path,
201
+ # ignore_mismatched_sizes=True,
202
+ config=config
203
+ )
204
+ if data_args.force_image_size != 224:
205
+ model.config.force_image_size = data_args.force_image_size
206
+ model.vision_model.resize_pos_embeddings(old_size=224, new_size=data_args.force_image_size, patch_size=14)
207
+
208
+ model.config.use_cache = False
209
+ model.config.qllama_config.use_cache = False
210
+ model.qllama.gradient_checkpointing = True
211
+ model.qllama.model.gradient_checkpointing = True
212
+ model.vision_model.gradient_checkpointing = True
213
+ model.vision_model.encoder.gradient_checkpointing = True
214
+
215
+ def _freeze_params(module):
216
+ for param in module.parameters():
217
+ param.requires_grad = False
218
+
219
+ if model_args.freeze_model:
220
+ _freeze_params(model)
221
+
222
+ if model_args.freeze_vision_model:
223
+ model.vision_model = model.vision_model.eval()
224
+ _freeze_params(model.vision_model)
225
+
226
+ if model_args.freeze_qllama:
227
+ model.qllama = model.qllama.eval()
228
+ _freeze_params(model.qllama)
229
+
230
+ if model_args.use_backbone_lora:
231
+ model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=model_args.use_backbone_lora * 2)
232
+ model.config.use_backbone_lora = model_args.use_backbone_lora
233
+
234
+ if model_args.use_qllama_lora:
235
+ model.wrap_qllama_lora(r=model_args.use_qllama_lora, lora_alpha=model_args.use_qllama_lora * 2)
236
+ model.config.use_qllama_lora = model_args.use_qllama_lora
237
+
238
+ if model_args.unfreeze_crossattn:
239
+ for name, param in model.qllama.named_parameters():
240
+ if 'cross_attn' in name:
241
+ param.requires_grad = True
242
+
243
+ if model_args.unfreeze_qllama_head:
244
+ model.qllama.lm_head.weight.requires_grad = True
245
+ model.text_projection.requires_grad = True
246
+
247
+ # print trainable parameters
248
+ if dist.get_rank() == 0:
249
+ for name, param in model.named_parameters():
250
+ print(name, param.requires_grad)
251
+
252
+ # set seed for torch dataloaders
253
+ set_seed(training_args.seed)
254
+
255
+ # Initialize our Trainer
256
+ if model_args.use_custom_trainer:
257
+ replace_create_optimizer()
258
+
259
+ trainer = Trainer(
260
+ model=model,
261
+ args=training_args,
262
+ train_dataset=train_dataset if training_args.do_train else None,
263
+ eval_dataset=None,
264
+ tokenizer=tokenizer,
265
+ data_collator=default_data_collator,
266
+ )
267
+
268
+ # Training
269
+ if training_args.do_train:
270
+ checkpoint = None
271
+ if training_args.resume_from_checkpoint is not None:
272
+ checkpoint = training_args.resume_from_checkpoint
273
+ elif last_checkpoint is not None:
274
+ checkpoint = last_checkpoint
275
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
276
+ trainer.save_model() # Saves the tokenizer too for easy upload
277
+
278
+ metrics = train_result.metrics
279
+ metrics['train_samples'] = len(train_dataset)
280
+ trainer.log_metrics('train', metrics)
281
+ trainer.save_metrics('train', metrics)
282
+ trainer.save_state()
283
+
284
+
285
+ if __name__ == '__main__':
286
+ main()
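A small standalone sketch of the HfArgumentParser pattern this entry point relies on: dataclass fields become command-line flags, and parse_args_into_dataclasses fills them from argv (or a JSON file). The dataclass here is a simplified stand-in, not the script's real ModelArguments.

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ToyArguments:
    dataset_name: str = field(default='flickr30k_en_train')
    force_image_size: int = field(default=224)

parser = HfArgumentParser(ToyArguments)
(toy_args,) = parser.parse_args_into_dataclasses(
    args=['--dataset_name', 'coco_karpathy_train', '--force_image_size', '364'])
print(toy_args.dataset_name, toy_args.force_image_size)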
InternVL/internvl_g/internvl/train/trainer_monkey_patch.py ADDED
@@ -0,0 +1,150 @@
1
+ import json
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import transformers
7
+ from transformers import Trainer, logging
8
+ from transformers.trainer import is_sagemaker_mp_enabled
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
14
+ if var_name in ('query_tokens', 'logit_scale',):
15
+ return 0
16
+ if var_name.startswith('clip_projector.'):
17
+ return vit_num_max_layer
18
+ if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
19
+ var_name == 'text_projection':
20
+ return llama_num_max_layer
21
+ if var_name.startswith('vision_model.'):
22
+ if 'embeddings.' in var_name:
23
+ return 0
24
+ if 'layers.' in var_name:
25
+ var_name = var_name.split('layers.')[-1]
26
+ layer_id = int(var_name.split('.')[0])
27
+ return layer_id + 1
28
+ if var_name.startswith('qllama.'):
29
+ if 'embed_tokens' in var_name:
30
+ return 0
31
+ if 'layers.' in var_name:
32
+ var_name = var_name.split('layers.')[-1]
33
+ layer_id = int(var_name.split('.')[0])
34
+ return layer_id + 1
35
+ else:
36
+ return llama_num_max_layer
37
+ return 0
38
+
39
+
40
+ def param_classification(name):
41
+ if name in ['query_tokens', 'text_projection', 'logit_scale']:
42
+ return 'qllama'
43
+ elif name.startswith('vision_model.'):
44
+ return 'vit'
45
+ elif name.startswith('qllama.'):
46
+ return 'qllama'
47
+ elif name.startswith('clip_projector.'):
48
+ return 'vit'
49
+ elif name.startswith('clip_projector2.'):
50
+ return 'qllama'
51
+ elif name.startswith('itm_head.'):
52
+ return 'qllama'
53
+ else:
54
+ return 'other'
55
+
56
+
57
+ def create_optimizer(self):
58
+ """
59
+ Setup the optimizer.
60
+
61
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
62
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
63
+ """
64
+ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
65
+
66
+ parameter_groups = {}
67
+ try: # for stage2 model
68
+ vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
69
+ qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
70
+ except: # for stage3 model
71
+ vit_num_layers = opt_model.qllama.config.vision_config.num_hidden_layers + 2
72
+ qllama_num_layers = opt_model.qllama.config.qllama_config.num_hidden_layers + 2
73
+ print('vit_num_layers:', vit_num_layers)
74
+ print('qllama_num_layers:', qllama_num_layers)
75
+
76
+ vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
77
+ qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
78
+ print('vit_layer_decay_rate:', vit_layer_decay_rate)
79
+ print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
80
+
81
+ for name, param in opt_model.named_parameters():
82
+ if not param.requires_grad:
83
+ continue # frozen weights
84
+ if len(param.shape) == 1 or name.endswith('.bias'):
85
+ group_name = 'no_decay'
86
+ this_weight_decay = 0.
87
+ else:
88
+ group_name = 'decay'
89
+ this_weight_decay = self.args.weight_decay
90
+
91
+ cls = param_classification(name)
92
+ layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
93
+ group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
94
+ if group_name not in parameter_groups:
95
+ if cls == 'vit':
96
+ scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
97
+ else:
98
+ scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
99
+ scale = min(1.0, scale)
100
+ parameter_groups[group_name] = {
101
+ 'weight_decay': this_weight_decay,
102
+ 'params': [],
103
+ 'param_names': [],
104
+ 'lr_scale': scale,
105
+ 'group_name': group_name,
106
+ 'lr': scale * self.args.learning_rate,
107
+ }
108
+ parameter_groups[group_name]['params'].append(param)
109
+ parameter_groups[group_name]['param_names'].append(name)
110
+
111
+ rank = torch.distributed.get_rank()
112
+ if rank == 0:
113
+ to_display = {}
114
+ for key in parameter_groups:
115
+ to_display[key] = {
116
+ 'param_names': parameter_groups[key]['param_names'],
117
+ 'lr_scale': parameter_groups[key]['lr_scale'],
118
+ 'lr': parameter_groups[key]['lr'],
119
+ 'weight_decay': parameter_groups[key]['weight_decay'],
120
+ }
121
+ print('Param groups = %s' % json.dumps(to_display, indent=2))
122
+
123
+ optimizer_grouped_parameters = list(parameter_groups.values())
124
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
125
+
126
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
127
+ if optimizer_cls.__name__ == 'Adam8bit':
128
+ import bitsandbytes
129
+
130
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
131
+
132
+ skipped = 0
133
+ for module in opt_model.modules():
134
+ if isinstance(module, nn.Embedding):
135
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
136
+ logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
137
+ manager.register_module_override(module, 'weight', {'optim_bits': 32})
138
+ logger.debug(f'bitsandbytes: will optimize {module} in fp32')
139
+ logger.info(f'skipped: {skipped / 2 ** 20}M params')
140
+
141
+ if is_sagemaker_mp_enabled():
142
+ import smdistributed.modelparallel.torch as smp
143
+ self.optimizer = smp.DistributedOptimizer(self.optimizer)
144
+
145
+ return self.optimizer
146
+
147
+
148
+ def replace_create_optimizer():
149
+ print('Replace original create_optimizer with custom create_optimizer')
150
+ transformers.Trainer.create_optimizer = create_optimizer
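A quick numeric sketch of the layer-wise learning-rate decay computed in create_optimizer above: with decay rate 0.9 and num_layers = num_hidden_layers + 2, shallower layers get geometrically smaller LR scales. The layer count and base LR are assumed example values.

base_lr = 1e-6
vit_layer_decay_rate = 0.9
vit_num_layers = 48 + 2  # e.g. a 48-layer ViT plus 2, as in create_optimizer

for layer_id in (0, 1, 25, vit_num_layers - 1):
    scale = min(1.0, vit_layer_decay_rate ** (vit_num_layers - layer_id - 1))
    print(f'layer {layer_id}: lr_scale={scale:.4f}, lr={scale * base_lr:.2e}')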
InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh ADDED
@@ -0,0 +1,58 @@
1
+ set -x
2
+
3
+ export VIT_LAYER_DECAY_RATE=0.9
4
+ export QLLAMA_LAYER_DECAY_RATE=0.9
5
+
6
+ PARTITION=${PARTITION:-"VC2"}
7
+ GPUS=${GPUS:-32}
8
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
9
+ QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
10
+ NODES=$((GPUS / GPUS_PER_NODE))
11
+ CPUS_PER_TASK=${CPUS_PER_TASK:-10}
12
+ SRUN_ARGS=${SRUN_ARGS:-""}
13
+
14
+
15
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
16
+
17
+ # number of gpus: 32
18
+ # batch size per gpu: 32
19
+ # gradient accumulation steps: 1
20
+ # total batch size: 1024
21
+ # epoch: 5
22
+ srun -p ${PARTITION} \
23
+ --gres=gpu:${GPUS_PER_NODE} \
24
+ --nodes=${NODES} \
25
+ --ntasks=${GPUS} \
26
+ --ntasks-per-node=${GPUS_PER_NODE} \
27
+ --cpus-per-task=${CPUS_PER_TASK} \
28
+ --kill-on-bad-exit=1 \
29
+ --quotatype=${QUOTA_TYPE} \
30
+ ${SRUN_ARGS} \
31
+ python -u internvl/train/internvl_stage2_finetune.py \
32
+ --dataset_name 'coco_karpathy_train' \
33
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
34
+ --output_dir "./work_dirs/internvl_stage2_finetune_coco_364_bs1024_ep5" \
35
+ --overwrite_output_dir True \
36
+ --force_image_size 364 \
37
+ --drop_path_rate 0.3 \
38
+ --use_custom_trainer \
39
+ --dataloader_num_workers 2 \
40
+ --pad_to_max_length True \
41
+ --bf16 True \
42
+ --num_train_epochs 5 \
43
+ --per_device_train_batch_size 32 \
44
+ --gradient_accumulation_steps 1 \
45
+ --evaluation_strategy "no" \
46
+ --save_strategy "steps" \
47
+ --save_steps 100 \
48
+ --save_total_limit 5 \
49
+ --learning_rate 1e-6 \
50
+ --weight_decay 0.05 \
51
+ --warmup_steps 100 \
52
+ --lr_scheduler_type "cosine" \
53
+ --logging_steps 1 \
54
+ --max_seq_length 80 \
55
+ --do_train True \
56
+ --optim adamw_torch \
57
+ --deepspeed "zero_stage1_config.json" \
58
+ --report_to "tensorboard"
InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh ADDED
@@ -0,0 +1,58 @@
1
+ set -x
2
+
3
+ export VIT_LAYER_DECAY_RATE=0.9
4
+ export QLLAMA_LAYER_DECAY_RATE=0.9
5
+
6
+ PARTITION=${PARTITION:-"VC2"}
7
+ GPUS=${GPUS:-32}
8
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
9
+ QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
10
+ NODES=$((GPUS / GPUS_PER_NODE))
11
+ CPUS_PER_TASK=${CPUS_PER_TASK:-10}
12
+ SRUN_ARGS=${SRUN_ARGS:-""}
13
+
14
+
15
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
16
+
17
+ # number of gpus: 32
18
+ # batch size per gpu: 32
19
+ # gradient accumulation steps: 1
20
+ # total batch size: 1024
21
+ # epoch: 10
22
+ srun -p ${PARTITION} \
23
+ --gres=gpu:${GPUS_PER_NODE} \
24
+ --nodes=${NODES} \
25
+ --ntasks=${GPUS} \
26
+ --ntasks-per-node=${GPUS_PER_NODE} \
27
+ --cpus-per-task=${CPUS_PER_TASK} \
28
+ --kill-on-bad-exit=1 \
29
+ --quotatype=${QUOTA_TYPE} \
30
+ ${SRUN_ARGS} \
31
+ python -u internvl/train/internvl_stage2_finetune.py \
32
+ --dataset_name 'flickr30k_en_train' \
33
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
34
+ --output_dir "./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10" \
35
+ --overwrite_output_dir True \
36
+ --force_image_size 364 \
37
+ --drop_path_rate 0.3 \
38
+ --use_custom_trainer \
39
+ --dataloader_num_workers 2 \
40
+ --pad_to_max_length True \
41
+ --bf16 True \
42
+ --num_train_epochs 10 \
43
+ --per_device_train_batch_size 32 \
44
+ --gradient_accumulation_steps 1 \
45
+ --evaluation_strategy "no" \
46
+ --save_strategy "steps" \
47
+ --save_steps 100 \
48
+ --save_total_limit 5 \
49
+ --learning_rate 1e-6 \
50
+ --weight_decay 0.05 \
51
+ --warmup_steps 100 \
52
+ --lr_scheduler_type "cosine" \
53
+ --logging_steps 1 \
54
+ --max_seq_length 80 \
55
+ --do_train True \
56
+ --optim adamw_torch \
57
+ --deepspeed "zero_stage1_config.json" \
58
+ --report_to "tensorboard"
InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_flickrcn_364_bs1024_ep10.sh ADDED
@@ -0,0 +1,58 @@
1
+ set -x
2
+
3
+ export VIT_LAYER_DECAY_RATE=0.9
4
+ export QLLAMA_LAYER_DECAY_RATE=0.9
5
+
6
+ PARTITION=${PARTITION:-"VC2"}
7
+ GPUS=${GPUS:-32}
8
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
9
+ QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
10
+ NODES=$((GPUS / GPUS_PER_NODE))
11
+ CPUS_PER_TASK=${CPUS_PER_TASK:-10}
12
+ SRUN_ARGS=${SRUN_ARGS:-""}
13
+
14
+
15
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
16
+
17
+ # number of gpus: 32
18
+ # batch size per gpu: 32
19
+ # gradient accumulation steps: 1
20
+ # total batch size: 1024
21
+ # epoch: 10
22
+ srun -p ${PARTITION} \
23
+ --gres=gpu:${GPUS_PER_NODE} \
24
+ --nodes=${NODES} \
25
+ --ntasks=${GPUS} \
26
+ --ntasks-per-node=${GPUS_PER_NODE} \
27
+ --cpus-per-task=${CPUS_PER_TASK} \
28
+ --kill-on-bad-exit=1 \
29
+ --quotatype=${QUOTA_TYPE} \
30
+ ${SRUN_ARGS} \
31
+ python -u internvl/train/internvl_stage2_finetune.py \
32
+ --dataset_name 'flickr30k_cn_train' \
33
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
34
+ --output_dir "./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10" \
35
+ --overwrite_output_dir True \
36
+ --force_image_size 364 \
37
+ --drop_path_rate 0.3 \
38
+ --use_custom_trainer \
39
+ --dataloader_num_workers 2 \
40
+ --pad_to_max_length True \
41
+ --bf16 True \
42
+ --num_train_epochs 10 \
43
+ --per_device_train_batch_size 32 \
44
+ --gradient_accumulation_steps 1 \
45
+ --evaluation_strategy "no" \
46
+ --save_strategy "steps" \
47
+ --save_steps 100 \
48
+ --save_total_limit 5 \
49
+ --learning_rate 1e-6 \
50
+ --weight_decay 0.05 \
51
+ --warmup_steps 100 \
52
+ --lr_scheduler_type "cosine" \
53
+ --logging_steps 1 \
54
+ --max_seq_length 80 \
55
+ --do_train True \
56
+ --optim adamw_torch \
57
+ --deepspeed "zero_stage1_config.json" \
58
+ --report_to "tensorboard"
InternVL/internvl_g/shell/head_finetune/internvl_stage2_finetune_coco_224_bs1024_ep5_head_4gpu.sh ADDED
@@ -0,0 +1,59 @@
1
+ set -x
2
+
3
+ GPUS=${GPUS:-4}
4
+ BATCH_SIZE=${BATCH_SIZE:-32}
5
+
6
+
7
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
8
+ export MASTER_PORT=34229
9
+ export TF_CPP_MIN_LOG_LEVEL=3
10
+ export LAUNCHER=pytorch
11
+
12
+ OUTPUT_DIR='work_dirs/internvl_stage2_finetune_coco_224_bs1024_ep5_head_4gpu'
13
+
14
+ if [ ! -d "$OUTPUT_DIR" ]; then
15
+ mkdir -p "$OUTPUT_DIR"
16
+ fi
17
+
18
+ # number of gpus: 32
19
+ # batch size per gpu: 32
20
+ # gradient accumulation steps: 1
21
+ # total batch size: 1024
22
+ # epoch: 5
23
+ torchrun \
24
+ --nnodes=1 \
25
+ --node_rank=0 \
26
+ --master_addr=127.0.0.1 \
27
+ --nproc_per_node=${GPUS} \
28
+ --master_port=${MASTER_PORT} \
29
+ internvl/train/internvl_stage2_finetune.py \
30
+ --dataset_name 'coco_karpathy_train' \
31
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
32
+ --output_dir ${OUTPUT_DIR} \
33
+ --overwrite_output_dir True \
34
+ --freeze_model \
35
+ --freeze_vision_model \
36
+ --freeze_qllama \
37
+ --unfreeze_qllama_head \
38
+ --force_image_size 224 \
39
+ --drop_path_rate 0.0 \
40
+ --dataloader_num_workers 2 \
41
+ --pad_to_max_length True \
42
+ --bf16 True \
43
+ --num_train_epochs 5 \
44
+ --per_device_train_batch_size ${BATCH_SIZE} \
45
+ --gradient_accumulation_steps 1 \
46
+ --evaluation_strategy "no" \
47
+ --save_strategy "steps" \
48
+ --save_steps 100 \
49
+ --save_total_limit 5 \
50
+ --learning_rate 1e-6 \
51
+ --weight_decay 0.05 \
52
+ --warmup_steps 100 \
53
+ --lr_scheduler_type "cosine" \
54
+ --logging_steps 1 \
55
+ --max_seq_length 80 \
56
+ --do_train True \
57
+ --optim adamw_torch \
58
+ --deepspeed "zero_stage3_config.json" \
59
+ --report_to "tensorboard"
InternVL/internvl_g/shell/head_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_head_4gpu.sh ADDED
@@ -0,0 +1,59 @@
1
+ set -x
2
+
3
+ GPUS=${GPUS:-4}
4
+ BATCH_SIZE=${BATCH_SIZE:-32}
5
+
6
+
7
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
8
+ export MASTER_PORT=34229
9
+ export TF_CPP_MIN_LOG_LEVEL=3
10
+ export LAUNCHER=pytorch
11
+
12
+ OUTPUT_DIR='work_dirs/internvl_stage2_finetune_flickr_224_bs1024_ep10_head_4gpu'
13
+
14
+ if [ ! -d "$OUTPUT_DIR" ]; then
15
+ mkdir -p "$OUTPUT_DIR"
16
+ fi
17
+
18
+ # number of gpus: 32
19
+ # batch size per gpu: 32
20
+ # gradient accumulation steps: 1
21
+ # total batch size: 1024
22
+ # epoch: 10
23
+ torchrun \
24
+ --nnodes=1 \
25
+ --node_rank=0 \
26
+ --master_addr=127.0.0.1 \
27
+ --nproc_per_node=${GPUS} \
28
+ --master_port=${MASTER_PORT} \
29
+ internvl/train/internvl_stage2_finetune.py \
30
+ --dataset_name 'flickr30k_en_train' \
31
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
32
+ --output_dir ${OUTPUT_DIR} \
33
+ --overwrite_output_dir True \
34
+ --freeze_model \
35
+ --freeze_vision_model \
36
+ --freeze_qllama \
37
+ --unfreeze_qllama_head \
38
+ --force_image_size 224 \
39
+ --drop_path_rate 0.0 \
40
+ --dataloader_num_workers 2 \
41
+ --pad_to_max_length True \
42
+ --bf16 True \
43
+ --num_train_epochs 10 \
44
+ --per_device_train_batch_size ${BATCH_SIZE} \
45
+ --gradient_accumulation_steps 1 \
46
+ --evaluation_strategy "no" \
47
+ --save_strategy "steps" \
48
+ --save_steps 100 \
49
+ --save_total_limit 5 \
50
+ --learning_rate 1e-6 \
51
+ --weight_decay 0.05 \
52
+ --warmup_steps 100 \
53
+ --lr_scheduler_type "cosine" \
54
+ --logging_steps 1 \
55
+ --max_seq_length 80 \
56
+ --do_train True \
57
+ --optim adamw_torch \
58
+ --deepspeed "zero_stage3_config.json" \
59
+ --report_to "tensorboard"
InternVL/internvl_g/shell/head_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_head_4gpu.sh ADDED
@@ -0,0 +1,59 @@
1
+ set -x
2
+
3
+ GPUS=${GPUS:-4}
4
+ BATCH_SIZE=${BATCH_SIZE:-32}
5
+
6
+
7
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
8
+ export MASTER_PORT=34229
9
+ export TF_CPP_MIN_LOG_LEVEL=3
10
+ export LAUNCHER=pytorch
11
+
12
+ OUTPUT_DIR='work_dirs/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_head_4gpu'
13
+
14
+ if [ ! -d "$OUTPUT_DIR" ]; then
15
+ mkdir -p "$OUTPUT_DIR"
16
+ fi
17
+
18
+ # number of gpus: 32
19
+ # batch size per gpu: 32
20
+ # gradient accumulation steps: 1
21
+ # total batch size: 1024
22
+ # epoch: 10
23
+ torchrun \
24
+ --nnodes=1 \
25
+ --node_rank=0 \
26
+ --master_addr=127.0.0.1 \
27
+ --nproc_per_node=${GPUS} \
28
+ --master_port=${MASTER_PORT} \
29
+ internvl/train/internvl_stage2_finetune.py \
30
+ --dataset_name 'flickr30k_cn_train' \
31
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
32
+ --output_dir ${OUTPUT_DIR} \
33
+ --overwrite_output_dir True \
34
+ --freeze_model \
35
+ --freeze_vision_model \
36
+ --freeze_qllama \
37
+ --unfreeze_qllama_head \
38
+ --force_image_size 224 \
39
+ --drop_path_rate 0.0 \
40
+ --dataloader_num_workers 2 \
41
+ --pad_to_max_length True \
42
+ --bf16 True \
43
+ --num_train_epochs 10 \
44
+ --per_device_train_batch_size ${BATCH_SIZE} \
45
+ --gradient_accumulation_steps 1 \
46
+ --evaluation_strategy "no" \
47
+ --save_strategy "steps" \
48
+ --save_steps 100 \
49
+ --save_total_limit 5 \
50
+ --learning_rate 1e-6 \
51
+ --weight_decay 0.05 \
52
+ --warmup_steps 100 \
53
+ --lr_scheduler_type "cosine" \
54
+ --logging_steps 1 \
55
+ --max_seq_length 80 \
56
+ --do_train True \
57
+ --optim adamw_torch \
58
+ --deepspeed "zero_stage3_config.json" \
59
+ --report_to "tensorboard"
InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_coco_224_bs1024_ep5_lora16_4gpu.sh ADDED
@@ -0,0 +1,61 @@
1
+ set -x
2
+
3
+ GPUS=${GPUS:-4}
4
+ BATCH_SIZE=${BATCH_SIZE:-32}
5
+
6
+
7
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
8
+ export MASTER_PORT=34229
9
+ export TF_CPP_MIN_LOG_LEVEL=3
10
+ export LAUNCHER=pytorch
11
+
12
+ OUTPUT_DIR='work_dirs/internvl_stage2_finetune_coco_224_bs1024_ep5_lora_4gpu'
13
+
14
+ if [ ! -d "$OUTPUT_DIR" ]; then
15
+ mkdir -p "$OUTPUT_DIR"
16
+ fi
17
+
18
+ # number of gpus: 32
19
+ # batch size per gpu: 32
20
+ # gradient accumulation steps: 1
21
+ # total batch size: 1024
22
+ # epoch: 5
23
+ torchrun \
24
+ --nnodes=1 \
25
+ --node_rank=0 \
26
+ --master_addr=127.0.0.1 \
27
+ --nproc_per_node=${GPUS} \
28
+ --master_port=${MASTER_PORT} \
29
+ internvl/train/internvl_stage2_finetune.py \
30
+ --dataset_name 'coco_karpathy_train' \
31
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
32
+ --output_dir ${OUTPUT_DIR} \
33
+ --overwrite_output_dir True \
34
+ --freeze_model \
35
+ --freeze_vision_model \
36
+ --freeze_qllama \
37
+ --unfreeze_qllama_head \
38
+ --use_backbone_lora 16 \
39
+ --use_qllama_lora 16 \
40
+ --force_image_size 224 \
41
+ --drop_path_rate 0.0 \
42
+ --dataloader_num_workers 2 \
43
+ --pad_to_max_length True \
44
+ --bf16 True \
45
+ --num_train_epochs 5 \
46
+ --per_device_train_batch_size ${BATCH_SIZE} \
47
+ --gradient_accumulation_steps 1 \
48
+ --evaluation_strategy "no" \
49
+ --save_strategy "steps" \
50
+ --save_steps 100 \
51
+ --save_total_limit 5 \
52
+ --learning_rate 1e-6 \
53
+ --weight_decay 0.05 \
54
+ --warmup_steps 100 \
55
+ --lr_scheduler_type "cosine" \
56
+ --logging_steps 1 \
57
+ --max_seq_length 80 \
58
+ --do_train True \
59
+ --optim adamw_torch \
60
+ --deepspeed "zero_stage3_config.json" \
61
+ --report_to "tensorboard"
InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora16_4gpu.sh ADDED
@@ -0,0 +1,61 @@
1
+ set -x
2
+
3
+ GPUS=${GPUS:-4}
4
+ BATCH_SIZE=${BATCH_SIZE:-32}
5
+
6
+
7
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
8
+ export MASTER_PORT=34229
9
+ export TF_CPP_MIN_LOG_LEVEL=3
10
+ export LAUNCHER=pytorch
11
+
12
+ OUTPUT_DIR='work_dirs/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora_4gpu'
13
+
14
+ if [ ! -d "$OUTPUT_DIR" ]; then
15
+ mkdir -p "$OUTPUT_DIR"
16
+ fi
17
+
18
+ # number of gpus: 32
19
+ # batch size per gpu: 32
20
+ # gradient accumulation steps: 1
21
+ # total batch size: 1024
22
+ # epoch: 10
23
+ torchrun \
24
+ --nnodes=1 \
25
+ --node_rank=0 \
26
+ --master_addr=127.0.0.1 \
27
+ --nproc_per_node=${GPUS} \
28
+ --master_port=${MASTER_PORT} \
29
+ internvl/train/internvl_stage2_finetune.py \
30
+ --dataset_name 'flickr30k_en_train' \
31
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
32
+ --output_dir ${OUTPUT_DIR} \
33
+ --overwrite_output_dir True \
34
+ --freeze_model \
35
+ --freeze_vision_model \
36
+ --freeze_qllama \
37
+ --unfreeze_qllama_head \
38
+ --use_backbone_lora 16 \
39
+ --use_qllama_lora 16 \
40
+ --force_image_size 224 \
41
+ --drop_path_rate 0.0 \
42
+ --dataloader_num_workers 2 \
43
+ --pad_to_max_length True \
44
+ --bf16 True \
45
+ --num_train_epochs 10 \
46
+ --per_device_train_batch_size ${BATCH_SIZE} \
47
+ --gradient_accumulation_steps 1 \
48
+ --evaluation_strategy "no" \
49
+ --save_strategy "steps" \
50
+ --save_steps 100 \
51
+ --save_total_limit 5 \
52
+ --learning_rate 1e-6 \
53
+ --weight_decay 0.05 \
54
+ --warmup_steps 100 \
55
+ --lr_scheduler_type "cosine" \
56
+ --logging_steps 1 \
57
+ --max_seq_length 80 \
58
+ --do_train True \
59
+ --optim adamw_torch \
60
+ --deepspeed "zero_stage3_config.json" \
61
+ --report_to "tensorboard"
InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora16_4gpu.sh ADDED
@@ -0,0 +1,61 @@
1
+ set -x
2
+
3
+ GPUS=${GPUS:-4}
4
+ BATCH_SIZE=${BATCH_SIZE:-32}
5
+
6
+
7
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
8
+ export MASTER_PORT=34229
9
+ export TF_CPP_MIN_LOG_LEVEL=3
10
+ export LAUNCHER=pytorch
11
+
12
+ OUTPUT_DIR='work_dirs/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora_4gpu'
13
+
14
+ if [ ! -d "$OUTPUT_DIR" ]; then
15
+ mkdir -p "$OUTPUT_DIR"
16
+ fi
17
+
18
+ # number of gpus: 32
19
+ # batch size per gpu: 32
20
+ # gradient accumulation steps: 1
21
+ # total batch size: 1024
22
+ # epoch: 10
23
+ torchrun \
24
+ --nnodes=1 \
25
+ --node_rank=0 \
26
+ --master_addr=127.0.0.1 \
27
+ --nproc_per_node=${GPUS} \
28
+ --master_port=${MASTER_PORT} \
29
+ internvl/train/internvl_stage2_finetune.py \
30
+ --dataset_name 'flickr30k_cn_train' \
31
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
32
+ --output_dir ${OUTPUT_DIR} \
33
+ --overwrite_output_dir True \
34
+ --freeze_model \
35
+ --freeze_vision_model \
36
+ --freeze_qllama \
37
+ --unfreeze_qllama_head \
38
+ --use_backbone_lora 16 \
39
+ --use_qllama_lora 16 \
40
+ --force_image_size 224 \
41
+ --drop_path_rate 0.0 \
42
+ --dataloader_num_workers 2 \
43
+ --pad_to_max_length True \
44
+ --bf16 True \
45
+ --num_train_epochs 10 \
46
+ --per_device_train_batch_size ${BATCH_SIZE} \
47
+ --gradient_accumulation_steps 1 \
48
+ --evaluation_strategy "no" \
49
+ --save_strategy "steps" \
50
+ --save_steps 100 \
51
+ --save_total_limit 5 \
52
+ --learning_rate 1e-6 \
53
+ --weight_decay 0.05 \
54
+ --warmup_steps 100 \
55
+ --lr_scheduler_type "cosine" \
56
+ --logging_steps 1 \
57
+ --max_seq_length 80 \
58
+ --do_train True \
59
+ --optim adamw_torch \
60
+ --deepspeed "zero_stage3_config.json" \
61
+ --report_to "tensorboard"
InternVL/segmentation/configs/_base_/datasets/ade20k_504x504.py ADDED
@@ -0,0 +1,56 @@
1
+ # dataset settings
2
+ dataset_type = 'ADE20KDataset'
3
+ data_root = 'data/ade/ADEChallengeData2016'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ crop_size = (504, 504)
7
+ train_pipeline = [
8
+ dict(type='LoadImageFromFile'),
9
+ dict(type='LoadAnnotations', reduce_zero_label=True),
10
+ dict(type='Resize', img_scale=(2016, 504), ratio_range=(0.5, 2.0)),
11
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12
+ dict(type='RandomFlip', prob=0.5),
13
+ dict(type='PhotoMetricDistortion'),
14
+ dict(type='Normalize', **img_norm_cfg),
15
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16
+ dict(type='DefaultFormatBundle'),
17
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18
+ ]
19
+ test_pipeline = [
20
+ dict(type='LoadImageFromFile'),
21
+ dict(
22
+ type='MultiScaleFlipAug',
23
+ img_scale=(2016, 504),
24
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25
+ flip=False,
26
+ transforms=[
27
+ dict(type='SETR_Resize', keep_ratio=True,
28
+ crop_size=crop_size, setr_multi_scale=True),
29
+ dict(type='ResizeToMultiple', size_divisor=14),
30
+ dict(type='RandomFlip'),
31
+ dict(type='Normalize', **img_norm_cfg),
32
+ dict(type='ImageToTensor', keys=['img']),
33
+ dict(type='Collect', keys=['img']),
34
+ ])
35
+ ]
36
+ data = dict(
37
+ samples_per_gpu=4,
38
+ workers_per_gpu=4,
39
+ train=dict(
40
+ type=dataset_type,
41
+ data_root=data_root,
42
+ img_dir='images/training',
43
+ ann_dir='annotations/training',
44
+ pipeline=train_pipeline),
45
+ val=dict(
46
+ type=dataset_type,
47
+ data_root=data_root,
48
+ img_dir='images/validation',
49
+ ann_dir='annotations/validation',
50
+ pipeline=test_pipeline),
51
+ test=dict(
52
+ type=dataset_type,
53
+ data_root=data_root,
54
+ img_dir='images/validation',
55
+ ann_dir='annotations/validation',
56
+ pipeline=test_pipeline))
InternVL/segmentation/configs/_base_/datasets/ade20k_504x504_1of16.py ADDED
@@ -0,0 +1,56 @@
1
+ # dataset settings
2
+ dataset_type = 'ADE20KDataset'
3
+ data_root = 'data/ade/ADEChallengeData2016'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ crop_size = (504, 504)
7
+ train_pipeline = [
8
+ dict(type='LoadImageFromFile'),
9
+ dict(type='LoadAnnotations', reduce_zero_label=True),
10
+ dict(type='Resize', img_scale=(2016, 504), ratio_range=(0.5, 2.0)),
11
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12
+ dict(type='RandomFlip', prob=0.5),
13
+ dict(type='PhotoMetricDistortion'),
14
+ dict(type='Normalize', **img_norm_cfg),
15
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16
+ dict(type='DefaultFormatBundle'),
17
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18
+ ]
19
+ test_pipeline = [
20
+ dict(type='LoadImageFromFile'),
21
+ dict(
22
+ type='MultiScaleFlipAug',
23
+ img_scale=(2016, 504),
24
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25
+ flip=False,
26
+ transforms=[
27
+ dict(type='Resize', keep_ratio=True),
28
+ dict(type='ResizeToMultiple', size_divisor=14),
29
+ dict(type='RandomFlip'),
30
+ dict(type='Normalize', **img_norm_cfg),
31
+ dict(type='ImageToTensor', keys=['img']),
32
+ dict(type='Collect', keys=['img']),
33
+ ])
34
+ ]
35
+ data = dict(
36
+ samples_per_gpu=4,
37
+ workers_per_gpu=4,
38
+ train=dict(
39
+ type=dataset_type,
40
+ data_root=data_root,
41
+ img_dir='images/training',
42
+ ann_dir='annotations/training',
43
+ max_image_num=20210 // 16,
44
+ pipeline=train_pipeline),
45
+ val=dict(
46
+ type=dataset_type,
47
+ data_root=data_root,
48
+ img_dir='images/validation',
49
+ ann_dir='annotations/validation',
50
+ pipeline=test_pipeline),
51
+ test=dict(
52
+ type=dataset_type,
53
+ data_root=data_root,
54
+ img_dir='images/validation',
55
+ ann_dir='annotations/validation',
56
+ pipeline=test_pipeline))
InternVL/segmentation/configs/_base_/datasets/cityscapes_1024x1024.py ADDED
@@ -0,0 +1,35 @@
1
+ _base_ = './cityscapes.py'
2
+ img_norm_cfg = dict(
3
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
4
+ crop_size = (1024, 1024)
5
+ train_pipeline = [
6
+ dict(type='LoadImageFromFile'),
7
+ dict(type='LoadAnnotations'),
8
+ dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
9
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
10
+ dict(type='RandomFlip', prob=0.5),
11
+ dict(type='PhotoMetricDistortion'),
12
+ dict(type='Normalize', **img_norm_cfg),
13
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
14
+ dict(type='DefaultFormatBundle'),
15
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
16
+ ]
17
+ test_pipeline = [
18
+ dict(type='LoadImageFromFile'),
19
+ dict(
20
+ type='MultiScaleFlipAug',
21
+ img_scale=(2048, 1024),
22
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
23
+ flip=False,
24
+ transforms=[
25
+ dict(type='Resize', keep_ratio=True),
26
+ dict(type='RandomFlip'),
27
+ dict(type='Normalize', **img_norm_cfg),
28
+ dict(type='ImageToTensor', keys=['img']),
29
+ dict(type='Collect', keys=['img']),
30
+ ])
31
+ ]
32
+ data = dict(
33
+ train=dict(pipeline=train_pipeline),
34
+ val=dict(pipeline=test_pipeline),
35
+ test=dict(pipeline=test_pipeline))
InternVL/segmentation/configs/_base_/models/apcnet_r50-d8.py ADDED
@@ -0,0 +1,44 @@
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='APCHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=512,
22
+ pool_scales=(1, 2, 3, 6),
23
+ dropout_ratio=0.1,
24
+ num_classes=19,
25
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
26
+ align_corners=False,
27
+ loss_decode=dict(
28
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
29
+ auxiliary_head=dict(
30
+ type='FCNHead',
31
+ in_channels=1024,
32
+ in_index=2,
33
+ channels=256,
34
+ num_convs=1,
35
+ concat_input=False,
36
+ dropout_ratio=0.1,
37
+ num_classes=19,
38
+ norm_cfg=norm_cfg,
39
+ align_corners=False,
40
+ loss_decode=dict(
41
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
42
+ # model training and testing settings
43
+ train_cfg=dict(),
44
+ test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/bisenetv1_r18-d32.py ADDED
@@ -0,0 +1,68 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     backbone=dict(
+         type='BiSeNetV1',
+         in_channels=3,
+         context_channels=(128, 256, 512),
+         spatial_channels=(64, 64, 64, 128),
+         out_indices=(0, 1, 2),
+         out_channels=256,
+         backbone_cfg=dict(
+             type='ResNet',
+             in_channels=3,
+             depth=18,
+             num_stages=4,
+             out_indices=(0, 1, 2, 3),
+             dilations=(1, 1, 1, 1),
+             strides=(1, 2, 2, 2),
+             norm_cfg=norm_cfg,
+             norm_eval=False,
+             style='pytorch',
+             contract_dilation=True),
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         init_cfg=None),
+     decode_head=dict(
+         type='FCNHead',
+         in_channels=256,
+         in_index=0,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     auxiliary_head=[
+         dict(
+             type='FCNHead',
+             in_channels=128,
+             channels=64,
+             num_convs=1,
+             num_classes=19,
+             in_index=1,
+             norm_cfg=norm_cfg,
+             concat_input=False,
+             align_corners=False,
+             loss_decode=dict(
+                 type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+         dict(
+             type='FCNHead',
+             in_channels=128,
+             channels=64,
+             num_convs=1,
+             num_classes=19,
+             in_index=2,
+             norm_cfg=norm_cfg,
+             concat_input=False,
+             align_corners=False,
+             loss_decode=dict(
+                 type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     ],
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/danet_r50-d8.py ADDED
@@ -0,0 +1,44 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         dilations=(1, 1, 2, 4),
+         strides=(1, 2, 1, 1),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     decode_head=dict(
+         type='DAHead',
+         in_channels=2048,
+         in_index=3,
+         channels=512,
+         pam_channels=64,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     auxiliary_head=dict(
+         type='FCNHead',
+         in_channels=1024,
+         in_index=2,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/deeplabv3plus_r50-d8.py ADDED
@@ -0,0 +1,46 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         dilations=(1, 1, 2, 4),
+         strides=(1, 2, 1, 1),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     decode_head=dict(
+         type='DepthwiseSeparableASPPHead',
+         in_channels=2048,
+         in_index=3,
+         channels=512,
+         dilations=(1, 12, 24, 36),
+         c1_in_channels=256,
+         c1_channels=48,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     auxiliary_head=dict(
+         type='FCNHead',
+         in_channels=1024,
+         in_index=2,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/dmnet_r50-d8.py ADDED
@@ -0,0 +1,44 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         dilations=(1, 1, 2, 4),
+         strides=(1, 2, 1, 1),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     decode_head=dict(
+         type='DMHead',
+         in_channels=2048,
+         in_index=3,
+         channels=512,
+         filter_sizes=(1, 3, 5, 7),
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=dict(type='SyncBN', requires_grad=True),
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     auxiliary_head=dict(
+         type='FCNHead',
+         in_channels=1024,
+         in_index=2,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/encnet_r50-d8.py ADDED
@@ -0,0 +1,48 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         dilations=(1, 1, 2, 4),
+         strides=(1, 2, 1, 1),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     decode_head=dict(
+         type='EncHead',
+         in_channels=[512, 1024, 2048],
+         in_index=(1, 2, 3),
+         channels=512,
+         num_codes=32,
+         use_se_loss=True,
+         add_lateral=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+         loss_se_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)),
+     auxiliary_head=dict(
+         type='FCNHead',
+         in_channels=1024,
+         in_index=2,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/erfnet_fcn.py ADDED
@@ -0,0 +1,32 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained=None,
+     backbone=dict(
+         type='ERFNet',
+         in_channels=3,
+         enc_downsample_channels=(16, 64, 128),
+         enc_stage_non_bottlenecks=(5, 8),
+         enc_non_bottleneck_dilations=(2, 4, 8, 16),
+         enc_non_bottleneck_channels=(64, 128),
+         dec_upsample_channels=(64, 16),
+         dec_stages_non_bottleneck=(2, 2),
+         dec_non_bottleneck_channels=(64, 16),
+         dropout_ratio=0.1,
+         init_cfg=None),
+     decode_head=dict(
+         type='FCNHead',
+         in_channels=16,
+         channels=128,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py ADDED
@@ -0,0 +1,53 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         dilations=(1, 1, 2, 4),
+         strides=(1, 2, 2, 2),
+         out_indices=(1, 2, 3),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     neck=dict(
+         type='JPU',
+         in_channels=(512, 1024, 2048),
+         mid_channels=512,
+         start_level=0,
+         end_level=-1,
+         dilations=(1, 2, 4, 8),
+         align_corners=False,
+         norm_cfg=norm_cfg),
+     decode_head=dict(
+         type='PSPHead',
+         in_channels=2048,
+         in_index=2,
+         channels=512,
+         pool_scales=(1, 2, 3, 6),
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     auxiliary_head=dict(
+         type='FCNHead',
+         in_channels=1024,
+         in_index=1,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/fcn_hr18.py ADDED
@@ -0,0 +1,52 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://msra/hrnetv2_w18',
+     backbone=dict(
+         type='HRNet',
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         extra=dict(
+             stage1=dict(
+                 num_modules=1,
+                 num_branches=1,
+                 block='BOTTLENECK',
+                 num_blocks=(4, ),
+                 num_channels=(64, )),
+             stage2=dict(
+                 num_modules=1,
+                 num_branches=2,
+                 block='BASIC',
+                 num_blocks=(4, 4),
+                 num_channels=(18, 36)),
+             stage3=dict(
+                 num_modules=4,
+                 num_branches=3,
+                 block='BASIC',
+                 num_blocks=(4, 4, 4),
+                 num_channels=(18, 36, 72)),
+             stage4=dict(
+                 num_modules=3,
+                 num_branches=4,
+                 block='BASIC',
+                 num_blocks=(4, 4, 4, 4),
+                 num_channels=(18, 36, 72, 144)))),
+     decode_head=dict(
+         type='FCNHead',
+         in_channels=[18, 36, 72, 144],
+         in_index=(0, 1, 2, 3),
+         channels=sum([18, 36, 72, 144]),
+         input_transform='resize_concat',
+         kernel_size=1,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=-1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/fpn_r50.py ADDED
@@ -0,0 +1,36 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         dilations=(1, 1, 1, 1),
+         strides=(1, 2, 2, 2),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     neck=dict(
+         type='FPN',
+         in_channels=[256, 512, 1024, 2048],
+         out_channels=256,
+         num_outs=4),
+     decode_head=dict(
+         type='FPNHead',
+         in_channels=[256, 256, 256, 256],
+         in_index=[0, 1, 2, 3],
+         feature_strides=[4, 8, 16, 32],
+         channels=128,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/isanet_r50-d8.py ADDED
@@ -0,0 +1,45 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         dilations=(1, 1, 2, 4),
+         strides=(1, 2, 1, 1),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     decode_head=dict(
+         type='ISAHead',
+         in_channels=2048,
+         in_index=3,
+         channels=512,
+         isa_channels=256,
+         down_factor=(8, 8),
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     auxiliary_head=dict(
+         type='FCNHead',
+         in_channels=1024,
+         in_index=2,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/lraspp_m-v3-d8.py ADDED
@@ -0,0 +1,25 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     backbone=dict(
+         type='MobileNetV3',
+         arch='large',
+         out_indices=(1, 3, 16),
+         norm_cfg=norm_cfg),
+     decode_head=dict(
+         type='LRASPPHead',
+         in_channels=(16, 24, 960),
+         in_index=(0, 1, 2),
+         channels=128,
+         input_transform='multiple_select',
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         act_cfg=dict(type='ReLU'),
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
InternVL/segmentation/configs/_base_/models/pointrend_r50.py ADDED
@@ -0,0 +1,56 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='CascadeEncoderDecoder',
+     num_stages=2,
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         dilations=(1, 1, 1, 1),
+         strides=(1, 2, 2, 2),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     neck=dict(
+         type='FPN',
+         in_channels=[256, 512, 1024, 2048],
+         out_channels=256,
+         num_outs=4),
+     decode_head=[
+         dict(
+             type='FPNHead',
+             in_channels=[256, 256, 256, 256],
+             in_index=[0, 1, 2, 3],
+             feature_strides=[4, 8, 16, 32],
+             channels=128,
+             dropout_ratio=-1,
+             num_classes=19,
+             norm_cfg=norm_cfg,
+             align_corners=False,
+             loss_decode=dict(
+                 type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+         dict(
+             type='PointHead',
+             in_channels=[256],
+             in_index=[0],
+             channels=256,
+             num_fcs=3,
+             coarse_pred_each_layer=True,
+             dropout_ratio=-1,
+             num_classes=19,
+             align_corners=False,
+             loss_decode=dict(
+                 type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+     ],
+     # model training and testing settings
+     train_cfg=dict(
+         num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75),
+     test_cfg=dict(
+         mode='whole',
+         subdivision_steps=2,
+         subdivision_num_points=8196,
+         scale_factor=2))
InternVL/segmentation/configs/_base_/models/pspnet_unet_s5-d16.py ADDED
@@ -0,0 +1,50 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained=None,
+     backbone=dict(
+         type='UNet',
+         in_channels=3,
+         base_channels=64,
+         num_stages=5,
+         strides=(1, 1, 1, 1, 1),
+         enc_num_convs=(2, 2, 2, 2, 2),
+         dec_num_convs=(2, 2, 2, 2),
+         downsamples=(True, True, True, True),
+         enc_dilations=(1, 1, 1, 1, 1),
+         dec_dilations=(1, 1, 1, 1),
+         with_cp=False,
+         conv_cfg=None,
+         norm_cfg=norm_cfg,
+         act_cfg=dict(type='ReLU'),
+         upsample_cfg=dict(type='InterpConv'),
+         norm_eval=False),
+     decode_head=dict(
+         type='PSPHead',
+         in_channels=64,
+         in_index=4,
+         channels=16,
+         pool_scales=(1, 2, 3, 6),
+         dropout_ratio=0.1,
+         num_classes=2,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     auxiliary_head=dict(
+         type='FCNHead',
+         in_channels=128,
+         in_index=3,
+         channels=64,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=2,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='slide', crop_size=256, stride=170))
InternVL/segmentation/configs/_base_/models/upernet_r50.py ADDED
@@ -0,0 +1,44 @@
+ # model settings
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ model = dict(
+     type='EncoderDecoder',
+     pretrained='open-mmlab://resnet50_v1c',
+     backbone=dict(
+         type='ResNetV1c',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         dilations=(1, 1, 1, 1),
+         strides=(1, 2, 2, 2),
+         norm_cfg=norm_cfg,
+         norm_eval=False,
+         style='pytorch',
+         contract_dilation=True),
+     decode_head=dict(
+         type='UPerHead',
+         in_channels=[256, 512, 1024, 2048],
+         in_index=[0, 1, 2, 3],
+         pool_scales=(1, 2, 3, 6),
+         channels=512,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     auxiliary_head=dict(
+         type='FCNHead',
+         in_channels=1024,
+         in_index=2,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         num_classes=19,
+         norm_cfg=norm_cfg,
+         align_corners=False,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+     # model training and testing settings
+     train_cfg=dict(),
+     test_cfg=dict(mode='whole'))
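The model fragments above are not trained directly; they are meant to be pulled in through `_base_` inheritance, as the InternViT-6B configs later in this commit do. A hedged sketch of such a composition follows; this derived file is illustrative only and is not part of the commit.

    # Hypothetical derived config composing fragments from this commit.
    _base_ = [
        '../_base_/models/upernet_r50.py',
        '../_base_/datasets/ade20k_504x504.py',
        '../_base_/default_runtime.py',
        '../_base_/schedules/schedule_160k.py'
    ]
    # Override only what differs from the base fragments,
    # e.g. ADE20K uses 150 classes instead of the Cityscapes default of 19.
    model = dict(
        decode_head=dict(num_classes=150),
        auxiliary_head=dict(num_classes=150))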
InternVL/segmentation/configs/_base_/schedules/schedule_10k.py ADDED
@@ -0,0 +1,9 @@
+ # optimizer
+ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+ optimizer_config = dict()
+ # learning policy
+ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+ # runtime settings
+ runner = dict(type='IterBasedRunner', max_iters=10000)
+ checkpoint_config = dict(by_epoch=False, interval=1000)
+ evaluation = dict(interval=1000, metric='mIoU', pre_eval=True)
InternVL/segmentation/configs/_base_/schedules/schedule_160k.py ADDED
@@ -0,0 +1,9 @@
+ # optimizer
+ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+ optimizer_config = dict()
+ # learning policy
+ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+ # runtime settings
+ runner = dict(type='IterBasedRunner', max_iters=160000)
+ checkpoint_config = dict(by_epoch=False, interval=16000)
+ evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)
InternVL/segmentation/configs/_base_/schedules/schedule_20k.py ADDED
@@ -0,0 +1,9 @@
+ # optimizer
+ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+ optimizer_config = dict()
+ # learning policy
+ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+ # runtime settings
+ runner = dict(type='IterBasedRunner', max_iters=20000)
+ checkpoint_config = dict(by_epoch=False, interval=2000)
+ evaluation = dict(interval=2000, metric='mIoU', pre_eval=True)
InternVL/segmentation/configs/_base_/schedules/schedule_320k.py ADDED
@@ -0,0 +1,9 @@
+ # optimizer
+ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+ optimizer_config = dict()
+ # learning policy
+ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+ # runtime settings
+ runner = dict(type='IterBasedRunner', max_iters=320000)
+ checkpoint_config = dict(by_epoch=False, interval=32000)
+ evaluation = dict(interval=32000, metric='mIoU')
InternVL/segmentation/configs/_base_/schedules/schedule_40k.py ADDED
@@ -0,0 +1,9 @@
+ # optimizer
+ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+ optimizer_config = dict()
+ # learning policy
+ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+ # runtime settings
+ runner = dict(type='IterBasedRunner', max_iters=40000)
+ checkpoint_config = dict(by_epoch=False, interval=4000)
+ evaluation = dict(interval=4000, metric='mIoU', pre_eval=True)
InternVL/segmentation/configs/_base_/schedules/schedule_5k.py ADDED
@@ -0,0 +1,9 @@
+ # optimizer
+ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+ optimizer_config = dict()
+ # learning policy
+ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+ # runtime settings
+ runner = dict(type='IterBasedRunner', max_iters=5000)
+ checkpoint_config = dict(by_epoch=False, interval=1000)
+ evaluation = dict(interval=1000, metric='mIoU', pre_eval=True)
InternVL/segmentation/configs/_base_/schedules/schedule_80k.py ADDED
@@ -0,0 +1,9 @@
+ # optimizer
+ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+ optimizer_config = dict()
+ # learning policy
+ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+ # runtime settings
+ runner = dict(type='IterBasedRunner', max_iters=80000)
+ checkpoint_config = dict(by_epoch=False, interval=8000)
+ evaluation = dict(interval=8000, metric='mIoU', pre_eval=True)
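All of the schedules above use the 'poly' learning-rate policy. The sketch below shows the per-iteration value it produces; it mirrors mmcv's PolyLrUpdaterHook formula as I understand it and is meant as an illustration, not a reference implementation.

    def poly_lr(it, base_lr=0.01, min_lr=1e-4, power=0.9, max_iters=80000):
        # (base_lr - min_lr) scaled by (1 - progress)^power, floored at min_lr
        coeff = (1 - it / max_iters) ** power
        return (base_lr - min_lr) * coeff + min_lr

    print(poly_lr(0))        # 0.01 at the first iteration
    print(poly_lr(40000))    # roughly 5.4e-3 halfway through an 80k run
    print(poly_lr(80000))    # decays to min_lr = 1e-4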
InternVL/segmentation/configs/intern_vit_6b/few_shot/linear_intern_vit_6b_504_10k_ade20k_bs16_lr4e-5_1of8.py ADDED
@@ -0,0 +1,72 @@
+ # --------------------------------------------------------
+ # InternVL
+ # Copyright (c) 2023 OpenGVLab
+ # Licensed under The MIT License [see LICENSE for details]
+ # --------------------------------------------------------
+
+ _base_ = [
+     '../../_base_/models/segmenter_vit-b16_mask.py',
+     '../../_base_/datasets/ade20k_504x504_1of8.py',
+     '../../_base_/default_runtime.py',
+     '../../_base_/schedules/schedule_10k.py'
+ ]
+ deepspeed = False
+ deepspeed_config = 'zero_configs/adam_zero1_bf16.json'
+ pretrained = './pretrained/intern_vit_6b_224px.pth'
+ model = dict(
+     pretrained=None,
+     backbone=dict(
+         _delete_=True,
+         type='InternViT6B',
+         pretrain_size=224,
+         img_size=504,
+         patch_size=14,
+         embed_dim=3200,
+         depth=48,
+         num_heads=25,
+         mlp_ratio=4.,
+         qkv_bias=False,
+         drop_path_rate=0.4,
+         init_values=0.1,
+         with_cp=True,
+         use_flash_attn=True,
+         qk_normalization=True,
+         layerscale_force_fp32=False,
+         freeze_vit=False,
+         out_indices=[47],
+         pretrained=pretrained),
+     decode_head=dict(
+         _delete_=True,
+         type='FCNHead',
+         in_channels=3200,
+         channels=3200,
+         num_convs=0,
+         dropout_ratio=0.0,
+         concat_input=False,
+         num_classes=150,
+         with_norm=True,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     test_cfg=dict(mode='slide', crop_size=(504, 504), stride=(322, 322))
+ )
+ optimizer = dict(_delete_=True, type='AdamW', lr=4e-5, betas=(0.9, 0.999), weight_decay=0.05,
+     constructor='CustomLayerDecayOptimizerConstructor',
+     paramwise_cfg=dict(num_layers=48, layer_decay_rate=0.95))
+ lr_config = dict(_delete_=True, policy='poly',
+     warmup='linear',
+     warmup_iters=200,
+     warmup_ratio=1e-6,
+     power=1.0, min_lr=0.0, by_epoch=False)
+ # By default, models are trained on 8 GPUs with 2 images per GPU
+ data = dict(samples_per_gpu=2)
+ runner = dict(type='IterBasedRunner')
+ if deepspeed:
+     checkpoint_config = dict(deepspeed=deepspeed, by_epoch=False, interval=1000, max_keep_ckpts=2)
+ else:
+     checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=2)
+ evaluation = dict(interval=1000, metric='mIoU', save_best='auto')
+ custom_hooks = [
+     dict(
+         type='ToBFloat16Hook',
+         priority=49),
+ ]
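The optimizer above pairs a flat 4e-5 base learning rate with CustomLayerDecayOptimizerConstructor and layer_decay_rate=0.95 over the 48 transformer blocks. The exact parameter grouping is defined by that constructor in this repo; the sketch below only illustrates the usual layer-decay rule (deeper blocks keep a larger share of the base LR) and is an assumption, not the constructor's code.

    # Assumed BEiT-style layer-decay rule; numbers are illustrative.
    base_lr, rate, num_layers = 4e-5, 0.95, 48
    for layer_id in (0, 24, 48):          # patch embed, a middle block, the last block
        scale = rate ** (num_layers - layer_id)
        print(layer_id, base_lr * scale)  # 0 -> ~3.4e-6, 24 -> ~1.2e-5, 48 -> 4e-5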
InternVL/segmentation/configs/intern_vit_6b/few_shot/linear_intern_vit_6b_504_20k_ade20k_bs16_lr4e-5_1of4.py ADDED
@@ -0,0 +1,72 @@
+ # --------------------------------------------------------
+ # InternVL
+ # Copyright (c) 2023 OpenGVLab
+ # Licensed under The MIT License [see LICENSE for details]
+ # --------------------------------------------------------
+
+ _base_ = [
+     '../../_base_/models/segmenter_vit-b16_mask.py',
+     '../../_base_/datasets/ade20k_504x504_1of4.py',
+     '../../_base_/default_runtime.py',
+     '../../_base_/schedules/schedule_20k.py'
+ ]
+ deepspeed = False
+ deepspeed_config = 'zero_configs/adam_zero1_bf16.json'
+ pretrained = './pretrained/intern_vit_6b_224px.pth'
+ model = dict(
+     pretrained=None,
+     backbone=dict(
+         _delete_=True,
+         type='InternViT6B',
+         pretrain_size=224,
+         img_size=504,
+         patch_size=14,
+         embed_dim=3200,
+         depth=48,
+         num_heads=25,
+         mlp_ratio=4.,
+         qkv_bias=False,
+         drop_path_rate=0.4,
+         init_values=0.1,
+         with_cp=True,
+         use_flash_attn=True,
+         qk_normalization=True,
+         layerscale_force_fp32=False,
+         freeze_vit=False,
+         out_indices=[47],
+         pretrained=pretrained),
+     decode_head=dict(
+         _delete_=True,
+         type='FCNHead',
+         in_channels=3200,
+         channels=3200,
+         num_convs=0,
+         dropout_ratio=0.0,
+         concat_input=False,
+         num_classes=150,
+         with_norm=True,
+         loss_decode=dict(
+             type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+     test_cfg=dict(mode='slide', crop_size=(504, 504), stride=(322, 322))
+ )
+ optimizer = dict(_delete_=True, type='AdamW', lr=4e-5, betas=(0.9, 0.999), weight_decay=0.05,
+     constructor='CustomLayerDecayOptimizerConstructor',
+     paramwise_cfg=dict(num_layers=48, layer_decay_rate=0.95))
+ lr_config = dict(_delete_=True, policy='poly',
+     warmup='linear',
+     warmup_iters=400,
+     warmup_ratio=1e-6,
+     power=1.0, min_lr=0.0, by_epoch=False)
+ # By default, models are trained on 8 GPUs with 2 images per GPU
+ data = dict(samples_per_gpu=2)
+ runner = dict(type='IterBasedRunner')
+ if deepspeed:
+     checkpoint_config = dict(deepspeed=deepspeed, by_epoch=False, interval=1000, max_keep_ckpts=2)
+ else:
+     checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=2)
+ evaluation = dict(interval=1000, metric='mIoU', save_best='auto')
+ custom_hooks = [
+     dict(
+         type='ToBFloat16Hook',
+         priority=49),
+ ]
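Assuming this segmentation folder keeps the standard MMSegmentation entry points (tools/train.py and tools/dist_train.sh; the script names are an assumption, not verified from this commit), the few-shot configs above would be launched along these lines:

    # Single-GPU debugging run (hypothetical paths):
    #   python tools/train.py configs/intern_vit_6b/few_shot/linear_intern_vit_6b_504_10k_ade20k_bs16_lr4e-5_1of8.py
    # 8 GPUs x samples_per_gpu=2 gives the total batch size 16 implied by the config name:
    #   bash tools/dist_train.sh configs/intern_vit_6b/few_shot/linear_intern_vit_6b_504_10k_ade20k_bs16_lr4e-5_1of8.py 8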