diff --git a/inference.py b/inference.py index 313bfad932dc225bacb1d7e85d7ed6dd425bc283..9785f35c5a56b19eb3afcfcb0d7c7364494af537 100644 --- a/inference.py +++ b/inference.py @@ -81,11 +81,11 @@ model = torch.load(model_path, map_location=device) model.eval() -train_sampler = RandomSampler(data_sample) -train_dataloader = DataLoader(data_sample, sampler=train_sampler, batch_size=1) +train_sampler = RandomSampler(input_as_dataset) +train_dataloader = DataLoader(input_as_dataset, sampler=train_sampler, batch_size=1) for batch in train_dataloader: - print(batch) + #print(batch) inputs = {'input_ids': batch[0], 'attention_mask': batch[1], 'token_type_ids': batch[2],