{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"none","dataSources":[{"sourceId":1111676,"sourceType":"datasetVersion","datasetId":623289},{"sourceId":1462296,"sourceType":"datasetVersion","datasetId":857191},{"sourceId":4845244,"sourceType":"datasetVersion","datasetId":2808179},{"sourceId":10260246,"sourceType":"datasetVersion","datasetId":6346990},{"sourceId":10279488,"sourceType":"datasetVersion","datasetId":6299544},{"sourceId":10322702,"sourceType":"datasetVersion","datasetId":6367524},{"sourceId":138371258,"sourceType":"kernelVersion"},{"sourceId":208899055,"sourceType":"kernelVersion"}],"dockerImageVersionId":30823,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":false}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# **ViT + GPT2**\n---","metadata":{}},{"cell_type":"markdown","source":"This model combines ViT as the image encoder and GPT2, a generative language model, as the caption decoder. ViT extracts patch-level features, which are passed to GPT2 for autoregressive caption generation.","metadata":{}},{"cell_type":"markdown","source":"## Import Library\n---","metadata":{}},{"cell_type":"code","source":"!pip install pycocoevalcap\n\n!pip install -U nltk\n!pip install nltk==3.5\n\n!unzip /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import gc\nimport json\nimport os\nimport re\nfrom collections import Counter\nfrom math import log, sqrt\nfrom pathlib import Path\nfrom types import SimpleNamespace\n\nimport albumentations as A\nfrom albumentations.pytorch import ToTensorV2\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nfrom PIL import Image\nfrom sklearn.model_selection import train_test_split\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.cuda.amp import GradScaler, autocast\nfrom timm import create_model, list_models\nfrom transformers import GPT2LMHeadModel, GPT2TokenizerFast, get_linear_schedule_with_warmup\nfrom tqdm.auto import tqdm\nfrom tqdm import tqdm as tqdm_base\nfrom nltk.tokenize import word_tokenize\nfrom nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\nfrom nltk.translate.meteor_score import meteor_score","metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.status.busy":"2024-12-28T18:10:09.426610Z","iopub.execute_input":"2024-12-28T18:10:09.426884Z","iopub.status.idle":"2024-12-28T18:10:18.550257Z","shell.execute_reply.started":"2024-12-28T18:10:09.426864Z","shell.execute_reply":"2024-12-28T18:10:18.549552Z"},"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"%env TOKENIZERS_PARALLELISM = false","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:18.551479Z","iopub.execute_input":"2024-12-28T18:10:18.552035Z","iopub.status.idle":"2024-12-28T18:10:18.556736Z","shell.execute_reply.started":"2024-12-28T18:10:18.552002Z","shell.execute_reply":"2024-12-28T18:10:18.556024Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Dataset & Preprocessing\n---\nTo train this 
model, we use the MS-COCO dataset, which contains thousands of images paired with human-generated captions.\n\nThe preprocessing pipeline transforms raw images into a format suitable for patch embedding and subsequent transformer processing. The provided preprocessing steps use the Albumentations library for augmentations and normalization.\n\n* Augments the dataset with realistic transformations to improve generalization.\n* Standardizes image dimensions and pixel distributions for compatibility with the Vision Transformer.\n* Balances augmentation to prevent overfitting while retaining essential visual features.","metadata":{}},{"cell_type":"code","source":"sample_tfms = [\n A.HorizontalFlip(),\n A.RandomBrightnessContrast(),\n A.ColorJitter(),\n A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.3, rotate_limit=45, p=0.5),\n A.HueSaturationValue(p=0.3),\n]\ntrain_tfms = A.Compose([\n *sample_tfms,\n A.Resize(224,224),\n A.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5],always_apply=True),\n ToTensorV2()\n])\nvalid_tfms = A.Compose([\n A.Resize(224,224),\n A.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5],always_apply=True),\n ToTensorV2()\n])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:18.558648Z","iopub.execute_input":"2024-12-28T18:10:18.559005Z","iopub.status.idle":"2024-12-28T18:10:18.581647Z","shell.execute_reply.started":"2024-12-28T18:10:18.558971Z","shell.execute_reply":"2024-12-28T18:10:18.580826Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# GPT2 tokenizer is used for tokenization and padding\n\ntokenizer = GPT2TokenizerFast.from_pretrained('gpt2')\ntokenizer.pad_token = tokenizer.eos_token\ntokenizer.pad_token","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:18.583035Z","iopub.execute_input":"2024-12-28T18:10:18.583333Z","iopub.status.idle":"2024-12-28T18:10:20.863348Z","shell.execute_reply.started":"2024-12-28T18:10:18.583311Z","shell.execute_reply":"2024-12-28T18:10:20.862444Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class Dataset:\n def __init__(self, df, tfms):\n self.df = df\n self.tfms = tfms\n def __len__(self):\n return len(self.df)\n def __getitem__(self,idx):\n sample = self.df.iloc[idx,:]\n image = sample['image']\n caption = sample['caption']\n image = Image.open(image).convert('RGB')\n image = np.array(image)\n augs = self.tfms(image=image)\n image = augs['image']\n caption = f\"{caption}<|endoftext|>\"\n input_ids = tokenizer(\n caption,\n truncation=True)['input_ids']\n labels = input_ids.copy()\n labels[:-1] = input_ids[1:]\n return image,input_ids,labels","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:20.864257Z","iopub.execute_input":"2024-12-28T18:10:20.864548Z","iopub.status.idle":"2024-12-28T18:10:20.869766Z","shell.execute_reply.started":"2024-12-28T18:10:20.864521Z","shell.execute_reply":"2024-12-28T18:10:20.868826Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import pandas as pd\n\ntrain_df = pd.read_csv('/kaggle/input/coco-df/df_coco_train_complete.csv', index_col=False)\nval_df = pd.read_csv('/kaggle/input/coco-df/df_coco_val_complete.csv', 
index_col=False)\n\ntrain_df.reset_index(drop=True,inplace=True)\nval_df.reset_index(drop=True,inplace=True)\n\nprint(len(train_df),len(val_df))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:20.870443Z","iopub.execute_input":"2024-12-28T18:10:20.870689Z","iopub.status.idle":"2024-12-28T18:10:22.546250Z","shell.execute_reply.started":"2024-12-28T18:10:20.870670Z","shell.execute_reply":"2024-12-28T18:10:22.545431Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"train_df.head(5)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.547118Z","iopub.execute_input":"2024-12-28T18:10:22.547445Z","iopub.status.idle":"2024-12-28T18:10:22.559856Z","shell.execute_reply.started":"2024-12-28T18:10:22.547411Z","shell.execute_reply":"2024-12-28T18:10:22.558968Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"val_df.head(5)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.562352Z","iopub.execute_input":"2024-12-28T18:10:22.562589Z","iopub.status.idle":"2024-12-28T18:10:22.577297Z","shell.execute_reply.started":"2024-12-28T18:10:22.562569Z","shell.execute_reply":"2024-12-28T18:10:22.576650Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"train_ds = Dataset(train_df,train_tfms)\nval_ds = Dataset(val_df,valid_tfms)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.578561Z","iopub.execute_input":"2024-12-28T18:10:22.578841Z","iopub.status.idle":"2024-12-28T18:10:22.590144Z","shell.execute_reply.started":"2024-12-28T18:10:22.578816Z","shell.execute_reply":"2024-12-28T18:10:22.589505Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def collate_fn(batch):\n image = [i[0] for i in batch]\n input_ids = [i[1] for i in batch]\n labels = [i[2] for i in batch]\n image = torch.stack(image,dim=0)\n input_ids = tokenizer.pad(\n {'input_ids':input_ids},\n padding='longest',\n return_attention_mask=False,\n return_tensors='pt'\n )['input_ids']\n labels = tokenizer.pad(\n {'input_ids':labels},\n padding='longest',\n return_attention_mask=False,\n return_tensors='pt'\n )['input_ids']\n mask = (input_ids!=tokenizer.pad_token_id).long()\n labels[mask==0]=-100\n return image, input_ids, labels","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.591014Z","iopub.execute_input":"2024-12-28T18:10:22.591252Z","iopub.status.idle":"2024-12-28T18:10:22.604353Z","shell.execute_reply.started":"2024-12-28T18:10:22.591223Z","shell.execute_reply":"2024-12-28T18:10:22.603746Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Architecture Building\n---","metadata":{}},{"cell_type":"code","source":"class GPT2Attention(nn.Module):\n def __init__(self,config):\n super().__init__()\n self.embed_dim = config.embed_dim\n self.n_heads = config.num_heads\n assert self.embed_dim % self.n_heads == 0, 'embedding dimension must be divisible by number of heads'\n self.head_size = self.embed_dim // self.n_heads\n self.seq_len = config.seq_len\n \n self.c_attn = nn.Linear(self.embed_dim, self.head_size * self.n_heads * 3,bias=True)\n self.scale = self.head_size ** -0.5\n \n self.register_buffer('mask',torch.tril(torch.ones(1,1,self.seq_len,self.seq_len)))\n \n self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n \n self.attn_dropout = nn.Dropout(config.attention_dropout)\n self.resid_dropout = nn.Dropout(config.residual_dropout)\n \n \n def forward(self, x):\n b,t,c = x.shape\n 
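# causal (masked) self-attention over the caption tokens: each position attends only to itself and earlier positions via the lower-triangular mask\n 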
# q,k,v shape individually: batch_size x seq_len x embed_dim\n # we know that qk_t = q x k_t, where q=bxtxhead_dim, k_t=bxhead_dimxt\n q,k,v = self.c_attn(x).chunk(3,dim=-1)\n q = q.view(b,t,self.n_heads,self.head_size).permute(0,2,1,3) # batch x n_heads x seq_len x head_dim\n k = k.view(b,t,self.n_heads,self.head_size).permute(0,2,1,3)\n v = v.view(b,t,self.n_heads,self.head_size).permute(0,2,1,3)\n \n qk_t = (q@k.transpose(-2,-1)) * self.scale\n qk_t = qk_t.masked_fill(self.mask[:,:,:t,:t]==0,float('-inf'))\n qk_t = F.softmax(qk_t,dim=-1)\n weights = self.attn_dropout(qk_t)\n \n attention = weights @ v # batch x n_heads x t x head_size\n attention = attention.permute(0,2,1,3).contiguous().view(b,t,c) # batch x t x embed_dim\n \n out = self.c_proj(attention)\n out = self.resid_dropout(out)\n \n return out","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.605075Z","iopub.execute_input":"2024-12-28T18:10:22.605297Z","iopub.status.idle":"2024-12-28T18:10:22.618563Z","shell.execute_reply.started":"2024-12-28T18:10:22.605279Z","shell.execute_reply":"2024-12-28T18:10:22.617878Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class GPT2CrossAttention(nn.Module):\n def __init__(self,config):\n super().__init__()\n self.embed_dim = config.embed_dim\n self.n_heads = config.num_heads\n assert self.embed_dim % self.n_heads == 0, 'embedding dimension must be divisible by number of heads'\n self.head_size = self.embed_dim // self.n_heads\n self.seq_len = config.seq_len\n \n self.q = nn.Linear(self.embed_dim,self.embed_dim)\n self.k = nn.Linear(self.embed_dim,self.embed_dim)\n self.v = nn.Linear(self.embed_dim,self.embed_dim)\n self.scale = self.head_size ** -0.5\n \n self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n \n self.attn_dropout = nn.Dropout(config.attention_dropout)\n self.resid_dropout = nn.Dropout(config.residual_dropout)\n \n self.apply(self._init_weights)\n \n def _init_weights(self, module):\n if isinstance(module, nn.Linear):\n torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n if module.bias is not None:\n torch.nn.init.zeros_(module.bias)\n \n \n def forward(self, q,k,v):\n b,t,c = q.shape\n \n q = self.q(q)\n k = self.k(k)\n v = self.v(v)\n \n q = q.view(b,q.size(1),self.n_heads,self.head_size).permute(0,2,1,3) # batch x n_heads x seq_len x head_dim\n k = k.view(b,k.size(1),self.n_heads,self.head_size).permute(0,2,1,3)\n v = v.view(b,v.size(1),self.n_heads,self.head_size).permute(0,2,1,3)\n \n qk_t = (q@k.transpose(-2,-1)) * self.scale\n qk_t = F.softmax(qk_t,dim=-1)\n weights = self.attn_dropout(qk_t)\n \n attention = weights @ v # batch x n_heads x t x head_size\n attention = attention.permute(0,2,1,3).contiguous().view(b,t,c) # batch x t x embed_dim\n \n out = self.c_proj(attention)\n out = self.resid_dropout(out)\n \n return out","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.619273Z","iopub.execute_input":"2024-12-28T18:10:22.619535Z","iopub.status.idle":"2024-12-28T18:10:22.635423Z","shell.execute_reply.started":"2024-12-28T18:10:22.619501Z","shell.execute_reply":"2024-12-28T18:10:22.634758Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class GPT2MLP(nn.Module):\n def __init__(self,config):\n super().__init__()\n self.embed_dim = config.embed_dim\n self.mlp_ratio = config.mlp_ratio\n self.mlp_dropout = config.mlp_dropout\n \n self.c_fc = nn.Linear(self.embed_dim,self.embed_dim*self.mlp_ratio)\n self.c_proj = 
nn.Linear(self.embed_dim*self.mlp_ratio,self.embed_dim)\n self.act = nn.GELU()\n self.dropout = nn.Dropout(self.mlp_dropout)\n \n def forward(self,x):\n x = self.c_fc(x)\n x = self.act(x)\n x = self.c_proj(x)\n x = self.dropout(x)\n return x","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.636344Z","iopub.execute_input":"2024-12-28T18:10:22.636649Z","iopub.status.idle":"2024-12-28T18:10:22.651124Z","shell.execute_reply.started":"2024-12-28T18:10:22.636618Z","shell.execute_reply":"2024-12-28T18:10:22.650531Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class GPT2Block(nn.Module):\n def __init__(self,config):\n super().__init__()\n self.embed_dim = config.embed_dim\n self.ln_1 = nn.LayerNorm(self.embed_dim)\n self.attn = GPT2Attention(config)\n self.ln_2 = nn.LayerNorm(self.embed_dim)\n self.mlp = GPT2MLP(config)\n self.ln_3 = nn.LayerNorm(self.embed_dim)\n self.cross_attn = GPT2CrossAttention(config)\n \n def forward(self,x,enc_out):\n x = x+self.attn(self.ln_1(x))\n x = x+self.cross_attn(self.ln_2(x),enc_out,enc_out)\n x = x+self.mlp(self.ln_3(x))\n return x","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.651984Z","iopub.execute_input":"2024-12-28T18:10:22.652254Z","iopub.status.idle":"2024-12-28T18:10:22.667036Z","shell.execute_reply.started":"2024-12-28T18:10:22.652227Z","shell.execute_reply":"2024-12-28T18:10:22.666383Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Model Implementation\n---\nViT Encoder:\nExtracts image embeddings in a transformer-friendly format.\n\nGPT2 Decoder:\nGPT2 generates captions by predicting the next word in the sequence based on the encoded image.","metadata":{}},{"cell_type":"code","source":"class VisionGPT2Model(nn.Module):\n def __init__(self,config):\n super().__init__()\n \n self.config = config\n \n vit = create_model('vit_base_patch16_224',pretrained=True,num_classes=0)\n self.patch_embed = vit.patch_embed\n num_patches = self.patch_embed.num_patches\n \n self.cls_token = vit.cls_token\n embed_len = num_patches + vit.num_prefix_tokens\n self.pos_embed = vit.pos_embed\n self.pos_drop = nn.Dropout(p=0.)\n \n self.blocks = nn.ModuleList([vit.blocks[i] for i in range(config.depth)])\n \n self.transformer = nn.ModuleDict(dict(\n wte = nn.Embedding(config.vocab_size,config.embed_dim),\n wpe = nn.Embedding(config.seq_len,config.embed_dim),\n drop = nn.Dropout(config.emb_dropout),\n h = nn.ModuleList([GPT2Block(config) for _ in range(config.depth)]),\n ln_f = nn.LayerNorm(config.embed_dim)\n ))\n self.lm_head = nn.Linear(config.embed_dim,config.vocab_size,bias=False)\n self.transformer.wte.weight = self.lm_head.weight\n \n def _pos_embed(self,x):\n pos_embed = self.pos_embed\n x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n x = x + pos_embed\n return self.pos_drop(x)\n \n def pretrained_layers_trainable(self,trainable=False):\n layers = [\n self.cls_token, self.patch_embed, self.pos_embed, self.blocks,\n self.transformer.wte, self.transformer.wpe,\n self.transformer.ln_f, self.lm_head\n ]\n gpt_layers = [[\n self.transformer.h[i].ln_1,self.transformer.h[i].ln_2,\n self.transformer.h[i].attn,self.transformer.h[i].mlp\n ] for i in range(self.config.depth)]\n for l in gpt_layers:\n layers.extend(l)\n \n for layer in layers:\n if not isinstance(layer,nn.Parameter):\n for p in layer.parameters():\n p.requires_grad = trainable\n else:\n layer.requires_grad = trainable\n \n total_frozen_params = sum([p.numel() for p in 
self.parameters() if not p.requires_grad])\n print(f'{total_frozen_params=}')\n \n def unfreeze_gpt_layers(self,):\n gpt_layers = [[\n self.transformer.h[i].ln_1,self.transformer.h[i].ln_2,\n self.transformer.h[i].attn,self.transformer.h[i].mlp\n ] for i in range(self.config.depth)]\n flatten = []\n for l in gpt_layers:\n flatten.extend(l)\n \n for layer in flatten:\n if not isinstance(layer,nn.Parameter):\n for p in layer.parameters():\n p.requires_grad = True\n else:\n layer.requires_grad = True\n \n @classmethod \n def from_pretrained(self,config):\n model = VisionGPT2Model(config)\n sd = model.state_dict()\n keys = sd.keys()\n ignore_matches = ['blocks.','cross_attn.','ln_3','cls_token','pos_embed','patch_embed.','.attn.mask']\n vit_keys = [key for key in keys if any(match in key for match in ignore_matches)]\n gpt_keys = [key for key in keys if key not in vit_keys]\n \n gpt2_small = GPT2LMHeadModel.from_pretrained('gpt2')\n sd_hf = gpt2_small.state_dict()\n hf_keys = sd_hf.keys()\n hf_keys = [k for k in hf_keys if not k.endswith('.attn.masked_bias')]\n hf_keys = [k for k in hf_keys if not k.endswith('.attn.bias')]\n transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']\n \n for k in hf_keys:\n if any(match in k for match in ignore_matches):\n continue\n if any(k.endswith(w) for w in transposed):\n assert sd_hf[k].shape[::-1] == sd[k].shape\n with torch.no_grad():\n sd[k].copy_(sd_hf[k].t())\n else:\n assert sd_hf[k].shape == sd[k].shape\n with torch.no_grad():\n sd[k].copy_(sd_hf[k])\n \n model.load_state_dict(sd)\n \n return model\n \n def forward(self,image,input_ids,labels=None):\n \n image = self.patch_embed(image)\n image = self._pos_embed(image)\n \n token_embeddings = self.transformer.wte(input_ids)\n pos_embs = torch.arange(0,input_ids.size(1)).to(input_ids.device)\n positional_embeddings = self.transformer.wpe(pos_embs)\n input_ids = self.transformer.drop(token_embeddings+positional_embeddings)\n \n for i in range(self.config.depth):\n image = self.blocks[i](image)\n input_ids = self.transformer.h[i](input_ids, image)\n \n input_ids = self.transformer.ln_f(input_ids)\n \n if labels is not None:\n lm_logits = self.lm_head(input_ids)\n loss = F.cross_entropy(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))\n return loss\n \n lm_logits = self.lm_head(input_ids[:,[-1],:])\n return lm_logits\n \n def generate(self,image,sequence,max_tokens=50,temperature=1.0,deterministic=False):\n for _ in range(max_tokens):\n out = self(image,sequence)\n out = out[:,-1,:] / temperature\n probs = F.softmax(out,dim=-1)\n if deterministic:\n next_token = torch.argmax(probs,dim=-1,keepdim=True)\n else:\n next_token = torch.multinomial(probs,num_samples=1)\n sequence = torch.cat([sequence,next_token],dim=1)\n if next_token.item() == tokenizer.eos_token_id:\n break\n \n return sequence.cpu().flatten()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.667794Z","iopub.execute_input":"2024-12-28T18:10:22.668086Z","iopub.status.idle":"2024-12-28T18:10:22.686873Z","shell.execute_reply.started":"2024-12-28T18:10:22.668059Z","shell.execute_reply":"2024-12-28T18:10:22.686053Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Training\n---\n* Batch Size: 32, selected to balance GPU memory limitations and ensure stable gradient updates.\n* Learning Rate: 0.00001, chosen for fine-tuning pre-trained models to allow gradual and stable weight updates.\n* Optimizer: Adam, used for its momentum and adaptive 
learning rate capabilities, aiding efficient convergence.\n* Loss Function: CrossEntropyLoss, utilized to measure the difference between predicted and ground-truth sequences in image captioning tasks.\n* Learning Rate Scheduler: OneCycleLR, dynamically adjusted the learning rate by starting low, peaking mid-training, and reducing it towards the end to accelerate convergence and minimize overfitting.","metadata":{}},{"cell_type":"code","source":"class Trainer:\n def __init__(self, model_config, train_config, dls):\n self.train_config = train_config\n self.model_config = model_config\n self.device = self.train_config.device\n\n self.model = VisionGPT2Model.from_pretrained(model_config).to(self.device)\n self.model.pretrained_layers_trainable(trainable=True)\n\n print(f'trainable parameters: {sum([p.numel() for p in self.model.parameters() if p.requires_grad])}')\n\n self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')\n self.tokenizer.pad_token = self.tokenizer.eos_token\n\n self.scaler = GradScaler()\n\n self.train_dl, self.val_dl = dls\n\n total_steps = len(self.train_dl)\n\n self.optim = torch.optim.Adam(self.model.parameters(), lr=self.train_config.lr / 25.)\n self.sched = torch.optim.lr_scheduler.OneCycleLR(\n self.optim,\n max_lr=self.train_config.lr,\n epochs=self.train_config.epochs,\n steps_per_epoch=total_steps\n )\n\n self.metrics = pd.DataFrame()\n self.metrics[['train_loss', 'train_perplexity', 'val_loss', 'val_perplexity']] = None\n\n self.gen_tfms = A.Compose([\n A.Resize(224, 224),\n A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], always_apply=True),\n ToTensorV2()\n ])\n\n def save_model(self):\n self.train_config.model_path.mkdir(exist_ok=True)\n sd = self.model.state_dict()\n torch.save(sd, self.train_config.model_path / 'captioner.pt')\n\n def load_best_model(self):\n sd = torch.load(self.train_config.model_path / 'captioner.pt')\n self.model.load_state_dict(sd)\n\n def load_checkpoint(self):\n sd = torch.load(self.train_config.pretrained_model_path / 'captioner (5).pt')\n self.model.load_state_dict(sd)\n\n def train_one_epoch(self, epoch):\n print(f\"Epoch {epoch + 1}/{self.train_config.epochs}\")\n\n running_loss = 0.\n for batch_idx, (image, input_ids, labels) in enumerate(self.train_dl, 1):\n with autocast():\n image = image.to(self.device)\n input_ids = input_ids.to(self.device)\n labels = labels.to(self.device)\n\n loss = self.model(image, input_ids, labels)\n\n self.scaler.scale(loss).backward()\n self.scaler.step(self.optim)\n self.scaler.update()\n self.sched.step()\n self.optim.zero_grad(set_to_none=True)\n\n running_loss += loss.item()\n\n if batch_idx % 1000 == 0:\n print(f\"Batch {batch_idx}/{len(self.train_dl)} - Train Loss: {loss.item():.3f}\")\n\n del image, input_ids, labels, loss\n\n train_loss = running_loss / len(self.train_dl)\n train_pxp = np.exp(train_loss)\n\n self.metrics.loc[epoch, ['train_loss', 'train_perplexity']] = (train_loss, train_pxp)\n\n @torch.no_grad()\n def valid_one_epoch(self, epoch):\n print(f\"Validating Epoch {epoch + 1}/{self.train_config.epochs}\")\n\n running_loss = 0.\n for batch_idx, (image, input_ids, labels) in enumerate(self.val_dl, 1):\n with autocast():\n image = image.to(self.device)\n input_ids = input_ids.to(self.device)\n labels = labels.to(self.device)\n\n loss = self.model(image, input_ids, labels)\n running_loss += loss.item()\n\n if batch_idx % 500 == 0:\n print(f\"Batch {batch_idx}/{len(self.val_dl)} - Validation Loss: {loss.item():.3f}\")\n\n del image, input_ids, labels, loss\n\n val_loss = 
running_loss / len(self.val_dl)\n val_pxp = np.exp(val_loss)\n\n self.metrics.loc[epoch, ['val_loss', 'val_perplexity']] = (val_loss, val_pxp)\n\n return val_pxp\n\n def clean(self):\n gc.collect()\n torch.cuda.empty_cache()\n\n def fit(self):\n best_pxp = 1e9\n best_epoch = -1\n\n for epoch in range(self.train_config.epochs):\n if epoch == self.train_config.freeze_epochs_gpt:\n self.model.unfreeze_gpt_layers()\n print('Unfreezing GPT2 entirely...')\n\n if epoch == self.train_config.freeze_epochs_all:\n self.model.pretrained_layers_trainable(trainable=True)\n\n self.model.train()\n self.train_one_epoch(epoch)\n self.clean()\n\n self.model.eval()\n pxp = self.valid_one_epoch(epoch)\n self.clean()\n\n print(self.metrics.tail(1))\n\n if pxp < best_pxp:\n best_pxp = pxp\n best_epoch = epoch\n print('Saving best model...')\n self.save_model()\n\n return {\n 'best_perplexity': best_pxp,\n 'best_epoch': best_epoch\n }\n\n @torch.no_grad()\n def generate_caption(self, image, max_tokens=30, temperature=0, deterministic=True):\n self.model.eval()\n\n image = Image.open(image).convert('RGB')\n image = np.array(image)\n image = self.gen_tfms(image=image)['image']\n image = image.unsqueeze(0).to(self.device)\n sequence = torch.ones(1, 1).to(device=self.device).long() * self.tokenizer.bos_token_id\n\n caption = self.model.generate(\n image,\n sequence,\n max_tokens=max_tokens,\n temperature=temperature,\n deterministic=deterministic\n )\n caption = self.tokenizer.decode(caption.numpy(), skip_special_tokens=True)\n\n return caption\n \n @torch.no_grad()\n def generate_caption_from_tensor(self, image_tensor, max_tokens=30, temperature=0, deterministic=False):\n \"\"\"\n Generate caption from an image tensor.\n \"\"\"\n self.model.eval()\n \n # Ensure the tensor is in the correct format and on the appropriate device\n image_tensor = image_tensor.unsqueeze(0).to(self.device)\n \n # Prepare the initial sequence for caption generation\n sequence = torch.ones(1, 1).to(device=self.device).long() * self.tokenizer.bos_token_id\n \n # Generate the caption using the model\n caption = self.model.generate(\n image_tensor,\n sequence,\n max_tokens=max_tokens,\n temperature=temperature,\n deterministic=deterministic\n )\n caption = self.tokenizer.decode(caption.cpu().numpy(), skip_special_tokens=True)\n \n return caption","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.687717Z","iopub.execute_input":"2024-12-28T18:10:22.687948Z","iopub.status.idle":"2024-12-28T18:10:22.710018Z","shell.execute_reply.started":"2024-12-28T18:10:22.687930Z","shell.execute_reply":"2024-12-28T18:10:22.709192Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"model_config = SimpleNamespace(\n vocab_size = 50_257,\n embed_dim = 768,\n num_heads = 12,\n seq_len = 1024,\n depth = 12,\n attention_dropout = 0.1,\n residual_dropout = 0.1,\n mlp_ratio = 4,\n mlp_dropout = 0.1,\n emb_dropout = 0.1,\n)\ntrain_config = SimpleNamespace(\n epochs = 10,\n freeze_epochs_gpt = 1,\n freeze_epochs_all = 2,\n lr = 1e-5,\n device = 'cuda',\n model_path = Path('captioner'),\n pretrained_model_path = Path('/kaggle/input/vit-gpt-result'),\n batch_size = 
32\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.710679Z","iopub.execute_input":"2024-12-28T18:10:22.710884Z","iopub.status.idle":"2024-12-28T18:10:22.728307Z","shell.execute_reply.started":"2024-12-28T18:10:22.710865Z","shell.execute_reply":"2024-12-28T18:10:22.727611Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"train_dl = torch.utils.data.DataLoader(train_ds,batch_size=train_config.batch_size,shuffle=True,pin_memory=True,num_workers=2,persistent_workers=True,collate_fn=collate_fn)\nval_dl = torch.utils.data.DataLoader(val_ds,batch_size=train_config.batch_size,shuffle=False,pin_memory=True,num_workers=2,persistent_workers=True,collate_fn=collate_fn)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.728965Z","iopub.execute_input":"2024-12-28T18:10:22.729151Z","iopub.status.idle":"2024-12-28T18:10:22.742688Z","shell.execute_reply.started":"2024-12-28T18:10:22.729134Z","shell.execute_reply":"2024-12-28T18:10:22.741960Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"trainer = Trainer(model_config,train_config,(train_dl,val_dl))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:22.743293Z","iopub.execute_input":"2024-12-28T18:10:22.743470Z","iopub.status.idle":"2024-12-28T18:10:31.802454Z","shell.execute_reply.started":"2024-12-28T18:10:22.743454Z","shell.execute_reply":"2024-12-28T18:10:31.801471Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"trainer.load_checkpoint()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:10:52.194321Z","iopub.execute_input":"2024-12-28T18:10:52.194662Z","iopub.status.idle":"2024-12-28T18:10:52.979064Z","shell.execute_reply.started":"2024-12-28T18:10:52.194635Z","shell.execute_reply":"2024-12-28T18:10:52.978370Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"trainer.fit()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T18:11:06.689658Z","iopub.execute_input":"2024-12-28T18:11:06.689986Z","iopub.status.idle":"2024-12-29T01:29:08.003989Z","shell.execute_reply.started":"2024-12-28T18:11:06.689951Z","shell.execute_reply":"2024-12-29T01:29:08.002569Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Predictions\n---","metadata":{}},{"cell_type":"code","source":"trainer.load_best_model()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-26T12:50:44.345564Z","iopub.execute_input":"2024-12-26T12:50:44.345865Z","iopub.status.idle":"2024-12-26T12:50:45.154500Z","shell.execute_reply.started":"2024-12-26T12:50:44.345843Z","shell.execute_reply":"2024-12-26T12:50:45.153545Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def display_images_with_generated_captions(trainer, image_paths, num_samples=50):\n for i in range(min(num_samples, len(image_paths))):\n test_img = os.path.join(\"/kaggle/input/coco-2017-dataset/coco2017/test2017/\", image_paths[i])\n\n t = np.random.uniform(0, 0.5)\n\n gen_caption = trainer.generate_caption(test_img, temperature=t, deterministic=True)\n\n plt.imshow(Image.open(test_img).convert('RGB'))\n plt.axis('off')\n plt.show()\n print(f\"Generated: 
{gen_caption}\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-26T18:12:49.016432Z","iopub.execute_input":"2024-12-26T18:12:49.016780Z","iopub.status.idle":"2024-12-26T18:12:49.022592Z","shell.execute_reply.started":"2024-12-26T18:12:49.016753Z","shell.execute_reply":"2024-12-26T18:12:49.021688Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# image_paths = [\n# '000000166391.jpg',\n# '000000425702.jpg',\n# '000000575815.jpg',\n# '000000525322.jpg',\n# '000000574520.jpg',\n# '000000184324.jpg'\n# ]\n\nimage_paths = [\n'000000407116.jpg',\n'000000100786.jpg',\n'000000499739.jpg',\n'000000337080.jpg',\n'000000523525.jpg',\n'000000534661.jpg'\n]\n\ndisplay_images_with_generated_captions(trainer, image_paths)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-26T18:12:50.450040Z","iopub.execute_input":"2024-12-26T18:12:50.450346Z","iopub.status.idle":"2024-12-26T18:12:52.833629Z","shell.execute_reply.started":"2024-12-26T18:12:50.450319Z","shell.execute_reply":"2024-12-26T18:12:52.832832Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"for i in range(10):\n det = True\n test = val_df.sample(n=1).values[0]\n test_img, test_caption = test[0],test[1]\n plt.imshow(Image.open(test_img).convert('RGB'))\n t = np.random.uniform(0,0.1)\n gen_caption = trainer.generate_caption(test_img,temperature=t,deterministic=det)\n plt.title(f\"actual: {test_caption}\\nmodel: {gen_caption}\\ntemp: {t} deterministic generation: {det}\")\n plt.axis('off')\n plt.show()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T05:20:07.206411Z","iopub.execute_input":"2024-12-28T05:20:07.206738Z","iopub.status.idle":"2024-12-28T05:20:11.746696Z","shell.execute_reply.started":"2024-12-28T05:20:07.206713Z","shell.execute_reply":"2024-12-28T05:20:11.745681Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Evaluation\n---\nTo assess the performance of the image captioning models, we use the following metrics:\n\nBLEU (Bilingual Evaluation Understudy): Measures how many n-grams in the generated caption overlap with the reference captions.\nBLEU-1: Measures unigram precision. BLEU-2, BLEU-3, BLEU-4: Extend to bigrams, trigrams, and 4-grams, respectively.\n\nMETEOR (Metric for Evaluation of Translation with Explicit ORdering): Considers precision, recall, and alignment by matching words and phrases. 
It accounts for synonyms and stemming, making it more semantically aware.\n\nROUGE-L (Recall-Oriented Understudy for Gisting Evaluation): Focuses on the longest common subsequence (LCS) between the generated and reference captions, capturing sentence-level structure.\n\nCIDEr (Consensus-based Image Description Evaluation): Weights n-gram matches by TF-IDF, so n-grams that are distinctive for an image count more than generic ones; the implementation below is a simplified TF-IDF cosine-similarity variant of the official metric.","metadata":{}},{"cell_type":"code","source":"def clean_and_tokenize(text):\n text = re.sub(r'[^a-zA-Z0-9\\s]', '', text) \n return word_tokenize(text.lower())\n\ndef generate_tokenized_captions(trainer, df):\n pred_result = []\n valid_result = []\n \n grouped_df = df.groupby('image')['caption'].apply(list).reset_index()\n \n print(\"Generating captions...\")\n for _, row in tqdm(grouped_df.iterrows(), total=len(grouped_df), desc=\"Processing\"):\n test_img = row['image']\n captions = row['caption']\n \n tokenized_captions = [clean_and_tokenize(caption) for caption in captions]\n valid_result.append(tokenized_captions)\n \n t = np.random.uniform(0, 0.1)\n gen_caption = trainer.generate_caption(test_img, temperature=t, deterministic=True)\n \n pred_result.append(clean_and_tokenize(gen_caption))\n \n return pred_result, valid_result\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-27T10:02:34.389657Z","iopub.execute_input":"2024-12-27T10:02:34.389953Z","iopub.status.idle":"2024-12-27T10:02:34.396573Z","shell.execute_reply.started":"2024-12-27T10:02:34.389930Z","shell.execute_reply":"2024-12-27T10:02:34.395609Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"pred_result, valid_result = generate_tokenized_captions(trainer, val_df)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-27T10:02:42.146715Z","iopub.execute_input":"2024-12-27T10:02:42.146993Z","iopub.status.idle":"2024-12-27T10:21:36.340795Z","shell.execute_reply.started":"2024-12-27T10:02:42.146970Z","shell.execute_reply":"2024-12-27T10:21:36.339842Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def preview_results(pred_result, valid_result, num_samples=5):\n for i, (pred, refs) in enumerate(zip(pred_result, valid_result)):\n if i >= num_samples:\n break\n print(f\"Sample {i + 1}:\")\n print(f\" Predicted: {' '.join(pred)}\")\n print(\" References:\")\n for j, ref in enumerate(refs):\n print(f\" {j + 1}: {' '.join(ref)}\")\n print()\n \npreview_results(pred_result, valid_result, num_samples=5)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-27T10:21:36.341916Z","iopub.execute_input":"2024-12-27T10:21:36.342178Z","iopub.status.idle":"2024-12-27T10:21:36.352867Z","shell.execute_reply.started":"2024-12-27T10:21:36.342147Z","shell.execute_reply":"2024-12-27T10:21:36.352076Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def calculate_cider(predicted, references):\n def generate_ngrams(tokens, n):\n return [' '.join(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]\n\n cider_score = 0.0\n n = 4 # CIDEr typically uses up to 4-grams\n\n for k in range(1, n+1):\n pred_kgrams = generate_ngrams(predicted, k)\n ref_kgrams = [generate_ngrams(ref, k) for ref in references]\n \n pred_tf = Counter(pred_kgrams)\n ref_tf = Counter([gram for ref in ref_kgrams for gram in ref])\n\n def compute_tfidf(tf, doc_tf, n_docs):\n return {gram: tf[gram] * log(n_docs / (1 + doc_tf[gram])) for gram in tf}\n\n pred_tfidf = compute_tfidf(pred_tf, ref_tf, len(references))\n ref_tfidf = compute_tfidf(ref_tf, ref_tf, len(references))\n\n overlap = set(pred_tfidf.keys()) & set(ref_tfidf.keys())\n numerator = sum(pred_tfidf[gram] * ref_tfidf[gram] for gram in overlap)\n 
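# numerator/denominator form the cosine similarity between the candidate's and the pooled references' TF-IDF vectors (a simplified CIDEr; the official metric also applies a length penalty and averages per reference)\n 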
denominator = sqrt(sum(v**2 for v in pred_tfidf.values())) * \\\n sqrt(sum(v**2 for v in ref_tfidf.values()))\n\n cider_score += numerator / denominator if denominator > 0 else 0.0\n\n return cider_score / n\n\ndef calculate_rouge_l(predicted, references):\n def lcs_length(x, y):\n dp = [[0] * (len(y) + 1) for _ in range(len(x) + 1)]\n for i in range(1, len(x) + 1):\n for j in range(1, len(y) + 1):\n if x[i-1] == y[j-1]:\n dp[i][j] = dp[i-1][j-1] + 1\n else:\n dp[i][j] = max(dp[i-1][j], dp[i][j-1])\n return dp[-1][-1]\n\n rouge_l_scores = []\n for ref in references:\n lcs = lcs_length(predicted, ref)\n precision = lcs / len(predicted) if len(predicted) > 0 else 0\n recall = lcs / len(ref) if len(ref) > 0 else 0\n if precision + recall > 0:\n f1 = (2 * precision * recall) / (precision + recall)\n else:\n f1 = 0\n rouge_l_scores.append(f1)\n\n return max(rouge_l_scores) # Use the best matching reference\n\ndef evaluate(pred_result, valid_result):\n smoothing_function = SmoothingFunction().method1\n bleu_scores = {'BLEU-1': [], 'BLEU-2': [], 'BLEU-3': [], 'BLEU-4': []}\n cider_scores = []\n meteor_scores = []\n rouge_l_scores = []\n\n for pred, refs in zip(pred_result, valid_result):\n bleu_scores['BLEU-1'].append(sentence_bleu(refs, pred, weights=(1, 0, 0, 0), smoothing_function=smoothing_function))\n bleu_scores['BLEU-2'].append(sentence_bleu(refs, pred, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing_function))\n bleu_scores['BLEU-3'].append(sentence_bleu(refs, pred, weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothing_function))\n bleu_scores['BLEU-4'].append(sentence_bleu(refs, pred, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing_function))\n\n cider_scores.append(calculate_cider(pred, refs))\n meteor_scores.append(meteor_score([' '.join(ref) for ref in refs], ' '.join(pred)))\n rouge_l_scores.append(calculate_rouge_l(pred, refs))\n\n avg_scores = {\n 'BLEU-1': np.mean(bleu_scores['BLEU-1']),\n 'BLEU-2': np.mean(bleu_scores['BLEU-2']),\n 'BLEU-3': np.mean(bleu_scores['BLEU-3']),\n 'BLEU-4': np.mean(bleu_scores['BLEU-4']),\n 'CIDEr': np.mean(cider_scores),\n 'METEOR': np.mean(meteor_scores),\n 'ROUGE-L': np.mean(rouge_l_scores)\n }\n\n return avg_scores\n","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"scores = evaluate(pred_result, valid_result)\nprint(\"Evaluation Scores:\", scores)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-27T10:21:36.371303Z","iopub.execute_input":"2024-12-27T10:21:36.371501Z","iopub.status.idle":"2024-12-27T10:21:50.491589Z","shell.execute_reply.started":"2024-12-27T10:21:36.371484Z","shell.execute_reply":"2024-12-27T10:21:50.490678Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Save Config\n---","metadata":{}},{"cell_type":"code","source":"def save_model_as_torchscript(model, save_path):\n model.eval()\n \n image_size = torch.rand(1, 3, 224, 224)\n input_ids = torch.randint(0, model.config.vocab_size, (1, model.config.seq_len))\n \n traced_script_module = torch.jit.trace(model, (image_size, input_ids))\n \n traced_script_module.save(save_path)\n print(f\"Model saved as TorchScript to {save_path}\")\n\n\nmodel = VisionGPT2Model.from_pretrained(model_config)\n\nmodel.load_state_dict(torch.load('/kaggle/input/vit-gpt-result/captioner (3).pt'))\nsave_model_as_torchscript(model, 
\"vision_gpt2_model.pt\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T05:05:33.853622Z","iopub.execute_input":"2024-12-28T05:05:33.853952Z","iopub.status.idle":"2024-12-28T05:05:57.314450Z","shell.execute_reply.started":"2024-12-28T05:05:33.853921Z","shell.execute_reply":"2024-12-28T05:05:57.313457Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def load_torchscript_model(load_path):\n loaded_model = torch.jit.load(load_path)\n print(f\"Model loaded from {load_path}\")\n return loaded_model\n\nloaded_model = load_torchscript_model(\"vision_gpt2_model.pt\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-28T05:06:01.798208Z","iopub.execute_input":"2024-12-28T05:06:01.798502Z","iopub.status.idle":"2024-12-28T05:06:02.702897Z","shell.execute_reply.started":"2024-12-28T05:06:01.798481Z","shell.execute_reply":"2024-12-28T05:06:02.702023Z"}},"outputs":[],"execution_count":null}]}