diff --git a/Code/train.py b/Code/train.py index 55bdd7fc099cc75e096f1fa24cfd404567e02b42..919fb063a8d58f0db930d2af9898da11567e1e53 100644 --- a/Code/train.py +++ b/Code/train.py @@ -17,6 +17,7 @@ import os import pandas as pd import sklearn +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') metric=evaluate.load("accuracy") torch.cuda.empty_cache() diff --git a/inference.py b/inference.py index a330f5ef654932959b8448a153d4a80209106ec4..6caeb49806571ffcd29210b5ba26ac277e106ebe 100644 --- a/inference.py +++ b/inference.py @@ -1,12 +1,13 @@ """Demo for inference: User enters a sentence and our trained BERT model predicts if the target word is literal or non-literal""" import sys sys.path.insert(0, 'Code/') -from Code.preprocess import * -from Code.models import * -from Code.train import * -#import preprocess -#import models -#import train +# from Code.preprocess import * +# from Code.models import * +# from Code.train import * + +import Code.preprocess +import Code.models +import Code.train import json import torch from transformers import BertTokenizer, BertModel, BertConfig, BertPreTrainedModel, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer