Building the wheel by hand: implementing a simple dependency injection container

Building the wheel by hand: implementing a simple dependency injection container (1)

Intro

In the previous article, I mainly introduced the overall design of dependency injection and some general programming experience. In this article we start writing code and implement our own dependency injection framework.

Class diagram

Let's review the class diagram mentioned last time.

Service life cycle

Service life cycle definition:

/// <summary>
/// Controls how often the container creates a new instance of a registered service.
/// </summary>
public enum ServiceLifetime : sbyte
{
    /// <summary>Only one instance of the service is ever created.</summary>
    Singleton = 0,

    /// <summary>One instance of the service is created per scope.</summary>
    Scoped = 1,

    /// <summary>A fresh instance of the service is created on every request.</summary>
    Transient = 2
}

Service definition

Service registration definition:

/// <summary>
/// Describes one service registration: the service type, its lifetime, and how an instance is obtained.
/// The instance can be a pre-built object, the result of a factory delegate, or a newly constructed
/// implementation type.
/// </summary>
public class ServiceDefinition
{
    /// <summary>Lifetime of the service.</summary>
    public ServiceLifetime ServiceLifetime { get; }

    /// <summary>
    /// Implementation type to construct. Null for instance and factory registrations.
    /// </summary>
    public Type ImplementType { get; }

    /// <summary>The service type under which this definition is registered.</summary>
    public Type ServiceType { get; }

    /// <summary>
    /// Pre-built instance. Only set by the instance constructor, which always uses a singleton lifetime.
    /// </summary>
    public object ImplementationInstance { get; }

    /// <summary>Factory delegate. Only set by the factory constructor.</summary>
    public Func<IServiceProvider, object> ImplementationFactory { get; }

    /// <summary>
    /// Gets the best-known implementation type. Checks, in order: the runtime type of
    /// <see cref="ImplementationInstance"/>, the declaring type of <see cref="ImplementationFactory"/>'s
    /// method, <see cref="ImplementType"/>, and finally <see cref="ServiceType"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): for a lambda factory the declaring type is usually a compiler-generated class, not
    /// the type the factory returns. Two lambda factories declared in the same class therefore report
    /// the same type here. Confirm that is acceptable wherever this value is used to identify a registration.
    /// </remarks>
    public Type GetImplementType()
    {
        if (ImplementationInstance != null)
        {
            return ImplementationInstance.GetType();
        }

        if (ImplementationFactory != null)
        {
            return ImplementationFactory.Method.DeclaringType;
        }

        return ImplementType ?? ServiceType;
    }

    /// <summary>Registers a pre-built <paramref name="instance"/> as a singleton.</summary>
    /// <param name="instance">The instance to return for <paramref name="serviceType"/>.</param>
    /// <param name="serviceType">The service type the instance is registered as.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public ServiceDefinition(object instance, Type serviceType)
    {
        if (instance == null)
        {
            throw new ArgumentNullException(nameof(instance));
        }
        if (serviceType == null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        ImplementationInstance = instance;
        ServiceType = serviceType;
        ServiceLifetime = ServiceLifetime.Singleton;
    }

    /// <summary>Registers <paramref name="serviceType"/> as its own implementation type.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null.</exception>
    public ServiceDefinition(Type serviceType, ServiceLifetime serviceLifetime) : this(serviceType, serviceType, serviceLifetime)
    {
    }

    /// <summary>
    /// Registers <paramref name="implementType"/> as the implementation of <paramref name="serviceType"/>.
    /// If <paramref name="implementType"/> is null, <paramref name="serviceType"/> is used instead.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is null.</exception>
    public ServiceDefinition(Type serviceType, Type implementType, ServiceLifetime serviceLifetime)
    {
        if (serviceType == null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        ServiceType = serviceType;
        ImplementType = implementType ?? serviceType;
        ServiceLifetime = serviceLifetime;
    }

    /// <summary>Registers a <paramref name="factory"/> that builds <paramref name="serviceType"/> instances.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> or <paramref name="factory"/> is null.</exception>
    public ServiceDefinition(Type serviceType, Func<IServiceProvider, object> factory, ServiceLifetime serviceLifetime)
    {
        if (serviceType == null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }
        if (factory == null)
        {
            throw new ArgumentNullException(nameof(factory));
        }

        ServiceType = serviceType;
        ImplementationFactory = factory;
        ServiceLifetime = serviceLifetime;
    }
}

For convenience, some static factory methods are added to ServiceDefinition:

/// <summary>Creates a singleton definition for <typeparamref name="TService"/> built by <paramref name="factory"/>.</summary>
public static ServiceDefinition Singleton<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Singleton);

/// <summary>Creates a scoped definition for <typeparamref name="TService"/> built by <paramref name="factory"/>.</summary>
public static ServiceDefinition Scoped<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Scoped);

/// <summary>Creates a transient definition for <typeparamref name="TService"/> built by <paramref name="factory"/>.</summary>
public static ServiceDefinition Transient<TService>(Func<IServiceProvider, object> factory)
    => new ServiceDefinition(typeof(TService), factory, ServiceLifetime.Transient);

/// <summary>Creates a singleton definition that constructs <typeparamref name="TService"/> itself.</summary>
public static ServiceDefinition Singleton<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Singleton);

/// <summary>Creates a scoped definition that constructs <typeparamref name="TService"/> itself.</summary>
public static ServiceDefinition Scoped<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Scoped);

/// <summary>Creates a transient definition that constructs <typeparamref name="TService"/> itself.</summary>
public static ServiceDefinition Transient<TService>()
    => new ServiceDefinition(typeof(TService), ServiceLifetime.Transient);

/// <summary>Creates a singleton definition mapping <typeparamref name="TService"/> to <typeparamref name="TServiceImplement"/>.</summary>
public static ServiceDefinition Singleton<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Singleton);

/// <summary>Creates a scoped definition mapping <typeparamref name="TService"/> to <typeparamref name="TServiceImplement"/>.</summary>
public static ServiceDefinition Scoped<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Scoped);

/// <summary>Creates a transient definition mapping <typeparamref name="TService"/> to <typeparamref name="TServiceImplement"/>.</summary>
public static ServiceDefinition Transient<TService, TServiceImplement>() where TServiceImplement : TService
    => new ServiceDefinition(typeof(TService), typeof(TServiceImplement), ServiceLifetime.Transient);

ServiceContainer

ServiceContainer v1

/// <summary>
/// First version of the container. It caches at most one instance per service type.
/// The root container owns the singleton cache. Scopes created by <see cref="CreateScope"/> share the
/// root's registrations and singleton cache, and keep their own cache of scoped instances.
/// </summary>
public class ServiceContainer : IServiceContainer
{
    // Registrations in insertion order; GetService uses the last matching one.
    internal readonly List<ServiceDefinition> _services;

    // Shared between the root container and every scope created from it.
    private readonly ConcurrentDictionary<Type, object> _singletonInstances;

    // Only created for scopes; null in the root container, which rejects scoped services.
    private readonly ConcurrentDictionary<Type, object> _scopedInstances;

    // Disposable transients created through this container, disposed with it.
    // A ConcurrentBag so that concurrent GetService calls can record them safely.
    private readonly ConcurrentBag<object> _transientDisposables = new ConcurrentBag<object>();

    private readonly bool _isRootScope;

    // The root container. Singletons are always constructed through it. This way a singleton never
    // captures a scoped service, and any disposable transient it depends on is owned by the root
    // rather than by a short-lived scope.
    private readonly ServiceContainer _rootContainer;

    /// <summary>Creates a root container with no registrations.</summary>
    public ServiceContainer()
    {
        _isRootScope = true;
        _rootContainer = this;
        _singletonInstances = new ConcurrentDictionary<Type, object>();
        _services = new List<ServiceDefinition>();
    }

    /// <summary>Creates a scope that shares <paramref name="serviceContainer"/>'s registrations and singletons.</summary>
    internal ServiceContainer(ServiceContainer serviceContainer)
    {
        _isRootScope = false;
        _rootContainer = serviceContainer._rootContainer;
        _singletonInstances = serviceContainer._singletonInstances;
        _services = serviceContainer._services;
        _scopedInstances = new ConcurrentDictionary<Type, object>();
    }

    /// <summary>Appends a registration; a later registration for the same service type takes precedence.</summary>
    public void Add(ServiceDefinition item)
    {
        _services.Add(item);
    }

    /// <summary>Creates a child scope of this container.</summary>
    public IServiceContainer CreateScope()
    {
        return new ServiceContainer(this);
    }

    private bool _disposed;

    /// <summary>
    /// Disposes the instances this container owns. The root disposes the singletons and its transients;
    /// a scope disposes its scoped instances and its transients. Calling this more than once has no effect.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        if (_isRootScope)
        {
            lock (_singletonInstances)
            {
                if (_disposed)
                {
                    return;
                }

                _disposed = true;
                foreach (var instance in _singletonInstances.Values)
                {
                    (instance as IDisposable)?.Dispose();
                }

                foreach (var o in _transientDisposables)
                {
                    (o as IDisposable)?.Dispose();
                }
            }
        }
        else
        {
            lock (_scopedInstances)
            {
                if (_disposed)
                {
                    return;
                }

                _disposed = true;
                foreach (var instance in _scopedInstances.Values)
                {
                    (instance as IDisposable)?.Dispose();
                }

                foreach (var o in _transientDisposables)
                {
                    (o as IDisposable)?.Dispose();
                }
            }
        }
    }

    /// <summary>
    /// Produces a new instance for <paramref name="serviceDefinition"/>. Returns the registered instance,
    /// invokes the factory, or constructs the implementation type with the public constructor that has
    /// the fewest parameters, resolving each parameter from this container.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The implementation type is abstract or an interface, has no public constructor, or has a
    /// required constructor parameter that cannot be resolved.
    /// </exception>
    private object GetServiceInstance(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (serviceDefinition.ImplementationInstance != null)
        {
            return serviceDefinition.ImplementationInstance;
        }

        if (serviceDefinition.ImplementationFactory != null)
        {
            return serviceDefinition.ImplementationFactory.Invoke(this);
        }

        var implementType = serviceDefinition.ImplementType ?? serviceType;

        if (implementType.IsInterface || implementType.IsAbstract)
        {
            throw new InvalidOperationException($"invalid service registered, serviceType: {serviceType.FullName}, implementType: {serviceDefinition.ImplementType}");
        }

        var ctorInfos = implementType.GetConstructors(BindingFlags.Instance | BindingFlags.Public);
        if (ctorInfos.Length == 0)
        {
            throw new InvalidOperationException($"service {serviceType.FullName} does not have any public constructors");
        }

        // TODO: try find best ctor; for now the constructor with the fewest parameters wins.
        var ctor = ctorInfos.OrderBy(c => c.GetParameters().Length).First();

        var parameters = ctor.GetParameters();
        var arguments = new Expression[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameter = parameters[index];
            var argument = GetService(parameter.ParameterType);
            if (argument == null)
            {
                if (!parameter.HasDefaultValue)
                {
                    throw new InvalidOperationException($"can not resolve parameter '{parameter.Name}' of type {parameter.ParameterType.FullName} for service {serviceType.FullName}");
                }
                argument = parameter.DefaultValue;
            }

            // Type every constant with the parameter type. An untyped Expression.Constant(null) has type
            // object, which Expression.New rejects for any other parameter type. For value types,
            // ParameterInfo.DefaultValue may be null (e.g. "= default"), so use Expression.Default there.
            arguments[index] = argument == null && parameter.ParameterType.IsValueType
                ? (Expression)Expression.Default(parameter.ParameterType)
                : Expression.Constant(argument, parameter.ParameterType);
        }

        // TODO: cache the compiled constructor delegate instead of compiling on every resolve
        return Expression.Lambda<Func<object>>(Expression.New(ctor, arguments)).Compile().Invoke();
    }

    /// <summary>
    /// Resolves <paramref name="serviceType"/> using its last registration.
    /// Returns null when the type is not registered.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The container has been disposed, or a scoped service is requested from the root container.
    /// </exception>
    public object GetService(Type serviceType)
    {
        if (_disposed)
        {
            throw new InvalidOperationException($"can not get scope service from a disposed scope, serviceType: {serviceType.FullName}");
        }

        var serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == serviceType);
        if (null == serviceDefinition)
        {
            return null;
        }

        if (_isRootScope && serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            throw new InvalidOperationException($"can not get scope service from the root scope, serviceType: {serviceType.FullName}");
        }

        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Singleton)
        {
            // Construct through the root so singleton dependencies come from the root as well.
            // Note: ConcurrentDictionary.GetOrAdd may run the factory more than once under contention.
            return _singletonInstances.GetOrAdd(serviceType, t => _rootContainer.GetServiceInstance(t, serviceDefinition));
        }

        if (serviceDefinition.ServiceLifetime == ServiceLifetime.Scoped)
        {
            return _scopedInstances.GetOrAdd(serviceType, t => GetServiceInstance(t, serviceDefinition));
        }

        var svc = GetServiceInstance(serviceType, serviceDefinition);
        if (svc is IDisposable)
        {
            _transientDisposables.Add(svc);
        }
        return svc;
    }
}

To make service registration more convenient, we can add some extension methods:

/// <summary>Registers an existing <paramref name="service"/> instance as a singleton <typeparamref name="TService"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]TService service)
{
    var definition = new ServiceDefinition(service, typeof(TService));
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers <paramref name="serviceType"/> as a singleton implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Singleton);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers <paramref name="implementType"/> as the singleton implementation of <paramref name="serviceType"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Singleton);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers a singleton <typeparamref name="TService"/> built by <paramref name="func"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Singleton<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>Registers <typeparamref name="TService"/> as a singleton implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Singleton<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>Registers <typeparamref name="TServiceImplement"/> as the singleton implementation of <typeparamref name="TService"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddSingleton<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Singleton<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers <paramref name="serviceType"/> as a scoped service implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Scoped);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers <paramref name="implementType"/> as the scoped implementation of <paramref name="serviceType"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Scoped);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers a scoped <typeparamref name="TService"/> built by <paramref name="func"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Scoped<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>Registers <typeparamref name="TService"/> as a scoped service implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Scoped<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>Registers <typeparamref name="TServiceImplement"/> as the scoped implementation of <typeparamref name="TService"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddScoped<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Scoped<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers <paramref name="serviceType"/> as a transient service implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType)
{
    var definition = new ServiceDefinition(serviceType, ServiceLifetime.Transient);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers <paramref name="implementType"/> as the transient implementation of <paramref name="serviceType"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient([NotNull]this IServiceContainer serviceContainer, [NotNull]Type serviceType, [NotNull]Type implementType)
{
    var definition = new ServiceDefinition(serviceType, implementType, ServiceLifetime.Transient);
    serviceContainer.Add(definition);
    return serviceContainer;
}

/// <summary>Registers a transient <typeparamref name="TService"/> built by <paramref name="func"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient<TService>([NotNull]this IServiceContainer serviceContainer, [NotNull]Func<IServiceProvider, object> func)
{
    var definition = ServiceDefinition.Transient<TService>(func);
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>Registers <typeparamref name="TService"/> as a transient service implemented by itself.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient<TService>([NotNull]this IServiceContainer serviceContainer)
{
    var definition = ServiceDefinition.Transient<TService>();
    serviceContainer.Add(definition);
    return serviceContainer;
}


/// <summary>Registers <typeparamref name="TServiceImplement"/> as the transient implementation of <typeparamref name="TService"/>.</summary>
/// <returns>The same <paramref name="serviceContainer"/>, for chaining.</returns>
public static IServiceContainer AddTransient<TService, TServiceImplement>([NotNull]this IServiceContainer serviceContainer) where TServiceImplement : TService
{
    var definition = ServiceDefinition.Transient<TService, TServiceImplement>();
    serviceContainer.Add(definition);
    return serviceContainer;
}

The code above implements basic dependency injection, but it has a functional limitation: it only resolves a single instance of a service. It does not support registering multiple implementations of an interface and then resolving all of them. To support that, we change the key of the instance ConcurrentDictionary in ServiceContainer to the combination of the service type and the implementation type. This gives the second version of ServiceContainer.

ServiceContainer v2

A ServiceKey type is defined for this purpose. Because ServiceKey is used as a dictionary key, it must override both Equals and GetHashCode:

/// <summary>
/// Dictionary key that identifies a cached instance by the requested service type plus the
/// implementation type of the registration.
/// </summary>
private class ServiceKey : IEquatable<ServiceKey>
{
    /// <summary>The service type the instance was requested as.</summary>
    public Type ServiceType { get; }

    /// <summary>The implementation type reported by the registration's <c>GetImplementType()</c>.</summary>
    public Type ImplementType { get; }

    public ServiceKey(Type serviceType, ServiceDefinition definition)
    {
        ServiceType = serviceType;
        ImplementType = definition.GetImplementType();
    }

    /// <summary>Two keys are equal when both the service type and the implementation type match.</summary>
    public bool Equals(ServiceKey other)
    {
        if (ReferenceEquals(other, null))
        {
            return false;
        }
        if (ReferenceEquals(this, other))
        {
            return true;
        }
        return ServiceType == other.ServiceType && ImplementType == other.ImplementType;
    }

    // "as" instead of a hard cast: Equals(object) must return false for foreign types, not throw.
    public override bool Equals(object obj) => Equals(obj as ServiceKey);

    public override int GetHashCode()
    {
        // Combine the two Type hashes directly. This avoids building a string on every lookup and is
        // consistent with the Type equality used in Equals. The null-conditionals guard against a null
        // implementation type (a delegate method's DeclaringType can be null).
        unchecked
        {
            var hash = ServiceType?.GetHashCode() ?? 0;
            return (hash * 397) ^ (ImplementType?.GetHashCode() ?? 0);
        }
    }
}

The second version of ServiceContainer:

/// <summary>
/// Second version of the container. Cached instances are keyed by (service type, implementation type),
/// so several implementations of one service can be registered and resolved together. A request for
/// <c>IEnumerable&lt;T&gt;</c>, or another interface implemented by <c>T[]</c> such as
/// <c>IReadOnlyList&lt;T&gt;</c>, returns all of them. Open generic registrations are closed on demand.
/// </summary>
public class ServiceContainer : IServiceContainer
{
    // Registrations in insertion order. ConcurrentQueue enumerates in FIFO order, so "last registration
    // wins" (LastOrDefault) and the IEnumerable<T> order follow registration order. A ConcurrentBag is
    // an unordered collection and gives neither guarantee.
    internal readonly ConcurrentQueue<ServiceDefinition> _services;

    // Shared between the root container and every scope created from it.
    private readonly ConcurrentDictionary<ServiceKey, object> _singletonInstances;

    // Only created for scopes; null in the root container, which rejects scoped services.
    private readonly ConcurrentDictionary<ServiceKey, object> _scopedInstances;

    // Disposable transients created through this container; set to null once disposed.
    private ConcurrentBag<object> _transientDisposables = new ConcurrentBag<object>();

    private readonly bool _isRootScope;

    // The root container. Singletons are always constructed through it, so a singleton never captures
    // a scoped service and its disposable transient dependencies are owned by the root, not by a scope.
    private readonly ServiceContainer _rootContainer;

    private bool _disposed;

    /// <summary>
    /// Dictionary key that identifies a cached instance by the requested service type plus the
    /// implementation type of the registration.
    /// </summary>
    private class ServiceKey : IEquatable<ServiceKey>
    {
        public Type ServiceType { get; }

        public Type ImplementType { get; }

        public ServiceKey(Type serviceType, ServiceDefinition definition)
        {
            ServiceType = serviceType;
            ImplementType = definition.GetImplementType();
        }

        public bool Equals(ServiceKey other)
        {
            if (ReferenceEquals(other, null))
            {
                return false;
            }
            if (ReferenceEquals(this, other))
            {
                return true;
            }
            return ServiceType == other.ServiceType && ImplementType == other.ImplementType;
        }

        // "as" instead of a hard cast: Equals(object) must return false for foreign types, not throw.
        public override bool Equals(object obj) => Equals(obj as ServiceKey);

        public override int GetHashCode()
        {
            // Allocation-free and null-safe; consistent with the Type equality used in Equals.
            unchecked
            {
                var hash = ServiceType?.GetHashCode() ?? 0;
                return (hash * 397) ^ (ImplementType?.GetHashCode() ?? 0);
            }
        }
    }

    /// <summary>Creates a root container with no registrations.</summary>
    public ServiceContainer()
    {
        _isRootScope = true;
        _rootContainer = this;
        _singletonInstances = new ConcurrentDictionary<ServiceKey, object>();
        _services = new ConcurrentQueue<ServiceDefinition>();
    }

    /// <summary>Creates a scope that shares <paramref name="serviceContainer"/>'s registrations and singletons.</summary>
    private ServiceContainer(ServiceContainer serviceContainer)
    {
        _isRootScope = false;
        _rootContainer = serviceContainer._rootContainer;
        _singletonInstances = serviceContainer._singletonInstances;
        _services = serviceContainer._services;
        _scopedInstances = new ConcurrentDictionary<ServiceKey, object>();
    }

    /// <summary>
    /// Adds a registration unless one with the same service type and implementation type already exists.
    /// </summary>
    /// <exception cref="InvalidOperationException">The container has been disposed.</exception>
    public IServiceContainer Add(ServiceDefinition item)
    {
        if (_disposed)
        {
            throw new InvalidOperationException("the service container had been disposed");
        }
        if (_services.Any(_ => _.ServiceType == item.ServiceType && _.GetImplementType() == item.GetImplementType()))
        {
            return this;
        }

        _services.Enqueue(item);
        return this;
    }

    /// <summary>Adds a registration only if the service type has no registration yet.</summary>
    /// <exception cref="InvalidOperationException">The container has been disposed.</exception>
    public IServiceContainer TryAdd(ServiceDefinition item)
    {
        if (_disposed)
        {
            throw new InvalidOperationException("the service container had been disposed");
        }
        if (_services.Any(_ => _.ServiceType == item.ServiceType))
        {
            return this;
        }
        _services.Enqueue(item);
        return this;
    }

    /// <summary>Creates a child scope of this container.</summary>
    public IServiceContainer CreateScope()
    {
        return new ServiceContainer(this);
    }

    /// <summary>
    /// Disposes the instances this container owns. The root disposes the singletons and its transients;
    /// a scope disposes its scoped instances and its transients. Calling this more than once has no effect.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        if (_isRootScope)
        {
            lock (_singletonInstances)
            {
                if (_disposed)
                {
                    return;
                }

                _disposed = true;
                foreach (var instance in _singletonInstances.Values)
                {
                    (instance as IDisposable)?.Dispose();
                }

                foreach (var o in _transientDisposables)
                {
                    (o as IDisposable)?.Dispose();
                }

                _singletonInstances.Clear();
                _transientDisposables = null;
            }
        }
        else
        {
            lock (_scopedInstances)
            {
                if (_disposed)
                {
                    return;
                }

                _disposed = true;
                foreach (var instance in _scopedInstances.Values)
                {
                    (instance as IDisposable)?.Dispose();
                }

                foreach (var o in _transientDisposables)
                {
                    (o as IDisposable)?.Dispose();
                }

                _scopedInstances.Clear();
                _transientDisposables = null;
            }
        }
    }

    /// <summary>
    /// Produces a new instance for <paramref name="serviceDefinition"/>. Returns the registered instance,
    /// invokes the factory, or constructs the implementation type. An open generic implementation type
    /// is first closed over <paramref name="serviceType"/>'s type arguments.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The implementation type is abstract or an interface, has no public constructor, or has a
    /// required constructor parameter that cannot be resolved.
    /// </exception>
    private object GetServiceInstance(Type serviceType, ServiceDefinition serviceDefinition)
    {
        if (serviceDefinition.ImplementationInstance != null)
        {
            return serviceDefinition.ImplementationInstance;
        }

        if (serviceDefinition.ImplementationFactory != null)
        {
            return serviceDefinition.ImplementationFactory.Invoke(this);
        }

        var implementType = serviceDefinition.ImplementType ?? serviceType;

        if (implementType.IsInterface || implementType.IsAbstract)
        {
            throw new InvalidOperationException($"invalid service registered, serviceType: {serviceType.FullName}, implementType: {serviceDefinition.ImplementType}");
        }

        // Only an open generic definition (e.g. Repository<>) must be closed with the requested
        // service's type arguments. A closed generic such as Repository<User> is already constructible,
        // and MakeGenericType would throw on it.
        if (implementType.IsGenericTypeDefinition)
        {
            implementType = implementType.MakeGenericType(serviceType.GetGenericArguments());
        }

        var ctorInfos = implementType.GetConstructors(BindingFlags.Instance | BindingFlags.Public);
        if (ctorInfos.Length == 0)
        {
            throw new InvalidOperationException($"service {serviceType.FullName} does not have any public constructors");
        }

        // TODO: try find best ctor; for now the constructor with the fewest parameters wins.
        var ctor = ctorInfos.OrderBy(c => c.GetParameters().Length).First();

        var parameters = ctor.GetParameters();
        var arguments = new Expression[parameters.Length];
        for (var index = 0; index < parameters.Length; index++)
        {
            var parameter = parameters[index];
            var argument = GetService(parameter.ParameterType);
            if (argument == null)
            {
                if (!parameter.HasDefaultValue)
                {
                    throw new InvalidOperationException($"can not resolve parameter '{parameter.Name}' of type {parameter.ParameterType.FullName} for service {serviceType.FullName}");
                }
                argument = parameter.DefaultValue;
            }

            // Type every constant with the parameter type. An untyped Expression.Constant(null) has type
            // object, which Expression.New rejects for any other parameter type. For value types,
            // ParameterInfo.DefaultValue may be null (e.g. "= default"), so use Expression.Default there.
            arguments[index] = argument == null && parameter.ParameterType.IsValueType
                ? (Expression)Expression.Default(parameter.ParameterType)
                : Expression.Constant(argument, parameter.ParameterType);
        }

        // TODO: cache the compiled constructor delegate instead of compiling on every resolve
        return Expression.Lambda<Func<object>>(Expression.New(ctor, arguments)).Compile().Invoke();
    }

    /// <summary>
    /// Resolves <paramref name="serviceType"/>. Lookup order: the last exact registration, then the last
    /// open generic registration, then (for <c>IEnumerable&lt;T&gt;</c>-compatible types) all
    /// registrations of <c>T</c>. Returns null when nothing matches.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The container has been disposed, or a scoped service is requested from the root container.
    /// </exception>
    public object GetService(Type serviceType)
    {
        if (_disposed)
        {
            throw new InvalidOperationException($"can not get scope service from a disposed scope, serviceType: {serviceType.FullName}");
        }

        var serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == serviceType);
        if (serviceDefinition == null)
        {
            if (!serviceType.IsGenericType)
            {
                return null;
            }

            var genericType = serviceType.GetGenericTypeDefinition();
            serviceDefinition = _services.LastOrDefault(_ => _.ServiceType == genericType);
            if (serviceDefinition == null)
            {
                return ResolveEnumerable(serviceType);
            }
        }

        return ResolveDefinition(serviceType, serviceDefinition);
    }

    /// <summary>
    /// Resolves every registration of <c>T</c> when <paramref name="serviceType"/> is assignable to
    /// <c>IEnumerable&lt;T&gt;</c>, where <c>T</c> is its first type argument. The instances are returned
    /// as a <c>T[]</c> in registration order. Returns null when <paramref name="serviceType"/> is not
    /// enumerable-compatible.
    /// </summary>
    private object ResolveEnumerable(Type serviceType)
    {
        var innerServiceType = serviceType.GetGenericArguments().First();
        if (!typeof(IEnumerable<>).MakeGenericType(innerServiceType).IsAssignableFrom(serviceType))
        {
            return null;
        }

        // Match registrations of the closed type (e.g. IHandler<Foo>) as well as open generic ones
        // (IHandler<>); for non-generic T both comparisons are the same.
        var openInnerType = innerServiceType.IsGenericType
            ? innerServiceType.GetGenericTypeDefinition()
            : innerServiceType;

        var instances = new List<object>(4);
        foreach (var definition in _services.Where(_ => _.ServiceType == innerServiceType || _.ServiceType == openInnerType))
        {
            var instance = ResolveDefinition(innerServiceType, definition);
            if (instance != null)
            {
                instances.Add(instance);
            }
        }

        // T[] implements IEnumerable<T>, IReadOnlyCollection<T> and IReadOnlyList<T>.
        var result = Array.CreateInstance(innerServiceType, instances.Count);
        for (var index = 0; index < instances.Count; index++)
        {
            result.SetValue(instances[index], index);
        }
        return result;
    }

    /// <summary>Resolves one registration according to its lifetime, using or filling the matching cache.</summary>
    /// <exception cref="InvalidOperationException">A scoped service is requested from the root container.</exception>
    private object ResolveDefinition(Type serviceType, ServiceDefinition serviceDefinition)
    {
        switch (serviceDefinition.ServiceLifetime)
        {
            case ServiceLifetime.Singleton:
                // Construct through the root so singleton dependencies come from the root as well.
                // Note: ConcurrentDictionary.GetOrAdd may run the factory more than once under contention.
                return _singletonInstances.GetOrAdd(
                    new ServiceKey(serviceType, serviceDefinition),
                    key => _rootContainer.GetServiceInstance(key.ServiceType, serviceDefinition));

            case ServiceLifetime.Scoped:
                if (_isRootScope)
                {
                    throw new InvalidOperationException($"can not get scope service from the root scope, serviceType: {serviceType.FullName}");
                }
                return _scopedInstances.GetOrAdd(
                    new ServiceKey(serviceType, serviceDefinition),
                    key => GetServiceInstance(key.ServiceType, serviceDefinition));

            default:
            {
                var svc = GetServiceInstance(serviceType, serviceDefinition);
                if (svc is IDisposable)
                {
                    _transientDisposables.Add(svc);
                }
                return svc;
            }
        }
    }
}

This way we can resolve not only `IEnumerable<TService>` but also `IReadOnlyList<TService>` / `IReadOnlyCollection<TService>`.

Because GetService returns an object rather than a strongly typed value, several extension methods are defined for convenience. They mirror `GetService<TService>()` / `GetServices<TService>()` / `GetRequiredService<TService>()` in Microsoft's dependency injection framework.

/// <summary>
/// Resolves <typeparamref name="TService"/> from <paramref name="serviceProvider"/>.
/// </summary>
/// <typeparam name="TService">The service type to resolve.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>
/// The provider's result cast to <typeparamref name="TService"/>. This is null when the provider
/// returns null and <typeparamref name="TService"/> is a reference type.
/// </returns>
public static TService ResolveService<TService>([NotNull]this IServiceProvider serviceProvider)
{
    var instance = serviceProvider.GetService(typeof(TService));
    return (TService)instance;
}

/// <summary>
/// Resolves <typeparamref name="TService"/> and fails loudly when the provider has no instance for it.
/// </summary>
/// <typeparam name="TService">The service type to resolve.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>The resolved service.</returns>
/// <exception cref="InvalidOperationException">The provider returned null for <typeparamref name="TService"/>.</exception>
public static TService ResolveRequiredService<TService>([NotNull] this IServiceProvider serviceProvider)
{
    var requestedType = typeof(TService);
    var instance = serviceProvider.GetService(requestedType);
    if (instance != null)
    {
        return (TService)instance;
    }

    throw new InvalidOperationException($"service had not been registered, serviceType: {requestedType}");
}

/// <summary>
/// Resolves all registered <typeparamref name="TService"/> instances by asking the provider for
/// <c>IEnumerable&lt;TService&gt;</c>.
/// </summary>
/// <typeparam name="TService">The element service type.</typeparam>
/// <param name="serviceProvider">The provider to resolve from.</param>
/// <returns>The provider's <c>IEnumerable&lt;TService&gt;</c>, or null if it returns none.</returns>
public static IEnumerable<TService> ResolveServices<TService>([NotNull]this IServiceProvider serviceProvider)
{
    return ResolveService<IEnumerable<TService>>(serviceProvider);
}

More

There is also a later version that mainly focuses on performance. I'm not yet satisfied with it, so I won't cover it here for now.

Reference

Keywords: C# Lambda Programming github

Added by kunalk on Tue, 29 Oct 2019 19:52:02 +0200