I'm trying to improve a CNN I made by implementing a weighted loss method described in this paper. To do this, I looked at this notebook, which implements the pseudo-code of the method described in the paper.
When translating their code to my model, I ran into the following error when calling torch.autograd.grad():

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior
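As far as I understand, torch.autograd.grad() raises this whenever one of the tensors you ask it to differentiate with respect to was never used to compute the output. A toy snippet (made-up tensors, unrelated to my model) that produces the same message:

import torch

# z is never used to compute out, so asking for d(out)/d(z) fails
a = torch.randn(2, requires_grad=True)
z = torch.randn(2, requires_grad=True)
out = (a * 2).sum()

# RuntimeError: One of the differentiated Tensors appears to not have been
# used in the graph. Set allow_unused=True if this is the desired behavior
torch.autograd.grad(out, z)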
My code is below; the error is raised on the second-to-last line:
for epoch in range(1): # tqdm(range(params['epochs'])):
    model.train()

    # Sample a training batch
    text_t, labels_t = next(iter(train_iterator))
    text_t = to_var(text_t, requires_grad=False)
    labels_t = to_var(labels_t, requires_grad=False)

    # Second model instance used for the meta step
    dummy = L2RWCNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM,
                    DROPOUT, PAD_IDX)
    dummy.state_dict(model.state_dict())
    dummy.cuda()

    # Per-example training loss, weighted by eps (initialized to zero)
    y_f_hat = dummy(text_t)
    cost = F.binary_cross_entropy_with_logits(y_f_hat.squeeze(), labels_t, reduce=False)
    eps = to_var(torch.zeros(cost.size()))
    l_f_meta = torch.sum(cost * eps)

    dummy.zero_grad()
    num_params = 0

    # Gradients of the weighted loss w.r.t. the dummy model's parameters
    grads = torch.autograd.grad(l_f_meta, (dummy.params()), create_graph=True)

    # One SGD step on the dummy model's parameters
    with torch.no_grad():
        for p, grad in zip(dummy.parameters(), grads):
            tmp = p - params['lr'] * grad
            p.copy_(tmp)

    # Validation loss of the updated dummy model
    text_v, labels_v = next(iter(valid_iterator))
    y_g_hat = dummy(text_v)
    l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat.squeeze(), labels_v, reduce=False)
    l_g_meta = torch.sum(l_g_meta)

    # The RuntimeError is raised on this line
    grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
    print(grad_eps)
I think the error is because eps was not an input to any previous torch.autograd.grad() call. I tried the suggested fix of setting allow_unused=True, but that just made the returned gradient None. I also looked at this post for a solution, but the fix that worked there (not slicing the tensors) doesn't apply to me, since I'm not passing in any partial or sliced variables. Setting create_graph=False in the first autograd.grad() call didn't fix the issue either. Does anyone have a solution?
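To rule out something specific to my model, here is a stripped-down toy version (a plain linear "model" with random tensors, not my actual L2RWCNN) that I believe fails in the same way for the same reason:

import torch

# Toy stand-ins for the dummy model's parameter, the data, and the labels
w = torch.randn(3, requires_grad=True)
x = torch.randn(5, 3)
y = torch.randn(5)

eps = torch.zeros(5, requires_grad=True)  # per-example weights, like my eps
cost = ((x @ w) - y) ** 2                 # per-example loss
l_f = torch.sum(cost * eps)               # weighted training loss

# First grad call works: eps is part of l_f's graph
grad_w = torch.autograd.grad(l_f, w, create_graph=True)[0]

# In-place parameter update under no_grad(), as in my training loop
with torch.no_grad():
    w.copy_(w - 0.1 * grad_w)

# "Validation" loss computed with the updated parameter
l_g = torch.sum(((x @ w) - y) ** 2)

# Raises: RuntimeError: One of the differentiated Tensors appears to not
# have been used in the graph. (With allow_unused=True it just returns None.)
grad_eps = torch.autograd.grad(l_g, eps)[0]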
EDIT: I created a new post with a different phrasing of the question here.