Commit 0badf970 authored by hubert

same style with black

parent 3b620821
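
The commit is a pure formatting change: black re-wraps long signatures, call sites, and tuple unpacking without touching behaviour. Below is a minimal sketch of how black produces this style, assuming the black package is installed (the project may simply run the black CLI over the whole file); the snippet is illustrative and not part of the repository.

import black

# Hypothetical example input: the pre-commit signature from the first hunk,
# flattened onto one line for illustration.
long_signature = (
    "def calculate_chrf_score_for_each_sentence_in_data(hypotheses: List[str], "
    "references: List[str], ast_dataset_filename: str, asr_refs: List[str], "
    "asr_hypos: List[str]) -> List[Tuple[float, str, str, str, str]]: ...\n"
)

# black.format_str applies the same wrapping seen in the hunks below: one
# parameter per line with a trailing comma once the line exceeds the default
# 88-character limit.
print(black.format_str(long_signature, mode=black.Mode()))
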
@@ -23,8 +23,13 @@ def calculate_chrf_score(hypothesis: str, reference: str) -> float:
     return chrf_score


-def calculate_chrf_score_for_each_sentence_in_data(hypotheses: List[str], references: List[str], ast_dataset_filename: str, asr_refs: List[str], asr_hypos: List[str]) -> List[
-        Tuple[float, str, str, str, str]]:
+def calculate_chrf_score_for_each_sentence_in_data(
+    hypotheses: List[str],
+    references: List[str],
+    ast_dataset_filename: str,
+    asr_refs: List[str],
+    asr_hypos: List[str],
+) -> List[Tuple[float, str, str, str, str]]:
     """
     Calculate chrF score for each hypothesis, reference pair available and rank each pair with decreasing chrF score.
@@ -33,7 +38,9 @@ def calculate_chrf_score_for_each_sentence_in_data(hypotheses: List[str], references: List[str], ast_dataset_filename: str, asr_refs: List[str], asr_hypos: List[str]) -> List[
     :return: return elements sorted by decreasing chrF score
     """
     chrf_scores = defaultdict(tuple)
-    ast_dataset = pd.read_csv(ast_dataset_filename, sep="\t").filter(['tgt_text', 'src_text'], axis=1)
+    ast_dataset = pd.read_csv(ast_dataset_filename, sep="\t").filter(
+        ["tgt_text", "src_text"], axis=1
+    )
     reference_to_transcript = dict(ast_dataset.values)
     asr_refs_to_hypos = {}
     for r, h in zip(asr_refs, asr_hypos):
@@ -52,13 +59,24 @@ def calculate_chrf_score_for_each_sentence_in_data(hypotheses: List[str], references: List[str], ast_dataset_filename: str, asr_refs: List[str], asr_hypos: List[str]) -> List[
         print(asr_transcript)
         # chrf_score = calculate_chrf_score(hypothesis, reference)
     ranked_chrf_score_list = sorted(
-        [(chrf_score, hypothesis, reference, asr_transcript, asr_hypo) for chrf_score, (hypothesis, reference, asr_transcript, asr_hypo) in chrf_scores.items()],
+        [
+            (chrf_score, hypothesis, reference, asr_transcript, asr_hypo)
+            for chrf_score, (
+                hypothesis,
+                reference,
+                asr_transcript,
+                asr_hypo,
+            ) in chrf_scores.items()
+        ],
         key=lambda x: x[0],
-        reverse=True)
+        reverse=True,
+    )
     return ranked_chrf_score_list


-def get_reference_and_hypothesis_strings_from_datafile(filename: str) -> Tuple[List[str], List[str]]:
+def get_reference_and_hypothesis_strings_from_datafile(
+    filename: str,
+) -> Tuple[List[str], List[str]]:
     with open(filename) as f:
         data = f.read().split("\n")[:-1]
     references, hypotheses = [], []
@@ -72,24 +90,53 @@ def get_reference_and_hypothesis_strings_from_datafile(filename: str) -> Tuple[List[str], List[str]]:
 def compare_chrf_scores(
     chrf_score_list_one: List[Tuple[float, str, str]],
     chrf_score_list_two: List[Tuple[float, str, str]],
-    model1: str, model2: str
+    model1: str,
+    model2: str,
 ) -> str:
     comparisons = []
     for (chrf_score, hypo, ref, transcript, _) in chrf_score_list_one:
-        for (chrf_score_other, other_hypo, other_ref, other_transcript, asr_hypo) in chrf_score_list_two:
+        for (
+            chrf_score_other,
+            other_hypo,
+            other_ref,
+            other_transcript,
+            asr_hypo,
+        ) in chrf_score_list_two:
             # we need to remove those to get ANYTHING OTHER than TO REMOVE and PLEASE REMOVE as highest chrF samples
             # if "TO REMOVE" in ref or "PLEASE REMOVE" in ref:
             #     continue
             if ref == other_ref:
-                comparisons.append((chrf_score - chrf_score_other, chrf_score, chrf_score_other, hypo, other_hypo, ref, transcript, asr_hypo))
+                comparisons.append(
+                    (
+                        chrf_score - chrf_score_other,
+                        chrf_score,
+                        chrf_score_other,
+                        hypo,
+                        other_hypo,
+                        ref,
+                        transcript,
+                        asr_hypo,
+                    )
+                )
     comparisons.sort(key=lambda x: x[0], reverse=True)
     markdown_string = ""
-    for ter_difference, first_score, second_score, first_hypo, second_hypo, ref_string, transcript, asr_hypo in comparisons[:15]:
-        string = f"_chrF Difference_: {ter_difference}\n_chrF {model1}_: {first_score}\n_chrF {model2}_: {second_score}\n" + \
-                 f"*Transcript*:\n{transcript}\n*ASR Output*:\n{asr_hypo}\n*Reference*:\n{ref_string}\n" + \
-                 f"*{model1} hypo*: {first_hypo}\n*{model2} hypo*: {second_hypo}"
+    for (
+        ter_difference,
+        first_score,
+        second_score,
+        first_hypo,
+        second_hypo,
+        ref_string,
+        transcript,
+        asr_hypo,
+    ) in comparisons[:15]:
+        string = (
+            f"_chrF Difference_: {ter_difference}\n_chrF {model1}_: {first_score}\n_chrF {model2}_: {second_score}\n"
+            + f"*Transcript*:\n{transcript}\n*ASR Output*:\n{asr_hypo}\n*Reference*:\n{ref_string}\n"
+            + f"*{model1} hypo*: {first_hypo}\n*{model2} hypo*: {second_hypo}"
+        )
         markdown_string += string
         markdown_string += "\n" + "---" * 15 + "\n"
     return markdown_string
@@ -110,53 +157,70 @@ def test_ter_all_wrong():
 def test_calculate_chrf_score_for_each_sentence_in_data():
-    assert [(100, "Hello , goodbye", "Goodbye ! or"), (0, "Hello , goodbye", "Hello , goodbye")] == \
-        calculate_chrf_score_for_each_sentence_in_data(
-            ["Hello , goodbye", "Hello , goodbye"], ["Hello , goodbye", "Goodbye ! or"]
-        )
+    assert [
+        (100, "Hello , goodbye", "Goodbye ! or"),
+        (0, "Hello , goodbye", "Hello , goodbye"),
+    ] == calculate_chrf_score_for_each_sentence_in_data(
+        ["Hello , goodbye", "Hello , goodbye"], ["Hello , goodbye", "Goodbye ! or"]
+    )


 def test_get_reference_and_hypothesis_strings_from_datafile():
-    assert ["Hello , goodbye", "Goodbye ! or"], ["Hello , goodbye", "Hello , goodbye"] == \
-        get_reference_and_hypothesis_strings_from_datafile("test_file.txt")
+    assert ["Hello , goodbye", "Goodbye ! or"], [
+        "Hello , goodbye",
+        "Hello , goodbye",
+    ] == get_reference_and_hypothesis_strings_from_datafile("test_file.txt")


 def test_combination():
     r, h = get_reference_and_hypothesis_strings_from_datafile("test_file.txt")
-    assert [(100, "Hello , goodbye", "Goodbye ! or"), (0, "Hello , goodbye", "Hello , goodbye")] == \
-        calculate_chrf_score_for_each_sentence_in_data(h, r)
+    assert [
+        (100, "Hello , goodbye", "Goodbye ! or"),
+        (0, "Hello , goodbye", "Hello , goodbye"),
+    ] == calculate_chrf_score_for_each_sentence_in_data(h, r)


 if __name__ == "__main__":
     import argparse

     parser = argparse.ArgumentParser()
     parser.add_argument("log_file_one", help="name of fairseq-generate log file")
     parser.add_argument("log_file_two", help="name of fairseq-generate log file")
     parser.add_argument("ast_data_set", help="tsv file of dataset")
-    parser.add_argument("asr_output", help="name of fairseq-generate log file for asr model")
+    parser.add_argument(
+        "asr_output", help="name of fairseq-generate log file for asr model"
+    )
     parser.add_argument("model_one_name", help="specify name of first model")
     parser.add_argument("model_two_name", help="specify name of second model")
     args = parser.parse_args()
     fairseq_file = args.log_file_one
     fairseq_file_for_comparison = args.log_file_two
     result_filename = f"unfiltered_chrF_comparison_{fairseq_file.split('.txt')[0]}_{fairseq_file_for_comparison.split('.txt')[0]}"
-    references_list, hypotheses_list = get_reference_and_hypothesis_strings_from_datafile(fairseq_file)
-    asr_ref, asr_hypos = get_reference_and_hypothesis_strings_from_datafile(args.asr_output)
-    chrf_score_examples = calculate_chrf_score_for_each_sentence_in_data(hypotheses_list, references_list, args.ast_data_set, asr_ref, asr_hypos)
-    comparison_references_list, comparison_hypotheses_list = get_reference_and_hypothesis_strings_from_datafile(
-        fairseq_file_for_comparison
-    )
+    (
+        references_list,
+        hypotheses_list,
+    ) = get_reference_and_hypothesis_strings_from_datafile(fairseq_file)
+    asr_ref, asr_hypos = get_reference_and_hypothesis_strings_from_datafile(
+        args.asr_output
+    )
+    chrf_score_examples = calculate_chrf_score_for_each_sentence_in_data(
+        hypotheses_list, references_list, args.ast_data_set, asr_ref, asr_hypos
+    )
+    (
+        comparison_references_list,
+        comparison_hypotheses_list,
+    ) = get_reference_and_hypothesis_strings_from_datafile(fairseq_file_for_comparison)
     comparison_chrf_score_examples = calculate_chrf_score_for_each_sentence_in_data(
         comparison_hypotheses_list,
         comparison_references_list,
         args.ast_data_set,
         asr_ref,
-        asr_hypos
+        asr_hypos,
     )
     result_string = compare_chrf_scores(
         chrf_score_examples,
         comparison_chrf_score_examples,
         args.model_one_name,
-        args.model_two_name
+        args.model_two_name,
     )
     save_to_json(result_string, result_filename)
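
For context on the metric being ranked: the diff shows calculate_chrf_score only by name and return type, not its body. A minimal sentence-level chrF sketch using sacrebleu follows; treating sacrebleu as the backend, the 0-100 scale, and the helper name below are assumptions, not something this commit confirms.

from sacrebleu.metrics import CHRF


def sentence_chrf_score(hypothesis: str, reference: str) -> float:
    # sentence_score takes one hypothesis string and a list of reference
    # strings and returns a score object; .score is on a 0-100 scale.
    return CHRF().sentence_score(hypothesis, [reference]).score


# An identical hypothesis/reference pair scores 100.0.
print(sentence_chrf_score("Hello , goodbye", "Hello , goodbye"))
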