Commit 2135dd2c authored by mai

Add test script

The test script takes a hard-coded sentence and returns its prediction probabilities.
parent dd696cb3
venv
[submodule "bert-base-uncased-hatexplain-rationale-two"]
path = bert-base-uncased-hatexplain-rationale-two
url = https://huggingface.co/Hate-speech-CNERG/bert-base-uncased-hatexplain-rationale-two
# adversarial-hatespeech
## Installation
```bash
$ pip install transformers
$ pip install --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
$ git clone https://huggingface.co/Hate-speech-CNERG/bert-base-uncased-hatexplain-rationale-two
```
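As an optional sanity check (an editor's suggestion, not part of the original README), you can confirm that the CUDA-enabled PyTorch wheel installed above actually sees a GPU; the test script in this commit falls back to CPU if it does not:
```python
# Optional check: prints the installed torch version and whether a CUDA GPU is visible.
import torch
print(torch.__version__, torch.cuda.is_available())
```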
## Getting started
......
Subproject commit 7b1a724a178c639a4b3446c0ff8f13d19be4f471
bert-base-uncased-hatexplain-rationale-two
\ No newline at end of file
test.py 0 → 100644
import torch
from transformers import AutoTokenizer

# Model_Rational_Label comes from models.py in the hatexplain package.
from hatexplain.models import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained(
    "Hate-speech-CNERG/bert-base-uncased-hatexplain-rationale-two"
)
model = Model_Rational_Label.from_pretrained(
    "Hate-speech-CNERG/bert-base-uncased-hatexplain-rationale-two"
).to(device)

# Tokenize a hard-coded sentence and run it through the classifier.
inputs = tokenizer('He is a great guy', return_tensors="pt").to(device)
prediction_logits, _ = model(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
)

# Turn the two class logits into probabilities and print them.
softmax = torch.nn.Softmax(dim=1)
probs = softmax(prediction_logits)
print(f"Normal: {probs[0][0]}\nHatespeech: {probs[0][1]}")
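For reference, a minimal sketch (not part of this commit) of how the same `tokenizer`, `model`, and `device` from test.py could be reused to score several sentences at once; the sentence list below is made up for illustration, and the call signature follows the single-sentence call above.
```python
# Hypothetical batch variant of test.py; assumes tokenizer, model, and device
# have already been created as in the script above.
import torch

sentences = ["He is a great guy", "Have a nice day"]  # arbitrary example inputs
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)

with torch.no_grad():  # inference only, no gradients needed
    logits, _ = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

probs = torch.softmax(logits, dim=1)
for text, p in zip(sentences, probs):
    print(f"{text!r} -> Normal: {p[0].item():.3f}, Hatespeech: {p[1].item():.3f}")
```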