Skip to content
Snippets Groups Projects
Commit 3017363c authored by friebolin's avatar friebolin
Browse files

Add backtranslation code

parent bafe7145
No related branches found
No related tags found
No related merge requests found
""" Backtranslation: Translate original sentences using Fairseq (https://github.com/facebookresearch/fairseq/blob/main/examples/translation/README.md)
to create 5 paraphrases """
import numpy as np
import pandas as pd
import json
import torch
import tqdm as notebook_tqdm
import os
# Backtranslation sampling temperature: higher values produce more diverse
# (but noisier) German samples during nucleus sampling in `paraphrase`.
#temperature = 0.8
temperature = 1.2
def load_data_set(file_name):
    """Load and parse a JSON-encoded dataset.

    Parameters
    ----------
    file_name : str
        Path to a file containing a single JSON document.

    Returns
    -------
    The parsed JSON object (for these datasets, a list/dict of examples).
    """
    # json.load reads straight from the handle — no intermediate string as
    # with read() + json.loads. Explicit UTF-8 avoids locale-dependent decoding.
    with open(file_name, "r", encoding="utf-8") as file:
        return json.load(file)
def load_data_sets(data_dir):
    """Read the three training splits (semeval, companies, relocar) from *data_dir*.

    Returns them as a 3-tuple in that fixed order.
    """
    split_files = ("semeval_train.txt", "companies_train.txt", "relocar_train.txt")
    semeval, companies, relocar = (
        load_data_set(os.path.join(data_dir, fname)) for fname in split_files
    )
    return semeval, companies, relocar
data_dir = "./data"
semeval_train_data, companies_train_data, relocar_train_data = load_data_sets(data_dir)

# Load Fairseq transformers trained on WMT'19 (single models, Moses tokenizer,
# fastBPE) for the EN->DE / DE->EN round trips used by backtranslation.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model', tokenizer='moses', bpe='fastbpe')
en2de.eval()  # disable dropout for inference
de2en.eval()  # BUG FIX: the original never disabled dropout on the DE->EN model
en2de.cuda()  # move models to GPU for faster translation
de2en.cuda()
# Define helper functions
def extract_target(dp):
    """Extract the target mention from a data point.

    *dp* must provide ``dp["pos"]`` (a [start, end) token index pair) and
    ``dp["sentence"]`` (a token list). The selected tokens are joined with
    spaces and lowercased.
    """
    span_start, span_end = dp["pos"][0], dp["pos"][1]
    target_tokens = dp["sentence"][span_start:span_end]
    return " ".join(target_tokens).lower()
def join_tokens(dp):
    """Join a token list *dp* into a single detokenized sentence string.

    Tokens are space-joined, then punctuation is reattached to the preceding
    word and stray backticks (opening-quote tokens) are dropped. Replacement
    order matches the original chain exactly.
    """
    sentence = " ".join(dp)
    fixups = (
        (' ,', ','), (' .', '.'), (' !', '!'), (' ?', '?'),
        (' :', ':'), (' ;', ';'), (" '", "'"), (' "', '"'),
        ('` ', ''),
    )
    for needle, repl in fixups:
        sentence = sentence.replace(needle, repl)
    return sentence
# Apply backtranslation to each training set and save the augmented data.
datasets = [semeval_train_data, companies_train_data, relocar_train_data]
# NOTE(review): the second dataset is `companies_train_data` but its output
# name says "semeval_org" — confirm the file naming is intentional.
path_names = [f"semeval_loc_with_paraphrases_temp{temperature}", f"semeval_org_with_paraphrases_temp{temperature}", f"relocar_with_paraphrases_temp{temperature}"]

# BUG FIX: `paraphrase` was (re)defined inside the dataset loop on every
# iteration; hoisted out — it only depends on the module-level models.
def paraphrase(sent, temperature):
    """Paraphrase *sent* via EN->DE nucleus sampling plus two extra round trips.

    Each German candidate sampled from `en2de.generate` is backtranslated
    DE->EN->DE->EN, producing one English paraphrase per candidate.
    Returns the list of paraphrases.
    """
    encoded = en2de.encode(sent)
    # Nucleus (top-p) sampling yields several diverse German candidates.
    outputs = en2de.generate(encoded, sampling=True, topp=0.5, temperature=temperature)
    german_candidates = [en2de.decode(x['tokens']) for x in outputs]
    multi_paraphrases = []
    for german in german_candidates:
        back_en = de2en.translate(german)     # DE -> EN
        round_de = en2de.translate(back_en)   # EN -> DE
        final_en = de2en.translate(round_de)  # DE -> EN
        multi_paraphrases.append(final_en)
    return multi_paraphrases

for dataset, name in zip(datasets, path_names):
    # Preprocess original sentences into a DataFrame.
    data = pd.DataFrame.from_dict(dataset, orient='columns')
    # Extract target mentions and add them as a new column.
    data["targets"] = [extract_target(row) for _, row in data.iterrows()]
    # Join each tokenized sentence into a single detokenized string.
    data["joined_sents"] = [join_tokens(sent) for sent in data["sentence"]]
    # Paraphrase the joined sentences via backtranslation.
    data["paraphrases"] = [paraphrase(sent, temperature) for sent in data["joined_sents"]]
    data_path = os.path.join(data_dir, name)
    data.to_csv(data_path)
    print(f"Done: paraphrases generated and saved for {name}.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment