How to use C# async/await as a stand-alone CPS transform

Note: here CPS stands for "continuation-passing style".

I would be very interested in understanding how to hook into the C# async machinery. As I understand the async/await feature, the compiler performs a CPS transform and then passes the transformed code to a contextual object that manages the scheduling of tasks on various threads.

Is it possible to leverage that compiler feature to create powerful combinators while leaving aside the default threading aspect?

An example would be something that can derecursify and memoize a method like

async MyTask<BigInteger> Fib(int n)     // hypothetical example
{
    if (n <= 1) return n;
    return await Fib(n-1) + await Fib(n-2);
}

I managed to do it with something like:

void Fib(int n, Action<BigInteger> Ret, Action<int, Action<BigInteger>> Rec)
{
    if (n <= 1) Ret(n);
    else Rec(n-1, x => Rec(n-2, y => Ret(x + y)));
}

(no use of async, very kludgy...)

or using a monad (While<X> = Either<X, While<X>>)

While<X> Fib(int n) => n <= 1 ?
    While.Return((BigInteger) n) :
    from x in Fib(n-1)
    from y in Fib(n-2)
    select x + y;

a bit better but not as cute looking as the async syntax :)


I have asked this question on the blog of E. Lippert and he was kind enough to let me know it is indeed possible.


The need for me arose when implementing a ZBDD library: (a special kind of DAG)

  • lots of complex mutually recursive operations

  • stack overflows constantly on real examples

  • only practical if fully memoized

Manual CPS and derecursification were very tedious and error-prone.


The acid test for what I am after (stack safety) would be something like:

async MyTask<BigInteger> Fib(int n, BigInteger a, BigInteger b)
{
    if (n == 0) return b;
    if (n == 1) return a;
    return await Fib(n - 1, a + b, a);
}

With the default behavior, this produces a stack overflow on Fib(10000, 1, 0). Even better would be to take the code at the beginning, add memoization, and compute Fib(10000).

Sheath answered 16/5, 2019 at 14:55 Comment(11)
Doesn't IEnumerable<T>/IEnumerator<T> coupled with yield give you what you need? It is effectively the decoupled machinery of async/await.Pecos
It might be possible: IEnumerator<T> is conceptually similar to Maybe[(T, IEnumerator<T>)], albeit stateful. I also mentioned a monad construction While<X> = Either<X, While<X>> that does the trick, but really my question is about hijacking the CPS transform performed by the compiler on the async/await statements.Sheath
It's already possible to extend this mechanism without getting into the compiler level by using custom awaiters. As for the memoization I have serious doubts because async semantic doesn't imply it.Cryoscope
@DmytroMukalov reading that link seems promising, especially that one : Await task1.OnCoroutine(crm1)Sheath
Yes, and you can take a look at this one as well. But again, memoization is the problem.Cryoscope
@DmytroMukalov, You are right that IEnumerable<T> can be used to model (lazy) tree of computation. The classic paper by Moggi mentions something similar (Example 1.1 p. 3 nondeterminism) This idea can be applied to model minmax game tree search or to model monadic parsers as explained in this paper. It's a truly great idea.Sheath
Not to belittle the question or its answers, but wouldn't code like this be easier to write in F# (even if you needed to introduce transformations there as well)? This sounds like the sort of thing you could achieve with a custom computation expression (even if it might need considerable explicit machinery, at least the end result would look passable). These are functional concepts, so why not use a functional language?Murmansk
@JeroenMostert C# happens to enable a very comfortable style of functional programming too. I mention using C# query comprehension style on a monad in my question (similar to F# computation expressions). It is really a question about hijacking the CPS transform performed by the C# compiler.Sheath
I know, but do note that F#'s computation expressions are more flexible than what C# can do with query comprehension, as those patterns are fixed, whereas F# allows adding your own -- plus F#'s typing and generics make "doing things" to functions easier in general. I'm not disputing that you can do this sort of thing with C#, and the question is valid on its own, but taking a step back to consider if you're even using the right tool is never a bad thing either, beyond purely intellectual exercises. I specifically mentioned F# because C# and F# can easily mix, having .NET in common.Murmansk
That's a very valuable comment thanks. But you are right about the intellectual curiosity/speculative aspect. Please note that I couldn't care less about the Fibonacci example, but your remark can turn out to be a life saver for what I am really after (The ZBDD lib I am writing)Sheath
@JeroenMostert I completely forgot the other day, but would you like to post an answer using F# to demonstrate how it is used ?Sheath
1

Here is my version of a solution. It is stack-safe and doesn't use the thread pool, but it has a specific limitation: it requires the method to be written in tail-recursive style, so constructions like Fib(n-1) + Fib(n-2) won't work. On the other hand, because the tail-recursive method is actually executed iteratively, it doesn't need memoization: each iteration is called only once. It has no protection against edge cases, but it is a prototype rather than a final solution:

public class RecursiveTask<T>
{
    private T _result;

    private Func<RecursiveTask<T>> _function;

    public T Result
    {
        get
        {
            var current = this;
            var last = current;

            do
            {
                last = current;
                current = current._function?.Invoke();
            } while (current != null);

            return last._result;
        }
    }

    private RecursiveTask(Func<RecursiveTask<T>> function)
    {
        _function = function;
    }

    private RecursiveTask(T result)
    {
        _result = result;
    }

    public static implicit operator RecursiveTask<T>(T result)
    {
        return new RecursiveTask<T>(result);
    }

    public static RecursiveTask<T> FromFunc(Func<RecursiveTask<T>> func) => new RecursiveTask<T>(func);
}

And the usage:

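using System;
using System.Numerics;

class Program
{
    // BigInteger instead of int: Fib(100000) would silently overflow a 32-bit int.
    static RecursiveTask<BigInteger> Fib(int n, BigInteger a, BigInteger b)
    {
        if (n == 0) return a;
        if (n == 1) return b;

        // Return a thunk wrapping the recursive call; the Result loop invokes it later.
        return RecursiveTask<BigInteger>.FromFunc(() => Fib(n - 1, b, a + b));
    }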

    static RecursiveTask<int> Factorial(int n, int a)
    {
        if (n == 0) return a;

        return RecursiveTask<int>.FromFunc(() => Factorial(n - 1, n * a));
    }


    static void Main(string[] args)
    {
        Console.WriteLine(Factorial(5, 1).Result);
        Console.WriteLine(Fib(100000, 0, 1).Result);
    }
}

Note that it's important to return a function that wraps the recursive call, not the call itself, in order to avoid real recursion.

UPDATE: Below is another implementation, which still doesn't use the CPS transform but allows semantics close to algebraic recursion: it supports multiple recursive-like calls inside a function and doesn't require the function to be tail-recursive.

using System;
using System.Collections.Generic;
using System.Linq;

public class RecursiveTask<T1, T2>
{
    private readonly Func<RecursiveTask<T1, T2>, T1, T2> _func;
    private readonly Dictionary<T1, RecursiveTask<T1, T2>> _allTasks;
    private readonly List<RecursiveTask<T1, T2>> _subTasks;
    private readonly RecursiveTask<T1, T2> _rootTask;
    private T1 _arg;
    private T2 _result;
    private int _runsCount;
    private bool _isCompleted;
    private bool _isEvaluating;

    private RecursiveTask(Func<RecursiveTask<T1, T2>, T1, T2> func)
    {
        _func = func;
        _allTasks = new Dictionary<T1, RecursiveTask<T1, T2>>();
        _subTasks = new List<RecursiveTask<T1, T2>>();
        _rootTask = this;
    }

    private RecursiveTask(Func<RecursiveTask<T1, T2>, T1, T2> func, T1 arg, RecursiveTask<T1, T2> rootTask) : this(func)
    {
        _arg = arg;
        _rootTask = rootTask;
    }

    public T2 Run(T1 arg)
    {
        if (!_isEvaluating)
            BuildTasks(arg);

        if (_isEvaluating)
            return EvaluateTasks(arg);

        return default;
    }

    public static RecursiveTask<T1, T2> Create(Func<RecursiveTask<T1, T2>, T1, T2> func)
    {
        return new RecursiveTask<T1, T2>(func);
    }

    private void AddSubTask(T1 arg)
    {
        if (!_allTasks.TryGetValue(arg, out RecursiveTask<T1, T2> subTask))
        {
            subTask = new RecursiveTask<T1, T2>(_func, arg, this);
            _allTasks.Add(arg, subTask);
            _subTasks.Add(subTask);
        }
    }

    private T2 Run()
    {
        if (!_isCompleted)
        {
            var runsCount = _rootTask._runsCount;
            _result = _func(_rootTask, _arg);
            _isCompleted = runsCount == _rootTask._runsCount;
        }
        return _result;
    }

    private void BuildTasks(T1 arg)
    {
        if (_runsCount++ == 0)
            _arg = arg;

        if (EqualityComparer<T1>.Default.Equals(_arg, arg))
        {
            Run();

            var processed = 0;
            var addedTasksCount = _subTasks.Count;
            while (processed < addedTasksCount)
            {
                for (var i = processed; i < addedTasksCount; i++, processed++)
                    _subTasks[i].Run();
                addedTasksCount = _subTasks.Count;
            }
            _isEvaluating = true;
        }
        else
            AddSubTask(arg);
    }

    private T2 EvaluateTasks(T1 arg)
    {
        if (EqualityComparer<T1>.Default.Equals(_arg, arg))
        {
            foreach (var task in Enumerable.Reverse(_subTasks))
                task.Run();

            return Run();
        }
        else
        {
            if (_allTasks.TryGetValue(arg, out RecursiveTask<T1, T2> task))
                return task._isCompleted ? task._result : task.Run();
            else
                return default;
        }
    }
}

The usage:

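using System;
using System.Numerics;

class Program
{
    // BigInteger result: Fib(100000) would silently overflow a 32-bit int.
    static BigInteger Fib(int num)
    {
        return RecursiveTask<int, BigInteger>.Create((t, n) =>
        {
            if (n == 0) return 0;
            if (n == 1) return 1;

            return t.Run(n - 1) + t.Run(n - 2);
        }).Run(num);
    }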

    static void Main(string[] args)
    {
        Console.WriteLine(Fib(7));
        Console.WriteLine(Fib(100000));
    }
}

As benefits, it is stack-safe, doesn't use the thread pool, isn't burdened with the async/await infrastructure, uses memoization, and allows reasonably readable semantics. The current implementation only works with functions of a single argument. To make it applicable to a wider range of functions, similar implementations should be provided for different sets of generic arguments:

RecursiveTask<T1, T2, T3>
RecursiveTask<T1, T2, T3, T4>
...
Cryoscope answered 17/5, 2019 at 20:2 Comment(3)
Your last remark is also what makes the Task.Run(() => _) stackless in GSerg's solution.Sheath
That's a nice implementation of the trampoline pattern you have here (Eric also mentioned it in a comment). With the CPS transform all calls effectively become tail calls, so your solution will apply. The idea of my question is that the compiler effectively does a CPS transform to implement the async/await syntax; wouldn't it be nice to be able to use it stand-alone and apply it to your trampoline? +1 for the research effort and the clean reusable pattern implementation, by the way.Sheath
@SamuelVidal, I think it's problematic to combine trampolining and async await mechanism as trampolining implies narrower set of scenarios. However I provided another implementation which allows to use semantic close to what is used in your first solution. I think it's possible to achieve something similar with async await but I'm not sure if it's reasonable as the similar thing can be done without async mechanism burden and complexity.Cryoscope
0

The acid test for what I am after (stack safety) would be something like:

async MyTask<BigInteger> Fib(int n, BigInteger a, BigInteger b)
{
    if (n == 0) return b;
    if (n == 1) return a;
    return await Fib(n - 1, a + b, a);
}

Would that not be simply

public static Task<BigInteger> Fib(int n, BigInteger a, BigInteger b)
{
    if (n == 0) return Task.FromResult(b);
    if (n == 1) return Task.FromResult(a);

    return Task.Run(() => Fib(n - 1, a + b, a));
}

?


Or, without using the thread pool,

public static async Task<BigInteger> Fib(int n, BigInteger a, BigInteger b)
{
    if (n == 0) return b;
    if (n == 1) return a;

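    // Note: TaskScheduler.FromCurrentSynchronizationContext() throws InvalidOperationException
    // when there is no current SynchronizationContext (e.g. in a plain console app).
    return await Task.FromResult(a + b)
        .ContinueWith(t => Fib(n - 1, t.Result, a), TaskScheduler.FromCurrentSynchronizationContext())
        .Unwrap();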
}

, unless I grossly misunderstand something.

Elk answered 16/5, 2019 at 16:23 Comment(6)
I believe that Samuel is looking for a solution that does not do work on thread pool worker threads. I think he's looking for something more like a trampoline.Retake
As an alternative the Task.FromResult can be used instead of Task.Run as well.Cryoscope
@DmytroMukalov: How is that then stackless?Retake
@EricLippert The second version works on the context thread for me. For some reason I'm not entirely sure it is correct even though it appears to be.Elk
This is indeed stack safe and has the merit to be fast and clean. But I would like to see if that can be done without changing the initial code.Sheath
I was glad to learn that the Task.Run(() => _); construct produce a 'stackless' behavior hence the +1.Sheath
0

Without seeing your MyTask<T> and the stack trace of that exception, it's impossible to know what's happening.

Looks like what you're looking for is Generalized async return types.

You can browse the source to see how it's done for ValueTask and ValueTask<T>.
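
For illustration, here is a minimal sketch of how that feature might be applied to the acid test from the question. Tramp<T>, TrampolineBuilder<T> and Trampoline are hypothetical names made up for this example, not library types; only AsyncMethodBuilderAttribute, IAsyncStateMachine and INotifyCompletion come from the framework. The assumption is that the builder's Start queues the first MoveNext instead of running it, so every recursive call and every continuation is executed from one flat loop. The sketch is single-threaded, not thread-safe, and supports only one awaiter per task.

```csharp
using System;
using System.Collections.Generic;
using System.Numerics;
using System.Runtime.CompilerServices;

// Single-threaded work queue; the flat loop in RunUntil is the only caller of queued steps.
public static class Trampoline
{
    private static readonly Queue<Action> Pending = new Queue<Action>();
    public static void Schedule(Action step) => Pending.Enqueue(step);
    public static void RunUntil(Func<bool> done)
    {
        while (!done() && Pending.Count > 0) Pending.Dequeue()();
    }
}

// Hypothetical task-like type; the attribute tells the compiler which builder to use.
[AsyncMethodBuilder(typeof(TrampolineBuilder<>))]
public sealed class Tramp<T>
{
    private T _result;
    private Exception _error;
    private Action _continuation;
    public bool IsCompleted { get; private set; }

    internal void Complete(T result, Exception error)
    {
        _result = result; _error = error; IsCompleted = true;
        if (_continuation != null) Trampoline.Schedule(_continuation);
    }

    // Drains the queue until this task is done, then returns its result.
    public T Run()
    {
        Trampoline.RunUntil(() => IsCompleted);
        return GetAwaiter().GetResult();
    }

    public Awaiter GetAwaiter() => new Awaiter(this);

    public readonly struct Awaiter : INotifyCompletion
    {
        private readonly Tramp<T> _task;
        internal Awaiter(Tramp<T> task) => _task = task;
        public bool IsCompleted => _task.IsCompleted;
        public T GetResult() => _task._error == null ? _task._result : throw _task._error;
        public void OnCompleted(Action continuation)
        {
            if (_task.IsCompleted) Trampoline.Schedule(continuation);
            else _task._continuation = continuation; // one awaiter per task in this sketch
        }
    }
}

// Hypothetical builder: a class, so the boxed state machine shares it with the caller.
public sealed class TrampolineBuilder<T>
{
    private Action _moveNext;
    public static TrampolineBuilder<T> Create() => new TrampolineBuilder<T>();
    public Tramp<T> Task { get; } = new Tramp<T>();

    // Do not run the method body here (that would recurse); box the state
    // machine and queue its first step instead.
    public void Start<TSM>(ref TSM stateMachine) where TSM : IAsyncStateMachine
    {
        IAsyncStateMachine boxed = stateMachine;
        _moveNext = boxed.MoveNext;
        Trampoline.Schedule(_moveNext);
    }

    public void SetStateMachine(IAsyncStateMachine stateMachine) { }
    public void SetResult(T result) => Task.Complete(result, null);
    public void SetException(Exception exception) => Task.Complete(default(T), exception);

    public void AwaitOnCompleted<TA, TSM>(ref TA awaiter, ref TSM stateMachine)
        where TA : INotifyCompletion where TSM : IAsyncStateMachine
        => awaiter.OnCompleted(_moveNext);

    public void AwaitUnsafeOnCompleted<TA, TSM>(ref TA awaiter, ref TSM stateMachine)
        where TA : ICriticalNotifyCompletion where TSM : IAsyncStateMachine
        => awaiter.UnsafeOnCompleted(_moveNext);
}

public static class Demo
{
    // The acid test from the question, with MyTask replaced by the sketch's Tramp.
    public static async Tramp<BigInteger> Fib(int n, BigInteger a, BigInteger b)
    {
        if (n == 0) return b;
        if (n == 1) return a;
        return await Fib(n - 1, a + b, a);
    }

    public static void Main()
    {
        Console.WriteLine(Fib(10, 1, 0).Run());
        Console.WriteLine(Fib(100000, 1, 0).Run().ToString().Length); // digit count only
    }
}
```

The builder is a class on purpose: Start copies a struct state machine into a box, and a reference-type builder keeps that copy and the caller's original pointing at the same Tramp<T>. Depending on the target framework, AsyncMethodBuilderAttribute may have to come from a package or be declared by hand.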

Heptastich answered 16/5, 2019 at 17:44 Comment(3)
It is very easy to tell what is happening, the function is recursive and it calls itself 10000 times which results in stack overflow.Elk
Nice suggestion I have looked at the source. (it's heavy stuff ^^=)Sheath
It's what the compiler requires.Heptastich
0

A solution closer to what I am after, but not yet totally satisfactory, is the following. It builds on the insight from GSerg's proposed solution for stack safety, with memoization added.

Pro: the core of the algorithm (the FibAux method) uses the clean async/await syntax.

Con: it still uses the thread pool for execution.

    // Core algorithm using the cute async/await syntax
    // (n.b. this would be exponential without memoization.)
    private static async Task<BigInteger> FibAux(int n)
    {
        if (n <= 1) return n;
        return await Rec(n - 1) + await Rec(n - 2);
    }

    public static Func<int, Task<BigInteger>> Rec { get; }
        = Utils.StackSafeMemoize<int, BigInteger>(FibAux);

    public static BigInteger Fib(int n)
        => FibAux(n).Result;

    [Test]
    public void Test()
    {
        Console.WriteLine(Fib(100000));
    }

    public static class Utils
    {
        // the combinator (still using the thread pool for execution)
        public static Func<X, Task<Y>> StackSafeMemoize<X, Y>(Func<X, Task<Y>> func)
        {
            var memo = new Dictionary<X, Y>();
            return x =>
            {
                Y result;
                if (!memo.TryGetValue(x, out result))
                {
                    return Task.Run(() => func(x).ContinueWith(task =>
                    {
                        var y = task.Result;
                        memo[x] = y;
                        return y;
                    }));
                }

                return Task.FromResult(result);
            };
        }
    } 

For comparison, this is the CPS version, which does not use async/await.


    public static BigInteger Fib(int n)
    {
        var fib = Memo<int, BigInteger>((m, rec, cont) =>
        {
            if (m <= 1) cont(m);
            else rec(m - 1, x => rec(m - 2, y => cont(x + y)));
        });

        return fib(n);
    }

    [Test]
    public void Test()
    {
        Console.WriteLine(Fib(100000));
    }

    // ---------

    public static Func<X, Y> Memo<X, Y>(Action<X, Action<X, Action<Y>>, Action<Y>> func)
    {
        var memo = new Dictionary<X, Y>(); // can be a Lru cache
        var stack = new Stack<Action>();

        Action<X, Action<Y>> rec = null;
        rec = (x, cont) =>
        {
            stack.Push(() =>
            {
                Y res;
                if (memo.TryGetValue(x, out res))
                {
                    cont(res);
                }
                else
                {
                    func(x, rec, y =>
                    {
                        memo[x] = y;
                        cont(y);
                    });
                }
            });
        };

        return x =>
        {
            var res = default(Y);
            rec(x, y => res = y);
            while (stack.Count > 0)
            {
                var next = stack.Pop();
                next();
            }

            return res;
        };
    }

Sheath answered 17/5, 2019 at 1:19 Comment(0)
