Recursion over a list of s-expressions in Clojure

To set some context, I'm in the process of learning Clojure, and Lisp development more generally. On my path to Lisp, I'm currently working through the "Little" series in an effort to solidify a foundation in functional programming and recursion-based problem solving. I've worked through many of the exercises in "The Little Schemer", but I'm struggling a bit to convert some of them to Clojure. More specifically, I'm struggling to convert them to use "recur" so as to enable TCO. For example, here is a Clojure-based implementation of the "occurs*" function (from The Little Schemer), which counts the number of occurrences of an atom within a list of S-expressions:

(defn atom? [l]
  (not (list? l)))

(defn occurs [a lst]
  (cond
   (empty? lst) 0
   (atom? (first lst))
    (cond
     (= a (first lst)) (inc (occurs a (rest lst)))
     true (occurs a (rest lst)))
   true (+ (occurs a (first lst))
           (occurs a (rest lst)))))

Basically, (occurs 'abc '(abc (def abc) (abc (abc def) (def (((((abc))))))))) will evaluate to 5. The obvious problem is that this definition consumes a stack frame per recursive call and will blow the stack if given a list of S-expressions that is too long or too deeply nested.

Now, I understand the option of refactoring recursive functions to use an accumulator parameter so that the recursive call lands in the tail position (to allow for TCO), but I'm unsure whether this option is even applicable to situations such as this one.
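For comparison, the accumulator pattern applies directly when the list is flat, because each step makes only one recursive call. A minimal sketch, assuming a hypothetical helper named occurs-flat (not from the book) that counts top-level matches only:

```clojure
;; Hypothetical helper: counts only top-level occurrences of `a`.
;; It does not descend into nested lists.
(defn occurs-flat [a lst]
  (loop [remaining (seq lst)   ; nil once the list is exhausted
         total 0]              ; the accumulator
    (if remaining
      (recur (next remaining)
             (if (= a (first remaining)) (inc total) total))
      total)))

(occurs-flat 'abc '(abc def abc (abc)))
;; => 2
```

The nested case is harder because the non-atom branch needs two recursive results, and at most one of those calls can sit in tail position.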

Here's how far I get if I try to refactor this using "recur" along with using an accumulator parameter:

(defn recur-occurs [a lst]
  (letfn [(myoccurs [a lst count]
            (cond
             (empty? lst) 0
             (atom? (first lst))
             (cond
              (= a (first lst)) (recur a (rest lst) (inc count))
              true (recur a (rest lst) count))
             true (+ (recur a (first lst) count)
                     (recur a (rest lst) count))))]
    (myoccurs a lst 0)))

So, I feel like I'm almost there, but not quite. The obvious problem is my "else" clause, in which the head of the list is not an atom. Conceptually, I want to sum the result of recurring over the first element in the list with the result of recurring over the rest of the list. I can't see how to refactor this so that the recurs can be moved to the tail position.

Are there additional techniques beyond the "accumulator" pattern for getting recursive calls into the tail position that I should be applying here? Or is the issue simply more "fundamental", in that there isn't a clean Clojure-based solution due to the JVM's lack of TCO? If the latter, what is the general pattern for Clojure programs that need to recur over a list of S-expressions? For what it's worth, I've seen the multimethod with lazy-seq technique used (page 151 of Halloway's "Programming Clojure" for reference) to "Replace Recursion with Laziness", but I'm not sure how to apply that pattern to this example, in which I'm not attempting to build a list but to compute a single integer value.

Thank you in advance for any guidance on this.

Parthenos answered 8/11, 2011 at 4:9 Comment(1)
Just to be clear, I don't believe the code presented in the Little Schemer for occurs* can be tail call optimized in Scheme. - Ronnie

First, I must advise you not to worry much about implementation snags like stack overflows as you make your way through The Little Schemer. It is good to be conscious of issues like the lack of tail call optimization when you're programming in anger, but the main point of the book is to teach you to think recursively. Converting the examples to accumulator-passing style is certainly good practice, but it's essentially ditching recursion in favor of iteration.

However, and I must preface this with a spoiler warning, there is a way to keep the same recursive algorithm without being subject to the whims of the JVM stack. We can use continuation-passing style to make our own stack in the form of an extra anonymous function argument k:

(defn occurs-cps [a lst k]
  (cond
   (empty? lst) (k 0) 
   (atom? (first lst))
   (cond
    (= a (first lst)) (occurs-cps a (rest lst)
                                  (fn [v] (k (inc v))))
    :else (occurs-cps a (rest lst) k))
   :else (occurs-cps a (first lst)
                     (fn [fst]
                       (occurs-cps a (rest lst)
                                   (fn [rst] (k (+ fst rst))))))))

Instead of the stack being created implicitly by our non-tail function calls, we bundle up "what's left to do" after each call to occurs-cps and pass it along as the next continuation k. When we invoke it, we start off with a k that represents nothing left to do, the identity function:

scratch.core=> (occurs-cps 'abc 
                           '(abc (def abc) (abc (abc def) (def (((((abc)))))))) 
                           (fn [v] v))
5
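To see the continuation idea in isolation, here is a smaller sketch on a flat list of numbers. The name sum-cps is hypothetical and not from TLS. Each call wraps the previous k in a new function that remembers the pending addition:

```clojure
;; Hypothetical example: sum a flat list in continuation-passing style.
;; Like occurs-cps, this is not yet stack-safe for very long lists.
(defn sum-cps [nums k]
  (if (empty? nums)
    (k 0)                                     ; nothing left: feed 0 to the pending work
    (sum-cps (rest nums)
             (fn [v] (k (+ (first nums) v)))))) ; remember "add this element" for later

(sum-cps '(1 2 3) identity)
;; => 6
```

The final continuation here behaves like (fn [v] (+ 1 (+ 2 (+ 3 v)))), which is the same information the ordinary call stack would have held.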

I won't go further into the details of how to do CPS, as that's for a later chapter of TLS. However, I will note that this of course doesn't yet work completely:

scratch.core=> (def ls (repeat 20000 'foo))          
#'scratch.core/ls
scratch.core=> (occurs-cps 'foo ls (fn [v] v))       
java.lang.StackOverflowError (NO_SOURCE_FILE:0)

CPS lets us move all of our non-trivial, stack-building calls to tail position, but in Clojure we need to take the extra step of replacing them with recur:

(defn occurs-cps-recur [a lst k]
  (cond
   (empty? lst) (k 0)
   (atom? (first lst))
   (cond
    (= a (first lst)) (recur a (rest lst)
                             (fn [v] (k (inc v))))
    :else (recur a (rest lst) k))
   :else (recur a (first lst)
                (fn [fst]
                  (recur a (rest lst) ;; Problem
                         (fn [rst] (k (+ fst rst))))))))

Alas, this goes wrong: java.lang.IllegalArgumentException: Mismatched argument count to recur, expected: 1 args, got: 3 (core.clj:39). The very last recur actually refers to the fn right above it, the one we're using to represent our continuations! We can get good behavior most of the time by changing just that recur to a call to occurs-cps-recur, but pathologically-nested input will still overflow the stack:

scratch.core=> (occurs-cps-recur 'foo ls (fn [v] v))
20000
scratch.core=> (def nested (reduce (fn [onion _] (list onion)) 
                                   'foo (range 20000)))
#'scratch.core/nested
scratch.core=> (occurs-cps-recur 'foo nested (fn [v] v))
java.lang.StackOverflowError (NO_SOURCE_FILE:0)
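For reference, the variant exercised in the transcript above, with only the innermost recur replaced by a direct call, might look like the following sketch. atom? is repeated from the question so the block stands alone:

```clojure
(defn atom? [x]
  (not (list? x)))

(defn occurs-cps-recur [a lst k]
  (cond
    (empty? lst) (k 0)
    (atom? (first lst))
    (cond
      (= a (first lst)) (recur a (rest lst)
                               (fn [v] (k (inc v))))
      :else (recur a (rest lst) k))
    :else (recur a (first lst)
                 (fn [fst]
                   ;; A direct call rather than recur: inside this fn,
                   ;; recur would target the fn itself, which takes one argument.
                   (occurs-cps-recur a (rest lst)
                                     (fn [rst] (k (+ fst rst))))))))

(occurs-cps-recur 'abc '(abc (def abc) (abc (abc def))) identity)
;; => 4
```

The direct call runs while a continuation is being invoked, so deeply nested input still stacks up JVM frames, which is the overflow shown above.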

Instead of making the call to occurs-* and expecting it to give back an answer, we can have it return a thunk immediately. When we invoke that thunk, it'll go off and do some work right up until it does a recursive call, which in turn will return another thunk. This is trampolined style, and the function that "bounces" our thunks is trampoline. Returning a thunk each time we make a recursive call bounds our stack size to one call at a time, so our only limit is the heap:

(defn occurs-cps-tramp [a lst k]
  (fn [] 
    (cond
     (empty? lst) (k 0) 
     (atom? (first lst))
     (cond
      (= a (first lst)) (occurs-cps-tramp a (rest lst)
                                          (fn [v] (k (inc v))))
      :else (occurs-cps-tramp a (rest lst) k))
     :else (occurs-cps-tramp a (first lst)
                             (fn [fst]
                               (occurs-cps-tramp a (rest lst)
                                                 (fn [rst] (k (+ fst rst)))))))))

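;; Vars must be ^:dynamic to be rebound with `binding` (Clojure 1.3 and later).
(def ^:dynamic done false)
(def ^:dynamic answer nil)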

(defn my-trampoline [th]
  (if done
    answer
    (recur (th))))

(defn empty-k [v]
  (set! answer v)
  (set! done true))

(defn run []
  (binding [done false answer 'whocares]
    (my-trampoline (occurs-cps-tramp 'foo nested empty-k))))

;; scratch.core=> (run)                             
;; 1

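Note that Clojure has a built-in trampoline, with one limitation on the return type: it keeps calling any fn it gets back, so the final answer must not itself be a fn unless it is wrapped in some data structure. Using that instead, we don't need a specialized empty-k: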

scratch.core=> (trampoline (occurs-cps-tramp 'foo nested (fn [v] v)))
1
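The return-type limitation of the built-in trampoline comes from how it decides when to stop: it keeps calling the result for as long as the result is a fn. A small sketch, using a hypothetical countdown function whose final answer is itself a function and is therefore wrapped in a vector:

```clojure
;; Hypothetical example: the final answer is the function `inc`.
;; Returning it bare would make trampoline call it with no arguments,
;; so it is wrapped in a vector and unpacked afterwards.
(defn countdown [n]
  (if (zero? n)
    [inc]                          ; a vector is not fn?, so trampoline stops here
    (fn [] (countdown (dec n)))))  ; a thunk: trampoline will bounce it

(let [[f] (trampoline countdown 100000)]
  (f 41))
;; => 42
```

In occurs-cps-tramp the final answer is a number, so no wrapping is needed there.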

Trampolining is certainly a cool technique, but the prerequisite to trampoline a program is that it must contain only tail calls; CPS is the real star here. It lets you define your algorithm with the clarity of natural recursion, and through correctness-preserving transformations, express it efficiently on any host that has a single loop and a heap.

Spacecraft answered 8/11, 2011 at 6:32 Comment(2)
Thank you very, very much for such a detailed answer. And yes, agreed 100% on the comment about not worrying about letting the details of the JVM get in the way of learning the concepts (recursion-based thinking) from TLS. At a high-level I understand your solution (er, pattern) - I'll admit though that since I haven't dug into CPS yet (and the trampoline thing), that it'll take me a while to fully wrap my head around your code :) With that said, it's good to know that the capability exists in Clojure to craft solutions that align in-spirit with the lessons from TLS. - Parthenos
Happy to help. In many ways, TLS and its siblings are about giving you the tools to use the lessons in any environment. So, your question is quite well-aimed (I shared this page with Dan, incidentally, who described it as "Cool; very cool"). - Spacecraft

You can't do this with a fixed amount of memory. You can consume stack, or heap; that's the decision you get to make. If I were writing this in Clojure I would do it with map and reduce rather than with manual recursion:

(defn occurs [x coll]
  (if (coll? coll)
    (reduce + (map #(occurs x %) coll))
    (if (= x coll)
      1, 0)))

Note that shorter solutions exist if you use tree-seq or flatten, but at that point most of the problem is gone so there's not much to learn.
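As an illustration of the shorter route, here is one possible tree-seq version. The name occurs-tree-seq is made up for this sketch, and it is not claimed to be stack-safe for very deeply nested input:

```clojure
;; tree-seq walks every node of the nested structure depth-first;
;; coll? decides which nodes are branches and seq lists their children.
(defn occurs-tree-seq [x coll]
  (count (filter #(= x %) (tree-seq coll? seq coll))))

(occurs-tree-seq 'abc '(abc (def abc) (abc (abc def))))
;; => 4
```

Like the map/reduce version, this treats any collection (vectors, sets, maps) as a branch, not only lists.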

Edit

Here's a version that doesn't use any stack, instead letting its queue get larger and larger (using up heap).

(defn heap-occurs [item coll]
  (loop [count 0, queue coll]
    (if-let [[x & xs] (seq queue)]
      (if (coll? x)
        (recur count (concat x xs))
        (recur (+ (if (= item x) 1, 0)
                  count)
               xs))
      count)))
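The same trade of stack for heap can be sketched with an explicit vector used as a stack of pending work, which avoids layering lazy concat calls. stack-occurs is a hypothetical name, and the visiting order differs from heap-occurs, which does not matter for counting:

```clojure
(defn stack-occurs [item coll]
  (loop [total 0
         stack [coll]]              ; pending nodes, held on the heap
    (if (empty? stack)
      total
      (let [x    (peek stack)       ; last element pushed
            more (pop stack)]
        (if (coll? x)
          (recur total (into more x))   ; push all children of x
          (recur (if (= item x) (inc total) total) more))))))

;; A list nested 100000 levels deep, built the same way as `nested` earlier.
(stack-occurs 'foo (reduce (fn [onion _] (list onion)) 'foo (range 100000)))
;; => 1
```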
Rhomboid answered 8/11, 2011 at 4:25 Comment(4)
Ahhh - this solution makes perfect sense, and still captures the essence of the lesson from Little Schemer (the algorithm is essentially the same). So is it fair to say that this bit represents a decision to consume heap (instead of stack)? Thank you again. - Parthenos
To be honest, lazy sequences are sometimes hard to reason about, and you can be forgiven for getting this wrong. This algorithm actually consumes stack, because each "layer" of map can only fully return when the layer below it is finished. For example, (occurs 'x (nth (iterate list 'x) 1000)) runs out of stack space. Edit: on my machine, anyway. - Rhomboid
@PaulEvans I added a version that consumes heap instead - as is often the case it's harder to follow because you're essentially managing your own stack and storing it on the heap. - Rhomboid
Thanks for the clarification on the stack vs heap consuming variants. This is very helpful and is sure to help in "converting" the "Little Schemer" solutions to Clojure. - Parthenos
