From f1db10fc383dbfe3941a7154a545764796930d54 Mon Sep 17 00:00:00 2001 From: kupper <kupper@login.cl.uni-heidelberg.de> Date: Sun, 23 Mar 2025 19:41:26 +0100 Subject: [PATCH] Context importance analysis --- scripts/NEC_context_cl.sh | 16 +++++++++ .../NEC_evaluation/context_sensitivity.py | 33 +++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 scripts/NEC_context_cl.sh create mode 100644 src/experiments/NEC_evaluation/context_sensitivity.py diff --git a/scripts/NEC_context_cl.sh b/scripts/NEC_context_cl.sh new file mode 100644 index 0000000..424c68e --- /dev/null +++ b/scripts/NEC_context_cl.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name=NEC_context_sensitivity +#SBATCH --output=logs/NEC_context_sensitivity_%j.txt +#SBATCH --ntasks=1 +#SBATCH --time=24:00:00 +#SBATCH --mem=8000 +#SBATCH --mail-type=ALL +#SBATCH --mail-user=kupper@cl.uni-heidelberg.de +#SBATCH --partition=students +#SBATCH --cpus-per-task=4 +#SBATCH --qos=batch +#SBATCH --gres=gpu + +export PYTHONUNBUFFERED=1 +srun python3 -m src.experiments.NEC_evaluation.context_sensitivity diff --git a/src/experiments/NEC_evaluation/context_sensitivity.py b/src/experiments/NEC_evaluation/context_sensitivity.py new file mode 100644 index 0000000..5f76320 --- /dev/null +++ b/src/experiments/NEC_evaluation/context_sensitivity.py @@ -0,0 +1,33 @@ + + +import data.data_manager as data_manager +from src.common_interface import classify_entity + +def run_context_analysis(model_name, dataset, num_sentences): + labels = data_manager.get_labels(dataset) + data = data_manager.get_annotated_sentences(dataset, num_sentences) + + for i in range(min(len(data), num_sentences)): + annotated_sentence = data[i] + sentence = annotated_sentence[0] + entity = annotated_sentence[1][0] + + words = sentence.split() + + print(f"Original sentence: {sentence}") + print(f"Entity: {entity[0]}") + + for word_index in range(len(words)): + modified_words = words.copy() + modified_words[word_index] = "[blank]" + modified_sentence = ' '.join(modified_words) + + predicted = classify_entity(model_name, modified_sentence, entity[0], labels) + + print(f"Modified sentence: {modified_sentence}") + if predicted != entity[1]: + print("!!MISPREDICTION!!") + print(f"Predicted: {predicted}, True: {entity[1]}") + + +run_context_analysis("T5-NLI", "FIGER-coarse", 50) -- GitLab