How does pytorch backprop through argmax?

I'm building k-means in PyTorch using gradient descent on centroid locations, instead of expectation-maximisation. The loss is the sum of squared distances from each point to its nearest centroid. To identify which centroid is nearest to each point, I use argmin, which is not differentiable everywhere. However, PyTorch is still able to backprop and update the weights (centroid locations), giving similar performance to sklearn's k-means on the data.

Any ideas how this is working, or how I can figure this out within pytorch? Discussion on pytorch github suggests argmax is not differentiable: https://github.com/pytorch/pytorch/issues/1339.

Example code below (on random pts):

import numpy as np
import torch

num_pts, batch_size, n_dims, num_clusters, lr = 1000, 100, 200, 20, 1e-5

# generate random points
vector = torch.from_numpy(np.random.rand(num_pts, n_dims)).float()

# randomly pick starting centroids
idx = np.random.choice(num_pts, size=num_clusters, replace=False) # distinct starting points
kmean_centroids = vector[idx][:,None,:] # [num_clusters,1,n_dims]
# copy into a fresh leaf tensor; torch.tensor(existing_tensor) raises a UserWarning
kmean_centroids = kmean_centroids.clone().detach().requires_grad_(True)

for t in range(4001):
    # get batch
    idx = np.random.choice(num_pts, size=batch_size)
    vector_batch = vector[idx]

    distances = vector_batch - kmean_centroids # [num_clusters, #pts, #dims]
    distances = torch.sum(distances**2, dim=2) # [num_clusters, #pts]

    # argmin
    membership = torch.min(distances, 0)[1] # [#pts]

    # cluster distances
    cluster_loss = 0
    for i in range(num_clusters):
        subset = torch.transpose(distances,0,1)[membership==i]
        if len(subset)!=0: # to prevent NaN
            cluster_loss += torch.sum(subset[:,i])

    cluster_loss.backward()
    print(cluster_loss.item())

    with torch.no_grad():
        kmean_centroids -= lr * kmean_centroids.grad
        kmean_centroids.grad.zero_()
Teryl answered 3/3, 2019 at 14:6 Comment(1)
Argmax is non-differentiable. But you can try some tricks like homes.cs.washington.edu/~hapeng/paper/peng2018backprop.pdf; the paper also references other work along the same lines that tries to backprop past some sort of argmax/sparsemax. Disclaimer: I've personally not worked on such problems. – Fiona

As alvas noted in the comments, argmax is not differentiable. However, once you compute it and assign each datapoint to a cluster, the derivative of loss with respect to the location of these clusters is well-defined. This is what your algorithm does.

Why does it work? If you had only one cluster (so that the argmax operation didn't matter), your loss function would be quadratic, with minimum at the mean of the data points. Now with multiple clusters, you can see that your loss function is piecewise (in higher dimensions think volumewise) quadratic - for any set of centroids [C1, C2, C3, ...] each data point is assigned to some centroid CN and the loss is locally quadratic. The extent of this locality is given by all alternative centroids [C1', C2', C3', ...] for which the assignment coming from argmax remains the same; within this region the argmax can be treated as a constant, rather than a function and thus the derivative of loss is well-defined.
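
The "argmax is a constant" view can be checked directly. Below is a minimal sketch, not the asker's exact code: it assumes a small random dataset, uses illustrative names (points, centroids, manual), and compares the autograd gradient with the hand-derived gradient 2 * sum(c_k - x) over the points currently assigned to centroid k.

```python
import torch

torch.manual_seed(0)
points = torch.rand(50, 3)                         # [n_pts, n_dims]
centroids = torch.rand(4, 3, requires_grad=True)   # [n_clusters, n_dims]

# squared distances, shape [n_pts, n_clusters]
sq_dist = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(dim=2)

# argmin returns an integer tensor that carries no autograd history
membership = sq_dist.argmin(dim=1)
print(membership.dtype, membership.requires_grad)  # torch.int64 False

# loss: squared distance of each point to its assigned centroid
loss = sq_dist.gather(1, membership[:, None]).sum()
loss.backward()

# gradient with the assignments held fixed: 2 * sum_{x in cluster k} (c_k - x)
with torch.no_grad():
    manual = torch.zeros_like(centroids)
    for k in range(centroids.shape[0]):
        assigned = points[membership == k]
        manual[k] = 2 * (centroids[k] - assigned).sum(dim=0)

print(torch.allclose(centroids.grad, manual, atol=1e-5))
```

If the two gradients agree, autograd has done nothing more than differentiate the quadratic loss with the assignments frozen.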

In practice the assignments do change from one iteration to the next, so argmax is not truly constant. The naive "argmax-is-a-constant" derivative still points approximately towards a minimum, because most data points keep the same cluster between iterations. Once you are close enough to a local minimum that points no longer change their assignments, the process can converge to it.

Another, more theoretical way to look at it is that you're doing an approximation of expectation maximization. Normally, you would have the "compute assignments" step, which is mirrored by argmax, and the "minimize" step which boils down to finding the minimizing cluster centers given the current assignments. The minimum is given by d(loss)/d([C1, C2, ...]) == 0, which for a quadratic loss is given analytically by the means of data points within each cluster. In your implementation, you're solving the same equation but with a gradient descent step. In fact, if you used a 2nd order (Newton) update scheme instead of 1st order gradient descent, you would be implicitly reproducing exactly the baseline EM scheme.
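
The Newton remark can be illustrated for a single cluster with fixed assignments. This is a sketch under simple assumptions: 1-D points, made-up values, and an arbitrary learning rate of 0.01 for the gradient-descent comparison. For the loss sum_i (x_i - c)^2 the second derivative is the constant 2n, so one Newton step from any starting point lands on the mean, which is exactly the EM "minimize" step.

```python
# points currently assigned to one cluster (1-D for brevity, illustrative values)
xs = [1.0, 2.0, 4.0, 9.0]
n = len(xs)
mean = sum(xs) / n

def grad(c):
    # d/dc of sum_i (x_i - c)^2
    return sum(2 * (c - x) for x in xs)

hessian = 2 * n  # second derivative, constant for a quadratic loss

c0 = 0.0
newton_c = c0 - grad(c0) / hessian   # one Newton step
gd_c = c0 - 0.01 * grad(c0)          # one gradient-descent step, lr = 0.01

print(newton_c, mean, gd_c)
```

The gradient-descent step moves only part of the way towards the mean, which is why the asker's loop needs many iterations where EM needs one update per assignment.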

Erythro answered 4/3, 2019 at 10:18 Comment(2)
Thanks for the fantastic answer, the comparison to EM is helpful. So because there is no backprop through argmin, cluster assignments are treated as constant when pytorch backprops the loss to centroid locations? – Teryl
Yes. In general, only floating-point (and complex) tensors can require gradients, so an integer tensor such as membership has no .grad. This means that membership cannot be backpropagated through and is treated as constant. – Erythro

Imagine this:

t = torch.tensor([-0.0627,  0.1373,  0.0616, -1.7994,  0.8853, 
                  -0.0656,  1.0034,  0.6974,  -0.2919, -0.0456])
torch.argmax(t).item() # outputs 6

Suppose we increase t[0] by some δ close to 0. Will this change the argmax? It will not, so the gradient is zero almost everywhere. Just ignore this layer, or assume it is frozen.

The same holds for argmin, or any other function whose output takes discrete values.
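
A quick finite-difference check of this, as a plain-Python sketch: the argmax helper below is illustrative (not a library call) and delta = 1e-3 is an arbitrary small step. Each entry is bumped in turn and the resulting index is compared with the original one.

```python
t = [-0.0627, 0.1373, 0.0616, -1.7994, 0.8853,
     -0.0656, 1.0034, 0.6974, -0.2919, -0.0456]

def argmax(values):
    # index of the largest entry (illustrative helper)
    return max(range(len(values)), key=lambda i: values[i])

delta = 1e-3
base = argmax(t)

shifted = []
for i in range(len(t)):
    bumped = list(t)
    bumped[i] += delta       # perturb one input at a time
    shifted.append(argmax(bumped))

# finite-difference "derivative" of the index with respect to each input
finite_diff = [(s - base) / delta for s in shifted]
print(base, finite_diff)
```

The index only changes if a perturbation is large enough to overtake the current maximum, so for small delta every finite difference is zero.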

Nita answered 9/7, 2019 at 16:21 Comment(0)
