From c770b84a2c1cb51fdcb489a4ece880a4e324f799 Mon Sep 17 00:00:00 2001
From: mai <mai@cl.uni-heidelberg.de>
Date: Sat, 11 Mar 2023 21:52:49 +0100
Subject: [PATCH] Add attack function

---
 .gitignore                              |   1 +
 test.py                                 |  34 +++++----
 utils/__pycache__/attack.cpython-38.pyc | Bin 322 -> 0 bytes
 utils/__pycache__/eval.cpython-38.pyc   | Bin 1417 -> 0 bytes
 utils/attack.py                         |  93 +++++++++++++++++++++++-
 5 files changed, 112 insertions(+), 16 deletions(-)
 delete mode 100644 utils/__pycache__/attack.cpython-38.pyc
 delete mode 100644 utils/__pycache__/eval.cpython-38.pyc

diff --git a/.gitignore b/.gitignore
index c4a428c..e67bbc9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
+__pycache__
 venv
 bert-base-uncased-hatexplain-rationale-two
diff --git a/test.py b/test.py
index 41742ed..6ddfc17 100644
--- a/test.py
+++ b/test.py
@@ -6,13 +6,14 @@ from nltk.tokenize.treebank import TreebankWordDetokenizer
 from utils.eval import eval
 from utils.attack import attack
 
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-tokenizer = AutoTokenizer.from_pretrained("Hate-speech-CNERG/bert-base-uncased-hatexplain-rationale-two")
-model = \
-    Model_Rational_Label.from_pretrained(
-        "Hate-speech-CNERG/bert-base-uncased-hatexplain-rationale-two"
-    )
+tokenizer = AutoTokenizer.from_pretrained(
+    "Hate-speech-CNERG/bert-base-uncased-hatexplain-rationale-two"
+)
+model = Model_Rational_Label.from_pretrained(
+    "Hate-speech-CNERG/bert-base-uncased-hatexplain-rationale-two"
+)
 model = model.to(device)
 
 
@@ -26,16 +27,18 @@ model = model.to(device)
 # print(f"Normal: {probs[1][0]}\nHatespeech: {probs[1][1]}")
 
 # Load test dataset
-with open('data/post_id_divisions.json') as splits:
+with open("data/post_id_divisions.json") as splits:
     data = json.load(splits)
-    test_ids = data['test']
+    test_ids = data["test"]
+
 
 def dataset(ids):
-    with open('data/dataset.json') as data_file:
+    with open("data/dataset.json") as data_file:
         data = json.load(data_file)
     for i in ids:
         yield data[i]
 
+
 counter = 0
 batchsize = 8
 for post in dataset(test_ids):
@@ -43,15 +46,17 @@ for post in dataset(test_ids):
     #     break
     # counter += 1
 
-    detokenized = TreebankWordDetokenizer().detokenize(post["post_tokens"])
-    # batch = attack(detokenized)
+    text = TreebankWordDetokenizer().detokenize(post["post_tokens"])
 
-    # probabilities = eval(detokenized, model, tokenizer)
-    probabilities = eval(["this is a test", "this is a tast"], model, tokenizer)
+    attacks = attack(text, model, tokenizer)
+    print(attacks)
+
+    probabilities = eval(attacks, model, tokenizer)
+    # probabilities = eval(["this is a test", "this is a tast"], model, tokenizer)
     print(probabilities)
     # print(f"Normal: {probabilities[0][0]}\nHatespeech: {probabilities[0][1]}\n\n")
     # print(f"Normal: {probabilities[1][0]}\nHatespeech: {probabilities[1][1]}\n\n")
-    
+
     # ATTACK HERE
     # batch = attack(detokenized)
 
@@ -68,4 +73,3 @@ for post in dataset(test_ids):
     # print(post["post_id"])
     # print(post["annotators"][0]["label"])
     # print(TreebankWordDetokenizer().detokenize(post["post_tokens"]))
-
diff --git a/utils/__pycache__/attack.cpython-38.pyc b/utils/__pycache__/attack.cpython-38.pyc
deleted file mode 100644
index 775aa2dc19cfcc635d2bb60490f0cc6bc040346b..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 322
zcmYjMy-ou$47Tr}Tv5x!OWXi?0aPKecR?(LC5nxDx?jb)A`omm2#mZ^RwiD73FpdC
zOW$XIas0)(TP{yQ10NQ2ulYNP!^l)zE96{>fUtoLP)7MXY>@C4%A+6cz}n?0{YR=3
zZWMB+^ok!4@DtyVKxc4WW-$wWoDuIT$7LQ1;Vy9JI3Bk>aTDjEueqyU*nfdPQIFr1
zdFp*FgLP~qjHCl%T3Z=QJN9CIWt)qw4J-q*+nO;{jTjq(b@k=&9B2D7Ehgz-fPSv5
it1N=IojmrwRiEi_dhvh?y%>utU2SZ5g{mftBKre~qDs#I

diff --git a/utils/__pycache__/eval.cpython-38.pyc b/utils/__pycache__/eval.cpython-38.pyc
deleted file mode 100644
index af701cdd725a97b7278e7893c946fb1cf3c62780..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 1417
zcmbVMOK%f96drpXdC*3{1{GL%ha#2avPB4i5U&LY(ITM&N}RFNntH~A?MW$`tYF*!
z089Rox2*UJ2oT3kleSmrF0M0kGUqYpeCIp%yjWlFMzF-|HMY@;qHkU~y)Y<u;5GyV
zix`qofp!q`&sgFe%wm>&igsd#x1xUf9ipN>a_NgwsPbtZd;NHI;6H`?9^7UE!bWe=
zXZ$77J%&om@H}qN=eUW#pgC%A6ZzCEJ&KQQNrRe%Wg8L5)<(qINBA|Kr}GT*ByUsJ
zVcjD%Z#AtZ{fOr6rVW73(HJ0f39N_Ql0q>quNhKl&O~kjB7ut%%ir5UH-tT-nikx0
zZGw7zwSwvc$xRZvdu6QNBvKgrCir}jZSS{=OgPCUHAYMYX9IFy*gdYHkP+ps<OJS?
z2co>mCm?EiY+F<fIH8W!v@}zt3xG4YTU+(SH}F8|E&i_NWzO#%7s;t_<^+bOqU4M`
zqL#m`Bq)p47d@zbUWyNZ=Pw1em(G)1OADE5Rd`5$PaR%(&TXyBGt>oprPYK^gcMeA
z<LfG&?+vzjX_Q`7dOQt-Ai}^@<ROS9Q=0_%E6Fa5@NA|7Aa?cs_>Aw>KX2d~6QGZb
zOW=^oF`*?R{}F#<8NUa_nTqp#Z#-DVpTbwv(u!&|Nkf3)#VG$5RIz~S(fz9l@BpYy
zqCFhgu`@0YH^~9)CceTo7GuvDP(%zZJSsrrwF$Mra)@(DUNh?Ays90p>~<f!o({tr
z*)Rp;S`}pswx#Qcva0P!FyngEg4^${Q9;dqAL}+8T<wqYcLV~Cddx)8&s-Wts$(#^
zHDN|{Munsk$sM-J;j(nCEj6_Tec$i6#GfOVhHdUTcDch`#`sL+++_iQaTiVw>PV_x
zVU0s0*D{MlZvD)_zef5pRR3{nc%byYsc6oJ#?}n*n_)r4kg^$^BdP_J*G~pKtgVn{
z$Y)dzszaCfhx3ppeDR`77f}~u6n?+51piD@j53t{?j}9dj`dYfLgF%eSOIm~e=hZ!
c7v2%@oDx^S=GwQQ?OO8No<Ea-SVQgJAEo)Ep#T5?

diff --git a/utils/attack.py b/utils/attack.py
index d6d1c12..506214e 100644
--- a/utils/attack.py
+++ b/utils/attack.py
@@ -1,5 +1,96 @@
+from typing import Union
 import transformers
+import string
 
-def attack(sentence, model, tokenizer):
+def attack(text, model, tokenizer, subs=1, top_k=5):
+    """
+    Return adversarial examples
+
+    Parameters
+    ----------
+    text : str
+        Text to be attacked/modified.
+    model : transformers.AutoModelForSequenceClassification
+        Victim model, trained HateXplain model
+    tokenizer : transformers.AutoTokenizer
+        Tokenizer from trained HateXplain model
+    subs : int
+        Number of character substitutions. 
+        Default: 1
+    top_k : int
+        Return this many of the best candidates. Best is determined by how much
+        they influence the probability scores
+        Default: 5
+
+    Returns
+    -------
+    attacks : List[str]
+        List of the `top_k` attacks on the input text
+    """
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
     model = model.to(device)
 
+    # Compute probabilities prior to the attacks
+    # inputs = tokenizer(
+    #     text, 
+    #     return_tensors="pt", 
+    #     padding=True
+    # ).to(device)
+    # prediction_logits, _ = model(
+    #     input_ids=inputs['input_ids'],
+    #     attention_mask=inputs['attention_mask']
+    # )
+    # softmax = torch.nn.Softmax(dim=1)
+    # prior_probabilities = softmax(prediction_logits)
+    # prior_hatespeech_probability = prior_probabilities[0][1]
+
+    prior_hatespeech_probability = eval(text, model, tokenizer)[0][1]
+
+    # Generate attacks
+    candidate_scores = {}
+    for i, char in enumerate(text):
+        for candidate in generate_candidates(text, i, model, tokenizer):
+            candidate_probability = eval(candidate, model, tokenizer)[0][1]
+            
+            candidate_score = prior_hatespeech_probability - candidate_probability
+            # higher score is better
+            candidate_scores[candidate] = candidate_score
+
+    sorted_candidate_scores = dict(sorted(candidate_scores.items(), 
+                                   key=lambda item: item[1], 
+                                   reverse=True))
+    attacks = list(sorted_candidate_scores)[:top_k]
+    return attacks
+
+
+def generate_candidates(text, i, model, tokenizer)
+    """
+    Substitute a character in the text with every possible substitution 
+
+    Parameters
+    ----------
+    text : str
+        Text to be attacked/modified.
+    i : int
+        Index of character to be substituted
+    model : transformers.AutoModelForSequenceClassification
+        Victim model, trained HateXplain model
+    tokenizer : transformers.AutoTokenizer
+        Tokenizer from trained HateXplain model
+
+    Yields
+    ------
+    candidate : 
+        List of the `top_k` attacks on the input text
+    """
+
+    permissible_substitutions = string.printable
+    # 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
+
+    for substitution_char in permissible_substitutions:
+        if substitution_char == text[i]:
+            continue
+        candidate = list(text)
+        candidate[i] = substitution_char 
+        candidate = "".join(candidate)
+        yield candidate
-- 
GitLab