diff --git a/README.md b/README.md
index ccd468940ba36c67f6b8679a1b5721f802941cf5..7033cafe7e5844fc9abd5a81524e20d05365e6bc 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,9 @@
-Most of the code of this repo is a copy and reimplementation from [ BryanPlummer /
-two_branch_networks](https://github.com/BryanPlummer/two_branch_networks). The aim of the work of the repo is to apply Layer-Wise Relevance Propagation for investigating the image-text matching in a Two-Branch Network.
+The embedding network in `retrieval_model` is adapted from [BryanPlummer/two_branch_networks](https://github.com/BryanPlummer/two_branch_networks).
+
+Trained models for the different datasets are saved in `models`.
+
+The aim of this repo is to apply Layer-Wise Relevance Propagation (LRP) to investigate image-text matching in a Two-Branch Network.
+
+All details are illustrated in the Jupyter notebook `lrp.ipynb`.
 
 prerequisite:
diff --git a/retrieval_model_lrp.py b/retrieval_model_lrp.py
deleted file mode 100644
index c24a10ff290f6a948c3a554fc14337ede41a8e8b..0000000000000000000000000000000000000000
--- a/retrieval_model_lrp.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import pickle
-import numbers
-import torch
-import torch.nn as nn
-
-def make_fc_1d(f_in, f_out):
-    return nn.Sequential(nn.Linear(f_in, f_out),
-                         nn.BatchNorm1d(f_out),
-                         nn.ReLU(inplace=True),
-                         nn.Dropout(p=0.5))
-
-class EmbedBranch(nn.Module):
-    def __init__(self, feat_dim, embedding_dim, metric_dim):
-        super(EmbedBranch, self).__init__()
-        self.fc1 = make_fc_1d(feat_dim, embedding_dim)
-        print('embedding_dim, metric_dim: ', embedding_dim, metric_dim)
-        self.fc2 = nn.Linear(embedding_dim, metric_dim)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.fc2(x)
-
-        # L2 normalize each feature vector
-        x = nn.functional.normalize(x)
-        return x
-
-def pdist(x1, x2):
-    """
-    x1: Tensor of shape (h1, w)
-    x2: Tensor of shape (h2, w)
-    Return pairwise distance for each row vector in x1, x2 as
-    a Tensor of shape (h1, h2)
-    """
-    x1_square = torch.sum(x1*x1, 1).view(-1, 1)
-    x2_square = torch.sum(x2*x2, 1).view(1, -1)
-    return torch.sqrt(x1_square - 2 * torch.mm(x1, x2.transpose(0, 1)) + x2_square + 1e-4)
-
-def embedding_loss(im_embeds, sent_embeds, im_labels, args):
-    """
-    im_embeds: (b, 512) image embedding tensors
-    sent_embeds: (sample_size * b, 512) sentence embedding tensors
-        where the order of sentences corresponds to the order of images and
-        sentences for the same image are next to each other
-    im_labels: (sample_size * b, b) boolean tensor, where (i, j) entry is
-        True if and only if sentence[i], image[j] is a positive pair
-    """
-    # compute embedding loss
-    sent_im_ratio = args.sample_size
-    num_img = im_embeds.size(0)
-    num_sent = num_img * sent_im_ratio
-
-    sent_im_dist = pdist(sent_embeds, im_embeds)
-    im_labels = im_labels > 0
-
-    # image loss: sentence, positive image, and negative image
-    pos_pair_dist = torch.masked_select(sent_im_dist, im_labels).view(num_sent, 1)
-    neg_pair_dist = torch.masked_select(sent_im_dist, ~im_labels).view(num_sent, -1)
-    im_loss = torch.clamp(args.margin + pos_pair_dist - neg_pair_dist, 0, 1e6)
-    im_loss = im_loss.topk(args.num_neg_sample)[0].mean()
-
-    # sentence loss: image, positive sentence, and negative sentence
-    neg_pair_dist = torch.masked_select(sent_im_dist.t(), ~im_labels.t()).view(num_img, -1)
-    neg_pair_dist = neg_pair_dist.repeat(1, sent_im_ratio).view(num_sent, -1)
-    sent_loss = torch.clamp(args.margin + pos_pair_dist - neg_pair_dist, 0, 1e6)
-    sent_loss = sent_loss.topk(args.num_neg_sample)[0].mean()
-
-    # sentence only loss (neighborhood-preserving constraints)
-    sent_sent_dist = pdist(sent_embeds, sent_embeds)
-    sent_sent_mask = im_labels.t().repeat(1, sent_im_ratio).view(num_sent, num_sent)
-    pos_pair_dist = torch.masked_select(sent_sent_dist, sent_sent_mask).view(-1, sent_im_ratio)
-    pos_pair_dist = pos_pair_dist.max(dim=1, keepdim=True)[0]
-    neg_pair_dist = torch.masked_select(sent_sent_dist, ~sent_sent_mask).view(num_sent, -1)
-    sent_only_loss = torch.clamp(args.margin + pos_pair_dist - neg_pair_dist, 0, 1e6)
-    sent_only_loss = sent_only_loss.topk(args.num_neg_sample)[0].mean()
-
-    loss = im_loss * args.im_loss_factor + sent_loss + sent_only_loss * args.sent_only_loss_factor
-    return loss
-
-class ImageSentenceEmbeddingNetwork(nn.Module):
-    def __init__(self, args, vecs, image_feature_dim):
-        super(ImageSentenceEmbeddingNetwork, self).__init__()
-        embedding_dim = args.dim_embed
-        metric_dim = int(args.dim_embed / 4)
-        n_tokens, token_dim = vecs.shape
-        self.words = nn.Embedding(n_tokens, token_dim)
-        self.words.weight = nn.Parameter(torch.from_numpy(vecs))
-        self.vecs = torch.from_numpy(vecs)
-        self.word_reg = nn.MSELoss()
-        if args.language_model == 'attend':
-            self.word_attention = nn.Sequential(nn.Linear(vecs.shape[1] * 2, 1),
-                                                nn.ReLU(inplace=True),
-                                                nn.Softmax(dim=1))
-
-        self.text_branch = EmbedBranch(token_dim, embedding_dim, metric_dim)
-        self.image_branch = EmbedBranch(image_feature_dim, embedding_dim, metric_dim)
-        self.args = args
-        if args.cuda:
-            self.cuda()
-            self.vecs = self.vecs.cuda()
-
-
-    def forward(self, images, tokens):
-        words = self.words(tokens)
-        n_words = torch.sum(tokens > 0, 1).float() + 1e-10
-        sum_words = words.sum(1).squeeze()
-        sentences = sum_words / n_words.unsqueeze(1)
-
-        if self.args.language_model == 'attend':
-            max_length = tokens.size(-1)
-            context_vector = sentences.unsqueeze(1).repeat(1, max_length, 1)
-            attention_inputs = torch.cat((context_vector, words), 2)
-            attention_weights = self.word_attention(attention_inputs)
-            sentences = nn.functional.normalize(torch.sum(words * attention_weights, 1))
-
-        sentences = self.text_branch(sentences)
-        images = self.image_branch(images)
-        return images, sentences
-
-    def train_forward(self, images, sentences, im_labels):
-        im_embeds, sent_embeds = self(images, sentences)
-        embed_loss = embedding_loss(im_embeds, sent_embeds, im_labels, self.args)
-        word_loss = self.word_reg(self.words.weight, self.vecs)
-        loss = embed_loss + word_loss * self.args.word_embedding_reg
-        return loss, im_embeds, sent_embeds
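For reference, the `pdist` helper removed above computes pairwise Euclidean distances via the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, with a small constant added before the square root for numerical stability. A minimal sanity-check sketch of that helper, assuming a recent PyTorch with `torch.cdist` available; the shapes below are illustrative, not taken from the original training setup:

```python
import torch

def pdist(x1, x2):
    # Pairwise Euclidean distance via ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2;
    # the 1e-4 term keeps sqrt away from zero (and its non-differentiable point).
    x1_square = torch.sum(x1 * x1, 1).view(-1, 1)
    x2_square = torch.sum(x2 * x2, 1).view(1, -1)
    return torch.sqrt(x1_square - 2 * torch.mm(x1, x2.transpose(0, 1)) + x2_square + 1e-4)

# Illustrative shapes: e.g. sample_size=3 sentences per image, batch of 2 images.
sent_embeds = torch.randn(6, 512)
im_embeds = torch.randn(2, 512)

dist = pdist(sent_embeds, im_embeds)
print(dist.shape)  # torch.Size([6, 2]): one distance per (sentence, image) pair

# Agrees with torch.cdist up to the 1e-4 stabiliser inside the sqrt.
print(torch.allclose(dist, torch.cdist(sent_embeds, im_embeds), atol=1e-2))
```

The resulting (num_sent, num_img) distance matrix is what `embedding_loss` masks with `im_labels` to pull out positive and negative pairs for the margin-based triplet terms.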