How does this continuation-passing style Clojure function generator work?

This is from The Joy of Clojure, 2nd Edition (http://www.manning.com/fogus2/):

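(defn mk-cps [accept? kend kont]
  (fn [n]
    ((fn [n k]
       ;; cont wraps the current continuation k: it combines the incoming
       ;; value v with the current n via kont, then passes the result on to k
       (let [cont (fn [v] (k ((partial kont v) n)))]
         (if (accept? n)
           (k 1)                    ; base case: apply the accumulated continuation to 1
           (recur (dec n) cont))))  ; go round again with n - 1 and the wrapped continuation
     n kend)))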

Then to make a factorial:

(def fac (mk-cps zero? identity #(* %1 %2)))

My understanding:

  • mk-cps generates a function that takes n (the outer fn [n])
  • the function inside, fn [n k], is initially called with n and kend
  • the continuation function cont [v] is defined as calling k on ((partial kont v) n), i.e. on kont applied with v as its first argument and n as its second. Why would this be written using partial instead of simply (k (kont v n))?
  • if the accept? function passes, then finish the recursion, applying k to 1.
  • otherwise, the recur recurs back to fn [n k] with a decremented n, and with the continuation function.
  • all throughout, kont does not change.
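For reference, ((partial kont v) n) and (kont v n) evaluate to the same value for any kont, so the partial is not doing essential work here. A minimal sketch comparing the book's mk-cps with a hypothetical mk-cps-direct (a name made up for this comparison) that calls kont directly:

```clojure
;; The book's version, as quoted above.
(defn mk-cps [accept? kend kont]
  (fn [n]
    ((fn [n k]
       (let [cont (fn [v] (k ((partial kont v) n)))]
         (if (accept? n)
           (k 1)
           (recur (dec n) cont))))
     n kend)))

;; Hypothetical variant: identical logic, but calls kont directly
;; instead of going through partial.
(defn mk-cps-direct [accept? kend kont]
  (fn [n]
    ((fn [n k]
       (let [cont (fn [v] (k (kont v n)))]
         (if (accept? n)
           (k 1)
           (recur (dec n) cont))))
     n kend)))

(def fac        (mk-cps        zero? identity #(* %1 %2)))
(def fac-direct (mk-cps-direct zero? identity #(* %1 %2)))

(map fac (range 6))        ;=> (1 1 2 6 24 120)
(map fac-direct (range 6)) ;=> (1 1 2 6 24 120)
```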

Am I right that k isn't actually executed until the final (k 1)? In other words, is (fac 3) first expanded to (* 1 (* 2 3)) and only then evaluated?

Posted by Hitormiss, 31/1/2014 at 10:32

I don't have the book, but I assume the motivating example is:

(defn fact-n [n]
  (if (zero? n)
      1
      (* n (recur (dec n)))))

;=> CompilerException: Can only recur from tail position

That last form has to be written as (* n (fact-n (dec n))) instead, which is not tail-recursive. The problem is that there is something left to do after the recursive call returns, namely the multiplication by n.
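For completeness, here is the working (non-tail-recursive) fact-n that the traces further down refer to. Each pending (* n ...) waits on the call stack until the recursive call returns:

```clojure
;; Plain recursion: the multiplication by n happens *after* the
;; recursive call returns, so every pending (* n ...) keeps a stack frame alive.
(defn fact-n [n]
  (if (zero? n)
    1
    (* n (fact-n (dec n)))))

(fact-n 3) ;=> 6
(fact-n 5) ;=> 120
```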

What continuation passing style does is turn this inside out. Instead of applying what remains of the current context/continuation after the recursive call returns, pass the context/continuation into the recursive call to apply when complete. Instead of implicitly storing continuations on the stack as call frames, we explicitly accumulate them via function composition.
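To make "what remains to be done" concrete: in (* 3 (fact-n 2)), the pending work is just "multiply the result by 3", and that can be written as an ordinary function value. A small illustration (after-fact-2 and pending are names invented for this example):

```clojure
;; The work left over after the recursive call in (* 3 (fact-n 2)),
;; reified as a function of the not-yet-known result.
(def after-fact-2 (partial * 3))

;; Stacking up the leftovers for n = 3, 2, 1 with comp gives one function
;; that finishes the whole computation once the base-case value arrives.
;; comp applies right to left, so (partial * 1) runs first.
(def pending (comp (partial * 3) (partial * 2) (partial * 1)))

(after-fact-2 2) ;=> 6
(pending 1)      ;=> 6
```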

In this case, we add an additional argument k to our factorial, a function that does what we would have done after the recursive call returns.

(defn fact-nk [n k]
  (if (zero? n)
      (k 1)
      (recur (dec n) (comp k (partial * n)))))

The first k passed in is the last one applied. Ultimately we just want to return the value calculated, so the first k in should be the identity function.
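Because the first k is applied last, to the finished result, passing something other than identity post-processes the answer. A quick check (fact-nk is repeated from above so the snippet runs on its own):

```clojure
(defn fact-nk [n k]
  (if (zero? n)
    (k 1)
    (recur (dec n) (comp k (partial * n)))))

(fact-nk 3 identity)  ;=> 6
(fact-nk 3 str)       ;=> "6"   str is applied to the final 6
(fact-nk 3 #(* 10 %)) ;=> 60    the outermost step runs after all the others
```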

Here's the base case:

(fact-nk 0 identity)
;== (identity 1)
;=> 1

Here's n = 3:

(fact-nk 3 identity)
;== (fact-nk 2 (comp identity (partial * 3)))
;== (fact-nk 1 (comp identity (partial * 3) (partial * 2)))
;== (fact-nk 0 (comp identity (partial * 3) (partial * 2) (partial * 1)))
;== ((comp identity (partial * 3) (partial * 2) (partial * 1)) 1)
;== ((comp identity (partial * 3) (partial * 2)) 1)
;== ((comp identity (partial * 3)) 2)
;== ((comp identity) 6)
;== (identity 6)
;=> 6

Compare this with the non-tail-recursive version:

(fact-n 3)
;== (* 3 (fact-n 2))
;== (* 3 (* 2 (fact-n 1)))
;== (* 3 (* 2 (* 1 (fact-n 0))))
;== (* 3 (* 2 (* 1 1)))
;== (* 3 (* 2 1))
;== (* 3 2)
;=> 6
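In both traces, the multiplications only happen once the recursion bottoms out; in the CPS version they run when the accumulated continuation is finally applied to 1. One way to observe this is to log each step as it actually runs. A sketch using a hypothetical fact-nk-traced that takes an extra atom for the log:

```clojure
(defn fact-nk-traced
  "Like fact-nk, but each continuation step conj's its n onto the
  log atom at the moment it actually runs. Illustration only."
  [n k log]
  (if (zero? n)
    (k 1)
    (recur (dec n)
           (comp k (fn [v]
                     (swap! log conj n)
                     (* n v)))
           log)))

;; The descent visits n = 3, 2, 1, but the log comes out as [1 2 3]:
;; the multiplications run in reverse, after the base case.
(let [log (atom [])]
  [(fact-nk-traced 3 identity log) @log])
;=> [6 [1 2 3]]
```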

Now, to make this a bit more flexible, we could factor out the zero? and the * and make them function parameters instead.

A first approach would be:

(defn cps-anck [accept? n c k]
  (if (accept? n)
      (k 1)
      (recur accept?, (dec n), c, (comp k (partial c n)))))
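With accept? and c as parameters, the same skeleton computes other things. A few calls to cps-anck (repeated so the snippet stands alone); note that the base case always contributes the value 1:

```clojure
(defn cps-anck [accept? n c k]
  (if (accept? n)
    (k 1)
    (recur accept?, (dec n), c, (comp k (partial c n)))))

;; 5! with zero? as the stopping test and * as the combiner.
(cps-anck zero? 5 * identity) ;=> 120

;; Stopping at 1 with + sums 5 + 4 + 3 + 2 + (base value 1).
;; A set works as a predicate: (#{1} n) is truthy only when n is 1.
(cps-anck #{1} 5 + identity)  ;=> 15

;; Stopping at 0 with + adds the base value 1 on top of 5 + 4 + 3 + 2 + 1.
(cps-anck zero? 5 + identity) ;=> 16
```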

But since accept? and c never change, we could lift them out and recur to an inner anonymous function instead, which is what mk-cps does. Clojure has a special form for this: loop.

(defn cps-anckl [accept? n c k]
  (loop [n n, k k]
    (if (accept? n)
        (k 1)
        (recur (dec n) (comp k (partial c n))))))
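The inner-anonymous-function shape that mk-cps uses and the loop shape are interchangeable as recur targets. A sketch comparing a hypothetical cps-inner-fn (named here for illustration) with cps-anckl, repeated so the snippet stands alone:

```clojure
;; recur targets an immediately invoked anonymous fn (the mk-cps shape).
(defn cps-inner-fn [accept? n c k]
  ((fn [n k]
     (if (accept? n)
       (k 1)
       (recur (dec n) (comp k (partial c n)))))
   n k))

;; recur targets a loop (cps-anckl from above).
(defn cps-anckl [accept? n c k]
  (loop [n n, k k]
    (if (accept? n)
      (k 1)
      (recur (dec n) (comp k (partial c n))))))

(= (map #(cps-inner-fn zero? % * identity) (range 10))
   (map #(cps-anckl    zero? % * identity) (range 10)))
;=> true
```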

And finally, we might want to turn this into a function generator: a function that takes accept?, c and k and returns a function of n.

(defn gen-cps [accept? c k]
  (fn [n]
    (loop [n n, k k]
      (if (accept? n)
          (k 1)
          (recur (dec n) (comp k (partial c n)))))))

And that is how I would write mk-cps (note: the last two arguments are reversed; gen-cps takes the combining function c before the final continuation k).

(def factorial (gen-cps zero? * identity))
(factorial 5)
;=> 120

(def triangular-number (gen-cps #{1} + identity))
(triangular-number 5)
;=> 15
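To check that gen-cps matches the book's mk-cps, here are both side by side (repeated from above). One difference worth knowing: mk-cps combines as (kont v n), while gen-cps combines as (c n v), so they agree for commutative combiners like * and + but can differ for a non-commutative one such as -.

```clojure
(defn mk-cps [accept? kend kont]
  (fn [n]
    ((fn [n k]
       (let [cont (fn [v] (k ((partial kont v) n)))]
         (if (accept? n)
           (k 1)
           (recur (dec n) cont))))
     n kend)))

(defn gen-cps [accept? c k]
  (fn [n]
    (loop [n n, k k]
      (if (accept? n)
        (k 1)
        (recur (dec n) (comp k (partial c n)))))))

;; Commutative combiners: identical results.
(= (map (mk-cps zero? identity *) (range 8))
   (map (gen-cps zero? * identity) (range 8)))   ;=> true
;; #{1} never matches if n starts at 0, so start the range at 1.
(= (map (mk-cps #{1} identity +) (range 1 8))
   (map (gen-cps #{1} + identity) (range 1 8)))  ;=> true

;; Non-commutative combiner: the argument order matters.
((mk-cps zero? identity -) 2)  ;=> -2, computed as (- (- 1 1) 2)
((gen-cps zero? - identity) 2) ;=> 2,  computed as (- 2 (- 1 1))
```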
Answered by Stravinsky, 1/2/2014 at 6:4

Comment from Hitormiss: Thanks for spelling it out! The book has too much "magic" in this part, I think.
