How to inject the dependency of the next handler in a chain of responsibility?

In my current project, I'm using quite a few Chain of Responsibility patterns.

However, I find it a bit awkward to configure the chain via dependency injection.

Given this model:

public interface IChainOfResponsibility 
{
    IChainOfResponsibility Next { get; }
    void Handle(Foo foo);
}

public class HandlerOne : IChainOfResponsibility 
{
    private DbContext _dbContext;

    public HandlerOne(IChainOfResponsibility next, DbContext dbContext)
    {
        Next = next;
        _dbContext = dbContext;
    }

    public IChainOfResponsibility Next { get; }

    public void Handle(Foo foo) { /*...*/}
}

public class HandlerTwo : IChainOfResponsibility 
{
    private DbContext _dbContext;

    public HandlerTwo(IChainOfResponsibility next, DbContext dbContext)
    {
        Next = next;
        _dbContext = dbContext;
    }

    public IChainOfResponsibility Next { get; }

    public void Handle(Foo foo) { /*...*/}
}

My Startup becomes:

public void ConfigureServices(IServiceCollection services)
{
    services.AddTransient<IChainOfResponsibility>(x => 
        new HandlerOne(x.GetRequiredService<HandlerTwo>(), x.GetRequiredService<DbContext>())
    );

    services.AddTransient(x => 
        new HandlerTwo(null, x.GetRequiredService<DbContext>())
    );
}
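For reference, the two factory registrations above effectively build the chain back to front: the last handler gets null and each earlier handler receives the one after it. The sketch below does the same wiring without a container. It is illustrative only: it assumes each handler records itself and then forwards to Next, reduces Foo to a list of visited handler names, and drops DbContext.

```csharp
using System.Collections.Generic;

// Hypothetical stand-in for the question's Foo: records which handlers ran.
public class Foo
{
    public List<string> Visited { get; } = new List<string>();
}

public interface IChainOfResponsibility
{
    IChainOfResponsibility Next { get; }
    void Handle(Foo foo);
}

public class HandlerOne : IChainOfResponsibility
{
    public HandlerOne(IChainOfResponsibility next) => Next = next;
    public IChainOfResponsibility Next { get; }

    public void Handle(Foo foo)
    {
        foo.Visited.Add(nameof(HandlerOne));
        Next?.Handle(foo); // forward to the next link, if there is one
    }
}

public class HandlerTwo : IChainOfResponsibility
{
    public HandlerTwo(IChainOfResponsibility next) => Next = next;
    public IChainOfResponsibility Next { get; }

    public void Handle(Foo foo)
    {
        foo.Visited.Add(nameof(HandlerTwo));
        Next?.Handle(foo);
    }
}

public static class ManualWiring
{
    // Same shape as the DI registrations: the last handler gets null,
    // and HandlerOne receives HandlerTwo as its Next.
    public static IChainOfResponsibility Build() => new HandlerOne(new HandlerTwo(null));
}
```

Calling `ManualWiring.Build().Handle(foo)` runs HandlerOne first and then HandlerTwo, which is the order the container produces when IChainOfResponsibility is resolved.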

How can I configure my chain of responsibility more cleanly?

Matheny answered 2/4, 2019 at 13:45 Comment(0)

I've hacked together a simple solution, as I couldn't find anything that did what I wanted. It works well: it uses IServiceProvider.GetRequiredService to resolve every constructor dependency of every handler in the chain.

My startup class becomes:

public void ConfigureServices(IServiceCollection services)
{
    services.Chain<IChainOfResponsibility>()
        .Add<HandlerOne>()
        .Add<HandlerTwo>()
        .Configure();
}

What I'm doing is generating the lambdas from the question dynamically with expression trees (System.Linq.Expressions). Each one is then compiled and registered through IServiceCollection.AddTransient.

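Because it generates compiled delegates, resolution at runtime should be as fast as the hand-written registrations in the question; the reflection and compilation cost is paid only once, when Configure runs at startup.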
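To show the expression-tree technique in isolation, here is a minimal sketch that builds and compiles a factory of the shape `x => new HandlerOne((IHandler)x.GetService(typeof(HandlerTwo)))`. It is hypothetical in a few ways: to stay free of NuGet packages it calls the BCL method `IServiceProvider.GetService` plus a cast instead of the `GetRequiredService<T>` extension the answer uses, and `IHandler`, `DictionaryProvider` and `FactoryCompiler` are illustrative names standing in for the real interface and container.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public interface IHandler
{
    IHandler Next { get; }
}

public class HandlerOne : IHandler
{
    public HandlerOne(IHandler next) => Next = next;
    public IHandler Next { get; }
}

// Last link in this sketch: it has no "next" constructor parameter.
public class HandlerTwo : IHandler
{
    public IHandler Next => null;
}

// Hypothetical minimal container: maps a service type to a compiled factory.
public class DictionaryProvider : IServiceProvider
{
    private readonly Dictionary<Type, Func<IServiceProvider, object>> _factories =
        new Dictionary<Type, Func<IServiceProvider, object>>();

    public void Register(Type serviceType, Func<IServiceProvider, object> factory) =>
        _factories[serviceType] = factory;

    public object GetService(Type serviceType) =>
        _factories.TryGetValue(serviceType, out var factory) ? factory(this) : null;
}

public static class FactoryCompiler
{
    private static readonly System.Reflection.MethodInfo GetServiceMethod =
        typeof(IServiceProvider).GetMethod(nameof(IServiceProvider.GetService));

    // Emits (parameterType)x.GetService(serviceType).
    private static Expression Resolve(ParameterExpression x, Type serviceType, Type parameterType) =>
        Expression.Convert(
            Expression.Call(x, GetServiceMethod, Expression.Constant(serviceType, typeof(Type))),
            parameterType);

    // Builds and compiles x => new currentType(...), wiring IHandler parameters to nextType.
    public static Func<IServiceProvider, object> Compile(Type currentType, Type nextType)
    {
        var x = Expression.Parameter(typeof(IServiceProvider), "x");
        var ctor = currentType.GetConstructors()
            .OrderByDescending(c => c.GetParameters().Length)
            .First();

        var args = ctor.GetParameters().Select(p =>
        {
            if (typeof(IHandler).IsAssignableFrom(p.ParameterType))
            {
                // Last link: pass null, as the answer does; otherwise resolve the next handler.
                return nextType == null
                    ? Expression.Constant(null, p.ParameterType)
                    : Resolve(x, nextType, p.ParameterType);
            }
            // Any other dependency is resolved by its own type.
            return Resolve(x, p.ParameterType, p.ParameterType);
        }).ToArray();

        var body = Expression.New(ctor, args);
        return Expression.Lambda<Func<IServiceProvider, object>>(body, x).Compile();
    }
}
```

Registering HandlerTwo under its concrete type and HandlerOne under the interface mirrors what the answer's ConfigureType does for the last and first handlers; the compiled delegates are then invoked on every resolution without further reflection.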

Here's the code that does the magic:

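using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.Extensions.DependencyInjection;

public static class ChainConfigurator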
{
    public static IChainConfigurator<T> Chain<T>(this IServiceCollection services) where T : class
    {
        return new ChainConfiguratorImpl<T>(services);
    }

    public interface IChainConfigurator<T>
    {
        IChainConfigurator<T> Add<TImplementation>() where TImplementation : T;
        void Configure();
    }

    private class ChainConfiguratorImpl<T> : IChainConfigurator<T> where T : class
    {
        private readonly IServiceCollection _services;
        private readonly List<Type> _types;
        private readonly Type _interfaceType;

        public ChainConfiguratorImpl(IServiceCollection services)
        {
            _services = services;
            _types = new List<Type>();
            _interfaceType = typeof(T);
        }

        public IChainConfigurator<T> Add<TImplementation>() where TImplementation : T
        {
            var type = typeof(TImplementation);

            _types.Add(type);

            return this;
        }

        public void Configure()
        {
            if (_types.Count == 0)
                throw new InvalidOperationException($"No implementation defined for {_interfaceType.Name}");

            foreach (var type in _types)
            {
                ConfigureType(type);
            }
        }

        private void ConfigureType(Type currentType)
        {
            // gets the next type, as that will be injected in the current type
            var nextType = _types.SkipWhile(x => x != currentType).SkipWhile(x => x == currentType).FirstOrDefault();

            // Makes a parameter expression, that is the IServiceProvider x 
            var parameter = Expression.Parameter(typeof(IServiceProvider), "x");

            // pick the constructor with the most parameters. Ideally there is only one constructor, but better safe than sorry.
            var ctor = currentType.GetConstructors().OrderByDescending(x => x.GetParameters().Length).First();

            // for each parameter in the constructor
            var ctorParameters = ctor.GetParameters().Select(p =>
            {
                // check whether the parameter type is the chain interface (or a type implementing it); that's how we find the parameter that receives the next handler.
                if (_interfaceType.IsAssignableFrom(p.ParameterType))
                {
                    if (nextType is null)
                    {
                        // if there's no next type, current type is the last in the chain, so it just receives null
                        return Expression.Constant(null, _interfaceType);
                    }
                    else
                    {
                        // if there is, then we call IServiceProvider.GetRequiredService to resolve next type for us
                        return Expression.Call(typeof(ServiceProviderServiceExtensions), "GetRequiredService", new Type[] { nextType }, parameter);
                    }
                }
                
                // this is a parameter we don't care about, so we just ask GetRequiredService to resolve it for us 
                return (Expression)Expression.Call(typeof(ServiceProviderServiceExtensions), "GetRequiredService", new Type[] { p.ParameterType }, parameter);
            });

            // all constructor arguments are in place, so build a "new" expression that invokes the constructor
            var body = Expression.New(ctor, ctorParameters.ToArray());
            
            // if current type is the first in our list, then we register it by the interface, otherwise by the concrete type
            var first = _types[0] == currentType;
            var resolveType = first ? _interfaceType : currentType;
            var expressionType = Expression.GetFuncType(typeof(IServiceProvider), resolveType);

            // finally, we can build our expression
            var expression = Expression.Lambda(expressionType, body, parameter);

            // compile it
            var compiledExpression = (Func<IServiceProvider, object>)expression.Compile();

            // and register it in the services collection as transient
            _services.AddTransient(resolveType, compiledExpression);
        }
    }
}
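The `SkipWhile(...).SkipWhile(...).FirstOrDefault()` line in ConfigureType finds the type that follows `currentType`: the first `SkipWhile` drops everything before it, the second drops `currentType` itself, and `FirstOrDefault` yields the next type, or null for the last handler. A self-contained sketch of that lookup next to an index-based equivalent; `ChainOrder`, `NextOfSkipWhile` and `NextOfIndex` are illustrative names, not part of the answer, and the index version assumes each type appears only once in the chain.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ChainOrder
{
    // Same expression as ConfigureType: drop everything before currentType,
    // drop currentType itself, then take whatever follows (null for the last link).
    public static Type NextOfSkipWhile(List<Type> types, Type currentType) =>
        types.SkipWhile(x => x != currentType)
             .SkipWhile(x => x == currentType)
             .FirstOrDefault();

    // Index-based equivalent, assuming each type appears only once in the list.
    public static Type NextOfIndex(List<Type> types, Type currentType)
    {
        var index = types.IndexOf(currentType);
        return index >= 0 && index + 1 < types.Count ? types[index + 1] : null;
    }
}
```

Both return null for the last type in the list, which is what makes ConfigureType pass a null constant to the last handler's "next" parameter.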

PS: I'm answering my own question for future reference (for myself and hopefully others), but I'd love some feedback on this.

Matheny answered 2/4, 2019 at 13:45 Comment(3)
Func<IServiceProvider, object> x = (Func<IServiceProvider, object>)expression.Compile(); You are using x in the LINQ expression, so change the variable name. - Cankered
@GouravGarg edited, thanks! In hindsight, since I wrote this over a year ago, I could improve this code a bit more. Hopefully I'll have some time next week. - Matheny
I just used this code in a project I'm working on, and my first impression is that it works really well. Nice solution, and thanks for sharing! - Ortrud

A quick solution that works for the simplest dependency chains.

    public static IServiceCollection AddChained<TService>(this IServiceCollection services, params Type[] implementationTypes)
    {
        if (implementationTypes.Length == 0)
        {
            throw new ArgumentException("Pass at least one implementation type", nameof(implementationTypes));
        }

        foreach(Type type in implementationTypes)
        {
            services.AddScoped(type);
        }

        int order = 0;
        services.AddTransient(typeof(TService), provider =>
        {
            //starts again
            if (order > implementationTypes.Length - 1)
            {
                order = 0;
            }

            Type type = implementationTypes[order];
            order++;

            return provider.GetService(type);
        });

        return services;
    }

Then register the chain like this:

services.AddChained<IService>(typeof(SomeTypeWithIService), typeof(SomeType));

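Important notice:

Use this solution with care: it may not behave consistently in multithreaded scenarios, because the order counter is not thread-safe. As a result, it can't guarantee that resolving the service always returns the first implementation in the chain.

For instance, when we call GetService<IService>() on the service provider, we expect to receive an instance of SomeTypeWithIService every time, as it is the first implementation in the chain. But if the same call is made from multiple threads, we can sometimes receive SomeType instead.

The counter can also drift on a single thread: the implementations are registered as scoped, so resolving IService a second time in the same scope returns the cached SomeTypeWithIService without resolving its inner IService. The counter then advances only once, and the following resolution returns SomeType.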

Parmenides answered 10/12, 2019 at 15:26 Comment(0)

I developed your idea further by introducing the notion of a ChainLink(current, next):

public class ItemDecoratorChainLink : IItemDecorator
{
    private readonly IItemDecorator[] _decorators;

    public ItemDecoratorChainLink(
        IItemDecorator current,
        IItemDecorator next)
    {
        if (current == null)
        {
            throw new ArgumentNullException(nameof(current));
        }

        _decorators = next != null
            ? new[] { current, next }
            : new[] { current };
    }

    public bool CanHandle(Item item) =>
        _decorators.Any(d => d.CanHandle(item));

    public void Decorate(Item item)
    {
        var decorators = _decorators.Where(d => d.CanHandle(item)).ToArray();

        foreach (var decorator in decorators)
        {
            decorator.Decorate(item);
        }
    }
}

This way, the links don't need to keep a reference to the "next" link; that burden moves to the chain link. The links themselves become cleaner, free of duplicated wiring code, and can focus on a single responsibility.
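Note that `ItemDecoratorChainLink.Decorate` applies every wrapped decorator whose `CanHandle` returns true, rather than stopping at the first match as a classic chain of responsibility would. A self-contained sketch of that behaviour: the link class copies the answer's logic, while `Item` (with a list of applied tags), `ChannelDecorator` and `FallbackDecorator` are hypothetical stand-ins.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical Item: records which decorators touched it.
public class Item
{
    public string Kind { get; set; }
    public List<string> Applied { get; } = new List<string>();
}

public interface IItemDecorator
{
    bool CanHandle(Item item);
    void Decorate(Item item);
}

public class ChannelDecorator : IItemDecorator
{
    public bool CanHandle(Item item) => item.Kind == "channel";
    public void Decorate(Item item) => item.Applied.Add("channel");
}

public class FallbackDecorator : IItemDecorator
{
    public bool CanHandle(Item item) => true;
    public void Decorate(Item item) => item.Applied.Add("fallback");
}

// Same logic as the answer's ItemDecoratorChainLink.
public class ItemDecoratorChainLink : IItemDecorator
{
    private readonly IItemDecorator[] _decorators;

    public ItemDecoratorChainLink(IItemDecorator current, IItemDecorator next)
    {
        if (current == null) throw new ArgumentNullException(nameof(current));
        _decorators = next != null ? new[] { current, next } : new[] { current };
    }

    public bool CanHandle(Item item) => _decorators.Any(d => d.CanHandle(item));

    public void Decorate(Item item)
    {
        // Every decorator that can handle the item runs, in chain order.
        foreach (var decorator in _decorators.Where(d => d.CanHandle(item)).ToArray())
        {
            decorator.Decorate(item);
        }
    }
}
```

A "channel" item is decorated by both links, while any other item only gets the fallback. This is the behaviour the comments below discuss: no link can stop the rest of the chain from running.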

Below is the code for the chain builder:

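using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.Extensions.DependencyInjection;

public class ComponentChainBuilder<TInterface> : IChainBuilder<TInterface>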
    where TInterface : class
{
    private static readonly Type InterfaceType = typeof(TInterface);

    private readonly List<Type> _chain = new List<Type>();
    private readonly IServiceCollection _container;
    private readonly ConstructorInfo _chainLinkCtor;
    private readonly string _currentImplementationArgName;
    private readonly string _nextImplementationArgName;

    public ComponentChainBuilder(
        IServiceCollection container,
        Type chainLinkType,
        string currentImplementationArgName,
        string nextImplementationArgName)
    {
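        _container = container ?? throw new ArgumentNullException(nameof(container));
        _chainLinkCtor = chainLinkType?.GetConstructors().First()
            ?? throw new ArgumentNullException(nameof(chainLinkType));
        if (string.IsNullOrWhiteSpace(currentImplementationArgName))
            throw new ArgumentException("Argument name must not be empty.", nameof(currentImplementationArgName));
        if (string.IsNullOrWhiteSpace(nextImplementationArgName))
            throw new ArgumentException("Argument name must not be empty.", nameof(nextImplementationArgName));
        _currentImplementationArgName = currentImplementationArgName;
        _nextImplementationArgName = nextImplementationArgName;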
    }

    /// <inheritdoc />
    public IChainBuilder<TInterface> Link(Type implementationType)
    {
        _chain.Add(implementationType);
        return this;
    }

    /// <inheritdoc />
    public IChainBuilder<TInterface> Link<TImplementationType>()
      where TImplementationType : class, TInterface
        => Link(typeof(TImplementationType));

    public IServiceCollection Build(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
    {
        if (_chain.Count == 0)
        {
            throw new InvalidOperationException("At least one link must be registered.");
        }

        var serviceProviderParameter = Expression.Parameter(typeof(IServiceProvider), "x");
        Expression chainLink = null;

        for (var i = _chain.Count - 1; i > 0; i--)
        {
            var currentLink = CreateLinkExpression(_chain[i - 1], serviceProviderParameter);
            var nextLink = chainLink ?? CreateLinkExpression(_chain[i], serviceProviderParameter);
            chainLink = CreateChainLinkExpression(currentLink, nextLink, serviceProviderParameter);
        }

        if (chainLink == null)
        {
            // only one type is defined so we use it to register dependency
            _container.Add(new ServiceDescriptor(InterfaceType, _chain[0], serviceLifetime));
        }
        else
        {
            // chain is built so we use it to register dependency
            var expressionType = Expression.GetFuncType(typeof(IServiceProvider), InterfaceType);
            var createChainLinkLambda = Expression.Lambda(expressionType, chainLink, serviceProviderParameter);
            var createChainLinkFunction = (Func<IServiceProvider, object>)createChainLinkLambda.Compile();

            _container.Add(new ServiceDescriptor(InterfaceType, createChainLinkFunction, serviceLifetime));
        }

        return _container;
    }

    private NewExpression CreateLinkExpression(Type linkType, ParameterExpression serviceProviderParameter)
    {
        var linkCtor = linkType.GetConstructors().First();
        var linkCtorParameters = linkCtor.GetParameters()
            .Select(p => GetServiceProviderDependenciesExpression(p, serviceProviderParameter))
            .ToArray();
        return Expression.New(linkCtor, linkCtorParameters);
    }

    private Expression CreateChainLinkExpression(
        Expression currentLink,
        Expression nextLink,
        ParameterExpression serviceProviderParameter)
    {
        var chainLinkCtorParameters = _chainLinkCtor.GetParameters().Select(p =>
        {
            if (p.Name == _currentImplementationArgName)
            {
                return currentLink;
            }

            if (p.Name == _nextImplementationArgName)
            {
                return nextLink;
            }

            return GetServiceProviderDependenciesExpression(p, serviceProviderParameter);
        }).ToArray();

        return Expression.New(_chainLinkCtor, chainLinkCtorParameters);
    }

    private static Expression GetServiceProviderDependenciesExpression(ParameterInfo parameter, ParameterExpression serviceProviderParameter)
    {
        // this is a parameter we don't care about, so we just ask GetRequiredService to resolve it for us
        return Expression.Call(
            typeof(ServiceProviderServiceExtensions),
            nameof(ServiceProviderServiceExtensions.GetRequiredService),
            new[] { parameter.ParameterType },
            serviceProviderParameter);
    }
}
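The reverse `for` loop in Build nests the links from the end of the list, so a chain [A, B, C] becomes ChainLink(A, ChainLink(B, C)), and a single link is registered directly. A self-contained sketch of the same loop over strings; `ChainShape.Describe` is a hypothetical helper that produces a description in place of the expression trees the real builder creates.

```csharp
using System.Collections.Generic;

public static class ChainShape
{
    // Same loop as Build(): start from the last pair and wrap outwards.
    public static string Describe(IReadOnlyList<string> chain)
    {
        string chainLink = null;
        for (var i = chain.Count - 1; i > 0; i--)
        {
            var currentLink = chain[i - 1];
            var nextLink = chainLink ?? chain[i];
            chainLink = $"ChainLink({currentLink}, {nextLink})";
        }

        // With only one link, no ChainLink wrapper is created.
        return chainLink ?? chain[0];
    }
}
```

`ChainShape.Describe(new[] { "A", "B", "C" })` yields `ChainLink(A, ChainLink(B, C))`, which matches the nesting of the compiled factory.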

And its extension method (which, like any extension method, must be declared in a static class):

public static IChainBuilder<TInterface> Chain<TInterface, TChainLink>(
    this IServiceCollection container,
    string currentImplementationArgumentName = "current",
    string nextImplementationArgumentName = "next")
    where TInterface : class
    where TChainLink : TInterface
    => new ComponentChainBuilder<TInterface>(
        container,
        typeof(TChainLink),
        currentImplementationArgumentName,
        nextImplementationArgumentName);

The code for building chains looks like this:

services.Chain<IItemDecorator, ItemDecoratorChainLink>()
    .Link<ChannelItemDecorator>()
    .Link<CompetitionItemDecorator>()
    .Link<ProgramItemDecorator>()
    .Build(ServiceLifetime.Singleton);

And the full example of this approach can be found on my GitHub:

https://github.com/alex-valchuk/dot-net-expressions/blob/master/NetExpressions/ConsoleApp1/ConsoleApp1/ChainBuilder/ComponentChainBuilder.cs

Languorous answered 29/10, 2021 at 9:1 Comment(5)
Overall, I really like your solution. One catch is that it currently does not let a link decide whether to call the next handler or stop, which is why I was passing the next handler into each link. This is solvable by having your links receive a next Action, like the approach Microsoft took with middleware. - Matheny
The second thing is that giving the current link full control over the parameters, as in my approach, lets it decorate the original parameter before passing it to the next handler. The use case for this should be rare; it's just something to keep in mind. Thanks for sharing your approach, I really appreciate it. - Matheny
My pleasure. Thank you for the example with expression trees. It really made my day. - Languorous
As for letting a link decide whether to call the next handler: first, it's a matter of design, and second, it looks like a violation of SRP when every link is burdened with this additional responsibility. - Languorous
I think there is a simpler solution: services.AddTransient<IItemDecorator, FirstDecorator>(); services.AddTransient<IItemDecorator, SecondDecorator>(); var decorators = _serviceProvider.GetServices<IItemDecorator>(); foreach (var decorator in decorators.Where(d => d.CanHandle(item))) { decorator.Decorate(item); } - Inositol
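The first comment in this thread suggests letting each link receive the rest of the chain as a delegate, in the style of ASP.NET Core middleware, so a handler can decide whether to continue. A minimal, container-free sketch of that idea follows; `HandlerDelegate`, `IMiddlewareHandler`, `ValidationHandler`, `SaveHandler` and `Pipeline` are hypothetical names for illustration, not an existing API.

```csharp
using System.Collections.Generic;
using System.Linq;

public class Foo
{
    public int Value { get; set; }
    public List<string> Log { get; } = new List<string>();
}

// The rest of the chain, as seen by one handler.
public delegate void HandlerDelegate(Foo foo);

public interface IMiddlewareHandler
{
    void Handle(Foo foo, HandlerDelegate next);
}

// Stops the chain for negative values; otherwise passes control on.
public class ValidationHandler : IMiddlewareHandler
{
    public void Handle(Foo foo, HandlerDelegate next)
    {
        foo.Log.Add("validate");
        if (foo.Value < 0) return; // short-circuit: next is never called
        next(foo);
    }
}

public class SaveHandler : IMiddlewareHandler
{
    public void Handle(Foo foo, HandlerDelegate next)
    {
        foo.Log.Add("save");
        next(foo);
    }
}

public static class Pipeline
{
    // Fold from the last handler backwards; the terminal delegate does nothing.
    public static HandlerDelegate Build(IEnumerable<IMiddlewareHandler> handlers)
    {
        HandlerDelegate next = _ => { };
        foreach (var handler in handlers.Reverse())
        {
            var captured = next; // the tail of the chain seen by this handler
            next = foo => handler.Handle(foo, captured);
        }
        return next;
    }
}
```

With `Pipeline.Build(new IMiddlewareHandler[] { new ValidationHandler(), new SaveHandler() })`, a Foo with a positive value passes through both handlers, while a negative one stops after validation. The trade-off raised above still applies: every handler now carries the decision of whether to call `next`.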

© 2022 - 2024 — McMap. All rights reserved.