-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
157 lines (127 loc) · 5 KB
/
data.py
File metadata and controls
157 lines (127 loc) · 5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import spacy
from collections import Counter
from tqdm import tqdm
import json
import os
from PIL import Image
import matplotlib.pyplot as plt
class CapsCollate:
    """Batch-collation callable for a DataLoader.

    Stacks image tensors into one batch tensor and pads the variable-length
    caption tensors to the longest caption in the batch.
    """

    def __init__(self, pad_idx, batch_first=False):
        # token index used to fill short captions up to the batch max length
        self.pad_idx = pad_idx
        # whether padded captions come out as (batch, seq) instead of (seq, batch)
        self.batch_first = batch_first

    def __call__(self, batch):
        """Collate a list of (image, caption) pairs into batched tensors."""
        images = torch.stack([sample[0] for sample in batch], dim=0)
        captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return images, captions
class Vocabulary:
    """Word-level vocabulary built from caption strings.

    Words are admitted once their corpus frequency reaches ``freq_threshold``;
    everything else maps to ``<UNK>``. Indices 0-3 are reserved for the
    special tokens ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.
    """

    def __init__(self, freq_threshold):
        # index -> token for the pre-reserved special tokens
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        # token -> index: reverse mapping of self.itos
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold
        # FIX: original chained the load through a useless throwaway name
        # (`self.spacy = spacy_eng = spacy.load(...)`); keep only the attribute.
        self.spacy = spacy.load("en_core_web_sm")

    def __len__(self):
        """Number of tokens currently in the vocabulary."""
        return len(self.itos)

    @staticmethod
    def tokenize(text, spacy):
        """Lower-cased token texts of *text* using the given spaCy pipeline's tokenizer."""
        return [token.text.lower() for token in spacy.tokenizer(text)]

    def build_vocab(self, sentence_list):
        """Populate stoi/itos from an iterable of caption strings."""
        frequencies = Counter()
        idx = 4  # first index after the reserved special tokens
        for sentence in tqdm(sentence_list):
            for word in self.tokenize(sentence, self.spacy):
                frequencies[word] += 1
                # add the word the moment it reaches the minimum frequency
                # threshold (== avoids adding it again on later occurrences)
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Map each token of *text* to its vocab index, falling back to <UNK>."""
        tokenized_text = self.tokenize(text, self.spacy)
        return [self.stoi[token] if token in self.stoi else self.stoi["<UNK>"] for token in tokenized_text]
class FlickrDataset(Dataset):
    """Flickr captioning dataset yielding (image_tensor, caption_index_tensor) pairs."""

    def __init__(self, root_dir, captions_df, vocab, transform=None):
        self.root_dir = root_dir
        self.df = captions_df
        self.transform = transform
        # one row per caption: image filename column and caption string column
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]
        # vocabulary is built externally and injected
        self.vocab = vocab
        # index of the most recently fetched sample (-1 before any fetch)
        self.last_idx = -1

    def __len__(self):
        """Total number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load and (optionally) transform the image at row *idx*; return it
        with the caption numericalized as <SOS> tokens <EOS>."""
        self.last_idx = idx
        image = Image.open(os.path.join(self.root_dir, self.imgs[idx])).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        tokens = (
            [self.vocab.stoi["<SOS>"]]
            + self.vocab.numericalize(self.captions[idx])
            + [self.vocab.stoi["<EOS>"]]
        )
        return image, torch.tensor(tokens)

    def get_last_captions(self):
        """Return the five reference captions, each split into a word list,
        for the image group containing the last fetched index.

        NOTE(review): assumes the dataframe holds exactly 5 consecutive
        captions per image, aligned to multiples of 5 — confirm against the
        captions file layout.
        """
        base = (self.last_idx // 5) * 5
        references = []
        for offset in range(1, 6):
            # slice-then-last reproduces the original positional lookup exactly
            sentence = list(self.captions[base:base + offset])[-1]
            references.append(sentence.split())
        return references
def show_image(inp, title=None):
    """Display a CHW image tensor with matplotlib, optionally titled."""
    # channel-first tensor -> channel-last numpy array for imshow
    image = inp.numpy().transpose((1, 2, 0))
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    # brief pause so the figure refreshes in interactive mode
    plt.pause(0.001)
def build_vocab(captions_file_path, freq_threshold=5):
    """Build a Vocabulary from the 'caption' column of a captions CSV.

    Args:
        captions_file_path: path to a CSV file with a 'caption' column.
        freq_threshold: minimum number of occurrences for a word to enter
            the vocabulary. Defaults to 5, the previously hard-coded value,
            so existing callers are unaffected.

    Returns:
        The built Vocabulary instance.
    """
    # load all captions
    captions_df = pd.read_csv(captions_file_path)
    captions_list = captions_df['caption'].tolist()
    # build vocabulary over the whole caption corpus
    vocab = Vocabulary(freq_threshold=freq_threshold)
    vocab.build_vocab(captions_list)
    return vocab
def karpathy_split(captions_path, karpathy_json_path='dataset.json'):
    """Split a captions CSV into train/val/test DataFrames using the
    Karpathy split JSON ('restval' images are folded into train).

    Returns:
        (train_df, val_df, test_df) with indices reset.
    """
    captions_df = pd.read_csv(captions_path)
    with open(karpathy_json_path, 'r') as j:
        data = json.load(j)
    # accumulate filenames per split; unknown split labels are ignored
    buckets = {'train': [], 'val': [], 'test': []}
    for entry in data['images']:
        label = entry['split']
        if label == 'restval':
            label = 'train'
        if label in buckets:
            buckets[label].append(entry['filename'])

    def select(filenames):
        # rows whose image belongs to the split, with a clean 0..n-1 index
        return captions_df.loc[captions_df['image'].isin(filenames)].reset_index(drop=True)

    return select(buckets['train']), select(buckets['val']), select(buckets['test'])