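"""Data loading for the Recipe1M dataset: a PyTorch Dataset that pairs a recipe
image with its tokenized recipe (title + instructions) and ingredient labels,
plus a collate function that pads captions and a get_loader() factory."""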
import torch
import torch.utils.data as data
import os
import pickle
import numpy as np
from PIL import Image
from build_vocab import Vocabulary  # required so pickle can resolve the vocab class
import random
import lmdb


class Recipe1MDataset(data.Dataset):
    """Recipe1M samples: (image, recipe token sequence, ingredient labels)."""

    def __init__(self, data_dir, aux_data_dir, split, maxseqlen, maxnuminstrs, maxnumlabels, maxnumims,
                 transform=None, max_num_samples=-1, use_lmdb=False, suff=''):

        # Pre-built vocabularies and the pre-processed dataset split.
        self.ingrs_vocab = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_ingrs.pkl'), 'rb'))
        self.instrs_vocab = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_toks.pkl'), 'rb'))
        self.dataset = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_' + split + '.pkl'), 'rb'))

        self.label2word = self.get_ingrs_vocab()

        self.use_lmdb = use_lmdb
        if use_lmdb:
            self.image_file = lmdb.open(os.path.join(aux_data_dir, 'lmdb_' + split), max_readers=1, readonly=True,
                                        lock=False, readahead=False, meminit=False)

        # Keep only samples that have at least one image.
        self.ids = []
        self.split = split
        for i, entry in enumerate(self.dataset):
            if len(entry['images']) == 0:
                continue
            self.ids.append(i)

        self.root = os.path.join(data_dir, 'images', split)
        self.transform = transform
        self.max_num_labels = maxnumlabels
        self.max_num_instrs = maxnuminstrs
        # Total caption budget: maxseqlen tokens per instruction times maxnuminstrs instructions.
        self.maxseqlen = maxseqlen * maxnuminstrs
        self.maxnumims = maxnumims
        if max_num_samples != -1:
            # Optionally subsample the split (e.g. for quick experiments).
            random.shuffle(self.ids)
            self.ids = self.ids[:max_num_samples]

    def get_instrs_vocab(self):
        return self.instrs_vocab

    def get_instrs_vocab_size(self):
        return len(self.instrs_vocab)

    def get_ingrs_vocab(self):
        # Each vocab entry may be a list of synonyms; keep the shortest name.
        return [min(w, key=len) if not isinstance(w, str) else w
                for w in self.ingrs_vocab.idx2word.values()]

    def get_ingrs_vocab_size(self):
        return len(self.ingrs_vocab)
    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""

        sample = self.dataset[self.ids[index]]
        img_id = sample['id']
        captions = sample['tokenized']
        paths = sample['images'][0:self.maxnumims]

        labels = sample['ingredients']
        title = sample['title']

        # Flatten title + instructions into one token sequence, with <eoi>
        # closing the title and each instruction.
        tokens = []
        tokens.extend(title)
        tokens.append('<eoi>')
        for c in captions:
            tokens.extend(c)
            tokens.append('<eoi>')

        # Encode ingredient labels: deduplicated, padded to max_num_labels,
        # terminated with <end> (assumes fewer distinct labels than slots).
        ilabels_gt = np.ones(self.max_num_labels) * self.ingrs_vocab('<pad>')
        pos = 0
        for i in range(self.max_num_labels):
            if i >= len(labels):
                label = '<pad>'
            else:
                label = labels[i]
            label_idx = self.ingrs_vocab(label)
            if label_idx not in ilabels_gt:
                ilabels_gt[pos] = label_idx
                pos += 1
        ilabels_gt[pos] = self.ingrs_vocab('<end>')
        ingrs_gt = torch.from_numpy(ilabels_gt).long()

        # Load the image: pick a random one at train time, the first otherwise.
        if len(paths) == 0:
            path = None
            image_input = torch.zeros((3, 224, 224))
        else:
            if self.split == 'train':
                img_idx = np.random.randint(0, len(paths))
            else:
                img_idx = 0
            path = paths[img_idx]
            if self.use_lmdb:
                try:
                    with self.image_file.begin(write=False) as txn:
                        image = txn.get(path.encode())
                        # np.fromstring is deprecated; frombuffer is the replacement.
                        image = np.frombuffer(image, dtype=np.uint8)
                        image = np.reshape(image, (256, 256, 3))
                        image = Image.fromarray(image.astype('uint8'), 'RGB')
                except Exception:
                    print("Image id not found in lmdb. Loading jpeg file...")
                    image = Image.open(os.path.join(self.root, path[0], path[1],
                                                    path[2], path[3], path)).convert('RGB')
            else:
                image = Image.open(os.path.join(self.root, path[0], path[1], path[2], path[3], path)).convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            image_input = image

        # Encode the recipe tokens, append <end>, and truncate to the budget.
        caption = self.caption_to_idxs(tokens, [])
        caption.append(self.instrs_vocab('<end>'))
        caption = caption[0:self.maxseqlen]
        target = torch.Tensor(caption)

        return image_input, target, ingrs_gt, img_id, path, self.instrs_vocab('<pad>')
    def __len__(self):
        return len(self.ids)

    def caption_to_idxs(self, tokens, caption):
        # Prepend <start> and map each token to its vocabulary index.
        caption.append(self.instrs_vocab('<start>'))
        for token in tokens:
            caption.append(self.instrs_vocab(token))
        return caption
def collate_fn(data):
    """Stacks a batch and right-pads variable-length captions."""

    # Unzip the list of (image, caption, ingredients, id, path, pad) tuples.
    image_input, captions, ingrs_gt, img_id, path, pad_value = zip(*data)

    # Merge images and ingredient labels (tuples of 3D/1D tensors -> 4D/2D tensors).
    image_input = torch.stack(image_input, 0)
    ingrs_gt = torch.stack(ingrs_gt, 0)

    # Pad captions with the pad index up to the longest caption in the batch.
    lengths = [len(cap) for cap in captions]
    targets = torch.ones(len(captions), max(lengths)).long() * pad_value[0]
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]

    return image_input, targets, ingrs_gt, img_id, path
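# A minimal sketch of what collate_fn produces; the dummy captions, ids and
# paths below are illustrative assumptions, not real Recipe1M data:
#
#   batch = [
#       (torch.zeros(3, 224, 224), torch.tensor([1, 5, 2]), torch.zeros(20).long(), 'id_a', 'a.jpg', 0),
#       (torch.zeros(3, 224, 224), torch.tensor([1, 7, 9, 2]), torch.zeros(20).long(), 'id_b', 'b.jpg', 0),
#   ]
#   imgs, targets, ingrs, ids, paths = collate_fn(batch)
#   # imgs: (2, 3, 224, 224); targets: (2, 4), the shorter caption padded with 0.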
def get_loader(data_dir, aux_data_dir, split, maxseqlen,
               maxnuminstrs, maxnumlabels, maxnumims, transform, batch_size,
               shuffle, num_workers, drop_last=False,
               max_num_samples=-1,
               use_lmdb=False,
               suff=''):

    dataset = Recipe1MDataset(data_dir=data_dir, aux_data_dir=aux_data_dir, split=split,
                              maxseqlen=maxseqlen, maxnumlabels=maxnumlabels, maxnuminstrs=maxnuminstrs,
                              maxnumims=maxnumims,
                              transform=transform,
                              max_num_samples=max_num_samples,
                              use_lmdb=use_lmdb,
                              suff=suff)

    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                                              drop_last=drop_last, collate_fn=collate_fn, pin_memory=True)
    return data_loader, dataset
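
# Hedged usage sketch: how one might build a loader for the 'train' split.
# The directory paths and hyper-parameter values below are assumptions for
# illustration, not values prescribed by this module.
if __name__ == '__main__':
    import torchvision.transforms as transforms

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    loader, dataset = get_loader(data_dir='data', aux_data_dir='data', split='train',
                                 maxseqlen=15, maxnuminstrs=10, maxnumlabels=20, maxnumims=5,
                                 transform=transform, batch_size=8,
                                 shuffle=True, num_workers=0)
    imgs, targets, ingrs, ids, paths = next(iter(loader))
    print(imgs.shape, targets.shape, ingrs.shape)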