Building a Multi-Layer Website Architecture with ASP.NET Core [4.1 - Unit of Work and Repository Design]

2020/01/28, ASP.NET Core 3.1, VS2019, UnitOfWork, Repository

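Summary: building a multi-layer back-end website architecture on ASP.NET Core 3.1 WebApi [4.1 - Unit of Work and Repository Design]
Uses the generic repository and unit-of-work patterns to encapsulate the basic create, read, update and delete (CRUD) methods of the data access layer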

Project code for this branch
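About the unit-of-work pattern in this chapter:
The generic repository encapsulates the common CRUD methods. The unit of work manages all repositories centrally, so every repository shares the same database context.
Repositories are always obtained from the unit of work. After data has been changed through a repository, the unit of work commits the changes.
The code is based on Arch's UnitOfWork design, and most of it is taken from that project. I added comments and removed its support for distributed multi-database setups.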
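The sketch below shows the relationship just described as a minimal in-memory model with no dependency on EF Core. FakeContext, SketchRepository<T> and SketchUnitOfWork are hypothetical names for illustration only. They are not the types built later in this chapter. Repositories only stage changes on the shared context, and nothing is committed until the unit of work's SaveChanges is called.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for a database context: it only records staged and committed changes.
public class FakeContext
{
    public List<string> Pending { get; } = new List<string>();
    public List<string> Committed { get; } = new List<string>();

    // Commits everything staged so far and returns how many changes were committed.
    public int SaveChanges()
    {
        int count = Pending.Count;
        Committed.AddRange(Pending);
        Pending.Clear();
        return count;
    }
}

// Hypothetical repository: it never saves, it only stages changes on the shared context.
public class SketchRepository<T>
{
    private readonly FakeContext _context;

    public SketchRepository(FakeContext context) => _context = context;

    public void Insert(T entity) => _context.Pending.Add($"{typeof(T).Name}:{entity}");
}

// Hypothetical unit of work: hands out repositories that all share one context,
// and is the only place where changes are committed.
public class SketchUnitOfWork
{
    private readonly FakeContext _context = new FakeContext();
    private readonly Dictionary<Type, object> _repositories = new Dictionary<Type, object>();

    public FakeContext Context => _context;

    public SketchRepository<T> GetRepository<T>()
    {
        // Cache one repository per entity type so repeated calls return the same instance.
        if (!_repositories.TryGetValue(typeof(T), out var repository))
        {
            repository = new SketchRepository<T>(_context);
            _repositories[typeof(T)] = repository;
        }
        return (SketchRepository<T>)repository;
    }

    public int SaveChanges() => _context.SaveChanges();
}
```

Caching one repository per entity type is a choice made in this sketch only. The key property is that every repository from the same unit of work writes to the same context, so a single SaveChanges commits all of their changes together.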

Add package references

Add a reference to the Microsoft.EntityFrameworkCore.Relational package to the MS.UnitOfWork project:

<ItemGroup>
  <PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="3.1.1" />
</ItemGroup>

Paging helpers

In the MS.UnitOfWork project, add a Collections folder. Then add the classes IPagedList.cs, PagedList.cs, IEnumerablePagedListExtensions.cs and IQueryablePageListExtensions.cs to it.
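All of the paging classes below use the same arithmetic. This sketch isolates it. The PagingMath helper is hypothetical, but its formulas are copied from PagedList<T>. Here indexFrom is the number assigned to the first page. The number of items to skip is therefore (pageIndex - indexFrom) * pageSize, and the page count is rounded up.

```csharp
using System;

// Hypothetical helper holding the arithmetic that PagedList<T> performs inline.
public static class PagingMath
{
    // Items to skip before the requested page; indexFrom is the number of the first page (e.g. 0 or 1).
    public static int Skip(int pageIndex, int pageSize, int indexFrom) => (pageIndex - indexFrom) * pageSize;

    // Total pages, rounded up so that a partially filled last page still counts as a page.
    public static int TotalPages(int totalCount, int pageSize) => (int)Math.Ceiling(totalCount / (double)pageSize);

    // Any page after the first one has a previous page.
    public static bool HasPreviousPage(int pageIndex, int indexFrom) => pageIndex - indexFrom > 0;

    // Converts the page number to a 1-based position and compares it with the page count.
    public static bool HasNextPage(int pageIndex, int indexFrom, int totalPages) => pageIndex - indexFrom + 1 < totalPages;
}
```

Consider 45 items, 20 items per page and indexFrom = 1. That gives 3 pages. Page 3 skips 40 items and holds the remaining 5. It has a previous page but no next page.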

IPagedList.cs

using System.Collections.Generic;

namespace MS.UnitOfWork.Collections
{
    /// <summary>
    /// Paging interface for data of any type
    /// </summary>
    /// <typeparam name="T">Type of the data being paged</typeparam>
    public interface IPagedList<T>
    {
        /// <summary>
        /// Index of the first page
        /// </summary>
        int IndexFrom { get; }
        /// <summary>
        /// Index of the current page
        /// </summary>
        int PageIndex { get; }
        /// <summary>
        /// Page size
        /// </summary>
        int PageSize { get; }
        /// <summary>
        /// Total number of items
        /// </summary>
        int TotalCount { get; }
        /// <summary>
        /// Total number of pages
        /// </summary>
        int TotalPages { get; }
        /// <summary>
        /// Items on the current page
        /// </summary>
        IList<T> Items { get; }
        /// <summary>
        /// Whether there is a previous page
        /// </summary>
        bool HasPreviousPage { get; }
        /// <summary>
        /// Whether there is a next page
        /// </summary>
        bool HasNextPage { get; }
    }
}

PagedList.cs

using System;
using System.Collections.Generic;
using System.Linq;

namespace MS.UnitOfWork.Collections
{
    /// <summary>
    /// Pages a data set; the default implementation of <see cref="IPagedList{T}"/>
    /// </summary>
    /// <typeparam name="T">Type of the data being paged</typeparam>
    public class PagedList<T> : IPagedList<T>
    {
        /// <summary>
        /// Index of the current page
        /// </summary>
        public int PageIndex { get; set; }
        /// <summary>
        /// Page size
        /// </summary>
        public int PageSize { get; set; }
        /// <summary>
        /// Total number of items
        /// </summary>
        public int TotalCount { get; set; }
        /// <summary>
        /// Total number of pages
        /// </summary>
        public int TotalPages { get; set; }
        /// <summary>
        /// Index of the first page
        /// </summary>
        public int IndexFrom { get; set; }
        /// <summary>
        /// Items on the current page
        /// </summary>
        public IList<T> Items { get; set; }
        /// <summary>
        /// Whether there is a previous page
        /// </summary>
        public bool HasPreviousPage => PageIndex - IndexFrom > 0;
        /// <summary>
        /// Whether there is a next page
        /// </summary>
        public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;

        /// <summary>
        /// Initializes a new instance of the <see cref="PagedList{T}" /> class.
        /// </summary>
        /// <param name="source">The source.</param>
        /// <param name="pageIndex">The index of the page.</param>
        /// <param name="pageSize">The size of the page.</param>
        /// <param name="indexFrom">The index from.</param>
        internal PagedList(IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom)
        {
            if (indexFrom > pageIndex)
            {
                throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, indexFrom must be less than or equal to pageIndex");
            }

            if (source is IQueryable<T> queryable)
            {
                PageIndex = pageIndex;
                PageSize = pageSize;
                IndexFrom = indexFrom;
                TotalCount = queryable.Count();
                TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);

                Items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
            }
            else
            {
                PageIndex = pageIndex;
                PageSize = pageSize;
                IndexFrom = indexFrom;
                TotalCount = source.Count();
                TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);

                Items = source.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
            }
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="PagedList{T}" /> class.
        /// </summary>
        internal PagedList() => Items = new T[0];
    }

    /// <summary>
    /// Pages a data set and converts the items to another type
    /// </summary>
    /// <typeparam name="TSource">Source data type</typeparam>
    /// <typeparam name="TResult">Output data type</typeparam>
    internal class PagedList<TSource, TResult> : IPagedList<TResult>
    {
        /// <summary>
        /// Index of the current page
        /// </summary>
        public int PageIndex { get; set; }
        /// <summary>
        /// Page size
        /// </summary>
        public int PageSize { get; set; }
        /// <summary>
        /// Total number of items
        /// </summary>
        public int TotalCount { get; set; }
        /// <summary>
        /// Total number of pages
        /// </summary>
        public int TotalPages { get; set; }
        /// <summary>
        /// Index of the first page
        /// </summary>
        public int IndexFrom { get; set; }
        /// <summary>
        /// Items on the current page
        /// </summary>
        public IList<TResult> Items { get; set; }
        /// <summary>
        /// Whether there is a previous page
        /// </summary>
        public bool HasPreviousPage => PageIndex - IndexFrom > 0;
        /// <summary>
        /// Whether there is a next page
        /// </summary>
        public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;

        /// <summary>
        /// Initializes a new instance of the <see cref="PagedList{TSource, TResult}" /> class.
        /// </summary>
        /// <param name="source">The source.</param>
        /// <param name="converter">The converter.</param>
        /// <param name="pageIndex">The index of the page.</param>
        /// <param name="pageSize">The size of the page.</param>
        /// <param name="indexFrom">The index from.</param>
        public PagedList(IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom)
        {
            if (indexFrom > pageIndex)
            {
                throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, indexFrom must be less than or equal to pageIndex");
            }

            if (source is IQueryable<TSource> queryable)
            {
                PageIndex = pageIndex;
                PageSize = pageSize;
                IndexFrom = indexFrom;
                TotalCount = queryable.Count();
                TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);

                var items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray();

                Items = new List<TResult>(converter(items));
            }
            else
            {
                PageIndex = pageIndex;
                PageSize = pageSize;
                IndexFrom = indexFrom;
                TotalCount = source.Count();
                TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);

                var items = source.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray();

                Items = new List<TResult>(converter(items));
            }
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="PagedList{TSource, TResult}" /> class.
        /// </summary>
        /// <param name="source">The source.</param>
        /// <param name="converter">The converter.</param>
        public PagedList(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter)
        {
            PageIndex = source.PageIndex;
            PageSize = source.PageSize;
            IndexFrom = source.IndexFrom;
            TotalCount = source.TotalCount;
            TotalPages = source.TotalPages;

            Items = new List<TResult>(converter(source.Items));
        }
    }

    /// <summary>
    /// Provides some help methods for <see cref="IPagedList{T}"/> interface.
    /// </summary>
    public static class PagedList
    {
        /// <summary>
        /// Creates an empty of <see cref="IPagedList{T}"/>.
        /// </summary>
        /// <typeparam name="T">The type for paging </typeparam>
        /// <returns>An empty instance of <see cref="IPagedList{T}"/>.</returns>
        public static IPagedList<T> Empty<T>() => new PagedList<T>();
        /// <summary>
        /// Creates a new instance of <see cref="IPagedList{TResult}"/> from source of <see cref="IPagedList{TSource}"/> instance.
        /// </summary>
        /// <typeparam name="TResult">The type of the result.</typeparam>
        /// <typeparam name="TSource">The type of the source.</typeparam>
        /// <param name="source">The source.</param>
        /// <param name="converter">The converter.</param>
        /// <returns>An instance of <see cref="IPagedList{TResult}"/>.</returns>
        public static IPagedList<TResult> From<TResult, TSource>(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter) => new PagedList<TSource, TResult>(source, converter);
    }
}

IEnumerablePagedListExtensions.cs

using System;
using System.Collections.Generic;

namespace MS.UnitOfWork.Collections
{
    /// <summary>
    /// Extension methods that add paging support to <see cref="IEnumerable{T}"/>
    /// </summary>
    public static class IEnumerablePagedListExtensions
    {
        /// <summary>
        /// Gets the items of the specified page from the data
        /// </summary>
        /// <typeparam name="T">Data type</typeparam>
        /// <param name="source">Data source</param>
        /// <param name="pageIndex">Current page</param>
        /// <param name="pageSize">Page size</param>
        /// <param name="indexFrom">Index of the first page</param>
        /// <returns>The requested page as an <see cref="IPagedList{T}"/></returns>
        public static IPagedList<T> ToPagedList<T>(this IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom = 1) => new PagedList<T>(source, pageIndex, pageSize, indexFrom);

        /// <summary>
        /// Gets the items of the specified page and converts them to the specified type
        /// </summary>
        /// <typeparam name="TSource">Source data type</typeparam>
        /// <typeparam name="TResult">Output data type</typeparam>
        /// <param name="source">Data source</param>
        /// <param name="converter">Converts the items of the current page to <typeparamref name="TResult"/></param>
        /// <param name="pageIndex">Current page</param>
        /// <param name="pageSize">Page size</param>
        /// <param name="indexFrom">Index of the first page</param>
        /// <returns>The requested page as an <see cref="IPagedList{TResult}"/></returns>
        public static IPagedList<TResult> ToPagedList<TSource, TResult>(this IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom = 1) => new PagedList<TSource, TResult>(source, converter, pageIndex, pageSize, indexFrom);
    }
}
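The converter overload pages first and converts afterwards, so the converter only receives the items of the current page. The sketch below reproduces the non-IQueryable branch of PagedList<TSource, TResult> in memory. User, UserDto and ConverterPagingSketch are hypothetical names used only for this illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entity and DTO types for the illustration.
public class User { public int Id { get; set; } public string Name { get; set; } }
public class UserDto { public string Label { get; set; } }

public static class ConverterPagingSketch
{
    // Mirrors the non-IQueryable branch of PagedList<TSource, TResult>:
    // cut out one page first, then hand only that page to the converter.
    public static List<TResult> Page<TSource, TResult>(
        IEnumerable<TSource> source,
        Func<IEnumerable<TSource>, IEnumerable<TResult>> converter,
        int pageIndex, int pageSize, int indexFrom = 1)
    {
        var items = source.Skip((pageIndex - indexFrom) * pageSize).Take(pageSize).ToArray();
        return new List<TResult>(converter(items));
    }

    public static (int seen, List<UserDto> page) Demo()
    {
        var users = Enumerable.Range(1, 7)
            .Select(i => new User { Id = i, Name = "u" + i })
            .ToList();

        int seen = 0;
        var page = Page(users, us =>
        {
            seen = us.Count(); // how many items the converter was given
            return us.Select(u => new UserDto { Label = u.Id + ":" + u.Name });
        }, pageIndex: 2, pageSize: 3);

        return (seen, page);
    }
}
```

In Demo the source has 7 users, but the converter is called with only the 3 users on page 2 (ids 4 to 6). This keeps conversion cheap when the converter does real work, such as mapping entities to DTOs.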

IQueryablePageListExtensions.cs

using Microsoft.EntityFrameworkCore;
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace MS.UnitOfWork.Collections
{
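    /// <summary>
    /// Extension methods that add asynchronous paging support to <see cref="IQueryable{T}"/>
    /// </summary>
    public static class IQueryablePageListExtensions
    {
        /// <summary>
        /// Gets the items of the specified page from the data (asynchronous)
        /// </summary>
        /// <typeparam name="T">Data type</typeparam>
        /// <param name="source">Data source</param>
        /// <param name="pageIndex">Current page</param>
        /// <param name="pageSize">Page size</param>
        /// <param name="indexFrom">Index of the first page</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete</param>
        /// <returns>The requested page as an <see cref="IPagedList{T}"/></returns>
        public static async Task<IPagedList<T>> ToPagedListAsync<T>(this IQueryable<T> source, int pageIndex, int pageSize, int indexFrom = 1, CancellationToken cancellationToken = default(CancellationToken))
        {
            if (indexFrom > pageIndex)
            {
                throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, indexFrom must be less than or equal to pageIndex");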
            }

            var count = await source.CountAsync(cancellationToken).ConfigureAwait(false);
            var items = await source.Skip((pageIndex - indexFrom) * pageSize)
                                    .Take(pageSize).ToListAsync(cancellationToken).ConfigureAwait(false);

            var pagedList = new PagedList<T>()
            {
                PageIndex = pageIndex,
                PageSize = pageSize,
                IndexFrom = indexFrom,
                TotalCount = count,
                Items = items,
                TotalPages = (int)Math.Ceiling(count / (double)pageSize)
            };

            return pagedList;
        }
    }
}

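These paging extension methods cover both IQueryable and IEnumerable data. They are mainly used to page results when querying the database. When the source is an IQueryable such as a DbSet, EF Core translates Count, Skip and Take into SQL, so only the requested page is loaded. A plain in-memory IEnumerable is paged in memory.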
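The sketch below shows why the IQueryable path matters. It uses LINQ's in-memory AsQueryable() in place of an EF Core DbSet; that substitution is an assumption for illustration, and QueryablePagingSketch is a hypothetical name. Skip and Take are recorded in the expression tree instead of running immediately. The total count is a separate query, just as ToPagedListAsync awaits CountAsync and then runs the page query.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class QueryablePagingSketch
{
    // On an IQueryable, Skip/Take are appended to the expression tree; a provider such as
    // EF Core later translates the whole tree (e.g. into OFFSET/FETCH on SQL Server).
    public static IQueryable<int> PageQuery(IQueryable<int> source, int pageIndex, int pageSize, int indexFrom = 1)
        => source.Skip((pageIndex - indexFrom) * pageSize).Take(pageSize);

    public static (int total, List<int> items, string expression) Demo()
    {
        IQueryable<int> source = Enumerable.Range(1, 10).AsQueryable();

        int total = source.Count();           // with EF Core this would be its own COUNT query
        var query = PageQuery(source, 2, 4);  // nothing has executed yet, only the tree was built

        return (total, query.ToList(), query.Expression.ToString());
    }
}
```

The expression string returned by Demo contains the recorded Skip(4) and Take(4) calls. With a database provider, those calls are what gets translated into the paged SQL query.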

Generic repository

In the MS.UnitOfWork project, add a Repository folder. Then add the classes IRepository.cs and Repository.cs to it.

IRepository.cs

using MS.UnitOfWork.Collections;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Query;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Generic repository interface
    /// </summary>
    /// <typeparam name="TEntity">Entity type</typeparam>
    public interface IRepository<TEntity> where TEntity : class
    {
        #region GetAll
        /// <summary>
        /// Gets all entities.
        /// Mind the performance: enumerating the result without further filtering loads every row!
        /// </summary>
        /// <returns>The <see cref="IQueryable{TEntity}"/>.</returns>
        IQueryable<TEntity> GetAll();

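        /// <summary>
        /// Gets all entities
        /// </summary>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <returns>An <see cref="IQueryable{TEntity}"/> with the options applied</returns>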
        IQueryable<TEntity> GetAll(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

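        /// <summary>
        /// Gets all entities projected with the given selector; the selector is required
        /// </summary>
        /// <typeparam name="TResult">Output data type</typeparam>
        /// <param name="selector">Projection selector</param>
        /// <param name="predicate">Filter predicate</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <returns>An <see cref="IQueryable{TResult}"/> with the options applied</returns>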
        IQueryable<TResult> GetAll<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false
            ) where TResult : class;

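        /// <summary>
        /// Gets all entities (asynchronous)
        /// </summary>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <returns>A list of the matching entities</returns>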
        Task<IList<TEntity>> GetAllAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);
        #endregion

        #region GetPagedList
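        /// <summary>
        /// Gets a page of data
        /// Change tracking is disabled by default (changes to the returned entities are not tracked)
        /// Global query filters are applied by default
        /// </summary>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="pageIndex">Current page. Defaults to the first page</param>
        /// <param name="pageSize">Page size. Defaults to 20 items</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <returns>The requested page</returns>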
        IPagedList<TEntity> GetPagedList(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

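        /// <summary>
        /// Gets a page of data (asynchronous)
        /// Change tracking is disabled by default (changes to the returned entities are not tracked)
        /// Global query filters are applied by default
        /// </summary>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="pageIndex">Current page. Defaults to the first page</param>
        /// <param name="pageSize">Page size. Defaults to 20 items</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete</param>
        /// <returns>The requested page</returns>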
        Task<IPagedList<TEntity>> GetPagedListAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default);

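        /// <summary>
        /// Gets a page of data projected with the given selector
        /// Change tracking is disabled by default (changes to the returned entities are not tracked)
        /// Global query filters are applied by default
        /// </summary>
        /// <typeparam name="TResult">Output data type</typeparam>
        /// <param name="selector">Projection selector</param>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="pageIndex">Current page. Defaults to the first page</param>
        /// <param name="pageSize">Page size. Defaults to 20 items</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <returns>The requested page</returns>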
        IPagedList<TResult> GetPagedList<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false
            ) where TResult : class;

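        /// <summary>
        /// Gets a page of data projected with the given selector (asynchronous)
        /// Change tracking is disabled by default (changes to the returned entities are not tracked)
        /// Global query filters are applied by default
        /// </summary>
        /// <typeparam name="TResult">Output data type</typeparam>
        /// <param name="selector">Projection selector</param>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="pageIndex">Current page. Defaults to the first page</param>
        /// <param name="pageSize">Page size. Defaults to 20 items</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete</param>
        /// <returns>The requested page</returns>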
        Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default) where TResult : class;

        #endregion

        #region GetFirstOrDefault
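        /// <summary>
        /// Gets the first element of the sequence that satisfies the condition,
        /// or the default value if no element satisfies it
        /// Change tracking is disabled by default (changes to the returned entity are not tracked)
        /// Global query filters are applied by default
        /// </summary>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <returns>The first matching entity, or null if none matches</returns>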
        TEntity GetFirstOrDefault(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

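        /// <summary>
        /// Gets the first element of the sequence that satisfies the condition,
        /// or the default value if no element satisfies it (asynchronous)
        /// Change tracking is disabled by default (changes to the returned entity are not tracked)
        /// Global query filters are applied by default
        /// </summary>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete</param>
        /// <returns>The first matching entity, or null if none matches</returns>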
        Task<TEntity> GetFirstOrDefaultAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default);

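        /// <summary>
        /// Gets the first element of the sequence that satisfies the condition, projected with the given selector,
        /// or the default value if no element satisfies it
        /// Change tracking is disabled by default (changes to the returned entity are not tracked)
        /// Global query filters are applied by default
        /// </summary>
        /// <typeparam name="TResult">Output data type</typeparam>
        /// <param name="selector">Projection selector</param>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <returns>The first matching projected element, or the default value of <typeparamref name="TResult"/></returns>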
        TResult GetFirstOrDefault<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

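        /// <summary>
        /// Gets the first element of the sequence that satisfies the condition, projected with the given selector,
        /// or the default value if no element satisfies it (asynchronous)
        /// Change tracking is disabled by default (changes to the returned entity are not tracked)
        /// Global query filters are applied by default
        /// </summary>
        /// <typeparam name="TResult">Output data type</typeparam>
        /// <param name="selector">Projection selector</param>
        /// <param name="predicate">Filter expression</param>
        /// <param name="orderBy">Ordering</param>
        /// <param name="include">Navigation properties to include</param>
        /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
        /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete</param>
        /// <returns>The first matching projected element, or the default value of <typeparamref name="TResult"/></returns>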
        Task<TResult> GetFirstOrDefaultAsync<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default);

        #endregion

        #region Find
        /// <summary>
        /// Finds an entity with the given primary key values. If found, the entity is attached to the context and returned; if no entity is found, null is returned.
        /// </summary>
        /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
        /// <returns>The found entity or null.</returns>
        TEntity Find(params object[] keyValues);

        /// <summary>
        /// Finds an entity with the given primary key values. If found, the entity is attached to the context and returned; if no entity is found, null is returned.
        /// </summary>
        /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
        /// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous find operation. The task result contains the found entity or null.</returns>
        ValueTask<TEntity> FindAsync(params object[] keyValues);

        /// <summary>
        /// Finds an entity with the given primary key values. If found, the entity is attached to the context and returned; if no entity is found, null is returned.
        /// </summary>
        /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
        /// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous find operation. The task result contains the found entity or null.</returns>
        ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken);
        #endregion

        #region sql、count、exist
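        /// <summary>
        /// Queries data with a raw SQL query
        /// </summary>
        /// <param name="sql">The raw SQL query</param>
        /// <param name="parameters">Parameter values for the query</param>
        /// <returns>An <see cref="IQueryable{TEntity}"/> based on the raw SQL query</returns>
        IQueryable<TEntity> FromSql(string sql, params object[] parameters);

        /// <summary>
        /// Counts the entities
        /// </summary>
        /// <param name="predicate">Optional filter expression</param>
        /// <returns>The number of matching entities</returns>
        int Count(Expression<Func<TEntity, bool>> predicate = null);

        /// <summary>
        /// Counts the entities (asynchronous)
        /// </summary>
        /// <param name="predicate">Optional filter expression</param>
        /// <returns>The number of matching entities</returns>
        Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null);

        /// <summary>
        /// Checks whether any entity matches the specified condition
        /// </summary>
        /// <param name="predicate">Optional filter expression</param>
        /// <returns>true if a matching entity exists; otherwise false</returns>
        bool Exists(Expression<Func<TEntity, bool>> predicate = null);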
        #endregion

        #region Insert
        /// <summary>
        /// Inserts a new entity synchronously.
        /// </summary>
        /// <param name="entity">The entity to insert.</param>
        /// <returns>The inserted entity.</returns>
        TEntity Insert(TEntity entity);

        /// <summary>
        /// Inserts a range of entities synchronously.
        /// </summary>
        /// <param name="entities">The entities to insert.</param>
        void Insert(params TEntity[] entities);

        /// <summary>
        /// Inserts a range of entities synchronously.
        /// </summary>
        /// <param name="entities">The entities to insert.</param>
        void Insert(IEnumerable<TEntity> entities);

        /// <summary>
        /// Inserts a new entity asynchronously.
        /// </summary>
        /// <param name="entity">The entity to insert.</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
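        /// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous insert operation. The task result contains the <see cref="EntityEntry{TEntity}"/> of the inserted entity.</returns>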
        ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default);

        /// <summary>
        /// Inserts a range of entities asynchronously.
        /// </summary>
        /// <param name="entities">The entities to insert.</param>
        /// <returns>A <see cref="Task"/> that represents the asynchronous insert operation.</returns>
        Task InsertAsync(params TEntity[] entities);

        /// <summary>
        /// Inserts a range of entities asynchronously.
        /// </summary>
        /// <param name="entities">The entities to insert.</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
        /// <returns>A <see cref="Task"/> that represents the asynchronous insert operation.</returns>
        Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default);
        #endregion

        #region Update
        /// <summary>
        /// Updates the specified entity.
        /// </summary>
        /// <param name="entity">The entity.</param>
        void Update(TEntity entity);

        /// <summary>
        /// Updates the specified entities.
        /// </summary>
        /// <param name="entities">The entities.</param>
        void Update(params TEntity[] entities);

        /// <summary>
        /// Updates the specified entities.
        /// </summary>
        /// <param name="entities">The entities.</param>
        void Update(IEnumerable<TEntity> entities);
        #endregion

        #region Delete
        /// <summary>
        /// Deletes the entity by the specified primary key.
        /// </summary>
        /// <param name="id">The primary key value.</param>
        void Delete(object id);

        /// <summary>
        /// Deletes the specified entity.
        /// </summary>
        /// <param name="entity">The entity to delete.</param>
        void Delete(TEntity entity);

        /// <summary>
        /// Deletes the specified entities.
        /// </summary>
        /// <param name="entities">The entities.</param>
        void Delete(params TEntity[] entities);

        /// <summary>
        /// Deletes the specified entities.
        /// </summary>
        /// <param name="entities">The entities.</param>
        void Delete(IEnumerable<TEntity> entities);
        #endregion
    }
}
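The predicate and orderBy parameters compose as shown in this in-memory sketch. It mirrors the order used in Repository<TEntity>.GetAll: filter first, then let the caller's orderBy function sort. Product and GetAllSketch are hypothetical names. The EF Core-only options (AsNoTracking, Include, IgnoreQueryFilters) are left out because they need a real DbContext.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity for the illustration.
public class Product { public int Id { get; set; } public decimal Price { get; set; } }

public static class GetAllSketch
{
    // Same shape as the predicate/orderBy parameters of IRepository<TEntity>.GetAll:
    // the filter is an expression tree, the ordering is a function over the whole query.
    public static IQueryable<T> GetAll<T>(
        IQueryable<T> query,
        Expression<Func<T, bool>> predicate = null,
        Func<IQueryable<T>, IOrderedQueryable<T>> orderBy = null)
    {
        if (predicate != null)
        {
            query = query.Where(predicate);
        }
        return orderBy != null ? orderBy(query) : query;
    }

    public static List<int> Demo()
    {
        var products = new List<Product>
        {
            new Product { Id = 1, Price = 5m },
            new Product { Id = 2, Price = 20m },
            new Product { Id = 3, Price = 12m },
        }.AsQueryable();

        // Keep products over 10, most expensive first, and return their ids.
        return GetAll(products, p => p.Price > 10m, q => q.OrderByDescending(p => p.Price))
            .Select(p => p.Id)
            .ToList();
    }
}
```

Taking orderBy as a Func<IQueryable<T>, IOrderedQueryable<T>> rather than a single key selector has an advantage: the caller can choose the sort direction or chain ThenBy without any extra repository parameters.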

Repository.cs

using MS.UnitOfWork.Collections;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Query;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Default implementation of the generic repository
    /// </summary>
    /// <typeparam name="TEntity">Entity type</typeparam>
    public class Repository<TEntity> : IRepository<TEntity> where TEntity : class
    {
        protected readonly DbContext _dbContext;
        protected readonly DbSet<TEntity> _dbSet;

        public Repository(DbContext dbContext)
        {
            _dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
            _dbSet = _dbContext.Set<TEntity>();
        }

        #region GetAll
        public IQueryable<TEntity> GetAll() => _dbSet;

        public IQueryable<TEntity> GetAll(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false)
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return orderBy(query);
            }
            else
            {
                return query;
            }
        }

        public IQueryable<TResult> GetAll<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false) where TResult : class
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return orderBy(query).Select(selector);
            }
            else
            {
                return query.Select(selector);
            }
        }

        public async Task<IList<TEntity>> GetAllAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false)
        {
            IQueryable<TEntity> query = _dbSet;

            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return await orderBy(query).ToListAsync();
            }
            else
            {
                return await query.ToListAsync();
            }
        }
        #endregion

        #region GetPagedList
        public virtual IPagedList<TEntity> GetPagedList(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false)
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                // Pass indexFrom = 1 explicitly, matching GetPagedListAsync and the pageIndex default of 1
                return orderBy(query).ToPagedList(pageIndex, pageSize, 1);
            }
            else
            {
                return query.ToPagedList(pageIndex, pageSize, 1);
            }
        }

        public virtual async Task<IPagedList<TEntity>> GetPagedListAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default)
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return await orderBy(query).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
            }
            else
            {
                return await query.ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
            }
        }

        public virtual IPagedList<TResult> GetPagedList<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false)
            where TResult : class
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                // Pass indexFrom = 1 explicitly, matching GetPagedListAsync and the pageIndex default of 1
                return orderBy(query).Select(selector).ToPagedList(pageIndex, pageSize, 1);
            }
            else
            {
                return query.Select(selector).ToPagedList(pageIndex, pageSize, 1);
            }
        }

        public virtual async Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default)
            where TResult : class
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return await orderBy(query).Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
            }
            else
            {
                return await query.Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
            }
        }
        #endregion

        #region GetFirstOrDefault 

        public virtual TEntity GetFirstOrDefault(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false)
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return orderBy(query).FirstOrDefault();
            }
            else
            {
                return query.FirstOrDefault();
            }
        }

        public virtual async Task<TEntity> GetFirstOrDefaultAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default)
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return await orderBy(query).FirstOrDefaultAsync(cancellationToken);
            }
            else
            {
                return await query.FirstOrDefaultAsync(cancellationToken);
            }
        }

        public virtual TResult GetFirstOrDefault<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false)
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return orderBy(query).Select(selector).FirstOrDefault();
            }
            else
            {
                return query.Select(selector).FirstOrDefault();
            }
        }

        public virtual async Task<TResult> GetFirstOrDefaultAsync<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default)
        {
            IQueryable<TEntity> query = _dbSet;
            if (disableTracking)
            {
                query = query.AsNoTracking();
            }

            if (include != null)
            {
                query = include(query);
            }

            if (predicate != null)
            {
                query = query.Where(predicate);
            }

            if (ignoreQueryFilters)
            {
                query = query.IgnoreQueryFilters();
            }

            if (orderBy != null)
            {
                return await orderBy(query).Select(selector).FirstOrDefaultAsync(cancellationToken);
            }
            else
            {
                return await query.Select(selector).FirstOrDefaultAsync(cancellationToken);
            }
        }
        #endregion

        #region Find

        public virtual TEntity Find(params object[] keyValues) => _dbSet.Find(keyValues);

        public virtual ValueTask<TEntity> FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues);

        public virtual ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken);
        #endregion

        #region Sql, Count, Exists
        public virtual IQueryable<TEntity> FromSql(string sql, params object[] parameters) => _dbSet.FromSqlRaw(sql, parameters);

        public virtual int Count(Expression<Func<TEntity, bool>> predicate = null)
        {
            if (predicate == null)
            {
                return _dbSet.Count();
            }
            else
            {
                return _dbSet.Count(predicate);
            }
        }

        public virtual async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null)
        {
            if (predicate == null)
            {
                return await _dbSet.CountAsync();
            }
            else
            {
                return await _dbSet.CountAsync(predicate);
            }
        }
        public virtual bool Exists(Expression<Func<TEntity, bool>> predicate = null)
        {
            if (predicate == null)
            {
                return _dbSet.Any();
            }
            else
            {
                return _dbSet.Any(predicate);
            }
        }
        #endregion

        #region Insert
        public virtual TEntity Insert(TEntity entity)
        {
            return _dbSet.Add(entity).Entity;
        }

        public virtual void Insert(params TEntity[] entities) => _dbSet.AddRange(entities);

        public virtual void Insert(IEnumerable<TEntity> entities) => _dbSet.AddRange(entities);

        public virtual ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken))
        {
            return _dbSet.AddAsync(entity, cancellationToken);

            // Shadow properties?
            //var property = _dbContext.Entry(entity).Property("Created");
            //if (property != null) {
            //property.CurrentValue = DateTime.Now;
            //}
        }

        public virtual Task InsertAsync(params TEntity[] entities) => _dbSet.AddRangeAsync(entities);

        public virtual Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default(CancellationToken)) => _dbSet.AddRangeAsync(entities, cancellationToken);

        #endregion

        #region Update
        public virtual void Update(TEntity entity)
        {
            _dbSet.Update(entity);
        }

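        // Note: despite the Async suffix, this method is synchronous. Update only marks the entity
        // as Modified in the change tracker. No database I/O happens until SaveChanges/SaveChangesAsync.
        public virtual void UpdateAsync(TEntity entity)
        {
            _dbSet.Update(entity);
        }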

        public virtual void Update(params TEntity[] entities) => _dbSet.UpdateRange(entities);

        public virtual void Update(IEnumerable<TEntity> entities) => _dbSet.UpdateRange(entities);
        #endregion

        #region Delete

        public virtual void Delete(TEntity entity) => _dbSet.Remove(entity);

        public virtual void Delete(object id)
        {
            var entity = _dbSet.Find(id);
            if (entity != null)
            {
                Delete(entity);
            }
        }

        public virtual void Delete(params TEntity[] entities) => _dbSet.RemoveRange(entities);

        public virtual void Delete(IEnumerable<TEntity> entities) => _dbSet.RemoveRange(entities);

        #endregion

    }
}
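Every query method above follows the same pipeline. It optionally disables tracking, applies include, filters with predicate, optionally ignores query filters, and applies orderBy. It then finishes with a projection, paging, or FirstOrDefault. Below is a minimal sketch of pulling the filter and ordering steps into a single helper. QueryPipeline, BuildQuery and Role are hypothetical names, not part of the project. The sketch runs on LINQ-to-Objects through AsQueryable(), so it needs no database. It deliberately leaves out the EF Core-specific steps (AsNoTracking, include, IgnoreQueryFilters) because they require a real DbContext.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical helper: composes the optional filter and ordering steps
// that each GetAll/GetPagedList/GetFirstOrDefault overload repeats.
public static class QueryPipeline
{
    public static IQueryable<T> BuildQuery<T>(
        IQueryable<T> source,
        Expression<Func<T, bool>> predicate = null,
        Func<IQueryable<T>, IOrderedQueryable<T>> orderBy = null)
    {
        if (source == null) throw new ArgumentNullException(nameof(source));

        IQueryable<T> query = source;

        // In the real Repository, AsNoTracking / include / IgnoreQueryFilters
        // would also be applied here; they are EF Core-specific and omitted.
        if (predicate != null)
        {
            query = query.Where(predicate);
        }

        // Ordering is applied last so that it wraps the filtered query.
        return orderBy != null ? orderBy(query) : query;
    }
}

// Hypothetical entity used only to exercise the helper.
public class Role
{
    public Role(int id, string name)
    {
        Id = id;
        Name = name;
    }

    public int Id { get; }
    public string Name { get; }
}
```

In the real Repository, such a helper would also take the include, disableTracking and ignoreQueryFilters arguments. Each public overload would then shrink to one call followed by ToPagedList, FirstOrDefault or Select.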

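Notes

  • Wraps the common create, read, update and delete operations
  • Methods whose names end in Async are asynchronous
  • The method documentation lives in the interface
  • Queries:
    • GetAll returns all entities matching the condition (mind performance: without a predicate this covers the whole table, and GetAllAsync loads the full result into a list)
    • GetPagedList runs a paged query
    • GetFirstOrDefault returns the first entity matching the condition
    • Find looks up an entity by primary key, e.g. an Id value
    • FromSql runs a raw SQL query
    • Count returns the number of matching entities
    • Exists checks whether any matching entity exists
  • The query methods accept several options:
    • Paged queries return 20 items per page by default
    • Change tracking is disabled by default (disableTracking = true)
    • Global query filters are applied by default (ignoreQueryFilters = false)
    • The selector parameter projects the query results into another type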
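The paging defaults above (pageIndex = 1, pageSize = 20) only work if page numbers are counted from the same base that the paged list uses. GetPagedListAsync passes 1 as the third argument of ToPagedListAsync, presumably indexFrom, for this reason. Below is a minimal stand-alone sketch of the arithmetic, assuming the usual Skip/Take scheme. PageSketch is a hypothetical name, not the project's PagedList class.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-alone version of the paging arithmetic, assuming
// pages are numbered starting from indexFrom (the Repository passes 1).
public sealed class PageSketch<T>
{
    public PageSketch(IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom = 1)
    {
        if (source == null) throw new ArgumentNullException(nameof(source));
        if (pageSize <= 0) throw new ArgumentOutOfRangeException(nameof(pageSize));
        if (indexFrom > pageIndex) throw new ArgumentException("indexFrom must not exceed pageIndex.");

        var all = source.ToList(); // materialized once, for the sketch only
        IndexFrom = indexFrom;
        PageIndex = pageIndex;
        PageSize = pageSize;
        TotalCount = all.Count;
        TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);

        // Zero-based page offset times page size = number of rows to skip.
        Items = all.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
    }

    public int IndexFrom { get; }
    public int PageIndex { get; }
    public int PageSize { get; }
    public int TotalCount { get; }
    public int TotalPages { get; }
    public IList<T> Items { get; }
    public bool HasPreviousPage => PageIndex - IndexFrom > 0;
    public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;
}
```

With 45 rows and 20 per page, page 3 holds rows 41 to 45. With indexFrom = 0 and pageIndex = 1, the same code would skip the first 20 rows, so the base must match on both sides.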

Unit of Work

In the MS.UnitOfWork project, add a UnitOfWork folder and add the IUnitOfWork.cs and UnitOfWork.cs classes to it.

IUnitOfWork.cs

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using System;
using System.Linq;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Defines the unit-of-work interface.
    /// </summary>
    public interface IUnitOfWork<TContext> : IDisposable where TContext : DbContext
    {
        /// <summary>
        /// Gets the DbContext.
        /// </summary>
        TContext DbContext { get; }
        /// <summary>
        /// Begins a transaction.
        /// </summary>
        /// <returns></returns>
        IDbContextTransaction BeginTransaction();

        /// <summary>
        /// Gets the repository for the specified entity type.
        /// </summary>
        /// <typeparam name="TEntity"></typeparam>
        /// <param name="hasCustomRepository">Set to true if a custom repository exists (the default UnitOfWork implementation currently ignores this flag).</param>
        /// <returns></returns>
        IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class;

        /// <summary>
        /// Commits the pending changes in the DbContext.
        /// </summary>
        /// <returns></returns>
        int SaveChanges();

        /// <summary>
        /// Commits the pending changes in the DbContext (asynchronously).
        /// </summary>
        /// <returns></returns>
        Task<int> SaveChangesAsync();

        /// <summary>
        /// Executes a raw SQL command.
        /// </summary>
        /// <param name="sql">The SQL statement.</param>
        /// <param name="parameters">The parameters.</param>
        /// <returns></returns>
        int ExecuteSqlCommand(string sql, params object[] parameters);

        /// <summary>
        /// Queries data of the specified entity type with raw SQL.
        /// </summary>
        /// <typeparam name="TEntity"></typeparam>
        /// <param name="sql">The SQL query.</param>
        /// <param name="parameters">The parameters.</param>
        /// <returns></returns>
        IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class;
    }
}

UnitOfWork.cs

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Default implementation of the unit of work.
    /// </summary>
    /// <typeparam name="TContext"></typeparam>
    public class UnitOfWork<TContext> : IUnitOfWork<TContext> where TContext : DbContext
    {
        protected readonly TContext _context;
        protected bool _disposed = false;
        protected Dictionary<Type, object> _repositories;

        public UnitOfWork(TContext context)
        {
            _context = context ?? throw new ArgumentNullException(nameof(context));
        }

        /// <summary>
        /// Gets the DbContext.
        /// </summary>
        public TContext DbContext => _context;
        /// <summary>
        /// Begins a transaction.
        /// </summary>
        /// <returns></returns>
        public IDbContextTransaction BeginTransaction()
        {
            return _context.Database.BeginTransaction();
        }

        /// <summary>
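        /// Gets the repository for the specified entity type; each repository is created once and cached per unit of work.
        /// </summary>
        /// <typeparam name="TEntity"></typeparam>
        /// <param name="hasCustomRepository">Not used by this simplified implementation, which always returns the generic Repository&lt;TEntity&gt;.</param>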
        /// <returns></returns>
        public IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class
        {
            if (_repositories == null)
            {
                _repositories = new Dictionary<Type, object>();
            }

            Type type = typeof(IRepository<TEntity>);
            if (!_repositories.TryGetValue(type, out object repo))
            {
                IRepository<TEntity> newRepo = new Repository<TEntity>(_context);
                _repositories.Add(type, newRepo);
                return newRepo;
            }
            return (IRepository<TEntity>)repo;
        }

        /// <summary>
        /// Executes a raw SQL command.
        /// </summary>
        /// <param name="sql">The SQL statement.</param>
        /// <param name="parameters">The parameters.</param>
        /// <returns></returns>
        public int ExecuteSqlCommand(string sql, params object[] parameters) => _context.Database.ExecuteSqlRaw(sql, parameters);

        /// <summary>
        /// Queries data of the specified entity type with raw SQL.
        /// </summary>
        /// <typeparam name="TEntity"></typeparam>
        /// <param name="sql">The SQL query.</param>
        /// <param name="parameters">The parameters.</param>
        /// <returns></returns>
        public IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class => _context.Set<TEntity>().FromSqlRaw(sql, parameters);

        /// <summary>
        /// Commits the pending changes in the DbContext.
        /// </summary>
        /// <returns></returns>
        public int SaveChanges()
        {
            return _context.SaveChanges();
        }

        /// <summary>
        /// Commits the pending changes in the DbContext (asynchronously).
        /// </summary>
        /// <returns></returns>
        public async Task<int> SaveChangesAsync()
        {
            return await _context.SaveChangesAsync();
        }

        public void Dispose()
        {
            Dispose(true);

            GC.SuppressFinalize(this);
        }
        protected virtual void Dispose(bool disposing)
        {
            if (!_disposed)
            {
                if (disposing)
                {
                    // clear repositories
                    if (_repositories != null)
                    {
                        _repositories.Clear();
                    }

                    // dispose the db context.
                    _context.Dispose();
                }
            }

            _disposed = true;
        }
    }
}

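Notes

  • Repositories and the DbContext are both obtained from the unit of work
  • Transactions are also started from the unit of work (BeginTransaction)
  • After modifying data through a repository, call SaveChanges on the unit of work to commit the changes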
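The unit-of-work idea in these notes can be illustrated without EF Core: one object hands out the repositories, and their changes are collected and then committed together. In the real code, the DbContext change tracker holds the pending changes and SaveChanges writes them in one call. The sketch below imitates that with in-memory lists. MemoryRepository and MemoryUnitOfWork are hypothetical names, and the sketch has no rollback or concurrency handling.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical in-memory repository: Insert only queues the entity;
// nothing becomes visible until the owning unit of work commits.
public sealed class MemoryRepository<T> where T : class
{
    private readonly List<T> _committed = new List<T>();
    private readonly List<T> _pending = new List<T>();

    public void Insert(T entity) =>
        _pending.Add(entity ?? throw new ArgumentNullException(nameof(entity)));

    public IReadOnlyList<T> GetAll() => _committed;

    // Called by the unit of work; returns how many entities were written.
    internal int Commit()
    {
        int count = _pending.Count;
        _committed.AddRange(_pending);
        _pending.Clear();
        return count;
    }
}

// Hypothetical unit of work: caches one repository per entity type
// (like GetRepository's Dictionary<Type, object>) and commits them together.
public sealed class MemoryUnitOfWork
{
    private readonly Dictionary<Type, object> _repositories = new Dictionary<Type, object>();
    private readonly List<Func<int>> _commits = new List<Func<int>>();

    public MemoryRepository<T> GetRepository<T>() where T : class
    {
        if (!_repositories.TryGetValue(typeof(T), out object repo))
        {
            var created = new MemoryRepository<T>();
            _repositories.Add(typeof(T), created);
            _commits.Add(created.Commit);
            return created;
        }
        return (MemoryRepository<T>)repo;
    }

    // Commits every repository handed out so far; returns the total written.
    public int SaveChanges()
    {
        int total = 0;
        foreach (var commit in _commits)
        {
            total += commit();
        }
        return total;
    }
}
```

Caching by type mirrors the Dictionary<Type, object> in GetRepository. Two calls for the same entity type return the same repository, so everything queued in one request is committed by a single SaveChanges.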

Wrapping the IoC registration

Add a UnitOfWorkServiceExtensions.cs class to the MS.UnitOfWork project:

using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Extension methods that register the unit of work in an <see cref="IServiceCollection"/>.
    /// </summary>
    public static class UnitOfWorkServiceExtensions
    {
        /// <summary>
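        /// Registers the unit of work for the given context type in the <see cref="IServiceCollection"/>,
        /// and registers the DbContext itself as well.
        /// </summary>
        /// <typeparam name="TContext">The DbContext type.</typeparam>
        /// <param name="services">The service collection.</param>
        /// <param name="action">Configures the DbContext options (e.g. the database provider and connection string).</param>
        /// <remarks>Intended to be called once per DbContext type; the DbContext and the unit of work are both registered as scoped services.</remarks>
        /// <returns>The same service collection, for chaining.</returns>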
        public static IServiceCollection AddUnitOfWorkService<TContext>(this IServiceCollection services, System.Action<DbContextOptionsBuilder> action) where TContext : DbContext
        {
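            // Register the DbContext (scoped by default)
            services.AddDbContext<TContext>(action);
            // Register the unit of work (scoped, so it shares the DbContext of the current request)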
            services.AddScoped<IUnitOfWork<TContext>, UnitOfWork<TContext>>();
            return services;
        }
    }
}

With this in place, a project that wants to use the unit of work only needs to call AddUnitOfWorkService in Startup to register it.

The completed project is shown in the figure below:

Usage

To be added...

Original article: https://www.cnblogs.com/kasnti/p/12238521.html

Time: 2024-10-25 20:56:45
