How to optimize for-comprehensions and loops in Scala?

So Scala is supposed to be as fast as Java. I'm revisiting some Project Euler problems in Scala that I originally tackled in Java. Specifically Problem 5: "What is the smallest positive number that is evenly divisible by all of the numbers from 1 to 20?"

Here's my Java solution, which takes 0.7 seconds to complete on my machine:

public class P005_evenly_divisible implements Runnable{
    final int t = 20;

    public void run() {
        int i = 10;
        while(!isEvenlyDivisible(i, t)){
            i += 2;
        }
        System.out.println(i);
    }

    boolean isEvenlyDivisible(int a, int b){
        for (int i = 2; i <= b; i++) {
            if (a % i != 0) 
                return false;
        }
        return true;
    }

    public static void main(String[] args) {
        new P005_evenly_divisible().run();
    }
}

Here's my "direct translation" into Scala, which takes 103 seconds (147 times longer!)

object P005_JavaStyle {
    val t: Int = 20
    def run(): Unit = {
        var i = 10
        while (!isEvenlyDivisible(i, t))
            i += 2
        println(i)
    }
    def isEvenlyDivisible(a: Int, b: Int): Boolean = {
        for (i <- 2 to b)
            if (a % i != 0)
                return false
        return true
    }
    def main(args: Array[String]): Unit = {
        run()
    }
}

Finally here's my attempt at functional programming, which takes 39 seconds (55 times longer)

object P005 extends App{
    def isDivis(x:Int) = (1 to 20) forall {x % _ == 0}
    def find(n:Int):Int = if (isDivis(n)) n else find (n+2)
    println (find (2))
}

Using Scala 2.9.0.1 on Windows 7 64-bit. How do I improve performance? Am I doing something wrong? Or is Java just a lot faster?

Superbomb asked 26/5, 2011 at 23:18. Comments:
Do you compile, or interpret using the Scala shell? - Hyohyoid
There is a better way to do this than using trial division (Hint). - Mayday
You don't show how you're timing this. Did you try timing just the run method? - Charlinecharlock
@ahmet This is compiled, not run in the shell. - Superbomb
@Aaron - I timed just the run method using System.nanoTime() in Java, and used a physical stopwatch for the Scala versions. - Superbomb
@Mayday - yep, I just did it the pen & paper way: write down the prime factors of each number starting with the highest, then cross out the factors that you already have from higher numbers, so you finish with (5*2*2)*(19)*(3*3)*(17)*(2*2)*()*(7)*(13)*()*(11) = 232792560 - Superbomb
+1 This is the most interesting question I've seen in weeks on SO (and it also has the best answer I've seen in quite a while). - Huggermugger
+1 for saying "as fast as Java". - Arrange
@matt ball - implementing Runnable isn't the same as "spawning a new thread." new Thread(new Runnable() { public void run() { ... } }).start(), however, is the same as spawning a new thread. - Sure
@Matt, @Andrew, I usually implement Runnable in my Java classes that are meant to be run: it makes more conceptual sense than a static "main" method, and I can easily launch them in a new thread from elsewhere (e.g. a Swing GUI). But I should probably have left it out for this discussion because it's irrelevant when we use a "main". - Superbomb
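
The pen-and-paper method in the comment above amounts to taking, for each prime p up to 20, the largest power of p that does not exceed 20, and multiplying those powers together. A minimal Scala sketch of that idea (the object and method names are made up for illustration; trial division is used for primality only because n is tiny):

```scala
object PrimePowerLcm {
  // Trial-division primality test; adequate for the small n used here.
  def isPrime(p: Int): Boolean = p >= 2 && (2 until p).forall(p % _ != 0)

  // For each prime p <= n, keep the largest power p^k <= n.
  // The product of those powers is the least common multiple of 1..n.
  def lcmUpTo(n: Int): Long =
    (2 to n).filter(isPrime).map { p =>
      var pk = p.toLong
      while (pk * p <= n) pk *= p
      pk
    }.product
}
```

For n = 20 this multiplies 16, 9, 5, 7, 11, 13, 17 and 19, which is the same set of factors as the comment's crossed-out list.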

The problem in this particular case is that you return from within the for-expression. That in turn gets translated into a throw of a NonLocalReturnControl exception, which is caught at the enclosing method. The optimizer can eliminate the foreach but cannot yet eliminate the throw/catch, and throw/catch is expensive. But since such nested returns are rare in Scala programs, the optimizer has not yet addressed this case. There is work going on to improve the optimizer which hopefully will solve this issue soon.
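
To make the mechanism concrete, here is a hand-written approximation of what the compiler generates for a return inside a for body. It is a sketch, not scalac's actual output: ReturnSignal is a hypothetical stand-in for the real control exception in scala.runtime, and the key object plays the role of identifying which method invocation the return belongs to.

```scala
import scala.util.control.NoStackTrace

object NonLocalReturnSketch {
  // Hypothetical stand-in for the compiler's control exception.
  // NoStackTrace skips filling in the stack trace, but every throw still
  // allocates an object and unwinds through the foreach call.
  final class ReturnSignal(val key: AnyRef, val value: Boolean)
    extends RuntimeException with NoStackTrace

  // Approximate desugaring of:
  //   for (i <- 2 to b) if (a % i != 0) return false
  //   true
  def isEvenlyDivisible(a: Int, b: Int): Boolean = {
    val key = new AnyRef // unique to this invocation of the method
    try {
      (2 to b).foreach { i =>
        if (a % i != 0) throw new ReturnSignal(key, false) // the "return false"
      }
      true
    } catch {
      // Handle only the signal raised by this invocation; any other
      // exception does not match the guard and propagates unchanged.
      case r: ReturnSignal if r.key eq key => r.value
    }
  }
}
```

In the Project Euler search almost every candidate fails the test, so this throw/catch pair runs on nearly every call.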

Everhart answered 16/6, 2011 at 11:3. Comments:
Pretty heavy that a return becomes an exception. I'm sure it's documented somewhere, but it has the reek of incomprehensible hidden magic. Is that really the only way? - Rumilly
If the return happens from inside a closure, it seems to be the best available option. Returns from outside closures are of course translated directly to return instructions in the bytecode. - Everhart
I'm sure I'm overlooking something, but why not instead compile the return from inside a closure to set an enclosed boolean flag and return value, and check that after the closure call returns? - Bigamy
Why is his functional algorithm still 55 times slower? It doesn't look like it should suffer from such horrible performance. - Foresight
@Elijah, it looks like the Range literal is a lot of the problem. See my new post below. - Superbomb
Martin's book speaks to this case. - Minimum
Now, in 2014, I tested this again and for me the performance is as follows: Java 0.3 s; Scala 3.6 s; Scala optimized 3.5 s; Scala functional 4 s. That looks much better than 3 years ago, but the difference is still too big. Can we expect more performance improvements? In other words, Martin, is there anything, in theory, left to optimize? - Triecious

The problem is most likely the use of a for comprehension in the method isEvenlyDivisible. Replacing for by an equivalent while loop should eliminate the performance difference with Java.

As opposed to Java's for loops, Scala's for comprehensions are actually syntactic sugar for higher-order methods; in this case, you're calling the foreach method on a Range object. Scala's for is very general, but sometimes leads to painful performance.
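
The desugaring and the suggested fix can be sketched side by side. This is a minimal illustration, not code from the answer; the object and method names are made up, and the search mirrors the question's run method (even numbers starting at 10).

```scala
object WhileLoopSketch {
  // `for (i <- 2 to b) body` is sugar for `(2 to b).foreach(i => body)`:
  // the loop body becomes a closure passed to Range.foreach, so a `return`
  // inside it has to escape from that closure.

  // Hand-written replacement: no Range object, no closure, and `return`
  // is an ordinary method return because it is not inside a function literal.
  def isEvenlyDivisible(a: Int, b: Int): Boolean = {
    var i = 2
    while (i <= b) {
      if (a % i != 0) return false
      i += 1
    }
    true
  }

  // Same search as the question: try even numbers from 10 upward.
  def smallestEvenlyDivisible(t: Int): Int = {
    var i = 10
    while (!isEvenlyDivisible(i, t)) i += 2
    i
  }
}
```

With t = 20 this performs the same search as the Java version and should compile to comparable bytecode.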

You might want to try the -optimize flag in Scala 2.9. Observed performance may also depend on the particular JVM in use, and on the JIT optimizer having had sufficient "warm-up" time to identify and optimize hot spots.
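
A rough sketch of timing only after a warm-up phase, so the JIT has a chance to compile the hot methods first. The timed helper is hypothetical (hand-rolled, not a library function), and the workload is a while-loop version of the search; for serious measurements a dedicated harness such as JMH is a better choice.

```scala
object WarmUpSketch {
  // Hypothetical helper: evaluates `body` once, returns (result, elapsed ms).
  def timed[A](body: => A): (A, Double) = {
    val start = System.nanoTime()
    val result = body
    (result, (System.nanoTime() - start) / 1e6)
  }

  def isEvenlyDivisible(a: Int, b: Int): Boolean = {
    var i = 2
    while (i <= b) {
      if (a % i != 0) return false
      i += 1
    }
    true
  }

  def smallest(t: Int): Int = {
    var n = 2
    while (!isEvenlyDivisible(n, t)) n += 2
    n
  }

  def main(args: Array[String]): Unit = {
    // Warm-up: smaller runs exercise the same methods before we measure.
    for (_ <- 1 to 5) smallest(16)
    // Only these rounds are reported.
    for (round <- 1 to 3) {
      val (answer, ms) = timed(smallest(20))
      println(f"round $round: $answer in $ms%.1f ms")
    }
  }
}
```

Comparing the first reported round with later ones gives a rough idea of how much of a single stopwatch measurement is JIT warm-up.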

Recent discussions on the mailing list indicate that the Scala team is working on improving for performance in simple cases.

Here is the issue in the bug tracker: https://issues.scala-lang.org/browse/SI-4633

Update 5/28:

  • As a short term solution, the ScalaCL plugin (alpha) will transform simple Scala loops into the equivalent of while loops.
  • As a potential longer term solution, teams from the EPFL and Stanford are collaborating on a project enabling run-time compilation of "virtual" Scala for very high performance. For example, multiple idiomatic functional loops can be fused at run-time into optimal JVM bytecode, or to another target such as a GPU. The system is extensible, allowing user defined DSLs and transformations. Check out the publications and Stanford course notes. Preliminary code is available on GitHub, with a release intended in the coming months.

Hime answered 27/5, 2011 at 0:39. Comments:
Great, I replaced the for comprehension with a while loop and it runs at exactly the same speed (+/- < 1%) as the Java version. Thanks... I nearly lost faith in Scala for a minute! Now I just have to work on a good functional algorithm... :) - Superbomb
It's worth noting that tail-recursive functions are also as fast as while loops (since both are converted to very similar or identical bytecode). - Revolutionist
This got me once, too. I had to translate an algorithm from collection functions to nested while loops (six levels deep!) because of an incredible slow-down. This is something that needs to be heavily targeted, IMHO; what use is a nice programming style if I can't use it when I need decent (note: not blazing fast) performance? - Cita
When is for suitable then? - Cressida
@Cressida - a for in Scala behaves like the for ( : ) in Java, for the most part. - Dwanadwane
I tried out ScalaCL, and it gets the functional version above down to 1.89 s on my computer. That's a more than 20x speed-up! Not quite as good as tail recursion but not far off, and more concise. - Superbomb

As a follow-up, I tried the -optimize flag and it reduced running time from 103 to 76 seconds, but that's still 107x slower than Java or a while loop.

Then I was looking at the "functional" version:

object P005 extends App{
  def isDivis(x:Int) = (1 to 20) forall {x % _ == 0}
  def find(n:Int):Int = if (isDivis(n)) n else find (n+2)
  println (find (2))
}

and trying to figure out how to get rid of the "forall" in a concise manner. I failed miserably and came up with

object P005_V2 extends App {
  def isDivis(x:Int):Boolean = {
    var i = 1
    while(i <= 20) {
      if (x % i != 0) return false
      i += 1
    }
    return true
  }
  def find(n:Int):Int = if (isDivis(n)) n else find (n+2)
  println (find (2))
}

whereby my cunning 5-line solution has ballooned to 12 lines. However, this version runs in 0.71 seconds, the same speed as the original Java version, and 56 times faster than the version above using "forall" (40.2 s)! (see EDIT below for why this is faster than Java)

Obviously my next step was to translate the above back into Java, but Java can't handle it and throws a StackOverflowError with n around the 22000 mark.
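
The Scala version survives because scalac rewrites a self-recursive call in tail position into a jump, whereas neither javac nor the HotSpot JVM eliminates tail calls, so the Java translation keeps pushing stack frames. A sketch of this, with the bound turned into a parameter (limit, an addition for illustration and testing) and @tailrec added so the compiler reports an error if the call were not in tail position:

```scala
import scala.annotation.tailrec

object TailCallSketch {
  def isDivis(x: Int, limit: Int): Boolean = (1 to limit).forall(x % _ == 0)

  // @tailrec makes compilation fail unless the self-call can be turned
  // into a jump, so the stack does not grow with the number of candidates.
  @tailrec
  def find(n: Int, limit: Int): Int =
    if (isDivis(n, limit)) n else find(n + 2, limit)

  // Roughly the loop the compiler produces for `find`: the recursive call
  // becomes "update the parameter and jump back to the top".
  def findLoop(start: Int, limit: Int): Int = {
    var n = start
    while (!isDivis(n, limit)) n += 2
    n
  }
}
```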

I then scratched my head for a bit and replaced the "while" with a bit more tail recursion, which saves a couple of lines, runs just as fast, but let's face it, is more confusing to read:

object P005_V3 extends App {
  def isDivis(x:Int, i:Int):Boolean = 
    if(i > 20) true
    else if(x % i != 0) false
    else isDivis(x, i+1)

  def find(n:Int):Int = if (isDivis(n, 2)) n else find (n+2)
  println (find (2))
}

So Scala's tail recursion wins the day, but I'm surprised that something as simple as a "for" loop (and the "forall" method) is essentially broken and has to be replaced by inelegant and verbose "whiles", or tail recursion. A lot of the reason I'm trying Scala is because of the concise syntax, but it's no good if my code is going to run 100 times slower!

EDIT: (deleted)

EDIT OF EDIT: Former discrepancies between run times of 2.5s and 0.7s were entirely due to whether the 32-bit or 64-bit JVM was being used. Scala from the command line uses whatever is set by JAVA_HOME, while Java uses 64-bit if available regardless. IDEs have their own settings. Some measurements here: Scala execution times in Eclipse

Superbomb answered 28/5, 2011 at 6:35. Comments:
The isDivis method can be written as: def isDivis(x: Int, i: Int): Boolean = if (i > 20) true else if (x % i != 0) false else isDivis(x, i+1). Notice that in Scala if-else is an expression which always returns a value. No need for the return keyword here. - Kolinsky
Your last version (P005_V3) can be made shorter, more declarative and IMHO clearer by writing: def isDivis(x: Int, i: Int): Boolean = (i > 20) || (x % i == 0) && isDivis(x, i+1) - Barnyard
@Barnyard No. This would break the tail-recursiveness, which is required to translate to a while-loop in bytecode, which in turn makes the execution fast. - Brody
I see your point, but my example is still tail-recursive since && and || use short-circuit evaluation, as confirmed by using @tailrec: gist.github.com/Blaisorblade/5672562 - Barnyard
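
The last comment's claim can be checked by annotating the short-circuit version with @tailrec, which makes compilation fail if the recursive call is not in tail position. A sketch (the object name is made up); it relies on the compiler treating the right operand of && and || as a tail position, as the commenter reports:

```scala
import scala.annotation.tailrec

object ShortCircuitSketch {
  // The recursive call is the right operand of &&, which is itself the
  // right operand of || (&& binds tighter), so it is still a tail call.
  @tailrec
  def isDivis(x: Int, i: Int): Boolean =
    (i > 20) || (x % i == 0) && isDivis(x, i + 1)
}
```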

The answer about for comprehensions is right, but it's not the whole story. You should note that the use of return in isEvenlyDivisible is not free. The use of return inside the for forces the Scala compiler to generate a non-local return (i.e. to return outside its function).

This is done through the use of an exception to exit the loop. The same happens if you build your own control abstractions, for example:

def loop[T](times: Int, default: T)(body: => T): T = {
    var count = 0
    var result: T = default
    while (count < times) {
        result = body   // evaluates the by-name block on each iteration
        count += 1
    }
    result
}

def foo(): Int = {
    loop(5, 0) {
        println("Hi")
        return 5   // exits foo, not just the block passed to loop
    }
}

foo()

This prints "Hi" only once.

Note that the return in foo exits foo (which is what you would expect). Since the braced block is passed as a by-name parameter (the => T in the signature of loop), the compiler turns it into a function value, and the return inside it forces the compiler to generate a non-local return, that is, the return makes you exit foo, not just the body.

On the JVM, the only way to implement such behavior is to throw an exception.

Going back to isEvenlyDivisible:

def isEvenlyDivisible(a:Int, b:Int):Boolean = {
  for (i <- 2 to b) 
    if (a % i != 0) return false
  return true
}

The loop body if (a % i != 0) return false becomes a function literal that contains a return, so each time the return is hit (which, in this search, happens on almost every call), the runtime has to throw and catch an exception, which causes quite a bit of GC overhead.
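
One way to keep the early exit without writing return inside a closure is to let a library method stop the iteration, for example forall. A minimal sketch (the object name is made up); it removes the compiler-generated non-local return, although, as the measurements in the forall answer further down show, the Range and the closure still have a cost of their own:

```scala
object NoReturnSketch {
  // forall stops at the first i for which the predicate is false and
  // yields false; no `return` appears in the function literal, so the
  // compiler has no non-local return to generate.
  def isEvenlyDivisible(a: Int, b: Int): Boolean =
    (2 to b).forall(i => a % i == 0)
}
```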

Contestation answered 16/6, 2011 at 13:2

Some ways to speed up the forall method I discovered:

The original: 41.3 s

def isDivis(x:Int) = (1 to 20) forall {x % _ == 0}

Pre-instantiating the range, so we don't create a new range every time: 9.0 s

val r = (1 to 20)
def isDivis(x:Int) = r forall {x % _ == 0}

Converting to a List instead of a Range: 4.8 s

val rl = (1 to 20).toList
def isDivis(x:Int) = rl forall {x % _ == 0}

I tried a few other collections but List was fastest (although still 7x slower than if we avoid the Range and higher-order function altogether).

While I am new to Scala, I'd guess the compiler could easily implement a quick and significant performance gain by automatically replacing Range literals in methods (as above) with Range constants in the outermost scope. Or better, intern them like String literals in Java.


footnote: Arrays were about the same as Range, but interestingly, pimping a new forall method (shown below) resulted in 24% faster execution on 64-bit, and 8% faster on 32-bit. When I reduced the calculation size by reducing the number of factors from 20 to 15 the difference disappeared, so maybe it's a garbage collection effect. Whatever the cause, it's significant when operating under full load for extended periods.

A similar pimp for List also resulted in about 10% better performance.

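  import scala.language.implicitConversions

  val ra = (1 to 20).toArray
  def isDivis(x:Int) = ra forall2 {x % _ == 0}

  // Wraps an indexed sequence and adds a while-loop based forall
  case class PimpedSeq[A](s: collection.IndexedSeq[A]) {
    def forall2 (p: A => Boolean): Boolean = {
      var i = 0
      while (i < s.length) {
        if (!p(s(i))) return false
        i += 1
      }
      true
    }
  }
  // The Array is first wrapped as an IndexedSeq by Predef, then as a PimpedSeq
  implicit def arrayToPimpedSeq[A](in: Array[A]): PimpedSeq[A] = PimpedSeq(in)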

Superbomb answered 19/6, 2011 at 6:1

I just wanted to comment for people who might lose faith in Scala over issues like this that these kinds of issues come up in the performance of just about all functional languages. If you are optimizing a fold in Haskell, you will often have to re-write it as a recursive tail-call-optimized loop, or else you'll have performance and memory issues to contend with.

I know it's unfortunate that functional languages aren't yet optimized to the point where we don't have to think about things like this, but this is not at all a problem particular to Scala.
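
In Scala terms, the rewrite described above looks roughly like this: a fold over a Range turned into a tail-recursive loop with an accumulator. A sketch with made-up names (sumFold, sumLoop); how much the rewrite gains depends on the Scala version and the JVM.

```scala
import scala.annotation.tailrec

object FoldVsLoop {
  // Idiomatic fold: concise, but goes through a Range and a closure.
  def sumFold(n: Int): Long = (1 to n).foldLeft(0L)(_ + _)

  // Same computation as an explicit tail-recursive loop with an accumulator;
  // @tailrec guarantees it compiles to a jump rather than a nested call.
  @tailrec
  def sumLoop(i: Int, n: Int, acc: Long): Long =
    if (i > n) acc else sumLoop(i + 1, n, acc + i)
}
```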

Friedlander answered 5/7, 2011 at 18:7

Problems specific to Scala have already been discussed, but the main problem is that using a brute-force algorithm is not very cool. Consider this (much faster than the original Java code):

def gcd(a: Int, b: Int): Int = {
    if (a == 0)
        b
    else
        gcd(b % a, a)
}
print (1 to 20 reduce ((a, b) => {
  a / gcd(a, b) * b
}))

Bicollateral answered 6/10, 2012 at 19:53. Comment:
The question compares the performance of the same logic across languages. Whether the algorithm is optimal for the problem is immaterial. - Iny

Try the one-liner given in the solution Scala for Project Euler

The time given is at least faster than yours, though far from the while loop. :)

Vapor answered 5/7, 2011 at 19:12. Comment:
It's pretty similar to my functional version. You could write mine as def r(n:Int):Int = if ((1 to 20) forall {n % _ == 0}) n else r (n+2); r(2), which is 4 characters shorter than Pavel's. :) However I don't pretend my code is any good - when I posted this question I had coded a total of about 30 lines of Scala. - Superbomb
