Commit f152da51 authored by wu

Update main_Critic.py

parent 17e13823
@@ -19,9 +19,13 @@ def set_parameter_requires_grad(model):
     for param in model.parameters():
         param.requires_grad = False # in-place
+# set device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 # loads model ActorOnly
 model = SummarisationModel()
-model.load_state_dict(torch.load('model_actor_only_wts1_0.pth')) #blow up here?
+model.to(device)
+model.load_state_dict(torch.load('/home/students/kreuzer/ao+cel/aof/nn-projekt-ss22/model_actor_only_wts1_0.pth', map_location = device)) #blow up here?
 model.eval()
 set_parameter_requires_grad(model)
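
The hunk above is the substance of the commit: the device is now chosen before the checkpoint is read, and torch.load is given map_location=device so weights saved on a GPU can be restored on a CPU-only machine (the likely cause of the "blow up here?" note). A minimal, self-contained sketch of the same pattern; TinyModel and the checkpoint filename are placeholders, not part of this repository:

import torch
import torch.nn as nn

# Placeholder model standing in for SummarisationModel (hypothetical).
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a stand-in checkpoint so the sketch runs on its own.
torch.save(TinyModel().state_dict(), "tiny_model_wts.pth")

model = TinyModel()
# map_location remaps tensors saved on e.g. "cuda:0" onto the current device,
# so loading does not fail on a machine without that GPU.
state = torch.load("tiny_model_wts.pth", map_location=device)
model.load_state_dict(state)
model.to(device)
model.eval()  # inference mode: fixes dropout and batch-norm behaviour
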
@@ -30,14 +34,15 @@ set_parameter_requires_grad(model)
 batch_loss_tracker = open("CriticBatchLossTracker.txt", 'a')
 logger = open("CriticLogger.txt", 'a')
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 logger.write('Used device: {}\n'.format(str(device)))
 # initialize Critic
 critic = Critic()
+critic.to(device)
 optimizer = torch.optim.Adam(critic.parameters(), lr= 0.001)
 loss_fn = nn.MSELoss()
-m.to(device)
 # loads dataset cnn_dailymail
 train_dataset = PreprocessedDataSet(['/workspace/students/kreuzer/complete/train'])
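
This hunk places the critic on the selected device before the Adam optimizer and MSE loss are built. The Critic class itself is not shown in the diff; the sketch below is only a hypothetical stand-in with the call signature used later in the loop (two encoded document vectors in, one scalar score out), trained with the same Adam + MSELoss combination. Layer sizes and the toy data are invented for illustration:

import torch
import torch.nn as nn

class ToyCritic(nn.Module):
    """Hypothetical critic: concatenates two document vectors and regresses a score."""
    def __init__(self, hidden_size=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, doc_vec_1, doc_vec_2):
        return self.net(torch.cat([doc_vec_1, doc_vec_2], dim=-1)).squeeze(-1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic = ToyCritic().to(device)
optimizer = torch.optim.Adam(critic.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

# One toy update step: score two fake document vectors and regress
# the prediction towards a fake ROUGE target.
doc_vec_1 = torch.randn(256, device=device)
doc_vec_2 = torch.randn(256, device=device)
target_rouge = torch.tensor(0.42, device=device)

score = critic(doc_vec_1, doc_vec_2)
loss = loss_fn(score, target_rouge)
optimizer.zero_grad()
loss.backward()
optimizer.step()
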
@@ -79,17 +84,23 @@ for epoch in range(epochs):
             logger.write("Warning! This datapoint has a too short document and will be ignored\n")
             continue
+        # move "datapoint" to gpu
+        p_searchspace = datapoint.p_searchspace.to(device)
+        sent_vecs = datapoint.sent_vecs.to(device)
+        gold_sent_vecs = datapoint.gold_sent_vecs.to(device)
+        top_rouge = datapoint.top_rouge.to(device)
         # positive sampling
-        k = np.random.choice(len(datapoint.p_searchspace))
-        sample = datapoint.sent_vecs.masked_select(datapoint.p_searchspace[k].bool()) # not padded sent embeddings
+        k = np.random.choice(len(p_searchspace))
+        sample = sent_vecs.masked_select(p_searchspace[k].bool()) # not padded sent embeddings
         # hidden states, cell states of document encoder from model ActorOnly
         _, document_vec_1 = model.encode_document(sample)
-        _, document_vec_2 = model.encode_document(datapoint.gold_sent_vecs)
+        _, document_vec_2 = model.encode_document(gold_sent_vecs)
         # critic
         score = critic(document_vec_1, document_vec_2)
-        loss = loss_fn(score, datapoint.top_rouge[k])
+        loss = loss_fn(score, top_rouge[k])
         batch_loss += loss
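
The point of this hunk is to move each datapoint's tensors to the device once and then work with the local copies (p_searchspace, sent_vecs, gold_sent_vecs, top_rouge) instead of the CPU-side datapoint attributes. The positive-sampling step draws one candidate extract at random and uses its row of p_searchspace as a boolean mask over the padded sentence embeddings; torch.Tensor.masked_select returns the selected values as a flattened 1-D tensor. A small illustration of that masking idea with made-up shapes and values (the real tensors come from PreprocessedDataSet and may be laid out differently, which is presumably why the mask in the repository can be passed to masked_select directly):

import numpy as np
import torch

num_sents, emb_dim = 4, 3
sent_vecs = torch.randn(num_sents, emb_dim)          # padded sentence embeddings (toy)
p_searchspace = torch.tensor([[1, 0, 1, 0],          # one candidate extract per row (toy)
                              [0, 1, 1, 0],
                              [1, 1, 0, 1]])

# positive sampling: pick one candidate extract at random
k = np.random.choice(len(p_searchspace))

# masked_select needs a mask broadcastable to the input, so here the row mask is
# expanded over the embedding dimension; whole sentence vectors are kept and the
# result comes back flattened to 1-D.
mask = p_searchspace[k].bool().unsqueeze(-1).expand_as(sent_vecs)
sample = sent_vecs.masked_select(mask)
print(sample.shape)  # torch.Size([num_selected * emb_dim])

The flattened sample then goes through model.encode_document, and the critic's score for it is regressed towards the candidate's stored ROUGE value (top_rouge[k]) with nn.MSELoss, exactly as in the loop above.
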