LINQ Path: LINQ Extension

Original Link: http://www.cnblogs.com/jellochen/p/the-extension-of-linq.html

This article covers extending LINQ in three parts: extending the standard query operators, writing custom query operators, and building a simple simulation of LINQ to SQL.

1. Extended Query Operators

In practice, the extension methods in Enumerable or Queryable do not always meet our needs, so we sometimes have to add query operators of our own. Consider the following example:

var r = Enumerable.Range(1, 10).Zip(Enumerable.Range(11, 5), (s, d) => s + d);
foreach (var i in r)
{
    Console.WriteLine(i);
}
//output:
//12
//14
//16
//18
//20

The Enumerable.Zip extension method applies a specified function to the corresponding elements of two sequences to produce a result sequence. Here, the new sequence is generated by adding the elements at matching positions of the sequence [1...10] and the sequence [11...15]. Its internal implementation is as follows:

static IEnumerable<TResult> ZipIterator<TFirst, TSecond, TResult>(IEnumerable<TFirst> first, IEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector) {
        using (IEnumerator<TFirst> e1 = first.GetEnumerator())
            using (IEnumerator<TSecond> e2 = second.GetEnumerator())
                while (e1.MoveNext() && e2.MoveNext())
                    yield return resultSelector(e1.Current, e2.Current);
    }

Zip effectively takes the intersection of the two sequences by position: an element is produced only when both sequences have an element at that position, which is why the output above has five items. Sometimes, however, we want the first sequence to drive the result, so that the resulting sequence always has the same length as the first sequence. Let's add a query operator named LeftZip that does this:

    /// <summary>
    /// Merges the right sequence into the left sequence by using the specified result selector; the result always has as many elements as the left sequence.
    /// </summary>
    /// <typeparam name="TLeft"></typeparam>
    /// <typeparam name="TRight"></typeparam>
    /// <typeparam name="TResult"></typeparam>
    /// <param name="lefts"></param>
    /// <param name="rights"></param>
    /// <param name="resultSelector"></param>
    /// <returns></returns>
    public static IEnumerable<TResult> LeftZip<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
        IEnumerable<TRight> rights, Func<TLeft, TRight, TResult> resultSelector)
    {
        if(lefts == null)
            throw new ArgumentNullException("lefts");
        if(rights == null)
            throw new ArgumentNullException("rights");
        if (resultSelector == null)
            throw new ArgumentNullException("resultSelector");
        return LeftZipImpl(lefts, rights, resultSelector);
    }
    /// <summary>
    /// The Implementation of LeftZip
    /// </summary>
    /// <typeparam name="TLeft"></typeparam>
    /// <typeparam name="TRight"></typeparam>
    /// <typeparam name="TResult"></typeparam>
    /// <param name="lefts"></param>
    /// <param name="rights"></param>
    /// <param name="resultSelector"></param>
    /// <returns></returns>
    private static IEnumerable<TResult> LeftZipImpl<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
        IEnumerable<TRight> rights, Func<TLeft, TRight, TResult> resultSelector)
    {
        using (var left = lefts.GetEnumerator())
        {
            using (var right = rights.GetEnumerator())
            {
                while (left.MoveNext())
                {
                    if (right.MoveNext())
                    {
                        yield return resultSelector(left.Current, right.Current);
                    }
                    else
                    {
                        do
                        {
                            yield return resultSelector(left.Current, default(TRight));
                        } while (left.MoveNext());
                        yield break;
                    }
                }
            }
        }
    }

Call LeftZip with the following code:

var r = Enumerable.Range(1, 10).LeftZip(Enumerable.Range(11, 5), (s, d) => s + d);
foreach (var i in r)
{
    Console.WriteLine(i);
}
//output:
//12
//14
//16
//18
//20
//6
//7
//8
//9
//10
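
When the right sequence holds a reference type, the padding value default(TRight) is null, so the result selector has to cope with it. The sketch below illustrates this with a compact variant of LeftZip (argument checks omitted, and written slightly differently from the listing above); the LeftZipDemo class and the Label helper are names invented for this example.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class LeftZipDemo
{
    // Compact variant of LeftZip, repeated here so the sketch compiles on its own.
    public static IEnumerable<TResult> LeftZip<TLeft, TRight, TResult>(
        this IEnumerable<TLeft> lefts, IEnumerable<TRight> rights,
        Func<TLeft, TRight, TResult> resultSelector)
    {
        using (var left = lefts.GetEnumerator())
        using (var right = rights.GetEnumerator())
        {
            bool hasRight = true;
            while (left.MoveNext())
            {
                // Once the right side has run out, stop advancing it.
                hasRight = hasRight && right.MoveNext();
                yield return resultSelector(left.Current, hasRight ? right.Current : default(TRight));
            }
        }
    }

    // Hypothetical helper: pairs each id with a name, or with "(none)" when names run out.
    public static List<string> Label(IEnumerable<int> ids, IEnumerable<string> names)
    {
        // For string, default(TRight) is null, so the selector substitutes a placeholder.
        return ids.LeftZip(names, (id, name) => id + ":" + (name ?? "(none)")).ToList();
    }

    public static void Main()
    {
        foreach (var line in Label(new[] { 1, 2, 3 }, new[] { "a", "b" }))
            Console.WriteLine(line);
        // prints 1:a, 2:b, 3:(none)
    }
}
```

For value types such as int the padding is 0, which is why the listing above prints 6 through 10 for the unmatched elements.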

2. Custom Query Operators

Earlier, when implementing enumerators, we saw that foreach does not require the IEnumerable and IEnumerator interfaces: a class that exposes GetEnumerator(), together with an enumerator class that exposes Current and MoveNext(), is enough to iterate with foreach. We also know that a LINQ query expression is translated into a chain of extension method calls, with each query keyword mapped to the method of the same name (capitalized). So if we write our own extension method with the same name as a standard query operator, will it be the one that runs?
Start by creating a static class LinqExtensions that implements the Where extension method as follows:

    /// <summary>
    /// Filters a sequence of values based on a predicate.
    /// </summary>
    /// <typeparam name="TResult"></typeparam>
    /// <param name="source"></param>
    /// <param name="predicate"></param>
    /// <returns></returns>
    public static IEnumerable<TResult> Where<TResult>(this IEnumerable<TResult> source,
        Func<TResult, bool> predicate)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (predicate == null)
            throw new ArgumentNullException("predicate");
        return WhereImpl(source, predicate);
    }
    /// <summary>
    /// The implementation of Where
    /// </summary>
    /// <typeparam name="TResult"></typeparam>
    /// <param name="source"></param>
    /// <param name="predicate"></param>
    /// <returns></returns>
    private static IEnumerable<TResult> WhereImpl<TResult>(this IEnumerable<TResult> source,
        Func<TResult, bool> predicate)
    {
        using (var e = source.GetEnumerator())
        {
            while (e.MoveNext())
            {
                if (predicate(e.Current))
                    yield return e.Current;
            }
        }
    }

The calling code is as follows:

var r = from e in Enumerable.Range(1, 10)
            where e%2 == 0
            select e;
        foreach (var i in r)
        {
            Console.WriteLine(i);
        }
//output:
//2
//4
//6
//8
//10
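
The translation from query syntax to method calls can be sketched as follows. This standalone example uses the standard Enumerable.Where rather than the custom one above; the class and method names are invented for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class QuerySyntaxDemo
{
    public static List<int> WithQuerySyntax()
    {
        var r = from e in Enumerable.Range(1, 10)
                where e % 2 == 0
                select e;
        return r.ToList();
    }

    public static List<int> WithMethodSyntax()
    {
        // What the compiler produces for the query above: a call to a method named Where.
        // The trailing "select e" is dropped because it only returns the range variable.
        return Enumerable.Range(1, 10).Where(e => e % 2 == 0).ToList();
    }

    public static void Main()
    {
        Console.WriteLine(string.Join(",", WithQuerySyntax()));  // 2,4,6,8,10
        Console.WriteLine(string.Join(",", WithMethodSyntax())); // 2,4,6,8,10
    }
}
```

Because the compiler only rewrites the query into a call to Where, whichever accessible Where method wins overload resolution is the one that runs.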

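How can we tell that our Where extension method was the one called? Any of the following works: step through with the debugger, select Where in Visual Studio and press F12 (Go To Definition), or print a message inside the Where implementation. A natural follow-up question: does the method signature have to match the standard operator? The answer is no. The compiler translates a query expression by method name, so any accessible method with that name that accepts the lambda will do. For example, rename the Where extension method to Select and change the call to the following: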

var r = from e in Enumerable.Range(1, 10)
        //where e % 2 == 0
        select e % 2 == 0;
//output:
//2
//4
//6
//8
//10
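
The same name-based binding can be shown in isolation. In this sketch, Numbers, NumbersExtensions, and PatternDemo are hypothetical types invented for the example; Numbers deliberately does not implement IEnumerable<T>, so only the Select defined here can be chosen.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical source type; it does not implement IEnumerable<T>,
// so the operators in System.Linq do not apply to it.
public class Numbers
{
    public readonly List<int> Items;
    public Numbers(IEnumerable<int> items) { Items = new List<int>(items); }
}

public static class NumbersExtensions
{
    // Named Select, but it filters like Where. The compiler only needs a method
    // with this name whose parameter accepts the lambda from the query.
    public static Numbers Select(this Numbers source, Func<int, bool> predicate)
    {
        return new Numbers(source.Items.FindAll(x => predicate(x)));
    }
}

public static class PatternDemo
{
    public static List<int> Evens()
    {
        var r = from e in new Numbers(new[] { 1, 2, 3, 4, 5, 6 })
                select e % 2 == 0;   // compiles to .Select(e => e % 2 == 0)
        return r.Items;
    }

    public static void Main()
    {
        Console.WriteLine(string.Join(",", Evens())); // 2,4,6
    }
}
```

Giving Select filtering behavior is only a demonstration of how binding works; in real code it would confuse readers.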

Finally, here is a complete example that combines a custom enumerator with custom query operators:

public class Collection<T>
{
    private T[] items = new T[0]; // never null, so an empty Collection<T> can still be enumerated

    public Collection()
    {

    }

    public Collection(IEnumerable<T> collection)
    {
        if (collection == null)
            throw new ArgumentNullException("collection");
        // Enumerate the source only once; ToArray already returns a fresh copy.
        items = collection.ToArray();
    }

    public static implicit operator Collection<T>(T[] arr)
    {
        Collection<T> collection = new Collection<T>();
        collection.items = new T[arr.Length];
        Array.Copy(arr, collection.items, arr.Length);
        return collection;
    }

    public ItemEnumerator GetEnumerator()
    {
        return new ItemEnumerator(items);
    }

    #region Item Enumerator
    public class ItemEnumerator : IDisposable
    {
        private T[] items;
        private int index = -1;

        public ItemEnumerator(T[] arr)
        {
            this.items = arr;
        }
        /// <summary>
        /// Current property
        /// </summary>
        public T Current
        {
            get
            {
                if (index < 0 || index > items.Length - 1)
                    throw new InvalidOperationException();
                return items[index];
            }
        }
        /// <summary>
        /// MoveNext method
        /// </summary>
        /// <returns></returns>
        public bool MoveNext()
        {
            if (index < items.Length - 1)
            {
                index++;
                return true;
            }
            else
            {
                return false;
            }
        }

        public void Reset()
        {
            index = -1;
        }
        #region IDisposable member

        public void Dispose()
        {
            index = -1;
        }

        #endregion
    }
    #endregion
}

public static class EnumerableExtensions
{
    public static Collection<T> Where<T>(this Collection<T> source, Func<T, bool> predicate)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (predicate == null)
            throw new ArgumentNullException("predicate");
        return WhereImpl(source, predicate).ToCollection();
    }

    private static IEnumerable<T> WhereImpl<T>(this Collection<T> source, Func<T, bool> predicate)
    {
        using (var e = source.GetEnumerator())
        {
            while (e.MoveNext())
            {
                if (predicate(e.Current))
                {
                    yield return e.Current;
                }
            }
        }
    }

    public static Collection<TResult> Select<T, TResult>(this Collection<T> source, Func<T, TResult> selector)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (selector == null)
            throw new ArgumentNullException("selector");
        return SelectImpl(source, selector).ToCollection();
    }

    private static IEnumerable<TResult> SelectImpl<T, TResult>(this Collection<T> source, Func<T, TResult> selector)
    {
        using (var e = source.GetEnumerator())
        {
            while (e.MoveNext())
            {
                yield return selector(e.Current);
            }
        }
    }

    public static Collection<T> ToCollection<T>(this IEnumerable<T> source)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        return new Collection<T>(source);
    }
}

There are two classes here: Collection<T> is the data source and EnumerableExtensions supplies the query operators. They are called as follows:

Collection<int> collection = new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
        var r = from c in collection
                where c % 2 == 0
                select c;
        foreach (var i in r)
        {
            Console.WriteLine(i);
        }

//output:
//2
//4
//6
//8
//10

3. A Simple Simulation of LINQ to SQL

In part (2) of this LINQ series, we briefly introduced how LINQ to SQL works. Here we look more closely at those principles by building a simple simulation of LINQ to SQL.

First create the data source: define the Query class and implement the IQueryable<T> interface:

public class Query<T> : IQueryable<T>
{
    #region field

    private QueryProvider provider;
    private Expression expression;

    #endregion

    #region attribute

    #endregion

    #region constructor

    public Query(QueryProvider provider)
    {
        if (provider == null)
            throw new ArgumentNullException("provider");
        this.provider = provider;
        this.expression = Expression.Constant(this);
    }

    public Query(QueryProvider provider, Expression expression)
    {
        if (provider == null)
            throw new ArgumentNullException("provider");
        if (expression == null)
            throw new ArgumentNullException("expression");
        if (!typeof(IQueryable<T>).IsAssignableFrom(expression.Type))
            throw new ArgumentOutOfRangeException("expression");
        this.provider = provider;
        this.expression = expression;
    }
    #endregion

    #region method

    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>) this.provider.Execute(this.expression)).GetEnumerator();
    }

    #endregion

    #region IEnumerable<T> member

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return this.GetEnumerator();
    }

    #endregion

    #region IEnumerable member

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return this.GetEnumerator();
    }

    #endregion

    #region IQueryable Member

    Type IQueryable.ElementType
    {
        get { return typeof(T); }
    }

    Expression IQueryable.Expression
    {
        get { return this.expression; }
    }

    IQueryProvider IQueryable.Provider
    {
        get { return this.provider; }
    }

    #endregion
}

This class is fairly simple: it implements the interface members and stores the provider and expression passed to its constructors.
Next comes the provider. Create the QueryProvider class and implement the IQueryProvider interface:

public class QueryProvider:IQueryProvider
{
    #region field

    private IDbConnection dbConnection;

    #endregion

    #region attribute

    #endregion

    #region constructor
    public QueryProvider(IDbConnection dbConnection)
    {
        this.dbConnection = dbConnection;
    }
    #endregion

    #region method

    #endregion

    #region IQueryProvider Member

    public IQueryable<TElement> CreateQuery<TElement>(System.Linq.Expressions.Expression expression)
    {
        return new Query<TElement>(this, expression);
    }

    public IQueryable CreateQuery(System.Linq.Expressions.Expression expression)
    {
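        // expression.Type is IQueryable<T>, but Query<> must be closed over the element type T.
        // Assumes the expression type is a constructed generic type such as IQueryable<T>.
        var elementType = expression.Type.GetGenericArguments()[0];
        try
        {
            return (IQueryable) Activator.CreateInstance(typeof (Query<>).MakeGenericType(elementType), this, expression);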
        }
        catch (TargetInvocationException e)
        {
            throw e.InnerException;
        }
    }

    TResult IQueryProvider.Execute<TResult>(System.Linq.Expressions.Expression expression)
    {
        return (TResult) this.Execute(expression);
    }

    object IQueryProvider.Execute(System.Linq.Expressions.Expression expression)
    {
        return this.Execute(expression);
    }

    public virtual object Execute(Expression expression)
    {
        if(expression == null)
            throw new ArgumentNullException("expression");
        return ExecuteImpl(expression);
    }

    private IEnumerable ExecuteImpl(Expression expression)
    {
        //var type = expression.Type;
        //var entityType = type.GetGenericArguments()[0];
        List<Product> products = new List<Product>();
        QueryTranslator queryTranslator = new QueryTranslator();
        var cmdText = queryTranslator.Translate(expression);
        IDbCommand cmd = dbConnection.CreateCommand();
        cmd.CommandText = cmdText;
        using (IDataReader dataReader = cmd.ExecuteReader())
        {
            while (dataReader.Read())
            {
                Product product = new Product();
                product.ID = dataReader.GetInt32(0);
                product.Name = dataReader.GetString(1);
                product.Type = dataReader.GetInt32(2);
                products.Add(product);
            }
        }
        return products;
    }
    #endregion
}

Next is the query translation class. Create the QueryTranslator class, which inherits from the abstract ExpressionVisitor class:

public class QueryTranslator:ExpressionVisitor
{
    #region field

    private StringBuilder sb;

    #endregion

    #region attribute

    #endregion

    #region constructor
    public QueryTranslator()
    {

    }

    #endregion

    #region method

    public string Translate(Expression expression)
    {
        this.sb = new StringBuilder();
        this.Visit(expression);
        return this.sb.ToString();
    }

    private static Expression StripQuotes(Expression e)
    {
        while (e.NodeType == ExpressionType.Quote)
        {
            e = ((UnaryExpression) e).Operand;
        }
        return e;
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method.DeclaringType == typeof (Queryable) &&
            node.Method.Name == "Where")
        {
            sb.Append("SELECT * FROM (");
            this.Visit(node.Arguments[0]);
            sb.Append(") AS T WHERE ");
            LambdaExpression lambda = (LambdaExpression) StripQuotes(node.Arguments[1]);
            this.Visit(lambda.Body);
            return node;
        }
        throw new NotSupportedException(string.Format("The Method '{0}' is not supported", node.Method.Name));
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        sb.Append("(");
        this.Visit(node.Left);
        switch (node.NodeType)
        {
            case ExpressionType.Equal:
                sb.Append(" = ");
                break;
            case ExpressionType.NotEqual:
                sb.Append(" <> ");
                break;
            case ExpressionType.GreaterThan:
                sb.Append(" > ");
                break;
            case ExpressionType.GreaterThanOrEqual:
                sb.Append(" >= ");
                break;
            case ExpressionType.LessThan:
                sb.Append(" < ");
                break;
            case ExpressionType.LessThanOrEqual:
                sb.Append(" <= ");
                break;
            default:
                throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", node.NodeType));
        }
        this.Visit(node.Right);
        sb.Append(")");
        return node;
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        IQueryable q = node.Value as IQueryable;
        if (q != null)
        {
            sb.Append("SELECT * FROM ");
            sb.Append(DataContext.MetaTables.FirstOrDefault(f => f.Type == q.ElementType).TableName);
            return node;
        }
        else if(node.Value == null)
        {
            sb.Append("NULL");
        }
        else
        {
            switch (Type.GetTypeCode(node.Value.GetType()))
            {
                case TypeCode.Boolean:
                    sb.Append(((bool) node.Value) ? 1 : 0);
                    break;
                case TypeCode.String:
                    sb.AppendFormat("'{0}'", node.Value);
                    break;
                case TypeCode.Object:
                    throw new NotSupportedException(string.Format("The constant for '{0}' is not supported", node.Value));
                default:
                    sb.Append(node.Value);
                    break;
            }
        }
        return node;
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression != null && node.Expression.NodeType == ExpressionType.Parameter)
        {
            sb.Append(node.Member.Name);
            return node;
        }
        throw new NotSupportedException(string.Format("The member '{0}' is not supported", node.Member.Name));
    }

    #endregion
}

QueryTranslator overrides the relevant Visit methods to walk the expression tree using the Visitor pattern, emitting SQL text as it goes.
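
To see the tree shape that VisitMethodCall and StripQuotes deal with, the following sketch inspects the expression recorded by a Where call. It uses the in-memory AsQueryable() source instead of the Query<T> class, and ExpressionShapeDemo is a name invented for this example.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class ExpressionShapeDemo
{
    public static MethodCallExpression BuildWhereCall()
    {
        // AsQueryable gives an in-memory IQueryable<int>; calling Where on it
        // only records the call as an expression tree.
        IQueryable<int> q = new[] { 1, 2, 3 }.AsQueryable().Where(x => x > 1);
        return (MethodCallExpression)q.Expression;
    }

    public static void Main()
    {
        MethodCallExpression call = BuildWhereCall();
        Console.WriteLine(call.Method.DeclaringType == typeof(Queryable)); // True
        Console.WriteLine(call.Method.Name);                               // Where
        Console.WriteLine(call.Arguments[0].NodeType);                     // Constant
        Console.WriteLine(call.Arguments[1].NodeType);                     // Quote

        // Removing the Quote wrapper yields the lambda whose body is the predicate.
        var lambda = (LambdaExpression)((UnaryExpression)call.Arguments[1]).Operand;
        Console.WriteLine(lambda.Body.NodeType);                           // GreaterThan
    }
}
```

The first argument is the source (a constant holding the queryable) and the second is the quoted lambda, which matches the two arguments the translator visits.
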
Finally, let's look at the implementation of the DataContext:

public class DataContext : IDisposable
{
    #region field

    private IDbConnection dbConnection;
    private static List<MetaTable> metaTables; 
    #endregion

    #region attribute

    public TextWriter Log { get; set; }

    public IDbConnection DbConnection
    {
        get { return this.dbConnection; }
    }

    public static List<MetaTable> MetaTables
    {
        get { return metaTables; }
    }
    #endregion

    #region constructor
    public DataContext(string connString)
    {
        if (connString == null)
            throw new ArgumentNullException("connString");
        dbConnection = new SqlConnection(connString);
        dbConnection.Open();
        InitTables();
    }
    #endregion

    #region method

    private void InitTables()
    {
        metaTables = new List<MetaTable>();
        var props = this.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance);
        foreach (var prop in props)
        {
            var propType = prop.PropertyType;
            if (propType.IsGenericType && propType.GetGenericTypeDefinition() == typeof (Query<>))
            {
                var entityType = propType.GetGenericArguments()[0];
                var entityAttr = entityType.GetCustomAttribute<MappingAttribute>(true);
                if (entityAttr != null)
                {
                    var metaTable = new MetaTable();
                    metaTable.Type = entityType;
                    metaTable.TableName = entityAttr.Name;
                    metaTable.MappingAttribute = entityAttr;
                    var columnProps = entityType.GetProperties(BindingFlags.Public | BindingFlags.Instance);
                    foreach (var columnProp in columnProps)
                    {
                        var columnPropAttr = columnProp.GetCustomAttribute<MappingAttribute>(true);
                        if (columnPropAttr != null)
                        {
                            MetaColumn metaColumn = new MetaColumn();
                            metaColumn.MappingAttribute = columnPropAttr;
                            metaColumn.ColumnName = columnPropAttr.Name;
                            metaColumn.PropertyInfo = columnProp;
                            metaTable.MetaColumns.Add(metaColumn);
                        }
                    }
                    metaTables.Add(metaTable);
                }
            }
        }
    }
    #endregion

    #region IDisposable member

    protected virtual void Dispose(bool disposing)
    {
        if (!disposing) return;
        if (dbConnection != null)
            dbConnection.Close();
    }

    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    #endregion
}
[Database(Name = "IT_Company")]
public class QueryDataContext : DataContext
{
    public QueryDataContext(string connString)
        : base(connString)
    {
        QueryProvider provider = new QueryProvider(DbConnection);
        Products = new Query<Product>(provider);
    }
    public Query<Product> Products
    {
        get;
        set;
    }
}
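
InitTables relies on MappingAttribute, MetaTable, and MetaColumn, which are not shown in this article. The sketch below illustrates the reflection step on its own under stated assumptions: the MappingAttribute here is a hypothetical stand-in with only a Name property, and the table and column names are invented.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical mapping attribute; the article's MappingAttribute is not shown,
// so only a Name property is assumed here.
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Property)]
public class MappingAttribute : Attribute
{
    public string Name { get; set; }
}

[Mapping(Name = "T_Product")]   // table name invented for this sketch
public class Product
{
    [Mapping(Name = "ProductID")] public int ID { get; set; }
    [Mapping(Name = "ProductName")] public string Name { get; set; }
    public int Unmapped { get; set; }   // no attribute, so it is skipped
}

public static class MappingDemo
{
    public static string TableNameOf(Type entityType)
    {
        var attr = entityType.GetCustomAttribute<MappingAttribute>(true);
        return attr == null ? null : attr.Name;
    }

    // Maps property name to column name for every property carrying the attribute.
    public static Dictionary<string, string> ColumnsOf(Type entityType)
    {
        var columns = new Dictionary<string, string>();
        foreach (PropertyInfo prop in entityType.GetProperties(BindingFlags.Public | BindingFlags.Instance))
        {
            var attr = prop.GetCustomAttribute<MappingAttribute>(true);
            if (attr != null)
                columns[prop.Name] = attr.Name;
        }
        return columns;
    }

    public static void Main()
    {
        Console.WriteLine(TableNameOf(typeof(Product)));        // T_Product
        Console.WriteLine(ColumnsOf(typeof(Product))["Name"]);  // ProductName
    }
}
```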

The call is as follows:

class Program
{
    private static readonly string connString =
        "Data Source=.;Initial Catalog=IT_Company;Persist Security Info=True;User ID=sa;Password=123456";
    static void Main(string[] args)
    {
        using (var context = new QueryDataContext(connString))
        {
            var query = from product in context.Products
                where product.Type == 1
                select product;
            foreach (var product in query)
            {
                Console.WriteLine(product.Name);
            }
            Console.ReadKey();
        }
    }
}
//output:
//MG500
//MG1000
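
The translation step can also be tried without a database. The following is a reduced sketch, not the QueryTranslator above: WhereClauseTranslator is a hypothetical class that only handles comparisons between a member of the lambda parameter and a constant, and the Product class is a stand-in for the article's entity.

```csharp
using System;
using System.Linq.Expressions;
using System.Text;

// Hypothetical stand-in for the article's Product entity.
public class Product
{
    public int ID { get; set; }
    public string Name { get; set; }
    public int Type { get; set; }
}

// Reduced translator: handles only "member <op> constant" comparisons.
public class WhereClauseTranslator : ExpressionVisitor
{
    private readonly StringBuilder sb = new StringBuilder();

    public static string Translate(Expression<Func<Product, bool>> predicate)
    {
        var translator = new WhereClauseTranslator();
        translator.Visit(predicate.Body);
        return translator.sb.ToString();
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        sb.Append("(");
        Visit(node.Left);
        switch (node.NodeType)
        {
            case ExpressionType.Equal: sb.Append(" = "); break;
            case ExpressionType.NotEqual: sb.Append(" <> "); break;
            case ExpressionType.GreaterThan: sb.Append(" > "); break;
            case ExpressionType.LessThan: sb.Append(" < "); break;
            default: throw new NotSupportedException(node.NodeType.ToString());
        }
        Visit(node.Right);
        sb.Append(")");
        return node;
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        // Only members read directly off the lambda parameter are supported.
        if (node.Expression != null && node.Expression.NodeType == ExpressionType.Parameter)
        {
            sb.Append(node.Member.Name);
            return node;
        }
        throw new NotSupportedException(node.Member.Name);
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        if (node.Value is string) sb.AppendFormat("'{0}'", node.Value);
        else sb.Append(node.Value);
        return node;
    }
}

public static class TranslatorDemo
{
    public static void Main()
    {
        Console.WriteLine(WhereClauseTranslator.Translate(p => p.Type == 1));       // (Type = 1)
        Console.WriteLine(WhereClauseTranslator.Translate(p => p.Name == "MG500")); // (Name = 'MG500')
    }
}
```

Both this sketch and the simulation build SQL by string concatenation, which is fine for a demonstration; a production provider would pass values as command parameters instead. A captured local variable in the predicate is not a constant in the tree and would hit the NotSupportedException branch here.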

Reprinted at: https://www.cnblogs.com/jellochen/p/the-extension-of-linq.html

Added by iamthebugman on Tue, 23 Jul 2019 19:28:42 +0300