diff --git a/README.md b/README.md
index ccd468940ba36c67f6b8679a1b5721f802941cf5..7033cafe7e5844fc9abd5a81524e20d05365e6bc 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,11 @@
-Most of the code of this repo is a copy and reimplementation from [ BryanPlummer /
-two_branch_networks](https://github.com/BryanPlummer/two_branch_networks). The aim of the work of the repo is to apply Layer-Wise Relevance Propagation for investigating the image-text matching in a Two-Branch Network.
+The code of implementing the Embedding Work in retrieval_model is adapted from [ BryanPlummer /
+two_branch_networks](https://github.com/BryanPlummer/two_branch_networks). 
+
+Trained models given different datasets are saved in `models`. 
+
+The aim of the work of the repo is to apply Layer-Wise Relevance Propagation for investigating the image-text matching in a Two-Branch Network. 
+
+All details are inllustrated in the jupyter notebook `lrp.ipynb`
 
 prerequisite:
 
diff --git a/retrieval_model_lrp.py b/retrieval_model_lrp.py
deleted file mode 100644
index c24a10ff290f6a948c3a554fc14337ede41a8e8b..0000000000000000000000000000000000000000
--- a/retrieval_model_lrp.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import pickle
-import numbers
-import torch
-import torch.nn as nn
-
-def make_fc_1d(f_in, f_out):
-    return nn.Sequential(nn.Linear(f_in, f_out),
-                         nn.BatchNorm1d(f_out),
-                         nn.ReLU(inplace=True),
-                         nn.Dropout(p=0.5))
-
-class EmbedBranch(nn.Module):
-    def __init__(self, feat_dim, embedding_dim, metric_dim):
-        super(EmbedBranch, self).__init__()
-        self.fc1 = make_fc_1d(feat_dim, embedding_dim)
-        print('embedding_dim, metric_dim: ',embedding_dim, metric_dim)
-        self.fc2 = nn.Linear(embedding_dim, metric_dim)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.fc2(x)
-
-        # L2 normalize each feature vector
-        x = nn.functional.normalize(x)
-        return x
-
-def pdist(x1, x2):
-    """
-        x1: Tensor of shape (h1, w)
-        x2: Tensor of shape (h2, w)
-        Return pairwise distance for each row vector in x1, x2 as
-        a Tensor of shape (h1, h2)
-    """
-    x1_square = torch.sum(x1*x1, 1).view(-1, 1)
-    x2_square = torch.sum(x2*x2, 1).view(1, -1)
-    return torch.sqrt(x1_square - 2 * torch.mm(x1, x2.transpose(0, 1)) + x2_square + 1e-4)
-
-def embedding_loss(im_embeds, sent_embeds, im_labels, args):
-    """
-        im_embeds: (b, 512) image embedding tensors
-        sent_embeds: (sample_size * b, 512) sentence embedding tensors
-            where the order of sentence corresponds to the order of images and
-            setnteces for the same image are next to each other
-        im_labels: (sample_size * b, b) boolean tensor, where (i, j) entry is
-            True if and only if sentence[i], image[j] is a positive pair
-    """
-    # compute embedding loss
-    sent_im_ratio = args.sample_size
-    num_img = im_embeds.size(0)
-    num_sent = num_img * sent_im_ratio
-
-    sent_im_dist = pdist(sent_embeds, im_embeds)
-    im_labels = im_labels > 0
-
-    # image loss: sentence, positive image, and negative image
-    pos_pair_dist = torch.masked_select(sent_im_dist, im_labels).view(num_sent, 1)
-    neg_pair_dist = torch.masked_select(sent_im_dist, ~im_labels).view(num_sent, -1)
-    im_loss = torch.clamp(args.margin + pos_pair_dist - neg_pair_dist, 0, 1e6)
-    im_loss = im_loss.topk(args.num_neg_sample)[0].mean()
- 
-    # sentence loss: image, positive sentence, and negative sentence
-    neg_pair_dist = torch.masked_select(sent_im_dist.t(), ~im_labels.t()).view(num_img, -1)
-    neg_pair_dist = neg_pair_dist.repeat(1, sent_im_ratio).view(num_sent, -1)
-    sent_loss = torch.clamp(args.margin + pos_pair_dist - neg_pair_dist, 0, 1e6)
-    sent_loss = sent_loss.topk(args.num_neg_sample)[0].mean()
-
-    # sentence only loss (neighborhood-preserving constraints)
-    sent_sent_dist = pdist(sent_embeds, sent_embeds)
-    sent_sent_mask = im_labels.t().repeat(1, sent_im_ratio).view(num_sent, num_sent)
-    pos_pair_dist = torch.masked_select(sent_sent_dist, sent_sent_mask).view(-1, sent_im_ratio)
-    pos_pair_dist = pos_pair_dist.max(dim=1, keepdim=True)[0]
-    neg_pair_dist = torch.masked_select(sent_sent_dist, ~sent_sent_mask).view(num_sent, -1)
-    sent_only_loss = torch.clamp(args.margin + pos_pair_dist - neg_pair_dist, 0, 1e6)
-    sent_only_loss = sent_only_loss.topk(args.num_neg_sample)[0].mean()
-
-    loss = im_loss * args.im_loss_factor + sent_loss + sent_only_loss * args.sent_only_loss_factor
-    return loss
-
-class ImageSentenceEmbeddingNetwork(nn.Module):
-    def __init__(self, args, vecs, image_feature_dim):
-        super(ImageSentenceEmbeddingNetwork, self).__init__()
-        embedding_dim = args.dim_embed
-        metric_dim = int(args.dim_embed / 4)
-        n_tokens, token_dim = vecs.shape
-        self.words = nn.Embedding(n_tokens, token_dim)
-        self.words.weight = nn.Parameter(torch.from_numpy(vecs))
-        self.vecs = torch.from_numpy(vecs)
-        self.word_reg = nn.MSELoss()
-        if args.language_model == 'attend':
-            self.word_attention = nn.Sequential(nn.Linear(vecs.shape[1] * 2, 1),
-                                                nn.ReLU(inplace=True),
-                                                nn.Softmax(dim=1))
-
-        self.text_branch = EmbedBranch(token_dim, embedding_dim, metric_dim)
-        self.image_branch = EmbedBranch(image_feature_dim, embedding_dim, metric_dim)
-        self.args = args
-        if args.cuda:
-            self.cuda()
-            self.vecs = self.vecs.cuda()
-
-
-    def forward(self, images, tokens):
-        words = self.words(tokens)
-        n_words = torch.sum(tokens > 0, 1).float() + 1e-10
-        sum_words = words.sum(1).squeeze()
-        sentences = sum_words / n_words.unsqueeze(1)
-        
-        if self.args.language_model == 'attend':
-            max_length = tokens.size(-1)
-            context_vector = sentences.unsqueeze(1).repeat(1, max_length, 1)
-            attention_inputs = torch.cat((context_vector, words), 2)
-            attention_weights = self.word_attention(attention_inputs)
-            sentences = nn.functional.normalize(torch.sum(words * attention_weights, 1))
-
-        sentences = self.text_branch(sentences)
-        images = self.image_branch(images)
-        return images, sentences
-
-    def train_forward(self, images, sentences, im_labels):
-        im_embeds, sent_embeds = self(images, sentences)
-        embed_loss = embedding_loss(im_embeds, sent_embeds, im_labels, self.args)
-        word_loss = self.word_reg(self.words.weight, self.vecs)
-        loss = embed_loss + word_loss * self.args.word_embedding_reg
-        return loss, im_embeds, sent_embeds