diff --git a/test.py b/test.py index 136c6dd33f894d890a27ef320d21587a98954a70..3b778d327d98aa4743fce1bc35e667767e538331 100644 --- a/test.py +++ b/test.py @@ -45,8 +45,9 @@ for post in dataset(test_ids): detokenized = TreebankWordDetokenizer().detokenize(post["post_tokens"]) # ATTACK HERE + batch = attack(detokenized) - inputs = tokenizer(detokenized, return_tensors="pt", padding=True).to(device) + inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device) prediction_logits, _ = model(input_ids=inputs['input_ids'],attention_mask=inputs['attention_mask']) softmax = torch.nn.Softmax(dim=1) probs = softmax(prediction_logits) diff --git a/utils/attack.py b/utils/attack.py new file mode 100644 index 0000000000000000000000000000000000000000..5b456b35f7923e28458c683e5499b3b33655a368 --- /dev/null +++ b/utils/attack.py @@ -0,0 +1,4 @@ +import transformers + +def attack(sentence, model): +