Hibernate 5 通用基础DAO父类

/*
 * Copyright (c) 2017 西安才多信息技术有限责任公司。
 * 项目名称:dev-admin
 * 文件名称:HibernateBaseDaoImpl.java
 * 日期:17-5-31 下午6:39
 * 作者:yangyan
 *
 */
package cn.firegod.common.hibernate;


import cn.firegod.common.Page;
import cn.firegod.common.PageList;
import cn.firegod.dev.vo.CurrentUser;
import org.apache.commons.lang3.StringUtils;
import org.hibernate.FlushMode;
import org.hibernate.Query;
import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.hibernate.engine.query.spi.HQLQueryPlan;
import org.hibernate.engine.query.spi.QueryPlanCache;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.springframework.beans.BeanUtils;
import org.springframework.orm.hibernate5.SessionHolder;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.support.TransactionSynchronizationManager;

import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Generic base class for Hibernate DAOs. A concrete DAO extends this class, binding {@code T} to
 * its entity class and {@code PK} to the entity's primary-key type, and implements
 * {@link #setSessionFactory(SessionFactory)} to inject the session factory (data source) to use.
 *
 * <p>Audit fields are stamped through optional public setters on the entity class; setters the
 * entity does not declare are simply skipped:
 * <ul>
 *   <li>inserts ({@link #save}, and {@link #saveOrUpdate} when {@code getId()} returns null) call
 *       {@code setCreateDate(Timestamp)}, {@code setCreateBy(Integer)},
 *       {@code setDelFlag(Boolean)} with {@code false}, and then the update setters;</li>
 *   <li>updates ({@link #update}, {@link #updateNotNullProperties}, and {@link #saveOrUpdate} when
 *       {@code getId()} is non-null) call {@code setUpdateDate(Timestamp)} and
 *       {@code setUpdateBy(Integer)}.</li>
 * </ul>
 *
 * <p>Query helpers that take {@code params} bind them, in order, to JDBC-style positional
 * {@code ?} parameters starting at index 0.
 *
 * @param <T>  entity type
 * @param <PK> primary-key type of the entity
 */
@Repository
public abstract class HibernateBaseDaoImpl<T, PK extends Serializable> implements
        HibernateDao<T, PK> {
    static {
        // JVM-wide system property: selects the Log4j 2 backend for JBoss Logging (used by Hibernate).
        System.setProperty("org.jboss.logging.provider", "log4j2");
    }

    private static final Logger LOG = Logger.getLogger(HibernateBaseDaoImpl.class.getName());

    protected SessionFactory sessionFactory;
    private Class<T> entityClass;

    /**
     * Resolves the entity class from the generic type arguments bound by the concrete subclass.
     * If it cannot be resolved, {@code entityClass} stays {@code null} and entity-based
     * operations will fail.
     */
    @SuppressWarnings("unchecked")
    protected HibernateBaseDaoImpl() {
        this.entityClass = (Class<T>) resolveEntityClass(getClass());
    }

    /**
     * Walks up the superclass chain starting at {@code type} and returns the first type argument
     * of the nearest parameterized superclass whose first argument is a concrete class.
     *
     * @param type the runtime class of the DAO
     * @return the resolved entity class, or {@code null} if none is found
     */
    private static Class<?> resolveEntityClass(Class<?> type) {
        for (Class<?> current = type; current != null && current != Object.class;
                current = current.getSuperclass()) {
            Type generic = current.getGenericSuperclass();
            if (generic instanceof ParameterizedType) {
                Type[] args = ((ParameterizedType) generic).getActualTypeArguments();
                if (args.length > 0 && args[0] instanceof Class) {
                    return (Class<?>) args[0];
                }
            }
        }
        return null;
    }

    /**
     * Injects the session factory (data source) this DAO works against.
     *
     * @param sessionFactory the session factory to use
     */
    protected abstract void setSessionFactory(SessionFactory sessionFactory);

    /**
     * Returns the session associated with the current thread.
     *
     * <p>If {@link TransactionSynchronizationManager} already holds a resource for the session
     * factory, {@link SessionFactory#getCurrentSession()} is returned unchanged. Otherwise a new
     * session is opened with {@link FlushMode#AUTO} and bound to the current thread.
     *
     * <p>NOTE(review): nothing in this class unbinds or closes that fallback session; presumably
     * some caller/filter does — confirm, otherwise the session outlives the request on pooled
     * threads.
     *
     * @return the thread-bound session
     */
    @Override
    public Session getCurrentSession() {
        if (TransactionSynchronizationManager.hasResource(sessionFactory)) {
            // Do not modify the session: just participate in the existing binding.
            return sessionFactory.getCurrentSession();
        }
        Session session = sessionFactory.openSession();
        session.setHibernateFlushMode(FlushMode.AUTO);
        TransactionSynchronizationManager.bindResource(sessionFactory, new SessionHolder(session));
        return session;
    }

    /**
     * Opens a new, unbound session. The caller is responsible for closing it.
     *
     * @return a new session
     */
    @Override
    public Session openSession() {
        return sessionFactory.openSession();
    }

    /** Begins a transaction on the current session. */
    @Override
    public void beginTransaction() {
        getCurrentSession().beginTransaction();
    }

    /** Commits the current session's transaction. */
    @Override
    public void commitTransaction() {
        getCurrentSession().getTransaction().commit();
    }

    /** Rolls back the current session's transaction. */
    @Override
    public void rollbackTransaction() {
        getCurrentSession().getTransaction().rollback();
    }

    /**
     * Closes the given session.
     *
     * @param session the session to close
     */
    @Override
    public void closeSession(Session session) {
        session.close();
    }

    /**
     * Fetches the entity with the given id via {@link Session#get}.
     *
     * @param id primary key
     * @return the entity, or {@code null} if not found
     */
    @Override
    public T getById(PK id) {
        return (T) getCurrentSession().get(entityClass, id);
    }

    /**
     * Returns a reference to the entity with the given id via {@link Session#load}.
     *
     * @param id primary key
     * @return the entity reference
     */
    @Override
    public T loadById(PK id) {
        return (T) getCurrentSession().load(entityClass, id);
    }

    /**
     * Selects a single property of the entity with the given id.
     *
     * @param id           primary key
     * @param propertyName HQL property path to select (concatenated into the query — must not come
     *                     from untrusted input)
     * @param clazz        expected result type
     * @return the property value, or {@code null} if no row matches
     */
    @Override
    public <T1> T1 getPropertyById(PK id, String propertyName, Class<T1> clazz) {
        String hql = "select " + propertyName + " from " + entityClass.getName()
                + " where " + getIdPropertyName() + "=?";
        return getUniqueResult(clazz, hql, id);
    }

    /**
     * Stamps insert-time audit fields (create date/user, delete flag = false) and then the
     * update-time fields, for whichever of these setters the entity class declares.
     */
    private void preInsert(T model) {
        // Each setter is optional and looked up independently, so a missing one does not
        // prevent the others from being applied.
        Method setCreateDate = findPublicMethod("setCreateDate", Timestamp.class);
        if (setCreateDate != null) {
            invokeSetter(setCreateDate, model, new Timestamp(System.currentTimeMillis()));
        }
        Method setCreateBy = findPublicMethod("setCreateBy", Integer.class);
        if (setCreateBy != null) {
            invokeSetter(setCreateBy, model, CurrentUser.getCurrentUserId());
        }
        Method setDelFlag = findPublicMethod("setDelFlag", Boolean.class);
        if (setDelFlag != null) {
            invokeSetter(setDelFlag, model, Boolean.FALSE);
        }
        preUpdate(model);
    }

    /**
     * Stamps update-time audit fields (update date/user) for whichever of these setters the
     * entity class declares.
     */
    private void preUpdate(T model) {
        Method setUpdateDate = findPublicMethod("setUpdateDate", Timestamp.class);
        if (setUpdateDate != null) {
            invokeSetter(setUpdateDate, model, new Timestamp(System.currentTimeMillis()));
        }
        Method setUpdateBy = findPublicMethod("setUpdateBy", Integer.class);
        if (setUpdateBy != null) {
            invokeSetter(setUpdateBy, model, CurrentUser.getCurrentUserId());
        }
    }

    /**
     * Looks up a public method of the entity class.
     *
     * @return the method, or {@code null} if the entity class does not declare it
     */
    private Method findPublicMethod(String name, Class<?>... parameterTypes) {
        try {
            return entityClass.getMethod(name, parameterTypes);
        } catch (NoSuchMethodException ignored) {
            // Audit accessors are optional; entities without them are simply not stamped.
            return null;
        }
    }

    /** Invokes a one-argument setter; reflective failures are logged instead of propagated. */
    private static void invokeSetter(Method setter, Object model, Object value) {
        try {
            setter.invoke(model, value);
        } catch (IllegalAccessException | InvocationTargetException e) {
            LOG.log(Level.WARNING, "Failed to invoke " + setter.getName() + " on "
                    + model.getClass().getName(), e);
        }
    }

    /**
     * Stamps insert audit fields and saves a new entity.
     *
     * @param model the entity to persist
     */
    @Override
    public void save(T model) {
        this.preInsert(model);
        getCurrentSession().save(model);
    }

    /**
     * Stamps audit fields and saves or updates the entity. A non-null {@code getId()} is treated
     * as an update, otherwise as an insert. If the entity has no public {@code getId()}, or it
     * cannot be invoked, no audit fields are stamped.
     *
     * @param model the entity to persist
     */
    @Override
    public void saveOrUpdate(T model) {
        Method getId = findPublicMethod("getId");
        if (getId == null) {
            LOG.log(Level.WARNING, "{0} has no public getId(); audit fields not stamped",
                    entityClass.getName());
        } else {
            try {
                if (getId.invoke(model) != null) {
                    this.preUpdate(model);
                } else {
                    this.preInsert(model);
                }
            } catch (IllegalAccessException | InvocationTargetException e) {
                LOG.log(Level.WARNING, "Failed to read the id of " + entityClass.getName(), e);
            }
        }
        getCurrentSession().saveOrUpdate(model);
    }

    /**
     * Lists entities whose property equals {@code value}.
     *
     * @param propertyName HQL property path (concatenated into the query)
     * @param value        value to compare against
     * @return matching entities
     */
    @Override
    public List<T> findListByProperty(String propertyName, Object value) {
        Query query = getCurrentSession().createQuery("from " + entityClass.getName() + " where " + propertyName + "=?");
        query.setParameter(0, value);
        return query.list();
    }

    /**
     * Selects one property from entities whose {@code propertyName} equals {@code value}.
     *
     * @param resultClass     expected element type
     * @param getPropertyName HQL property path to select
     * @param propertyName    HQL property path to filter on
     * @param value           value to compare against
     * @return the selected values
     */
    @Override
    public <T> List<T> findListByProperty(Class<T> resultClass, String getPropertyName, String propertyName, Object value) {
        Query query = getCurrentSession().createQuery("select " + getPropertyName + " from " + entityClass.getName() + " where " + propertyName + "=?");
        query.setParameter(0, value);
        return query.list();
    }

    /**
     * Deletes the given entity via the session.
     *
     * @param model the entity to delete
     */
    @Override
    public void delete(T model) {
        getCurrentSession().delete(model);
    }

    /**
     * Deletes entities by primary key with one bulk HQL delete per id. Does nothing for a null or
     * empty argument.
     *
     * @param id primary keys to delete
     */
    @Override
    public void deleteById(PK... id) {
        if (id == null || id.length == 0) {
            return;
        }
        String hql = "delete from " + entityClass.getName() + " where " + getIdPropertyName() + "=?";
        for (PK pk : id) {
            execute(hql, pk);
        }
    }

    /**
     * Stamps update audit fields and updates the entity via the session.
     *
     * @param model the entity to update
     */
    @Override
    public void update(T model) {
        this.preUpdate(model);
        getCurrentSession().update(model);
    }

    /**
     * Lists all entities of this DAO's type.
     *
     * @return all entities
     */
    @Override
    public List<T> findAll() {
        return getCurrentSession().createCriteria(entityClass.getName()).list();
    }

    /**
     * Bulk-deletes entities whose property equals {@code value}.
     *
     * @param propertyName HQL property path (concatenated into the query)
     * @param value        value to compare against
     */
    @Override
    public void deleteResultsByProperty(String propertyName, Object value) {
        Query query = getCurrentSession().createQuery("delete from " + entityClass.getName() + " where " + propertyName + "=:p1");
        query.setParameter("p1", value);
        query.executeUpdate();
    }

    /**
     * Bulk-deletes entities whose property is one of {@code value}. Does nothing when no values
     * are given (an empty {@code in ()} list is not valid HQL).
     *
     * @param propertyName HQL property path (concatenated into the query)
     * @param value        candidate values
     */
    @Override
    public void deleteResultsByPropertyInValues(String propertyName, Object... value) {
        if (value == null || value.length == 0) {
            return;
        }
        String placeholders = String.join(",", Collections.nCopies(value.length, "?"));
        Query query = getCurrentSession().createQuery("delete from " + entityClass.getName() + " where " + propertyName + " in (" + placeholders + ") ");
        for (int i = 0; i < value.length; i++) {
            query.setParameter(i, value[i]);
        }
        query.executeUpdate();
    }

    /**
     * Runs a paged query and fills in the total row count on {@code page}.
     *
     * @param page   paging window; {@link Page.Offset} uses start/limit, otherwise 1-based page
     *               number and page size
     * @param hql    query to run
     * @param params positional parameters, may be {@code null}
     * @return the page of results
     */
    @Override
    public PageList<T> findByPage(Page page, String hql, List<Object> params) {
        Query q = getCurrentSession().createQuery(hql);
        applyPaging(q, page);
        bindParameters(q, params);
        page.setTotal(this.getTotalCount(hql, params));
        return new PageList<T>(page, q.list());
    }

    /**
     * Pages over all entities of this DAO's type.
     *
     * @param page paging window
     * @return the page of results
     */
    @Override
    public PageList<T> findByPage(Page page) {
        return this.findByPage(page, "from " + this.entityClass.getName(),
                Collections.<Object>emptyList());
    }

    /**
     * Varargs form of {@link #findByPage(Page, String, List)}; a {@code null} array binds nothing.
     */
    @Override
    public PageList<T> findByPage(Page page, String hql, Object... params) {
        return this.findByPage(page, hql, params == null ? null : Arrays.asList(params));
    }

    /** Clears the current session's persistence context. */
    @Override
    public void clear() {
        getCurrentSession().clear();
    }

    /** Flushes pending changes of the current session. */
    @Override
    public void flush() {
        getCurrentSession().flush();
    }

    /**
     * Detaches {@code o} from the current session.
     *
     * @param o the entity to evict
     */
    @Override
    public void evict(Object o) {
        getCurrentSession().evict(o);
    }

    /**
     * Tells whether any entity has a property equal to {@code value}.
     *
     * @param propertyName HQL property path (concatenated into the query)
     * @param value        value to compare against
     * @return {@code true} if at least one row matches
     */
    @Override
    public boolean isExist(String propertyName, Object value) {
        String hql = "select count(*) from " + entityClass.getName()
                + " where " + propertyName + "=?";
        return (Long) getCurrentSession().createQuery(hql)
                .setParameter(0, value).uniqueResult() > 0;
    }

    /**
     * Counts all entities of this DAO's type.
     *
     * @return the row count
     */
    @Override
    public long getTotalCount() {
        String hql = "select count(*) from " + entityClass.getName();
        return Long.valueOf(getCurrentSession().createQuery(hql).uniqueResult()
                .toString());
    }

    /**
     * Counts the rows {@code hql} would return, using {@link #prepareCountHql(String)}.
     *
     * @param hql    the original (select) query
     * @param params positional parameters, may be {@code null}
     * @return the row count
     */
    @Override
    public long getTotalCount(String hql, List<Object> params) {
        Query q = getCurrentSession().createQuery(prepareCountHql(hql));
        bindParameters(q, params);
        Object singleResult = q.uniqueResult();
        // The count query may keep the original select list after count(*); the row is then an
        // Object[] whose first element is the count.
        Object count = singleResult instanceof Object[] ? ((Object[]) singleResult)[0] : singleResult;
        return Long.parseLong(count.toString());
    }

    /**
     * Varargs form of {@link #getTotalCount(String, List)}; a {@code null} array binds nothing.
     */
    @Override
    public long getTotalCount(String hql, Object... params) {
        return this.getTotalCount(hql,
                params == null ? null : Arrays.asList(params));
    }

    /**
     * Derives a count query from {@code hql}.
     *
     * <p>The result is {@code select count(*)[, <select list>] from ...}. The original select
     * list is kept after the count; for a {@code select new X(...)} only the constructor
     * arguments are kept. Everything from {@code "order by"} on is dropped. Keyword matching is
     * case-sensitive ({@code "select"}, {@code "from "}, {@code "order by"}).
     *
     * <p>NOTE(review): keeping non-aggregated columns next to {@code count(*)} without
     * {@code group by} depends on the target database accepting it — confirm.
     *
     * @param hql the original query
     * @return the count query
     */
    protected String prepareCountHql(String hql) {

        String fromHql = hql;
        fromHql = " from " + StringUtils.substringAfter(fromHql, "from ");
        fromHql = StringUtils.substringBefore(fromHql, "order by");
        // Extract the selected expressions.
        String selectWhat = StringUtils.substringBetween(hql, "select", "from");
        // For "new ClassName(x.x.x)" selects, keep only the constructor arguments.
        if (selectWhat != null && selectWhat.contains("new ") && selectWhat.contains("(") && selectWhat.contains(")")) {
            selectWhat = StringUtils.substringBetween(selectWhat, "(", ")");
        }
        // The first column holds the total row count.
        String countHql = "select count(*)" + (selectWhat == null ? "" : ", " + selectWhat + " ") + fromHql;
        return countHql;

    }

    /**
     * Translates {@code originalHql} to SQL via the query plan cache and wraps the first
     * generated SQL statement as {@code select count(*) from (<sql>) count}.
     *
     * @param originalHql    HQL to translate
     * @param sessionFactory factory to translate with; must be a {@link SessionFactoryImplementor}
     * @return the count SQL
     */
    protected String getCountSql(String originalHql,
                                 SessionFactory sessionFactory) {

        SessionFactoryImplementor sessionFactoryImplementor = (SessionFactoryImplementor) sessionFactory;

        HQLQueryPlan hqlQueryPlan = sessionFactoryImplementor
                .getQueryPlanCache().getHQLQueryPlan(originalHql, false,
                        Collections.emptyMap());

        String[] sqls = hqlQueryPlan.getSqlStrings();

        String countSql = "select count(*) from (" + sqls[0] + ") count";

        return countSql;

    }

    /**
     * Returns the session factory cast to {@link SessionFactoryImplementor}.
     *
     * @return the session factory implementor
     */
    public SessionFactoryImplementor getSessionFactoryImplementor() {

        return (SessionFactoryImplementor) sessionFactory;

    }

    /**
     * Returns the session factory's query plan cache.
     *
     * @return the query plan cache
     */
    public QueryPlanCache getQueryPlanCache() {

        return getSessionFactoryImplementor().getQueryPlanCache();

    }

    /**
     * Returns the (non-shallow, filter-less) query plan for {@code hql}.
     *
     * @param hql the HQL string
     * @return its query plan
     */
    public HQLQueryPlan getHqlQueryPlan(String hql) {

        return getQueryPlanCache().getHQLQueryPlan(hql, false,
                Collections.emptyMap());

    }

    /**
     * Builds a count SQL statement for {@code sql} (interpreted as HQL) via
     * {@link #getCountSql(String, SessionFactory)}.
     *
     * @param sql the HQL to count
     * @return the count SQL
     */
    protected String prepareCountSql(String sql) {
        return getCountSql(sql, sessionFactory);

    }

    /**
     * Runs {@code hql} and returns all rows.
     *
     * @param clazz  expected element type
     * @param hql    query to run
     * @param params positional parameters, may be {@code null}
     * @return the result rows
     */
    @Override
    public <T> List<T> findList(Class<T> clazz, String hql, List<Object> params) {
        Query q = getCurrentSession().createQuery(hql);
        bindParameters(q, params);
        return q.list();
    }

    /**
     * Runs {@code hql} and returns at most {@code limit} rows starting from the first.
     *
     * @param clazz  expected element type
     * @param hql    query to run
     * @param limit  maximum number of rows
     * @param params positional parameters, may be {@code null}
     * @return the result rows
     */
    @Override
    public <T> List<T> findListLimit(Class<T> clazz, String hql, int limit,
                                     List<Object> params) {
        Query q = getCurrentSession().createQuery(hql);
        bindParameters(q, params);
        q.setFirstResult(0);
        q.setMaxResults(limit);
        return q.list();
    }

    /**
     * Returns the first row of {@code hql} typed as this DAO's entity.
     *
     * @param hql    query to run
     * @param params positional parameters
     * @return the first row, or {@code null} if there is none
     */
    @Override
    public T getUniqueResult(String hql, Object... params) {
        return getUniqueResult(this.entityClass, hql, params);
    }

    /**
     * Returns the first row of {@code hql}. The query is limited to one row, so additional
     * matches are ignored rather than reported.
     *
     * @param clazz  expected result type
     * @param hql    query to run
     * @param params positional parameters, may be {@code null}
     * @return the first row, or {@code null} if there is none
     */
    @Override
    public <T> T getUniqueResult(Class<T> clazz, String hql, List<Object> params) {
        Query q = getCurrentSession().createQuery(hql);
        bindParameters(q, params);
        q.setFirstResult(0);
        q.setMaxResults(1);
        List<?> rows = q.list();
        return rows == null || rows.isEmpty() ? null : (T) rows.get(0);
    }

    /**
     * Runs {@code hql} and returns all rows typed as this DAO's entity.
     *
     * @param hql    query to run
     * @param params positional parameters
     * @return the result rows
     */
    @Override
    public List<T> findList(String hql, Object... params) {
        return this.findList(this.entityClass, hql, params);
    }

    /**
     * Varargs form of {@link #findList(Class, String, List)}; a {@code null} array binds nothing.
     */
    @Override
    public <T> List<T> findList(Class<T> clazz, String hql, Object... params) {
        return this.findList(clazz, hql,
                params == null ? null : Arrays.asList(params));
    }

    /**
     * Runs the given HQL with the given parameters and returns at most {@code limit} rows of the
     * given type.
     *
     * @param clazz  expected element type
     * @param hql    query to run
     * @param limit  maximum number of rows
     * @param params positional parameters; a {@code null} array binds nothing
     * @return the result rows
     */
    @Override
    public <T> List<T> findListLimit(Class<T> clazz, String hql, int limit,
                                     Object... params) {
        return this.findListLimit(clazz, hql, limit, params == null ? null
                : Arrays.asList(params));
    }

    /**
     * Paged query typed as this DAO's entity; see
     * {@link #findPageList(Page, Class, String, Object...)}.
     */
    @Override
    public PageList<T> findPageList(Page page, String hql,
                                    Object... params) {
        return this.findPageList(page, this.entityClass, hql, params);
    }

    /**
     * Varargs form of {@link #findPageList(Page, Class, String, List)}. A single {@link List}
     * argument is used as the parameter list itself; a {@code null} array binds nothing.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> PageList<T> findPageList(Page page, Class<T> clazz, String hql,
                                        Object... params) {
        if (params == null) {
            return findPageList(page, clazz, hql, (List<Object>) null);
        }
        if (params.length == 1 && params[0] instanceof List) {
            // The cast is essential: with the static type Object, overload resolution would pick
            // this varargs method again and recurse forever.
            return findPageList(page, clazz, hql, (List<Object>) params[0]);
        }
        return findPageList(page, clazz, hql, Arrays.asList(params));
    }

    /**
     * Runs a paged query and fills in the total row count on {@code page}.
     *
     * @param page   paging window; {@link Page.Offset} uses start/limit, otherwise 1-based page
     *               number and page size
     * @param clazz  expected element type
     * @param hql    query to run
     * @param params positional parameters, may be {@code null}
     * @return the page of results
     */
    @Override
    public <T> PageList<T> findPageList(Page page, Class<T> clazz, String hql,
                                        List<Object> params) {
        Query q = getCurrentSession().createQuery(hql);
        applyPaging(q, page);
        bindParameters(q, params);
        page.setTotal(this.getTotalCount(hql, params));
        return new PageList<T>(page, q.list());
    }

    /** Applies the paging window described by {@code page} to {@code query}. */
    private static void applyPaging(Query query, Page page) {
        if (page instanceof Page.Offset) {
            Page.Offset offset = (Page.Offset) page;
            query.setFirstResult(offset.getStart()).setMaxResults(offset.getLimit());
        } else {
            // Page numbers are 1-based here.
            query.setFirstResult((page.getPage() - 1) * page.getPageSize())
                    .setMaxResults(page.getPageSize());
        }
    }

    /** Binds {@code params} in order to positional parameters starting at 0; null binds nothing. */
    private static void bindParameters(Query query, List<?> params) {
        if (params == null) {
            return;
        }
        for (int i = 0; i < params.size(); i++) {
            query.setParameter(i, params.get(i));
        }
    }

    /**
     * Returns the first entity whose property equals {@code value}.
     *
     * @param propertyName HQL property path (concatenated into the query)
     * @param value        value to compare against
     * @return the first match, or {@code null}
     */
    @Override
    public T getUniqueResultByProperty(String propertyName, Object value) {
        return this.getUniqueResult("from " + this.entityClass.getName()
                + " where " + propertyName + "=?", value);
    }

    /**
     * Varargs form of {@link #getUniqueResult(Class, String, List)}; a {@code null} array binds
     * nothing.
     */
    @Override
    public <T> T getUniqueResult(Class<T> clazz, String hql, Object... params) {
        return this.getUniqueResult(clazz, hql,
                params == null ? null : Arrays.asList(params));
    }

    /**
     * Executes a bulk HQL update/delete.
     *
     * @param hql    statement to execute
     * @param params positional parameters, may be {@code null}
     */
    @Override
    public void execute(String hql, List<Object> params) {
        Query q = this.getCurrentSession().createQuery(hql);
        bindParameters(q, params);
        q.executeUpdate();
    }

    /**
     * Varargs form of {@link #execute(String, List)}; a {@code null} array binds nothing.
     */
    @Override
    public void execute(String hql, Object... params) {
        this.execute(hql, params == null ? null : Arrays.asList(params));
    }

    /**
     * Sets one property of the entity with the given id.
     *
     * @param id           primary key
     * @param propertyName HQL property path (concatenated into the query)
     * @param status       new value
     */
    @Override
    public void updateProperty(PK id, String propertyName, Object status) {
        String hql = "update " + this.entityClass.getName() + " set " + propertyName + " = ? where " + getIdPropertyName() + "= ?";
        this.execute(hql, status, id);
    }

    /**
     * Increments a numeric property of the entity with the given id by one.
     *
     * @param id           primary key
     * @param propertyName HQL property path (concatenated into the query)
     */
    @Override
    public void incr(PK id, String propertyName) {
        String hql = "update " + this.entityClass.getName() + " set " + propertyName + " = " + propertyName + "+1 where " + getIdPropertyName() + "= ?";
        this.execute(hql, id);
    }

    /**
     * Increments a numeric property of the entity with the given id by {@code n}.
     *
     * @param id           primary key
     * @param propertyName HQL property path (concatenated into the query)
     * @param n            increment
     */
    @Override
    public void incr(PK id, String propertyName, Integer n) {
        String hql = "update " + this.entityClass.getName() + " set " + propertyName + " = " + propertyName + "+? where " + getIdPropertyName() + "= ?";
        this.execute(hql, n, id);
    }

    /**
     * Stamps update audit fields on {@code obj}, then issues a bulk HQL update that sets every
     * readable bean property of {@code obj} with a non-null value, matched by the identifier.
     * Does nothing when there is no non-null property to set or the identifier is null.
     *
     * @param obj the entity carrying the id and the values to write
     */
    @Override
    public void updateNotNullProperties(T obj) {
        preUpdate(obj);

        String idName = getIdPropertyName();
        Object idValue = null;
        List<String> names = new ArrayList<>();
        List<Object> values = new ArrayList<>();
        for (PropertyDescriptor descriptor : BeanUtils.getPropertyDescriptors(this.entityClass)) {
            String name = descriptor.getName();
            Method readMethod = descriptor.getReadMethod();
            if ("class".equals(name) || readMethod == null) {
                continue;
            }
            Object value;
            try {
                value = readMethod.invoke(obj);
            } catch (IllegalAccessException | InvocationTargetException e) {
                LOG.log(Level.WARNING, "Failed to read property " + name + " of "
                        + entityClass.getName(), e);
                continue;
            }
            if (value == null) {
                continue;
            }
            if (name.equals(idName)) {
                idValue = value;
            } else {
                names.add(name);
                values.add(value);
            }
        }

        if (names.isEmpty()) {
            // "update X set  where ..." would be invalid HQL.
            return;
        }
        if (idValue == null) {
            // "where id = null" cannot match any row; skip the round trip.
            LOG.log(Level.WARNING, "updateNotNullProperties skipped: {0} instance has a null id",
                    entityClass.getName());
            return;
        }

        StringBuilder hql = new StringBuilder("update ").append(entityClass.getName()).append(" set ");
        for (int i = 0; i < names.size(); i++) {
            if (i > 0) {
                hql.append(',');
            }
            hql.append(names.get(i)).append("= :v").append(i);
        }
        hql.append(" where ").append(idName).append(" =:id");

        Query query = getCurrentSession().createQuery(hql.toString());
        for (int i = 0; i < values.size(); i++) {
            query.setParameter("v" + i, values.get(i));
        }
        query.setParameter("id", idValue);
        query.executeUpdate();
    }

    /**
     * Sets several properties of the entity with the given id in one bulk HQL update.
     *
     * @param id            primary key
     * @param propertyNames HQL property paths (concatenated into the query)
     * @param values        new values, aligned with {@code propertyNames}
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    @Override
    public void updateProperties(PK id, String[] propertyNames, Object[] values) {
        if (propertyNames.length != values.length) {
            throw new IllegalArgumentException("propertyNames has " + propertyNames.length
                    + " entries but values has " + values.length);
        }
        if (propertyNames.length == 0) {
            // Nothing to set; "update X set  where ..." would be invalid HQL.
            return;
        }
        StringBuilder hql = new StringBuilder("update ").append(entityClass.getName()).append(" set ");
        for (int i = 0; i < propertyNames.length; i++) {
            if (i > 0) {
                hql.append(',');
            }
            hql.append(propertyNames[i]).append("= :v").append(i);
        }
        hql.append(" where ").append(getIdPropertyName()).append(" =:id");
        Query query = getCurrentSession().createQuery(hql.toString());
        for (int i = 0; i < values.length; i++) {
            query.setParameter("v" + i, values[i]);
        }
        query.setParameter("id", id).executeUpdate();
    }

    /**
     * Returns the name of the entity's identifier property from Hibernate's class metadata.
     *
     * @return the identifier property name
     */
    public String getIdPropertyName() {
        return sessionFactory.getClassMetadata(this.entityClass).getIdentifierPropertyName();
    }

    /**
     * Bulk-deletes every entity of this DAO's type.
     *
     * @return always {@code true}
     */
    @Override
    public boolean deleteAll() {
        String hql = "delete " + this.entityClass.getName();
        this.execute(hql);
        return true;
    }
}

 

Leave a Comment

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据