PyTorch: How to apply the same random transformation to multiple images?
H

5

14

I am writing a simple transformation for a dataset that contains many pairs of images. As data augmentation, I want to apply a random transformation to each pair, but both images in a pair should be transformed in the same way. For example, given a pair of images A and B, if A is flipped horizontally, B must be flipped horizontally too. The next pair, C and D, should get a different random transformation from A and B, but C and D must again be transformed identically. This is what I have tried:

import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

img_a = Image.open("sample_a.jpg") # note that the two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

transform = transforms.RandomChoice(
    [transforms.RandomHorizontalFlip(), 
     transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))

random.seed(1)
display(transform(img_c))
display(transform(img_d))

Yet the above code does not choose the same transformation for both images, and as far as I tested, the result depends on the number of times transform is called.

Is there any way to force transforms.RandomChoice to use the same transform when specified?

Hyoscyamus answered 25/12, 2020 at 12:18 Comment(0)
I
11

Usually a workaround is to apply the transform to the first image, retrieve the parameters of that transform, then apply a deterministic transform with those parameters to the remaining images. However, RandomChoice does not provide an API to get the parameters of the applied transform, since it involves a variable number of transforms. In those cases, I usually override the original class.

Looking at the torchvision implementation, it's as simple as:

class RandomChoice(RandomTransforms):
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)

Here are two possible solutions.

  1. You can either sample from the transform list on __init__ instead of on __call__:

    import random
    import torch
    import torchvision.transforms as T
    
    class RandomChoice(torch.nn.Module):
        def __init__(self, transforms):
            super().__init__()
            # pick one transform once, at construction time
            self.t = random.choice(transforms)
    
        def __call__(self, img):
            return self.t(img)
    

    So you can do:

    # p=1 so the chosen flip is always applied; with the default p=0.5
    # each call would decide independently whether to flip
    transform = RandomChoice([
         T.RandomHorizontalFlip(p=1),
         T.RandomVerticalFlip(p=1)
    ])
    display(transform(img_a)) # both img_a and img_b will
    display(transform(img_b)) # have the same transform
    
    transform = RandomChoice([
        T.RandomHorizontalFlip(p=1),
        T.RandomVerticalFlip(p=1)
    ])
    display(transform(img_c)) # both img_c and img_d will
    display(transform(img_d)) # have the same transform
    

  2. Or better yet, transform the images in a batch:

    import random
    import torch
    import torchvision.transforms as T
    
    class RandomChoice(torch.nn.Module):
        def __init__(self, transforms):
            super().__init__()
            self.transforms = transforms
    
        def __call__(self, imgs):
            # choose once per call, then apply to every image in the batch
            t = random.choice(self.transforms)
            return [t(img) for img in imgs]
    

    Which allows you to do:

    # p=1 so the chosen flip is always applied to every image in the batch
    transform = RandomChoice([
         T.RandomHorizontalFlip(p=1),
         T.RandomVerticalFlip(p=1)
    ])
    
    img_at, img_bt = transform([img_a, img_b])
    display(img_at) # both img_a and img_b will
    display(img_bt) # have the same transform
    
    img_ct, img_dt = transform([img_c, img_d])
    display(img_ct) # both img_c and img_d will
    display(img_dt) # have the same transform
    
Iinden answered 25/12, 2020 at 13:52 Comment(4)
Just to clarify, transform = transforms.RandomChoice([ in your answer is actually transform = RandomChoice([, right? – Hyoscyamus
Oops, sorry, it should be T.RandomChoice() since I imported torchvision.transforms as T. – Iinden
I agree transforming in a batch (where possible) seems like the best solution. – Finicky
I think it should be RandomChoice(), not T.RandomChoice(); otherwise it calls the RandomChoice class of torchvision.transforms. Also, when I tried this method with RandomRotation it didn't work, because it only randomly chooses a transformation from the list you pass, not the random parameters within those transformations. For example, if you have a pair of images that need to be augmented the same way, this method unfortunately doesn't work because they might still be rotated by different random degrees. – Hypesthesia
U
5

Referencing "Random transforms for both input and target?", I think this is probably the cleanest way to do it. Save the random state before applying any transformation, then just restore it before each subsequent call:

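import torch
import torchvision.transforms as transforms

t = transforms.RandomRotation(degrees=360)
state = torch.get_rng_state()  # remember the RNG state before the first call
x = t(x)
torch.set_rng_state(state)     # rewind so the second call draws the same angle
y = t(y)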
Unthinkable answered 23/5, 2022 at 18:2 Comment(0)
D
4

Simply take the randomization out of PyTorch and into an if statement. The code below uses vflip; the same pattern works for horizontal flips and other transforms.

import random
import torchvision.transforms.functional as TF

if random.random() > 0.5:
    image = TF.vflip(image)
    mask  = TF.vflip(mask)

This issue has been discussed on the PyTorch forum, and the pros and cons of several solutions were discussed on the official GitHub repository page. PyTorch maintainers have suggested this simple approach.

Do not use torchvision.transforms.RandomVerticalFlip(p=1). Use torchvision.transforms.functional.vflip

Functional transforms give you fine-grained control of the transformation pipeline. As opposed to the transformations above, functional transforms don’t contain a random number generator for their parameters. That means you have to specify/generate all parameters, but you can reuse the functional transform.

Diabolism answered 4/2, 2021 at 19:59 Comment(3)
While this code may solve the question, including an explanation of how and why this solves the problem would really help to improve the quality of your post, and probably result in more up-votes. Remember that you are answering the question for readers in the future, not just the person asking now. Please edit your answer to add explanations and give an indication of what limitations and assumptions apply. – Organic
@AdrianMole Thanks for the suggestion. I've added an explanation :-) – Diabolism
This answer is underrated. Why change the whole class if you can just use a random number? This also works for things such as random cropping: simply use torchvision.transforms.functional.crop() with random ints for the top and left params (make sure they are within [0, orig_size - target_size)). – Battledore
F
4

I realize the OP requested a solution using torchvision and I think @Ivan's answer does a good job addressing this.

However, for those not tied to a specific augmentation library, I wanted to point out that Albumentations appears to handle this kind of situation natively by allowing the user to pass multiple source images, boxes, etc. into the same transform. The return value is structured as a dict:

import albumentations as A

transform = A.Compose(
    transforms=[
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5)],
    additional_targets={'image0': 'image', 'image1': 'image'}
)
transformed = transform(image=image, image0=image0, image1=image1)

Now you can access transformed['image0'], transformed['image1'], etc., and all of them will have the same random parameters applied.

Finicky answered 4/2, 2022 at 21:2 Comment(2)
What is the torchvision.transforms equivalent of this? – Bonspiel
I don't know of a torchvision equivalent, which was why I suggested Albumentations. – Finicky
D
0

I don't know of a function to fix the random output. Maybe try a different logic, such as doing the randomization yourself so that you can reuse the same transformation. The logic:

  • generate a random number
  • based on the number, apply a transformation to both images
  • generate another random number
  • do the same for the other two images

Try this:
import random
import torchvision.transforms.functional as TF
from PIL import Image

img_a = Image.open("sample_a.jpg") # note that the two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

# one random draw per pair decides which flip both images get
if random.random() > 0.5:
    image_a_flipped = TF.vflip(img_a)
    image_b_flipped = TF.vflip(img_b)
else:
    image_a_flipped = TF.hflip(img_a)
    image_b_flipped = TF.hflip(img_b)

if random.random() > 0.5:
    image_c_flipped = TF.vflip(img_c)
    image_d_flipped = TF.vflip(img_d)
else:
    image_c_flipped = TF.hflip(img_c)
    image_d_flipped = TF.hflip(img_d)

display(image_a_flipped)
display(image_b_flipped)

display(image_c_flipped)
display(image_d_flipped)
Delly answered 25/12, 2020 at 13:35 Comment(0)
