Skip to content
Snippets Groups Projects
Commit d6d128ea authored by hubert's avatar hubert
Browse files

marked knnmt parts

parent 4f718126
No related branches found
No related tags found
No related merge requests found
......@@ -91,9 +91,10 @@ class SequenceGenerator(nn.Module):
self.max_len = max_len or self.model.max_decoder_positions()
self.args = args
# print("sequence generator arguments: ", args)
### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
if self.args and self.args.knnmt:
self.knn_dstore = KNN_Dstore(args)
###
self.normalize_scores = normalize_scores
......@@ -323,13 +324,13 @@ class SequenceGenerator(nn.Module):
original_batch_idxs = sample["id"]
else:
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
if self.args and self.args.knnmt and self.args.save_knns:
assert beam_size == 1, "Saving knns for beam size > 1 is too complicated!"
knns = torch.zeros([bsz, max_len+1, self.args.k], dtype=torch.int)
vals = torch.zeros([bsz, max_len+1, self.args.k], dtype=torch.int)
probs = torch.zeros([bsz, max_len+1, self.args.k], dtype=torch.float32)
###
for step in range(max_len + 1): # one extra step for EOS marker
......@@ -372,6 +373,7 @@ class SequenceGenerator(nn.Module):
)
probs = probs[:, -1, :] * self.lm_weight
lprobs += probs
### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
if self.args and self.args.knnmt:
queries = avg_attn_scores[self.args.knn_keytype]
if len(avg_attn_scores.keys()) > 2:
......@@ -415,7 +417,7 @@ class SequenceGenerator(nn.Module):
#print(knn_scores)
#print(knn_scores[inds])
lprobs = knn_scores
###
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
lprobs[:, self.pad] = -math.inf # never select pad
......@@ -505,11 +507,12 @@ class SequenceGenerator(nn.Module):
attn,
src_lengths,
max_len,
knnmt=self.args and self.args.knnmt and self.args.save_knns,
knnmt=self.args and self.args.knnmt and self.args.save_knns, #### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
knns=knns if self.args and self.args.save_knns else None,
knn_vals=vals if self.args and self.args.save_knns else None,
knn_probs=probs if self.args and self.args.save_knns else None,
)
###
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
......@@ -675,6 +678,7 @@ class SequenceGenerator(nn.Module):
tensor[mask] = tensor[mask][:, :1, :]
return tensor.view(-1, tensor.size(-1))
# Adapted for kNN-MT, following https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
def finalize_hypos(
self,
step: int,
......@@ -722,11 +726,12 @@ class SequenceGenerator(nn.Module):
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
if knnmt:
knns = knns[bbsz_idx, :step+1]
knn_vals = knn_vals[bbsz_idx, :step+1]
knn_probs = knn_probs[bbsz_idx, :step+1]
###
# normalize sentence-level scores
if self.normalize_scores:
eos_scores /= (step + 1) ** self.len_penalty
......@@ -775,7 +780,7 @@ class SequenceGenerator(nn.Module):
hypo_attn = attn_clone[i]
else:
hypo_attn = torch.empty(0)
# adapted according to https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
finalized[sent].append(
{
"tokens": tokens_clone[i],
......@@ -868,8 +873,10 @@ class EnsembleModel(nn.Module):
args=None,
):
if args:
### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
if args.knnmt and len(self.models) > 1:
raise ValueError("Cannot use knnmt with actual ensembles!")
###
log_probs = []
avg_attn: Optional[Tensor] = None
encoder_out: Optional[Dict[str, List[Tensor]]] = None
......@@ -901,8 +908,10 @@ class EnsembleModel(nn.Module):
elif attn_holder is not None:
attn = attn_holder[0]
if args:
### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
if args.knnmt:
knn_queries = decoder_out[1][args.knn_keytype]
###
if attn is not None:
attn = attn[:, -1, :]
......@@ -916,13 +925,16 @@ class EnsembleModel(nn.Module):
probs = probs[:, -1, :]
if self.models_size == 1:
if args:
### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
if args.knnmt:
return probs, {"attn": attn, args.knn_keytype: knn_queries}
###
return probs, attn
elif args:
### taken from https://github.com/urvashik/knnmt/blob/master/fairseq/sequence_generator.py
if args.knnmt:
raise ValueError("Cannot use with a real ensemble yet!")
###
log_probs.append(probs)
if attn is not None:
if avg_attn is None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment