How to make a PyTorch Distribution on GPU

4

6

Is it possible to make the PyTorch distributions create their samples directly on the GPU?

If I do

import torch
from torch.distributions import Uniform, Normal
normal = Normal(3, 1)
sample = normal.sample()

Then sample will be on the CPU. Of course it is possible to do sample = sample.to(torch.device("cuda")) to move it to the GPU. But is there a way to have the sample created directly on the GPU without first creating it on the CPU?

PyTorch distributions inherit from object, not nn.Module, so they do not have a to method to put the distribution instance on the GPU.

Any ideas?

Proustite answered 4/12, 2019 at 15:38 Comment(0)
9

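Distributions use the reparametrization trick, so samples are computed from the parameter tensors and end up on the same device as those tensors. Thus passing 0-dimensional (scalar) tensors that are on the GPU to the distribution constructor works, as follows: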

normal = Normal(torch.tensor(0).to(device=torch.device("cuda")), torch.tensor(1).to(device=torch.device("cuda")))
Proustite answered 4/12, 2019 at 15:54 Comment(2)
This blows up with the error normal_cuda_kernel not implemented for long in the normal distribution. The correct code (note that 0 and 1 are now 0.0 and 1.0): normal = Normal(torch.tensor(0.0).to(device=device), torch.tensor(1.0).to(device=device)) - Fool
Really wish there were a more succinct notation for this - Cathleencathlene
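
A somewhat shorter way to write the corrected version from the comment above is to pass device= straight to torch.tensor, which avoids the separate .to() call. This is only a minimal sketch: the CPU fallback and the variable names are assumptions added so that it also runs on a machine without CUDA.

```python
import torch
from torch.distributions import Normal

# Assumption: fall back to CPU when no GPU is present so the snippet still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Float literals (0.0, 1.0) matter: integer tensors trigger the
# "not implemented for long" error mentioned in the comment above.
normal = Normal(
    torch.tensor(0.0, device=device),
    torch.tensor(1.0, device=device),
)

sample = normal.sample((4,))  # drawn on the same device as loc and scale
print(sample.device)
```

The sample inherits its device from the parameter tensors, so no later .to() call is needed on the sample itself.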
1

In my case, I'm using a Normal distribution as the prior in a neural net model. I have a class called class1, for example, and I have to create the prior in its __init__ function. However, calling .to('cuda') on an instance of class1 doesn't change the distribution's device, which causes errors later on. One way to manage this is to register the parameters as buffers, as follows.

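import torch
from torch import nn

class class1(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers follow the module when .to(device) or .cuda() is called.
        self.register_buffer("mean", torch.tensor(0.))
        # Normal's second argument is the standard deviation (scale), not the variance.
        self.register_buffer("std", torch.tensor(1.))
    def get_dist(self):
        # Built on demand so it uses the buffers' current device.
        return torch.distributions.Normal(self.mean, self.std)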

However, I have several priors, and it's not possible to register_buffer a list. One option is to construct the distributions inside get_dist on every call, if you don't mind the cost of constructing them repeatedly. I decided instead to define a function that constructs the distributions and to use a try-except in get_dist to handle the different states. If the distributions attribute has not been assigned yet, or is on the CPU while we expect it to be on the GPU, execution jumps to the except branch, where I construct the distributions using torch.zeros(..).to(device).

Overall, to avoid this CPU/GPU device error, you need to construct the distribution from tensor parameters that are already on the appropriate device. The main reason is that torch.distributions.Distribution unfortunately has no device attribute.
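
For several priors, one possible workaround for the "cannot register a list" limitation is to register each tensor under its own indexed name and rebuild the distributions on demand. The sketch below is not the try-except approach described above; the class name MultiPrior, the buffer names loc_i and scale_i, and the CPU fallback are all hypothetical choices made for illustration.

```python
import torch
from torch import nn
from torch.distributions import Normal


class MultiPrior(nn.Module):  # hypothetical name
    def __init__(self, locs, scales):
        super().__init__()
        self.num_priors = len(locs)
        # A list cannot be registered as one buffer, so each tensor gets its own name.
        for i, (loc, scale) in enumerate(zip(locs, scales)):
            self.register_buffer(f"loc_{i}", torch.tensor(float(loc)))
            self.register_buffer(f"scale_{i}", torch.tensor(float(scale)))

    def get_dists(self):
        # Rebuilt on every call, so they always follow the buffers' current device.
        return [
            Normal(getattr(self, f"loc_{i}"), getattr(self, f"scale_{i}"))
            for i in range(self.num_priors)
        ]


# Assumption: fall back to CPU when no GPU is present so the snippet still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiPrior(locs=[0.0, 3.0], scales=[1.0, 0.5]).to(device)
samples = [d.sample() for d in model.get_dists()]
```

This trades a small construction cost per call for never having a distribution left behind on the wrong device.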

Literacy answered 20/5, 2021 at 20:8 Comment(0)
1

I just came across the same problem, and thanks to the other answers here for the pointers. I want to offer another option if you want a distribution inside a module, which is to override the to method in the module and manually call the to methods on the distribution's parameter tensors. I've only tested this with Uniform, but it works well there.

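import torch
from torch import nn
from torch.distributions import Uniform

class MyModule(nn.Module):
    def __init__(self):  # add your own constructor arguments as needed
        super().__init__()
        self.rng = Uniform(
            low=torch.zeros(3),
            high=torch.ones(3)
        )

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # Move the distribution's parameter tensors as well.
        self.rng.low = self.rng.low.to(*args, **kwargs)
        self.rng.high = self.rng.high.to(*args, **kwargs)
        return self  # keep the usual model = model.to(device) idiom working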

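Now you can move your model to the GPU with model.to(device) as usual, and self.rng.sample() will return a sample on the correct device. Note that model.cuda() does not go through to, so it will not move the distribution.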

Ronnaronnholm answered 14/7, 2021 at 11:47 Comment(1)
Note that you have to include more parameters depending on the distribution. E.g. for torch.distributions.MultivariateNormal you need to set _unbroadcasted_scale_tril, loc, scale_tril, covariance_matrix, and precision_matrix, and you explicitly have to check whether the last three are None. - Divalent
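
Rather than listing each attribute by hand, one could loop over whatever tensors the distribution instance currently stores. The helper below (apply_to_distribution is a hypothetical name) is a rough sketch that assumes the distribution keeps its tensors as plain instance attributes. That is an implementation detail of torch.distributions rather than a documented guarantee, so it may not hold for every distribution or every PyTorch version.

```python
import torch
from torch.distributions import MultivariateNormal, Normal


def apply_to_distribution(dist, fn):
    """Apply fn to every tensor stored as an instance attribute of dist (in place)."""
    # Assumption: the distribution keeps its tensors as plain attributes in __dict__.
    for name, value in list(vars(dist).items()):
        if isinstance(value, torch.Tensor):
            setattr(dist, name, fn(value))
    return dist


# Assumption: fall back to CPU when no GPU is present so the snippet still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mvn = MultivariateNormal(torch.zeros(2), covariance_matrix=torch.eye(2))
apply_to_distribution(mvn, lambda t: t.to(device))
sample = mvn.sample()
```

Attributes that were never set on the instance are simply absent from vars(dist), so no explicit None checks are needed in this sketch.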
0

You can solve the problem of "transferring non-parameter/buffer attributes to GPU" by overriding the _apply(self, fn) method of your network, like this:

def _apply(self, fn):
    # apply fn() to your parameters, buffers, and submodules (like 'ResNet_backbone')
    super()._apply(fn)
    # apply fn() to your prior
    self.prior.attr1 = fn(self.prior.attr1)  # like 'MultivariateNormal.loc', needs to be a Tensor
    self.prior.attr2 = fn(self.prior.attr2)
    # ...
    self.prior.attrN = fn(self.prior.attrN)
    # if we do not use register_buffer(Tensor),
    #    apply fn() to your non-parameter/buffer attributes;
    #    they need to be Tensors too
    self.attr1 = fn(self.attr1)
    self.attr2 = fn(self.attr2)
    # ...
    self.attrN = fn(self.attrN)
    return self  # Module.to() and .cuda() return the result of _apply
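
To make the pattern above concrete, here is a small self-contained sketch with a Normal prior. The class name NetWithPrior and the layer sizes are made up for illustration, and note that _apply is a private nn.Module method, so its signature may differ between PyTorch versions; extra arguments are forwarded unchanged for that reason.

```python
import torch
from torch import nn
from torch.distributions import Normal


class NetWithPrior(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.prior = Normal(torch.zeros(2), torch.ones(2))

    def _apply(self, fn, *args, **kwargs):
        # Let nn.Module handle parameters, buffers, and child modules first.
        super()._apply(fn, *args, **kwargs)
        # Then convert the tensors that Normal stores as plain attributes.
        self.prior.loc = fn(self.prior.loc)
        self.prior.scale = fn(self.prior.scale)
        return self


# .double() is used here because it works without a GPU; .to() and .cuda()
# route through _apply in the same way.
model = NetWithPrior().double()
sample = model.prior.sample()
```

Because dtype conversions and device moves share the same hook, the prior stays consistent with the rest of the module in both cases.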
Recollection answered 28/3, 2022 at 12:52 Comment(0)
