How to write tail-recursive functions when working inside monads

In general, I have trouble figuring out how to write tail-recursive functions when working 'inside' monads. Here is a quick example:

This is from a small example application I am writing to better understand FP in Scala. First, the user is prompted to enter a Team consisting of 7 Players. This function reads the input recursively:

import cats.effect.{ExitCode, IO, IOApp}
import cats.implicits._

case class Player (name: String)
case class Team (players: List[Player])

/**
  * Reads a team of 7 players from the command line.
  * @return the team once it has 7 players
  */
def readTeam: IO[Team] = {
  def go(team: Team): IO[Team] = { // here I'd like to add @tailrec
    if(team.players.size >= 7){
      IO(println("Enough players!!")) >>= (_ => IO(team))
    } else {
      for {
        player <- readPlayer
        team   <- go(Team(team.players :+ player))
      } yield team
    }
  }
  go(Team(Nil))
}

private def readPlayer: IO[Player] = ???

Now what I'd like to achieve (mainly for educational purposes) is to be able to put a @tailrec annotation in front of def go(team: Team). But I don't see a way to make the recursive call the last statement, because as far as I can see, the last statement always has to 'lift' my Team into the IO monad.

Any hint would be greatly appreciated.

Snowfall asked 2/2, 2019 at 10:28

First of all, this isn't necessary, because IO is specifically designed to support stack-safe monadic recursion. From the docs:

IO is trampolined in its flatMap evaluation. This means that you can safely call flatMap in a recursive function of arbitrary depth, without fear of blowing the stack…

So your implementation will work just fine in terms of stack safety, even if instead of seven players you needed 70,000 players (although at that point you might need to worry about the heap).
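To see the stack-safety claim in action, here is a minimal sketch, assuming cats-effect 3.x (which needs the cats.effect.unsafe.implicits.global runtime for unsafeRunSync). The object DeepIORecursion and its countUp function are hypothetical names for illustration; countUp recurses through flatMap 70,000 times, a depth that can easily overflow the default JVM stack with ordinary non-tail recursion.

```scala
// Assumes cats-effect 3.x on the classpath.
import cats.effect.IO
import cats.effect.unsafe.implicits.global // default runtime, needed for unsafeRunSync()

object DeepIORecursion {
  // Not @tailrec: the recursive call sits inside the flatMap continuation.
  // IO's run loop evaluates the flatMap chain without growing the JVM stack.
  def countUp(n: Int, acc: Int): IO[Int] =
    if (n == 0) IO.pure(acc)
    else IO(acc + 1).flatMap(next => countUp(n - 1, next))

  def main(args: Array[String]): Unit = {
    val result = countUp(70000, 0).unsafeRunSync()
    println(result) // prints 70000
  }
}
```

The compiler would reject @tailrec on countUp, yet the program completes, because the recursion is trampolined by IO rather than executed on the call stack.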

This doesn't really answer your question, though, and of course even @tailrec is never necessary, since all it does is verify that the compiler has turned your recursion into a loop, as you expect it to.
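For contrast, @tailrec is accepted only when the recursive call is the very last thing the method does. In the question's go, the for-comprehension desugars to readPlayer.flatMap(player => go(...).map(team => team)), so the recursive call lives inside a function passed to flatMap and is not in tail position. The sketch below is a hypothetical pure variant (PureTeam and fromNames are made-up names) that takes player names from a list instead of reading stdin; here the compiler can turn go into a loop.

```scala
import scala.annotation.tailrec

final case class Player(name: String)
final case class Team(players: List[Player])

object PureTeam {
  // Hypothetical pure variant: players come from a list rather than the console.
  def fromNames(names: List[String]): Team = {
    // @tailrec compiles: each recursive call is the final expression of its branch.
    @tailrec
    def go(team: Team, remaining: List[String]): Team =
      if (team.players.size >= 7) team
      else remaining match {
        case Nil          => team // ran out of names before reaching 7
        case name :: rest => go(Team(team.players :+ Player(name)), rest)
      }
    go(Team(Nil), names)
  }
}
```

Once effects enter the picture, the "rest of the loop" has to be handed to flatMap, which is exactly what prevents the annotation in the IO version.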

While it's not possible to write this method in such a way that it can be annotated with @tailrec, you can get a similar kind of assurance by using Cats's tailRecM. For example, the following is equivalent to your implementation:

import cats.effect.IO
import cats.syntax.functor._

case class Player (name: String)
case class Team (players: List[Player])

// For the sake of example.
def readPlayer: IO[Player] = IO(Player("foo"))

/**
  * Reads a team of 7 players from the command line.
  * @return the team once it has 7 players
  */
def readTeam: IO[Team] = cats.Monad[IO].tailRecM(Team(Nil)) {
  case team if team.players.size >= 7 =>
    IO(println("Enough players!!")).as(Right(team))
  case team =>
    readPlayer.map(player => Left(Team(team.players :+ player)))
}

This says "start with an empty team and repeatedly add players until we have the necessary number", but without any explicit recursive calls. As long as the monad instance is lawful (according to Cats's definition—there's some question about whether tailRecM even belongs on Monad), you don't have to worry about stack safety.
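A runnable variant of the tailRecM version, assuming cats-effect 3.x. The hypothetical readPlayer(n) stands in for console input and names the n-th player "player-n" so the order is visible; the explicit type parameters on tailRecM are just a choice to keep type inference simple.

```scala
// Assumes cats-effect 3.x on the classpath.
import cats.Monad
import cats.effect.IO
import cats.effect.unsafe.implicits.global

final case class Player(name: String)
final case class Team(players: List[Player])

object TailRecMTeam {
  // Hypothetical stand-in for console input.
  def readPlayer(n: Int): IO[Player] = IO(Player(s"player-$n"))

  val readTeam: IO[Team] =
    Monad[IO].tailRecM[Team, Team](Team(Nil)) { team =>
      if (team.players.size >= 7)
        // Right(...) ends the loop; its value becomes the result.
        IO(println("Enough players!!")).as(Right(team))
      else
        // Left(...) feeds the bigger team into the next iteration.
        readPlayer(team.players.size + 1).map(p => Left(Team(team.players :+ p)))
    }
}
```

Calling TailRecMTeam.readTeam.unsafeRunSync() prints the message once and returns a team of seven players, player-1 through player-7.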

As a side note, fa.as(b) is equivalent to fa.map(_ => b), and here to fa >>= (_ => IO(b)), but more idiomatic.
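A small sketch of that equivalence, assuming cats-effect 3.x (AsExample is a made-up name). All three values run the same side effect and then produce 42.

```scala
// Assumes cats-effect 3.x on the classpath.
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import cats.implicits._ // provides >>=

object AsExample {
  val fa: IO[Unit] = IO(println("side effect"))

  val viaAs: IO[Int]      = fa.as(42)            // replace the result
  val viaMap: IO[Int]     = fa.map(_ => 42)      // what `as` does under the hood
  val viaFlatMap: IO[Int] = fa >>= (_ => IO(42)) // the form used in the question
}
```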

Also as a side note (but maybe a more interesting one), you can write this method even more concisely (and to my eye more clearly) as follows:

import cats.effect.IO
import cats.syntax.monad._

case class Player (name: String)
case class Team (players: List[Player])

// For the sake of example.
def readPlayer: IO[Player] = IO(Player("foo"))

/**
  * Reads a team of 7 players from the command line.
  * @return the team once it has 7 players
  */
def readTeam: IO[Team] = Team(Nil).iterateUntilM(team =>
  readPlayer.map(player => Team(team.players :+ player))
)(_.players.size >= 7)

Again there are no explicit recursive calls, and it's even more declarative than the tailRecM version—it's just "perform this action iteratively until the given condition holds".
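One difference from the original: the iterateUntilM version no longer prints "Enough players!!". If you want to keep that message, one option is to run it after the loop with flatTap. This is a sketch assuming cats-effect 3.x and a hypothetical readPlayer(n) stand-in for console input.

```scala
// Assumes cats-effect 3.x on the classpath.
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import cats.implicits._ // iterateUntilM on a plain value, flatTap

final case class Player(name: String)
final case class Team(players: List[Player])

object IterateUntilTeam {
  // Hypothetical stand-in for console input.
  def readPlayer(n: Int): IO[Player] = IO(Player(s"player-$n"))

  val readTeam: IO[Team] =
    Team(Nil)
      .iterateUntilM(team =>
        readPlayer(team.players.size + 1).map(p => Team(team.players :+ p))
      )(_.players.size >= 7)
      // flatTap runs the effect but keeps the team as the result.
      .flatTap(_ => IO(println("Enough players!!")))
}
```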


One postscript: you might wonder why you'd ever use tailRecM when IO#flatMap is stack safe. One reason is that you may someday decide to make your program generic in the effect context (e.g. via the tagless final pattern). In that case you should not assume that flatMap behaves the way you want it to, since lawfulness for cats.Monad doesn't require flatMap to be stack safe. It is best to avoid explicit recursive calls through flatMap and use tailRecM, iterateUntilM, etc. instead, since these are guaranteed to be stack safe for any lawful monadic context.
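To make the postscript concrete, here is a minimal tagless-final sketch. The readTeam signature, which takes the player source as a parameter, and the GenericTeam name are assumptions for illustration. It only requires Monad[F] and loops through iterateUntilM, whose default implementation in Cats delegates to tailRecM, so it does not depend on flatMap being stack safe. It is run here with cats.Id, which needs no effect runtime.

```scala
// Assumes cats 2.x on the classpath; no cats-effect needed for the Id run.
import cats._          // Monad, Id, and the Monad[Id] instance
import cats.implicits._ // map and iterateUntilM syntax

final case class Player(name: String)
final case class Team(players: List[Player])

object GenericTeam {
  // Generic in the effect F; the player source is a hypothetical parameter.
  def readTeam[F[_]: Monad](readPlayer: Int => F[Player]): F[Team] =
    Team(Nil).iterateUntilM[F](team =>
      readPlayer(team.players.size + 1).map(p => Team(team.players :+ p))
    )(_.players.size >= 7)
}
```

With cats-effect you could call GenericTeam.readTeam[IO](n => IO(Player(s"player-$n"))) instead; the loop itself stays the same.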

Leonaleonanie answered 2/2, 2019 at 12:2
Wow, thank you very much for this detailed answer. I'll definitely use iterateUntilM in the future; it's much easier to read. - Snowfall
