In my case, I'm using a Normal distribution as the prior in a neural-net model. I have a class called class1, for example, and in its __init__ I have to initialize the prior. However, calling .to('cuda') on an instance of class1 doesn't move the distribution to the GPU, which causes errors in later usage. To handle this, I could have registered the distribution's parameters as buffers, as follows.
import torch
import torch.nn as nn

class class1(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers are moved together with the module by .to('cuda')
        self.register_buffer("mean", torch.tensor(0.))
        self.register_buffer("var", torch.tensor(1.))

    def get_dist(self):
        # Note: Normal's second argument is the scale (standard deviation)
        return torch.distributions.Normal(self.mean, self.var)
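With the buffers registered, moving the module also moves the prior's parameters, so the distribution built in get_dist lives on the right device. A quick check (the names model and dist below are just for illustration):

model = class1().to('cuda')   # moves the "mean" and "var" buffers to the GPU
dist = model.get_dist()       # Normal is built from CUDA tensors
print(dist.loc.device)        # cuda:0
sample = dist.sample()        # the sample is on the GPU as well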
However, I have several priors, and it's not possible to register_buffer a list. One option is to create the distributions inside get_dist on every call, if you don't care about the cost of re-creating them each time. Instead, I decided to define a function that initializes the distributions and to use a try-except in get_dist to handle the different states: if the distributions haven't been assigned yet, or are on the CPU while we expect them to be on the GPU, execution falls into the except branch, where I initialize the distributions using torch.zeros(..).to(device) (see the sketch below).
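The exact code isn't included in the answer, but a minimal sketch of that pattern, using hypothetical names (class2, n_priors, _init_dists), might look like this:

import torch
import torch.nn as nn

class class2(nn.Module):
    def __init__(self, n_priors=3):
        super().__init__()
        self.n_priors = n_priors
        self.dists = None  # priors are created lazily on the requested device

    def _init_dists(self, device):
        # Build every prior from tensors that already live on `device`
        self.dists = [
            torch.distributions.Normal(torch.zeros(1).to(device), torch.ones(1).to(device))
            for _ in range(self.n_priors)
        ]

    def get_dist(self, i, device):
        try:
            dist = self.dists[i]
            # Re-create the priors if they sit on the wrong device (e.g. CPU instead of GPU)
            if dist.loc.device.type != torch.device(device).type:
                raise RuntimeError("priors are on the wrong device")
        except (TypeError, RuntimeError):
            # TypeError: self.dists is still None; RuntimeError: device mismatch above
            self._init_dists(device)
            dist = self.dists[i]
        return dist

With this sketch, get_dist(0, 'cuda') rebuilds the priors on the GPU only when needed, so the construction cost is paid once per device change rather than on every call.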
Overall, to handle this CPU/GPU device error, you need to construct the distribution from tensor parameters that are already on the appropriate device. The main reason is that torch.distributions.Distribution unfortunately has no device attribute.
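For example (a standalone snippet, not taken from the original answer; the name prior is illustrative):

import torch
from torch.distributions import Normal

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The distribution's device is fixed by the tensors it is constructed from,
# since the Distribution object itself cannot be moved afterwards.
prior = Normal(torch.tensor(0.0, device=device), torch.tensor(1.0, device=device))
print(prior.loc.device, prior.sample().device)  # both follow `device`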
You may also hit a normal_cuda_kernel not implemented for long error when building the Normal distribution from integer tensors. The correct code (note that 0 and 1 are now 0.0 and 1.0): normal = Normal(torch.tensor(0.0).to(device=device), torch.tensor(1.0).to(device=device))
– Fool