Kotlin suspend function recursive call

I suddenly discovered that a recursive call of a suspend function takes more time than calling the same function without the suspend modifier. Please consider the code snippet below (a basic Fibonacci series calculation):

suspend fun asyncFibonacci(n: Int): Long = when {
    n <= -2 -> asyncFibonacci(n + 2) - asyncFibonacci(n + 1)
    n == -1 -> 1
    n == 0 -> 0
    n == 1 -> 1
    n >= 2 -> asyncFibonacci(n - 1) + asyncFibonacci(n - 2)
    else -> throw IllegalArgumentException()
}

If I call this function and measure its execution time with the code below:

fun main(args: Array<String>) {
    val totalElapsedTime = measureTimeMillis {
        val nFibonacci = 40

        val deferredFirstResult: Deferred<Long> = async {
            asyncProfile("fibonacci") { asyncFibonacci(nFibonacci) } as Long
        }
        val deferredSecondResult: Deferred<Long> = async {
            asyncProfile("fibonacci") { asyncFibonacci(nFibonacci) } as Long
        }

        val firstResult: Long = runBlocking { deferredFirstResult.await() }
        val secondResult: Long = runBlocking { deferredSecondResult.await() }
        val superSum = secondResult + firstResult
        println("${thread()} - Sum of two $nFibonacci'th fibonacci numbers: $superSum")
    }
    println("${thread()} - Total elapsed time: $totalElapsedTime millis")
}

I observe further results:

commonPool-worker-2:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Start calculation...
commonPool-worker-2:fibonacci - Finish calculation...
commonPool-worker-2:fibonacci - Elapsed time: 7704 millis
commonPool-worker-1:fibonacci - Finish calculation...
commonPool-worker-1:fibonacci - Elapsed time: 7741 millis
main - Sum of two 40'th fibonacci numbers: 204668310
main - Total elapsed time: 7816 millis

But if I remove suspend modifier from asyncFibonacci function, I'll have this result:

commonPool-worker-2:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Finish calculation...
commonPool-worker-1:fibonacci - Elapsed time: 1179 millis
commonPool-worker-2:fibonacci - Finish calculation...
commonPool-worker-2:fibonacci - Elapsed time: 1201 millis
main - Sum of two 40'th fibonacci numbers: 204668310
main - Total elapsed time: 1250 millis
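To reproduce the comparison on a current Kotlin version without the asker's gist (which defines asyncProfile) or kotlinx.coroutines, here is a minimal stdlib-only sketch. It assumes the suspend function never actually suspends, so it can be driven to completion synchronously with startCoroutine. The names suspendFib, plainFib, runNonSuspending and compareTimings are hypothetical, and the timings will vary with the JVM and warm-up.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.startCoroutine
import kotlin.system.measureTimeMillis

// Same recursion as in the question, restricted to n >= 0 for brevity.
suspend fun suspendFib(n: Int): Long =
    if (n < 2) n.toLong() else suspendFib(n - 1) + suspendFib(n - 2)

// Identical body without the suspend modifier.
fun plainFib(n: Int): Long =
    if (n < 2) n.toLong() else plainFib(n - 1) + plainFib(n - 2)

// Runs a suspend block on the current thread and returns its result.
// Assumption: the block never suspends, so the completion callback has
// already been invoked by the time startCoroutine returns.
fun <T> runNonSuspending(block: suspend () -> T): T {
    var outcome: Result<T>? = null
    block.startCoroutine(Continuation<T>(EmptyCoroutineContext) { outcome = it })
    return checkNotNull(outcome) { "block suspended; this helper cannot wait for it" }.getOrThrow()
}

// Returns (suspend millis, plain millis) for a single run of each; no warm-up.
fun compareTimings(n: Int): Pair<Long, Long> {
    var suspendResult = 0L
    var plainResult = 0L
    val suspendMs = measureTimeMillis { suspendResult = runNonSuspending { suspendFib(n) } }
    val plainMs = measureTimeMillis { plainResult = plainFib(n) }
    check(suspendResult == plainResult) { "results differ" }
    return suspendMs to plainMs
}
```

For example, compareTimings(35) returns a pair of millisecond timings. The first calls also pay for JIT compilation, so repeat them several times before reading anything into the numbers.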

I know it's better to rewrite such a function with tailrec, which makes it roughly 100 times faster, but anyway: what does the suspend keyword do that increases the execution time from about 1 second to about 8 seconds?
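The tailrec rewrite mentioned above usually means switching to an accumulator-passing form. Here is a minimal sketch (the name tailrecFibonacci is hypothetical). Most of its speedup comes from doing O(n) work instead of the exponential number of calls made by the naive recursion, not from tailrec itself. tailrec mainly guarantees that the self-call compiles to a loop and cannot overflow the stack.

```kotlin
// Accumulator-passing Fibonacci: `current` holds F(k), `next` holds F(k + 1).
// tailrec turns the self-call in tail position into a loop, so this runs in
// O(n) time with constant stack depth. Results fit in a Long up to n = 92.
tailrec fun tailrecFibonacci(n: Int, current: Long = 0L, next: Long = 1L): Long {
    require(n >= 0) { "n must be non-negative" }
    return if (n == 0) current else tailrecFibonacci(n - 1, next, current + next)
}
```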

Is it a totally stupid idea to mark recursive functions with suspend?

Roentgenology asked 6/1, 2018 at 17:21 Comment(2)
I'm trying to reproduce your results, but I can't resolve asyncProfile. – Dinothere
@MarkoTopolnik Sorry about that, here is a gist with the full example, including the mentioned utility function. – Roentgenology

As an introductory comment, your testing code setup is too complex. This much simpler code achieves the same in terms of stressing suspend fun recursion:

fun main(args: Array<String>) {
    launch(Unconfined) {
        val nFibonacci = 37
        var sum = 0L
        (1..1_000).forEach {
            val took = measureTimeMillis {
                sum += suspendFibonacci(nFibonacci)
            }
            println("Sum is $sum, took $took ms")
        }
    }
}

suspend fun suspendFibonacci(n: Int): Long {
    return when {
        n >= 2 -> suspendFibonacci(n - 1) + suspendFibonacci(n - 2)
        n == 0 -> 0
        n == 1 -> 1
        else -> throw IllegalArgumentException()
    }
}

I tried to reproduce its performance by writing a plain function that approximates the kinds of things the suspend function must do to achieve suspendability:

val COROUTINE_SUSPENDED = Any()

fun fakeSuspendFibonacci(n: Int, inCont: Continuation<Unit>): Any? {
    val cont = if (inCont is MyCont && inCont.label and Integer.MIN_VALUE != 0) {
        inCont.label -= Integer.MIN_VALUE
        inCont
    } else MyCont(inCont)
    val suspended = COROUTINE_SUSPENDED
    loop@ while (true) {
        when (cont.label) {
            0 -> {
                when {
                    n >= 2 -> {
                        cont.n = n
                        cont.label = 1
                        val f1 = fakeSuspendFibonacci(n - 1, cont)!!
                        if (f1 === suspended) {
                            return f1
                        }
                        cont.data = f1
                        continue@loop
                    }
                    n == 1 || n == 0 -> return n.toLong()
                    else -> throw IllegalArgumentException("Negative input not allowed")
                }
            }
            1 -> {
                cont.label = 2
                cont.f1 = cont.data as Long
                val f2 = fakeSuspendFibonacci(cont.n - 2, cont)!!
                if (f2 === suspended) {
                    return f2
                }
                cont.data = f2
                continue@loop
            }
            2 -> {
                val f2 = cont.data as Long
                return cont.f1 + f2
            }
            else -> throw AssertionError("Invalid continuation label ${cont.label}")
        }
    }
}

class MyCont(val completion: Continuation<Unit>) : Continuation<Unit> {
    var label = 0
    var data: Any? = null
    var n: Int = 0
    var f1: Long = 0

    override val context: CoroutineContext get() = TODO("not implemented")
    override fun resumeWithException(exception: Throwable) = TODO("not implemented")
    override fun resume(value: Unit) = TODO("not implemented")
}

You have to invoke this one with

sum += fakeSuspendFibonacci(nFibonacci, InitialCont()) as Long

where InitialCont is

class InitialCont : Continuation<Unit> {
    override val context: CoroutineContext get() = TODO("not implemented")
    override fun resumeWithException(exception: Throwable) = TODO("not implemented")
    override fun resume(value: Unit) = TODO("not implemented")
}
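The answer was written against the experimental coroutine API of 2018. Since Kotlin 1.3, Continuation declares a single resumeWith(result: Result<T>) instead of resume and resumeWithException, so the two stub classes need a small adaptation to compile on a current compiler. This sketch assumes the rest of fakeSuspendFibonacci stays unchanged, with its Continuation and CoroutineContext now imported from kotlin.coroutines. Returning EmptyCoroutineContext, or delegating to the completion's context, instead of TODO() is my own choice. It is harmless here because the fake function never reads the context.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext

// Stub for the stable (Kotlin 1.3+) API: resume/resumeWithException were
// replaced by a single resumeWith(Result<T>).
class InitialCont : Continuation<Unit> {
    override val context: CoroutineContext get() = EmptyCoroutineContext
    override fun resumeWith(result: Result<Unit>) = TODO("not implemented")
}

// Same fields as the answer's MyCont; only the overridden members changed.
class MyCont(val completion: Continuation<Unit>) : Continuation<Unit> {
    var label = 0
    var data: Any? = null
    var n: Int = 0
    var f1: Long = 0

    override val context: CoroutineContext get() = completion.context
    override fun resumeWith(result: Result<Unit>) = TODO("not implemented")
}
```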

Basically, to compile a suspend fun, the compiler has to turn its body into a state machine. Each invocation must also create an object to hold the machine's state. When you resume, the state object tells which state handler to go to. The above still isn't all there is to it; the real code is even more complex.
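To see the state object route a resumption to the right state handler, the following stdlib-only sketch suspends a function twice and resumes it by hand. The names pause, twoSteps, pending, log and runDemo are hypothetical. With EmptyCoroutineContext there is no dispatcher, so each resume(Unit) call re-enters the function's state machine directly on the calling thread, at the point saved when it suspended.

```kotlin
import kotlin.coroutines.*

// Continuation captured at the most recent suspension point, if any.
var pending: Continuation<Unit>? = null
val log = mutableListOf<String>()

// Suspends the caller and stashes its continuation for a later manual resume.
suspend fun pause() = suspendCoroutine<Unit> { cont -> pending = cont }

suspend fun twoSteps() {
    log += "state 0"
    pause()          // first suspension point
    log += "state 1"
    pause()          // second suspension point
    log += "state 2"
}

fun runDemo(): List<String> {
    log.clear()
    val block: suspend () -> Unit = { twoSteps() }
    block.startCoroutine(Continuation<Unit>(EmptyCoroutineContext) { it.getOrThrow(); log += "done" })
    // Each resume continues twoSteps right after the pause() it suspended in.
    while (true) {
        val c = pending ?: break
        pending = null
        c.resume(Unit)
    }
    return log.toList()
}
```

Calling runDemo() yields the four log entries in order, "state 0", "state 1", "state 2", "done", even though the function was exited and re-entered twice.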

In interpreted mode (java -Xint), I get almost the same performance as the actual suspend fun, and with the JIT enabled it is less than twice as fast as the real one. By comparison, the "direct" (non-suspend) implementation is about 10 times as fast. This means the code shown accounts for a good part of the overhead of suspendability.

Dinothere answered 7/1, 2018 at 15:38 Comment(2)
OK, thanks a lot. I knew that a state machine with a kind of switch-case is generated in the bytecode, but I didn't expect that so many continuation objects would be created even though there are no other suspend function calls inside the recursive Fibonacci function. – Roentgenology
Recursion is a very narrow special case where the compiler could actually know that the callee will not suspend. In general there is no point in having a suspend fun that won't eventually reach suspendCoroutine, so there's no point in optimizing for it. – Dinothere

The problem lies in the Java bytecode generated from the suspend function. A non-suspend function generates just the bytecode we'd expect; decompiled to Java, it looks like this:

public static final long asyncFibonacci(int n) {
  long var10000;
  if (n <= -2) {
     var10000 = asyncFibonacci(n + 2) - asyncFibonacci(n + 1);
  } else if (n == -1) {
     var10000 = 1L;
  } else if (n == 0) {
     var10000 = 0L;
  } else if (n == 1) {
     var10000 = 1L;
  } else {
     if (n < 2) {
        throw (Throwable)(new IllegalArgumentException());
     }

     var10000 = asyncFibonacci(n - 1) + asyncFibonacci(n - 2);
  }

  return var10000;
}

When you add the suspend keyword, the decompiled Java source grows to 165 lines, so it's a lot larger. You can view the bytecode and the decompiled Java code in IntelliJ via Tools -> Kotlin -> Show Kotlin Bytecode (then click Decompile at the top of the panel). While it's not easy to tell exactly what the Kotlin compiler is doing in the function, it looks like it's doing a whole lot of coroutine status checking, which makes sense given that the function must be able to suspend and later resume at each call to another suspend function.

So, in conclusion, I'd say that every suspend function call is a lot heavier than a non-suspend call. This doesn't apply only to recursive functions, but they are probably hit the hardest.

Is it a totally stupid idea to mark recursive functions with suspend?

Unless you have a very good reason to do so: yes.
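If a suspend API is required anyway, one common arrangement is to keep the recursion in a plain function and expose only a thin suspend entry point. That way the continuation machinery is paid once per call rather than once per recursive step. This is a sketch with hypothetical names. With kotlinx.coroutines you would typically also wrap the body in withContext(Dispatchers.Default) { ... } so the CPU-bound work does not block the caller's thread; that part is omitted here to keep the example stdlib-only.

```kotlin
// Plain recursive worker: no continuation objects, no state machine.
fun fibonacciImpl(n: Int): Long = when {
    n >= 2 -> fibonacciImpl(n - 1) + fibonacciImpl(n - 2)
    n == 0 -> 0L
    n == 1 -> 1L
    else -> throw IllegalArgumentException("Negative input not allowed")
}

// Thin suspend entry point for callers that expect a suspend API.
// Only this outer call goes through the suspend calling convention.
suspend fun fibonacci(n: Int): Long = fibonacciImpl(n)
```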

Hesychast answered 6/1, 2018 at 21:19 Comment(5)
The suspend fun checks whether the continuation object is its own or comes from its caller. If it comes from the caller, it creates a new one; if it is its own, it jumps to the location in its body where the resumption should take place. asyncFibonacci(40) makes a total of 331,160,281 suspend fun calls, so it creates that many continuation objects. It never suspends or resumes, so the other checks aren't exercised. – Dinothere
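The 331,160,281 figure can be checked. For n >= 0, the naive recursion makes c(n) = 1 + c(n - 1) + c(n - 2) calls with c(0) = c(1) = 1, which works out to 2 * F(n + 1) - 1. A small sketch (callCount and countedFib are hypothetical helpers) computes the count iteratively and cross-checks it against an instrumented recursion for a small n.

```kotlin
// Iterates c(k) = 1 + c(k - 1) + c(k - 2), c(0) = c(1) = 1: the number of
// invocations made by the naive recursive Fibonacci for input n >= 0.
fun callCount(n: Int): Long {
    require(n >= 0) { "n must be non-negative" }
    var previous = 1L // c(k - 1)
    var current = 1L  // c(k), starting at k = 1
    for (k in 2..n) { // empty range when n < 2
        val next = 1 + current + previous
        previous = current
        current = next
    }
    return current
}

// Instrumented naive recursion; counter[0] counts invocations. Small n only.
fun countedFib(n: Int, counter: IntArray): Long {
    counter[0]++
    return if (n < 2) n.toLong() else countedFib(n - 1, counter) + countedFib(n - 2, counter)
}
```

callCount(40) gives 331,160,281, matching the number quoted above.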
Actually, self-recursion is the worst case: it must make all the checks to realize that, even though the continuation is of its own type, this is not a resuming invocation. – Dinothere
@Hesychast Thanks for the explanation. I also first tried to decompile the code and read the generated bytecode, but as Marko explained, the problem with suspendable functions lies a bit deeper ("asyncFibonacci(40) makes a total of 331,160,281 suspend fun calls so it creates that many continuation objects"). In any case, you both helped me a lot and clarified the case. – Roentgenology
You'd be surprised how little effect the object allocation itself has. In an early version I just created a mock object and there was almost no difference in speed. Allocating short-lived objects is an extremely cheap operation in a language with a generational GC. – Dinothere
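To try the allocation claim yourself, one option is a plain recursion that allocates one small object per call, loosely mimicking a continuation that links back to its caller. This is a hypothetical experiment (Frame and allocatingFib are made-up names), not the mock used in the comment above. Keep in mind that HotSpot's escape analysis can sometimes eliminate allocations it proves do not escape, so measured costs may differ between JVMs and flags.

```kotlin
// One heap object per call, holding a link to the caller's object,
// similar in spirit to the `completion` field of a continuation.
class Frame(val n: Int, val caller: Frame?)

fun allocatingFib(n: Int, caller: Frame? = null): Long {
    val frame = Frame(n, caller) // one allocation per invocation
    return when {
        n >= 2 -> allocatingFib(n - 1, frame) + allocatingFib(n - 2, frame)
        n == 0 || n == 1 -> n.toLong()
        else -> throw IllegalArgumentException("Negative input not allowed")
    }
}
```

Timing it with measureTimeMillis against the non-suspend asyncFibonacci from the question, after a few warm-up runs, gives a rough idea of how much of the suspend overhead is attributable to allocation alone.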
@MarkoTopolnik Interesting; most of the time I hear people blaming the JVM GC for every performance problem they have. Nice to see that the GC also has some advantages. – Hesychast

© 2022 - 2024 — McMap. All rights reserved.