numerically stable way to multiply log probability matrices in numpy

I need to take the matrix product of two NumPy matrices (or other 2d arrays) containing log probabilities. The naive way np.log(np.dot(np.exp(a), np.exp(b))) is not an option, for the obvious reason that exponentiating the log probabilities underflows to zero (or overflows) in floating point.
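For illustration, a minimal hypothetical example of the failure mode (the matrices and values here are made up): once the log probabilities are sufficiently negative, np.exp underflows to exactly 0.0 and the logarithm of the product comes out as -inf.

import numpy as np

a = np.full((2, 2), -800.0)  # exp(-800) underflows to 0.0 in double precision
b = np.full((2, 2), -800.0)

print(np.log(np.dot(np.exp(a), np.exp(b))))  # every entry is -inf
# The exact result is log(2) - 1600 ≈ -1599.3 for every entry.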

Using

from scipy.special import logsumexp  # formerly scipy.misc.logsumexp
res = np.zeros((a.shape[0], b.shape[1]))
for n in range(b.shape[1]):
    # broadcast b[:,n] over rows of a, sum columns
    res[:, n] = logsumexp(a + b[:, n], axis=1)

works, but runs about 100 times slower than np.log(np.dot(np.exp(a), np.exp(b))).

Using

logsumexp((np.tile(a, (b.shape[1], 1)) + np.repeat(b.T, a.shape[0], axis=0)).reshape(b.shape[1], a.shape[0], a.shape[1]), axis=2).T

or other combinations of tile and reshape also work but run even slower than the loop above due to the prohibitively large amounts of memory required for realistically sized input matrices.

I am currently considering writing a NumPy extension in C to compute this, but of course I'd rather avoid that. Is there an established way to do this, or does anybody know of a less memory intensive way of performing this computation?

EDIT: Thanks to larsmans for this solution (see below for derivation):

def logdot(a, b):
    max_a, max_b = np.max(a), np.max(b)
    exp_a, exp_b = a - max_a, b - max_b   # shift so the largest entry is 0
    np.exp(exp_a, out=exp_a)              # exponentiate in place
    np.exp(exp_b, out=exp_b)
    c = np.dot(exp_a, exp_b)
    np.log(c, out=c)
    c += max_a + max_b                    # undo the shifts in log space
    return c

A quick comparison of this method to the method posted above (logdot_old) using IPython's %timeit magic yields the following:

In [1]: a = np.log(np.random.rand(1000, 2000))

In [2]: b = np.log(np.random.rand(2000, 1500))

In [3]: x = logdot(a, b)

In [4]: y = logdot_old(a, b)  # this takes a while

In [5]: np.any(np.abs(x - y) > 1e-14)
Out[5]: False

In [6]: %timeit logdot_old(a, b)
1 loops, best of 3: 1min 18s per loop

In [7]: %timeit logdot(a, b)
1 loops, best of 3: 264 ms per loop

Obviously larsmans' method obliterates mine!

Splasher answered 13/5, 2014 at 11:40 Comment(23)
I've encountered this several times, and actually this is just a general question, as it seems like a problem others must have as well, yet I couldn't find anything online. Most recently it involved an N * K feature matrix, with N >> K, and a K * M weight matrix, with K and M of roughly the same size.Splasher
if you already know C, you could use scipy.weave.blitz to incorporate a few lines of C in the rest of your python codeMendacious
Thanks, I was not aware this existed!Splasher
Alas, scipy.weave is not available for python3Splasher
did you try with the most recent scipy? (two days ago 0.14 got released?)Mendacious
I tried it, but it does not seem to include weave support for python3 eitherSplasher
In your example I don't think that scipy.misc.logsumexp is doing what you think it is - according to the docs the b= parameter is actually a scaling factor for exp(a), i.e. np.log(np.sum(b*np.exp(a))).Elevator
The b in my code is a numpy ndarray, not a parameter to logsumexp. The 1 corresponds to the axis parameter, I've clarified this in the example above.Splasher
@mart: why are you interpreting your weights as probabilities?Mimosaceous
Weave is in a deprecation cycle. Any new code should be using Cython instead.Mosul
@Splasher Would you be ok with that? I can work out the details in an answer.Mosul
@Davidmh, thanks for your suggestion. I've checked out Cython and it looks like re-implementing the first method would be the easiest solution. I'll keep you posted!Splasher
@Splasher your first snippet cythonised may be good enough, but will still have some Python overhead. It will be CPU-time-faster to expand the dot product and rewrite it in pure Cython; but much developer-time-slower.Mosul
@Splasher I really wonder where you encountered this problem; the result of that dot product is no longer a matrix of probabilities.Counterattraction
@larsmans, it occurred in code I was writing for training an HMM using Baum-Welch. Standard implementations use a method of rescaling probabilities to prevent fp underflow, I was hoping to gain a performance boost by using log-probabilities and figuring out whether the code could still be vectorized. Calling them probability matrices is indeed not entirely true.Splasher
@mart: I still don't understand what you're trying to do. Normally, weights are applied in natural parameter space, which is log-odds. Therefore, you should be scaling the log-odds by the weight. logsumexp is usually used when, e.g., calculating the log-normalizer given a log-odds parameter vector. Are you sure that what you're doing makes sense?Mimosaceous
@NeilG One setting under which this makes sense is as follows. Suppose we have three random variables X, Y, Z. X takes values {1, ..., M}, Y takes values {1, ..., R}, Z takes values {1, ..., N}. And suppose that if the value of Y is known, then X and Z are independent. Then the conditional distribution p(Y|X) can be represented as a matrix of size R×M: p(Y=y|X=x) = A_{y,x}, p(Z|Y) can be represented as a matrix B of size N×R, and the table of the conditional distribution p(Z|X) equals matrix A multiplied by matrix B.Luthern
@Luthern Your P(Y|X) is fine, but it's in the natural parameter space (log probabilities). Normally, you would convert those to the expectation parameter space (probabilities) by applying the gradient log-normalizer. See, for example, Deep Exponential Families. Another way to see it is to interpret your model is a neural network with categorical units. Similarly, you should be applying the gradient log-normalizer.Mimosaceous
@NeilG by gradient log-normalizer you mean softmax, right? Well, if I have probabilities (in contrast to probabilities multiplied by some constant), then I don't need softmax.Luthern
@Luthern Right, but you don't have probabilities until you apply the GLN. The GLN is similar to the softmax, but without the final component. You can find a reference in Nielsen and Nock, Statistical Exponential Families.Mimosaceous
@NeilG To be clear, the purpose of my comment was to show another situation in which it's nice to have a logdot function. I am NOT trying to clarify @mart's idea. My usage is based on the fact that if you have a matrix A_{z,y} = p(Z=z|Y=y) and a matrix B_{y,x} = p(Y=y|X=x), and (X⊥Z|Y), then p(Z=z|X=x) = (AB)_{z,x}. Now, if I store logarithms of probabilities as A'_{z,y} = log p(Z=z|Y=y) and B'_{y,x} = log p(Y=y|X=x), then we have log p(Z=z|X=x) = logdot(A', B')_{z,x}. Personally I am planning to use this to build a categorical kinda directed graphical model. If I am wrong, please say where.Luthern
@NeilG Also, did you by any chance mean Nielsen and Garcia - "Statistical exponential families: A digest with flash cards". I didn't find Nielsen and Nock - "Statistical Exponential Families"Luthern
@crabman Yes, and Garcia, sorry. Was going from memory.Mimosaceous

logsumexp works by evaluating the right-hand side of the equation

log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])

I.e., it pulls out the max before starting to sum, to prevent overflow in exp. The same can be applied before doing vector dot products:

log(exp[a] ⋅ exp[b])
 = log(∑ exp[a] × exp[b])
 = log(∑ exp[a + b])
 = max(a + b) + log(∑ exp[a + b - max(a + b)])     { this is logsumexp(a + b) }

but by taking a different turn in the derivation, we obtain

log(∑ exp[a] × exp[b])
 = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
 = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])

The final form has a vector dot product in its innards. It also extends readily to matrix multiplication, so we get the algorithm

def logdotexp(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

This creates two A-sized temporaries and two B-sized ones, but one of each can be eliminated by

exp_A = A - max_A
np.exp(exp_A, out=exp_A)

and similarly for B. (If the input matrices may be modified by the function, all the temporaries can be eliminated.)
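For completeness, a sketch of the fully in-place variant mentioned in the parenthesis above; the name logdotexp_inplace is made up here, and note that it overwrites both of its arguments:

def logdotexp_inplace(A, B):
    """log(exp(A) @ exp(B)); A and B are destroyed in the process."""
    max_A = np.max(A)
    max_B = np.max(B)
    A -= max_A              # shift in place so exp cannot overflow
    B -= max_B
    np.exp(A, out=A)
    np.exp(B, out=B)
    C = np.dot(A, B)        # only the output C is newly allocated
    np.log(C, out=C)
    C += max_A + max_B      # undo the shifts in log space
    return C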

Counterattraction answered 5/6, 2014 at 15:3 Comment(4)
Thanks! I'll try if this gives the performance I was hoping for.Splasher
This is less stable than the original slower solution. Consider logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]).Misbecome
@Misbecome you're right. This method does not control stability for all elements. See my answer which also handles the counterexample you provided.Copeland
Fails for logdotexp(np.array([[0., -1000.]]), np.array([[-1000.], [0.]])). Produces the result [[-inf]] when it should be around [[-999.3]].Caxton

Suppose A.shape == (n, r) and B.shape == (r, m). In computing the matrix product C = A @ B there are actually n*m summations. To get stable results when working in log space, you need the logsumexp trick in each of these summations. Fortunately, with NumPy broadcasting it is easy to control the stability of the rows of A and the columns of B separately.

Here is the code:

def logdotexp(A, B):
    max_A = np.max(A, 1, keepdims=True)   # row-wise maxima of A, shape (n, 1)
    max_B = np.max(B, 0, keepdims=True)   # column-wise maxima of B, shape (1, m)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B                    # broadcasts back to shape (n, m)
    return C

Note:

The reasoning behind this is similar to Fred Foo's answer, but he used a single maximum value for each matrix. Since he did not consider each of the n*m summations separately, some elements of the final matrix might still be unstable, as mentioned in one of the comments.

Comparing with the currently accepted answer using @identity-m's counterexample:

def logdotexp_less_stable(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

print('old method:')
print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
print('new method:')
print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))

which prints

old method:
[[      -inf 0.69314718]
 [      -inf 0.69314718]]
new method:
[[-9.99306853e+02  6.93147181e-01]
 [-9.99306853e+02  6.93147181e-01]]
Copeland answered 21/10, 2018 at 14:5 Comment(1)
Your method is also not ideal. For example, take a = np.array([[-500., 900.]], dtype=np.float64), b = np.array([[900.], [-500.]], dtype=np.float64). Your logdotexp returns -inf, while scipy.special.logsumexp(a[0] + b[:, 0]) correctly returns 400.69314718055995.Luthern

Both the currently accepted answer by Fred Foo and Hassan's answer are numerically unstable (Hassan's is the better of the two). An example of an input on which Hassan's answer fails is provided at the end of this answer. My implementation is as follows:

import numpy as np
from scipy.special import logsumexp

def logmatmulexp(log_A: np.ndarray, log_B: np.ndarray) -> np.ndarray:
    """Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                             
    (log_A.exp() @ log_B.exp()).log() in a numerically stable way.                                                                                                                                                                           
    Has O(ϴRI) time complexity and space complexity."""
    ϴ, R = log_A.shape
    I = log_B.shape[1]
    assert log_B.shape == (R, I)
    log_A_expanded = np.broadcast_to(np.expand_dims(log_A, 2), (ϴ, R, I))
    log_B_expanded = np.broadcast_to(np.expand_dims(log_B, 0), (ϴ, R, I))
    log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)                                                                                                                                                              
    return logsumexp(log_pairwise_products, axis=1)
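As a quick check (not part of the original answer), this handles the counterexample quoted in the comments under the accepted answer; the expected value of roughly -999.3 comes from that comment:

print(logmatmulexp(np.array([[0., -1000.]]), np.array([[-1000.], [0.]])))
# [[-999.30685282]]  (i.e. log(2) - 1000; the accepted answer returns [[-inf]] here)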

Just like Hassan's answer and Fred Foo's answer, my answer has time complexity O(ϴRI). Their answers have space complexity O(ϴR + RI) (I am not actually sure about this), while mine unfortunately has space complexity O(ϴRI). This is because NumPy can multiply a ϴ×R matrix by an R×I matrix without allocating an additional array of size ϴ×R×I, whereas my code materializes exactly such an array. Having O(ϴRI) space complexity is not an inherent property of my method: if you write it out with explicit loops you can avoid it, but I don't think you can do that using stock NumPy functions (a chunked compromise is sketched below).
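A hypothetical middle ground, not from the original answer: process the columns of log_B in blocks, so that only a ϴ×R×block temporary is alive at any time while the per-element logsumexp stability is preserved. It assumes the numpy and scipy.special.logsumexp imports from the code above; the block size of 64 is an arbitrary choice.

def logmatmulexp_blocked(log_A: np.ndarray, log_B: np.ndarray, block: int = 64) -> np.ndarray:
    """Same result as logmatmulexp, but with only O(ϴ·R·block) extra memory."""
    ϴ, R = log_A.shape
    I = log_B.shape[1]
    assert log_B.shape == (R, I)
    out = np.empty((ϴ, I))
    for start in range(0, I, block):
        stop = min(start + block, I)
        # temporary of shape (ϴ, R, stop - start) instead of (ϴ, R, I)
        chunk = log_A[:, :, None] + log_B[None, :, start:stop]
        out[:, start:stop] = logsumexp(chunk, axis=1)
    return out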

I have checked how long my code actually takes to run; it is about 20 times slower than regular matrix multiplication.

Here's how you can know that my answer is numerically stable:

  1. Clearly, all lines other than the return line are numerically stable.
  2. The logsumexp function is known to be numerically stable.
  3. Therefore, my logmatmulexp function is numerically stable.

My implementation has another nice property. If instead of numpy you write the same code in pytorch or in another library with automatic differentiation, you will get a numerically stable backward pass automatically. Here's how we can know the backward pass will be numerically stable:

  1. All functions in my code are differentiable everywhere (unlike np.max)
  2. Clearly, back propagating through all lines except the return line is numerically stable, because absolutely nothing weird is happening there.
  3. Usually the developers of pytorch know what they're doing, so it's enough to trust that they implemented the backward pass of logsumexp in a numerically stable way.
  4. Actually, the gradient of logsumexp is the softmax function (for reference, google "softmax is gradient of logsumexp" or see https://arxiv.org/abs/1704.00805, proposition 1). It's known that softmax can be calculated in a numerically stable way, so the pytorch devs probably just use softmax there (I haven't actually checked).

Below is the same code in pytorch (in case you need backpropagation). Due to how pytorch backpropagation works, during the forward pass it will save the log_pairwise_products tensor for the backward pass. This tensor is large, and you probably don't want it to be saved; you can instead recalculate it during the backward pass. In that case I suggest you use checkpointing, which is really easy (see the second function below).

import torch
from torch.utils.checkpoint import checkpoint

def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                             
    (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
    ϴ, R = log_A.shape
    I = log_B.shape[1]
    assert log_B.shape == (R, I)
    log_A_expanded = log_A.unsqueeze(2).expand((ϴ, R, I))
    log_B_expanded = log_B.unsqueeze(0).expand((ϴ, R, I))
    log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)                                                                                                                                                              
    return torch.logsumexp(log_pairwise_products, dim=1)


def logmatmulexp_lowmem(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Same as logmatmulexp, but doesn't save a (ϴ, R, I)-shaped tensor for backward pass.                                                                                                                                                   

    Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                                
    (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
    return checkpoint(logmatmulexp, log_A, log_B)
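A minimal usage sketch (the shapes here are made up), assuming the definitions above; with the checkpointed version, the (ϴ, R, I)-shaped tensor is recomputed during the backward pass instead of being stored:

log_A = torch.randn(3, 4, requires_grad=True)
log_B = torch.randn(4, 5, requires_grad=True)
out = logmatmulexp_lowmem(log_A, log_B)    # shape (3, 5)
out.sum().backward()                       # gradients flow through the checkpoint
print(log_A.grad.shape, log_B.grad.shape)  # torch.Size([3, 4]) torch.Size([4, 5])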

Here's an input on which Hassan's implementation fails but my implementation gives the correct output:

def logmatmulexp_hassan(A, B):
    max_A = np.max(A,1,keepdims=True)
    max_B = np.max(B,0,keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

log_A = np.array([[-500., 900.]], dtype=np.float64)
log_B = np.array([[900.], [-500.]], dtype=np.float64)
print(logmatmulexp_hassan(log_A, log_B)) # prints -inf, while the correct answer is approximately 400.69.
Luthern answered 17/3, 2020 at 23:37 Comment(0)

You are accessing columns of res and b, which has poor locality of reference. One thing to try is to store these in column-major order.
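A hypothetical illustration of this suggestion, applied to the loop from the question: np.asfortranarray and order='F' store b and res in column-major order, so the column slices touched inside the loop are contiguous in memory.

import numpy as np
from scipy.special import logsumexp

a = np.log(np.random.rand(1000, 2000))
b = np.asfortranarray(np.log(np.random.rand(2000, 1500)))  # column-major copy of b
res = np.zeros((a.shape[0], b.shape[1]), order='F')        # column-major output
for n in range(b.shape[1]):
    res[:, n] = logsumexp(a + b[:, n], axis=1)             # b[:, n] is now contiguous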

Mimosaceous answered 13/5, 2014 at 22:23 Comment(1)
I noticed this too, but for larger arrays (size > 1000) the logsumexp operation dominates.Dipterocarpaceous
