Calculating the log-sum-exp function in C++

Are there any functions in the C++ standard library that calculate the log of the sum of exponentials? If not, how should I go about writing my own? Do you have any suggestions not mentioned in this wiki article?

I am specifically concerned that the summands might underflow. This happens when you exponentiate negative numbers with a large absolute value. I am using C++11.
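
To see the failure concretely, here is a small sketch of the naive formula (the name naive_log_sum_exp is made up for illustration). With inputs around -1000, std::exp underflows to 0.0 for every term, so the sum is 0.0 and the result is -inf, even though the true value is about -999.31 (that is, -1000 + log 2).

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Naive evaluation of log(sum_i exp(x_i)), shown only to illustrate the
// underflow problem: for large negative x_i every exp(x_i) rounds to 0.0,
// so the sum is 0.0 and std::log(0.0) returns -infinity.
inline double naive_log_sum_exp(const std::vector<double>& x)
{
  double sum = 0.0;
  for (double xi : x) sum += std::exp(xi);
  return std::log(sum);
}
```

The answers below avoid this by subtracting the largest input before exponentiating.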

Paterson answered 29/8, 2017 at 16:9 Comment(2)
There isn't anything for that particular application in the standard library, but Boost.Multiprecision provides arbitrary-precision arithmetic with very full-featured support for transcendental and special functions. – Suffragette
If you want to write your own, then why did you ask about library support? I don't have any implementation recommendations, no. – Suffragette

(C++11 variant of shuvro's code, using the Standard Library as per the question.)

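#include <algorithm>  // std::max_element
#include <cmath>      // std::exp, std::log
#include <iterator>   // std::iterator_traits
#include <numeric>    // std::accumulate

template <typename Iter>
typename std::iterator_traits<Iter>::value_type
log_sum_exp(Iter begin, Iter end)
{
  using VT = typename std::iterator_traits<Iter>::value_type;
  if (begin == end) return VT{};  // empty range (mathematically log(0) = -inf)
  using std::exp;  // unqualified calls below still use ADL, so the
  using std::log;  // value type's own exp/log overloads are preferred
  auto max_elem = *std::max_element(begin, end);
  // Every shifted exponent b - max_elem is <= 0, so each term is <= 1.
  auto sum = std::accumulate(begin, end, VT{},
     [max_elem](VT a, VT b) { return a + exp(b - max_elem); });
  return max_elem + log(sum);
}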

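This version is more generic: it works on any value type, in any container whose iterators are at least forward iterators (the range is traversed twice), as long as the type supports the operators used (<, + and -). Thanks to the using-declarations, the unqualified calls fall back to std::exp and std::log, but argument-dependent lookup still picks the value type's own exp and log overloads if it has them.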
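
To illustrate the lookup point, here is a sketch that repeats the template (with the typename keywords it needs to compile) and can be called both on a plain double array and on a hypothetical demo::Tracked type that carries its own exp and log. Tracked and exp_calls are made-up names for illustration only; the counter just records that the custom exp overload was the one chosen.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iterator>
#include <numeric>

// Generic log-sum-exp over a forward-iterator range, as in the answer.
template <typename Iter>
typename std::iterator_traits<Iter>::value_type
log_sum_exp(Iter begin, Iter end)
{
  using VT = typename std::iterator_traits<Iter>::value_type;
  if (begin == end) return VT{};
  using std::exp;  // fallbacks; ADL can still find better matches
  using std::log;
  const VT max_elem = *std::max_element(begin, end);
  const VT sum = std::accumulate(begin, end, VT{},
      [max_elem](VT a, VT b) { return a + exp(b - max_elem); });
  return max_elem + log(sum);
}

// Hypothetical value type, for illustration only. It provides the operators
// the template needs plus its own exp/log, which ADL finds in namespace demo.
namespace demo {
struct Tracked {
  double v;
};
inline Tracked operator+(Tracked a, Tracked b) { return Tracked{a.v + b.v}; }
inline Tracked operator-(Tracked a, Tracked b) { return Tracked{a.v - b.v}; }
inline bool operator<(Tracked a, Tracked b) { return a.v < b.v; }

// Counts how often the custom exp overload is selected.
inline int& exp_calls() { static int n = 0; return n; }

inline Tracked exp(Tracked x) { ++exp_calls(); return Tracked{std::exp(x.v)}; }
inline Tracked log(Tracked x) { return Tracked{std::log(x.v)}; }
}  // namespace demo
```

Calling log_sum_exp on an array of two demo::Tracked{0.0} values yields log 2 and invokes demo::exp twice, while a double array goes through std::exp as usual.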

To be really robust against underflow, even for unknown numeric types, it would probably be beneficial to sort the values. If you sort the inputs in descending order, the very first value will be max_elem, so the first term of sum will be exp(VT{0}), which presumably is VT{1}. That term is clearly free from underflow.
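
A minimal sketch of the sort-first idea for double inputs, assuming a std::vector<double> is acceptable (log_sum_exp_sorted is a hypothetical name). It sorts a copy in descending order, so the first accumulated term is exp(0.0) == 1.0; it asserts on empty input instead of choosing a return value.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

// Sort-first variant: xs is taken by value, so the caller's data is untouched.
inline double log_sum_exp_sorted(std::vector<double> xs)
{
  assert(!xs.empty() && "log-sum-exp of an empty range is -inf; not handled");
  std::sort(xs.begin(), xs.end(), std::greater<double>());  // descending
  const double max_elem = xs.front();
  double sum = 0.0;
  for (double x : xs) {
    sum += std::exp(x - max_elem);  // first term is exp(0.0) == 1.0
  }
  return max_elem + std::log(sum);
}
```

The sort costs O(n log n) instead of the O(n) scan used by std::max_element.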

Inconvenient answered 30/8, 2017 at 12:46 Comment(0)

If you want smaller code, this implementation could do the job:

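#include <cmath>  // std::exp, std::log

double log_sum_exp(const double arr[], int count)
{
   if (count > 0) {
      // Find the largest element.
      double maxVal = arr[0];
      for (int i = 1; i < count; i++) {
         if (arr[i] > maxVal) {
            maxVal = arr[i];
         }
      }

      // Shift every exponent by maxVal: each term is at most exp(0) = 1,
      // and the term for maxVal itself is exactly 1, so sum >= 1.
      double sum = 0;
      for (int i = 0; i < count; i++) {
         sum += std::exp(arr[i] - maxVal);
      }
      return std::log(sum) + maxVal;
   }
   else
   {
      return 0.0;  // empty input
   }
}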
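
If you would rather read the array only once (for example, when the values arrive as a stream), a common alternative is to keep a running maximum and rescale the partial sum whenever the maximum changes. This is a sketch of that different technique, not the implementation from Takeda 25's blog; log_sum_exp_one_pass is a hypothetical name, and it keeps this answer's convention of returning 0.0 for an empty input.

```cpp
#include <cassert>
#include <cmath>

// One-pass ("streaming") log-sum-exp. Invariant after element i:
// s == sum_{j<=i} exp(arr[j] - m), where m is the largest element so far.
inline double log_sum_exp_one_pass(const double* arr, int count)
{
  if (count <= 0) return 0.0;  // same empty-input convention as above
  assert(arr != nullptr);
  double m = arr[0];
  double s = 1.0;  // exp(arr[0] - m) == exp(0.0) == 1.0
  for (int i = 1; i < count; ++i) {
    if (arr[i] > m) {
      // New maximum: rescale the old sum to the new reference point.
      s = s * std::exp(m - arr[i]) + 1.0;
      m = arr[i];
    } else {
      s += std::exp(arr[i] - m);
    }
  }
  return m + std::log(s);
}
```

The trade-off is an extra exp call each time a new maximum appears, in exchange for touching each element once.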

You can see a more robust implementation on Takeda 25's blog (Japanese) (or, see it in English, via Google Translate).

Astarte answered 29/8, 2017 at 17:11 Comment(1)
$$\log\left(\sum_{i=1}^{n}\exp(x_i)\right) = \log\left(\sum_{i=1}^{n}\exp(x_{\max})\exp(x_i - x_{\max})\right) = \log\left(\exp(x_{\max})\sum_{i=1}^{n}\exp(x_i - x_{\max})\right) = x_{\max} + \log\left(\sum_{i=1}^{n}\exp(x_i - x_{\max})\right)$$
If I'm not wrong, there is no count multiplying the maxVal. – Tila
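
A quick numeric check of this identity, as a sketch with hypothetical helper names lse_direct and lse_shifted: both sides agree for moderate inputs, the shifted form adds x_max exactly once, and it also survives inputs large enough that the direct form overflows to +inf.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Left-hand side evaluated directly: only safe for moderate x_i.
inline double lse_direct(const std::vector<double>& x)
{
  double s = 0.0;
  for (double xi : x) s += std::exp(xi);
  return std::log(s);
}

// Right-hand side: x_max + log(sum_i exp(x_i - x_max)).
// x_max is added once, not multiplied by the number of terms.
inline double lse_shifted(const std::vector<double>& x)
{
  assert(!x.empty());
  const double xmax = *std::max_element(x.begin(), x.end());
  double s = 0.0;
  for (double xi : x) s += std::exp(xi - xmax);
  return xmax + std::log(s);
}
```

For inputs such as {800, 799}, std::exp(800.0) overflows a double, so lse_direct returns +inf while lse_shifted returns 800 + log(1 + e^-1).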
