package org.springframework.data.jpa.repository.support;
import static org.springframework.data.jpa.repository.query.QueryUtils.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import javax.persistence.EntityManager;
import javax.persistence.LockModeType;
import javax.persistence.NoResultException;
import javax.persistence.Parameter;
import javax.persistence.Query;
import javax.persistence.TypedQuery;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Order;
import javax.persistence.criteria.ParameterExpression;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.data.domain.Example;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.convert.QueryByExamplePredicateBuilder;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.provider.PersistenceProvider;
import org.springframework.data.jpa.repository.EntityGraph;
import org.springframework.data.jpa.repository.query.EscapeCharacter;
import org.springframework.data.jpa.repository.query.QueryUtils;
import org.springframework.data.jpa.repository.support.QueryHints.NoHints;
import org.springframework.data.repository.support.PageableExecutionUtils;
import org.springframework.data.util.ProxyUtils;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.Assert;
/**
 * Implementation of {@link JpaRepositoryImplementation} that executes all operations through a JPA
 * {@link EntityManager}.
 * <p>
 * The class is read-only transactional by default; every modifying method re-declares {@link Transactional} to
 * obtain a read-write transaction.
 *
 * @param <T> the type of the entity to handle
 * @param <ID> the type of the entity's identifier
 */
@Repository
@Transactional(readOnly = true)
public class SimpleJpaRepository<T, ID> implements JpaRepositoryImplementation<T, ID> {

	private static final String ID_MUST_NOT_BE_NULL = "The given id must not be null!";

	private final JpaEntityInformation<T, ?> entityInformation;
	private final EntityManager em;
	private final PersistenceProvider provider;

	// Optional per-invocation metadata (lock mode, query hints); null until set via setRepositoryMethodMetadata.
	private @Nullable CrudMethodMetadata metadata;
	private EscapeCharacter escapeCharacter = EscapeCharacter.DEFAULT;

	/**
	 * Returns the given {@link Iterable} as a {@link Collection}, copying its elements into a new list only if it is
	 * not a {@link Collection} already.
	 *
	 * @param ts the elements to expose as a collection
	 * @return {@code ts} itself when it already is a {@link Collection}, otherwise a new list holding its elements
	 */
	private static <T> Collection<T> toCollection(Iterable<T> ts) {

		if (ts instanceof Collection) {
			return (Collection<T>) ts;
		}

		List<T> copy = new ArrayList<>();
		for (T t : ts) {
			copy.add(t);
		}
		return copy;
	}

	/**
	 * Creates a new {@link SimpleJpaRepository} for the domain type described by the given
	 * {@link JpaEntityInformation}.
	 *
	 * @param entityInformation must not be {@literal null}
	 * @param entityManager must not be {@literal null}
	 * @throws IllegalArgumentException if one of the arguments is {@literal null}
	 */
	public SimpleJpaRepository(JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager) {

		Assert.notNull(entityInformation, "JpaEntityInformation must not be null!");
		Assert.notNull(entityManager, "EntityManager must not be null!");

		this.entityInformation = entityInformation;
		this.em = entityManager;
		this.provider = PersistenceProvider.fromEntityManager(entityManager);
	}

	/**
	 * Creates a new {@link SimpleJpaRepository} for the given domain class, looking up the entity information via
	 * {@link JpaEntityInformationSupport#getEntityInformation(Class, EntityManager)}.
	 *
	 * @param domainClass the entity type to manage
	 * @param em the {@link EntityManager} to use
	 */
	public SimpleJpaRepository(Class<T> domainClass, EntityManager em) {
		this(JpaEntityInformationSupport.getEntityInformation(domainClass, em), em);
	}

	/**
	 * Sets the {@link CrudMethodMetadata} that is consulted for lock modes and query hints.
	 *
	 * @param crudMethodMetadata the metadata to use from now on
	 */
	@Override
	public void setRepositoryMethodMetadata(CrudMethodMetadata crudMethodMetadata) {
		this.metadata = crudMethodMetadata;
	}

	/**
	 * Sets the {@link EscapeCharacter} handed to the query-by-example specifications.
	 *
	 * @param escapeCharacter the escape character to use from now on
	 */
	@Override
	public void setEscapeCharacter(EscapeCharacter escapeCharacter) {
		this.escapeCharacter = escapeCharacter;
	}

	/**
	 * @return the currently configured {@link CrudMethodMetadata}, or {@literal null} if none has been set
	 */
	@Nullable
	protected CrudMethodMetadata getRepositoryMethodMetadata() {
		return metadata;
	}

	/**
	 * @return the Java type reported by the {@link JpaEntityInformation} of this repository
	 */
	protected Class<T> getDomainClass() {
		return entityInformation.getJavaType();
	}

	/**
	 * @return the delete-all query string with the entity name of the managed type filled in
	 */
	private String getDeleteAllQueryString() {
		return getQueryString(DELETE_ALL_QUERY_STRING, entityInformation.getEntityName());
	}

	/**
	 * @return the count query string using the provider specific count placeholder and the managed entity name
	 */
	private String getCountQueryString() {

		// First pass fills in the provider placeholder and re-inserts "%s" so that the entity name can be
		// substituted by getQueryString(...) afterwards.
		String countQuery = String.format(COUNT_QUERY_STRING, provider.getCountQueryPlaceholder(), "%s");
		return getQueryString(countQuery, entityInformation.getEntityName());
	}

	/**
	 * Deletes the entity with the given id.
	 *
	 * @param id must not be {@literal null}
	 * @throws IllegalArgumentException if {@code id} is {@literal null}
	 * @throws EmptyResultDataAccessException if {@link #findById(Object)} finds no entity for the id
	 */
	@Transactional
	@Override
	public void deleteById(ID id) {

		Assert.notNull(id, ID_MUST_NOT_BE_NULL);

		delete(findById(id).orElseThrow(() -> new EmptyResultDataAccessException(
				String.format("No %s entity with id %s exists!", entityInformation.getJavaType(), id), 1)));
	}

	/**
	 * Deletes the given entity. Returns without doing anything if the entity is considered new or if no instance with
	 * its id can be found through the {@link EntityManager}.
	 *
	 * @param entity must not be {@literal null}
	 * @throws IllegalArgumentException if {@code entity} is {@literal null}
	 */
	@Override
	@Transactional
	public void delete(T entity) {

		Assert.notNull(entity, "Entity must not be null!");

		if (entityInformation.isNew(entity)) {
			return;
		}

		Class<?> type = ProxyUtils.getUserClass(entity);

		// The looked-up instance is only needed for the existence check, so no cast to T is required.
		if (em.find(type, entityInformation.getId(entity)) == null) {
			return;
		}

		// A detached instance has to be merged first, as only managed instances can be removed.
		em.remove(em.contains(entity) ? entity : em.merge(entity));
	}

	/**
	 * Deletes the given entities one by one via {@link #delete(Object)}.
	 *
	 * @param entities must not be {@literal null}
	 * @throws IllegalArgumentException if {@code entities} is {@literal null}
	 */
	@Transactional
	@Override
	public void deleteAll(Iterable<? extends T> entities) {

		Assert.notNull(entities, "Entities must not be null!");

		for (T entity : entities) {
			delete(entity);
		}
	}

	/**
	 * Deletes the given entities with a single update query. Does nothing if {@code entities} is empty.
	 *
	 * @param entities must not be {@literal null}
	 * @throws IllegalArgumentException if {@code entities} is {@literal null}
	 */
	@Transactional
	@Override
	public void deleteInBatch(Iterable<T> entities) {

		Assert.notNull(entities, "Entities must not be null!");

		if (!entities.iterator().hasNext()) {
			return;
		}

		applyAndBind(getDeleteAllQueryString(), entities, em).executeUpdate();
	}

	/**
	 * Deletes all entities by loading them via {@link #findAll()} and passing each one to {@link #delete(Object)}.
	 */
	@Transactional
	@Override
	public void deleteAll() {

		for (T element : findAll()) {
			delete(element);
		}
	}

	/**
	 * Deletes all entities with a single update query.
	 */
	@Transactional
	@Override
	public void deleteAllInBatch() {
		em.createQuery(getDeleteAllQueryString()).executeUpdate();
	}

	/**
	 * Looks up the entity with the given id, applying the lock mode and query hints of the repository method
	 * metadata if such metadata is present.
	 *
	 * @param id must not be {@literal null}
	 * @return the entity found, or {@link Optional#empty()} if the {@link EntityManager} returns {@literal null}
	 * @throws IllegalArgumentException if {@code id} is {@literal null}
	 */
	@Override
	public Optional<T> findById(ID id) {

		Assert.notNull(id, ID_MUST_NOT_BE_NULL);

		Class<T> domainType = getDomainClass();

		if (metadata == null) {
			return Optional.ofNullable(em.find(domainType, id));
		}

		LockModeType type = metadata.getLockModeType();
		Map<String, Object> hints = getQueryHints().withFetchGraphs(em).asMap();

		return Optional.ofNullable(type == null ? em.find(domainType, id, hints) : em.find(domainType, id, type, hints));
	}

	/**
	 * @return {@link NoHints#INSTANCE} if no metadata is set, otherwise the {@link QueryHints} derived from the entity
	 *         information and the metadata
	 */
	protected QueryHints getQueryHints() {
		return metadata == null ? NoHints.INSTANCE : DefaultQueryHints.of(entityInformation, metadata);
	}

	/**
	 * Returns a reference to the entity with the given id as obtained from
	 * {@link EntityManager#getReference(Class, Object)}.
	 *
	 * @param id must not be {@literal null}
	 * @return the reference obtained from the {@link EntityManager}
	 * @throws IllegalArgumentException if {@code id} is {@literal null}
	 */
	@Override
	public T getOne(ID id) {

		Assert.notNull(id, ID_MUST_NOT_BE_NULL);
		return em.getReference(getDomainClass(), id);
	}

	/**
	 * Returns whether an entity with the given id exists. Uses a count-style exists query where possible and falls
	 * back to {@link #findById(Object)} if the entity has no id attribute or if a composite id attribute value cannot
	 * be bound to the query parameter directly.
	 *
	 * @param id must not be {@literal null}
	 * @return {@literal true} if an entity with the given id exists
	 * @throws IllegalArgumentException if {@code id} is {@literal null}
	 */
	@Override
	public boolean existsById(ID id) {

		Assert.notNull(id, ID_MUST_NOT_BE_NULL);

		if (entityInformation.getIdAttribute() == null) {
			return findById(id).isPresent();
		}

		String placeholder = provider.getCountQueryPlaceholder();
		String entityName = entityInformation.getEntityName();
		Iterable<String> idAttributeNames = entityInformation.getIdAttributeNames();
		String existsQuery = QueryUtils.getExistsQueryString(entityName, placeholder, idAttributeNames);

		TypedQuery<Long> query = em.createQuery(existsQuery, Long.class);

		if (!entityInformation.hasCompositeId()) {
			query.setParameter(idAttributeNames.iterator().next(), id);
			return query.getSingleResult() == 1L;
		}

		for (String idAttributeName : idAttributeNames) {

			Object idAttributeValue = entityInformation.getCompositeIdAttributeValue(id, idAttributeName);

			boolean complexIdParameterValueDiscovered = idAttributeValue != null
					&& !query.getParameter(idAttributeName).getParameterType().isAssignableFrom(idAttributeValue.getClass());

			if (complexIdParameterValueDiscovered) {
				// The value's type does not match the parameter type, so it cannot be bound as is;
				// delegate to the lookup by id instead.
				return findById(id).isPresent();
			}

			query.setParameter(idAttributeName, idAttributeValue);
		}

		return query.getSingleResult() == 1L;
	}

	/**
	 * @return all entities of the managed type, unsorted
	 */
	@Override
	public List<T> findAll() {
		return getQuery(null, Sort.unsorted()).getResultList();
	}

	/**
	 * Returns the entities with the given ids. For composite ids every id is looked up individually via
	 * {@link #findById(Object)}; otherwise a single query with an {@code in} predicate on the id attribute is issued.
	 *
	 * @param ids must not be {@literal null}
	 * @return the entities found; an empty list if {@code ids} is empty
	 * @throws IllegalArgumentException if {@code ids} is {@literal null}
	 */
	@Override
	public List<T> findAllById(Iterable<ID> ids) {

		Assert.notNull(ids, "Ids must not be null!");

		if (!ids.iterator().hasNext()) {
			return Collections.emptyList();
		}

		if (entityInformation.hasCompositeId()) {

			List<T> results = new ArrayList<>();
			for (ID id : ids) {
				findById(id).ifPresent(results::add);
			}
			return results;
		}

		Collection<ID> idCollection = toCollection(ids);

		// Building the query invokes the specification, which creates the parameter that is bound below.
		ByIdsSpecification<T> specification = new ByIdsSpecification<>(entityInformation);
		TypedQuery<T> query = getQuery(specification, Sort.unsorted());

		return query.setParameter(specification.parameter, idCollection).getResultList();
	}

	/**
	 * @param sort the sort order to apply
	 * @return all entities of the managed type in the given order
	 */
	@Override
	public List<T> findAll(Sort sort) {
		return getQuery(null, sort).getResultList();
	}

	/**
	 * Returns a {@link Page} of entities. An unpaged {@link Pageable} yields a single page wrapping the result of
	 * {@link #findAll()}.
	 *
	 * @param pageable the paging information
	 * @return the requested page
	 */
	@Override
	public Page<T> findAll(Pageable pageable) {

		if (isUnpaged(pageable)) {
			return new PageImpl<>(findAll());
		}

		return findAll((Specification<T>) null, pageable);
	}

	/**
	 * Returns the single entity matching the given {@link Specification}.
	 *
	 * @param spec can be {@literal null}
	 * @return the matching entity, or {@link Optional#empty()} if the query raises a {@link NoResultException}
	 */
	@Override
	public Optional<T> findOne(@Nullable Specification<T> spec) {

		try {
			return Optional.of(getQuery(spec, Sort.unsorted()).getSingleResult());
		} catch (NoResultException e) {
			return Optional.empty();
		}
	}

	/**
	 * @param spec can be {@literal null}
	 * @return all entities matching the given {@link Specification}, unsorted
	 */
	@Override
	public List<T> findAll(@Nullable Specification<T> spec) {
		return getQuery(spec, Sort.unsorted()).getResultList();
	}

	/**
	 * Returns a {@link Page} of entities matching the given {@link Specification}.
	 *
	 * @param spec can be {@literal null}
	 * @param pageable the paging information
	 * @return the requested page; a single page holding the complete result if {@code pageable} is unpaged
	 */
	@Override
	public Page<T> findAll(@Nullable Specification<T> spec, Pageable pageable) {

		TypedQuery<T> query = getQuery(spec, pageable);
		return isUnpaged(pageable) ? new PageImpl<>(query.getResultList())
				: readPage(query, getDomainClass(), pageable, spec);
	}

	/**
	 * @param spec can be {@literal null}
	 * @param sort the sort order to apply
	 * @return all entities matching the given {@link Specification} in the given order
	 */
	@Override
	public List<T> findAll(@Nullable Specification<T> spec, Sort sort) {
		return getQuery(spec, sort).getResultList();
	}

	/**
	 * Returns the single entity matching the given {@link Example}.
	 *
	 * @param example must not be {@literal null}
	 * @return the matching entity, or {@link Optional#empty()} if the query raises a {@link NoResultException}
	 */
	@Override
	public <S extends T> Optional<S> findOne(Example<S> example) {

		try {
			return Optional
					.of(getQuery(new ExampleSpecification<>(example, escapeCharacter), example.getProbeType(), Sort.unsorted())
							.getSingleResult());
		} catch (NoResultException e) {
			return Optional.empty();
		}
	}

	/**
	 * @param example must not be {@literal null}
	 * @return the number of entities matching the given {@link Example}
	 */
	@Override
	public <S extends T> long count(Example<S> example) {
		return executeCountQuery(
				getCountQuery(new ExampleSpecification<>(example, escapeCharacter), example.getProbeType()));
	}

	/**
	 * @param example must not be {@literal null}
	 * @return {@literal true} if at least one entity matches the given {@link Example}
	 */
	@Override
	public <S extends T> boolean exists(Example<S> example) {

		// A single row is enough to answer the question, so do not load every match.
		return !getQuery(new ExampleSpecification<>(example, escapeCharacter), example.getProbeType(), Sort.unsorted())
				.setMaxResults(1).getResultList().isEmpty();
	}

	/**
	 * @param example must not be {@literal null}
	 * @return all entities matching the given {@link Example}, unsorted
	 */
	@Override
	public <S extends T> List<S> findAll(Example<S> example) {
		return getQuery(new ExampleSpecification<>(example, escapeCharacter), example.getProbeType(), Sort.unsorted())
				.getResultList();
	}

	/**
	 * @param example must not be {@literal null}
	 * @param sort the sort order to apply
	 * @return all entities matching the given {@link Example} in the given order
	 */
	@Override
	public <S extends T> List<S> findAll(Example<S> example, Sort sort) {
		return getQuery(new ExampleSpecification<>(example, escapeCharacter), example.getProbeType(), sort)
				.getResultList();
	}

	/**
	 * Returns a {@link Page} of entities matching the given {@link Example}.
	 *
	 * @param example must not be {@literal null}
	 * @param pageable the paging information
	 * @return the requested page; a single page holding the complete result if {@code pageable} is unpaged
	 */
	@Override
	public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {

		ExampleSpecification<S> spec = new ExampleSpecification<>(example, escapeCharacter);
		Class<S> probeType = example.getProbeType();

		// The same specification instance serves both the content query and the count query.
		TypedQuery<S> query = getQuery(spec, probeType, pageable);

		return isUnpaged(pageable) ? new PageImpl<>(query.getResultList()) : readPage(query, probeType, pageable, spec);
	}

	/**
	 * @return the number of entities of the managed type
	 */
	@Override
	public long count() {
		return em.createQuery(getCountQueryString(), Long.class).getSingleResult();
	}

	/**
	 * @param spec can be {@literal null}
	 * @return the number of entities matching the given {@link Specification}
	 */
	@Override
	public long count(@Nullable Specification<T> spec) {
		return executeCountQuery(getCountQuery(spec, getDomainClass()));
	}

	/**
	 * Saves the given entity: persists it if it is considered new, merges it otherwise.
	 *
	 * @param entity must not be {@literal null}
	 * @return the given instance if it was persisted, otherwise the instance returned by the merge
	 * @throws IllegalArgumentException if {@code entity} is {@literal null}
	 */
	@Transactional
	@Override
	public <S extends T> S save(S entity) {

		Assert.notNull(entity, "Entity must not be null!");

		if (entityInformation.isNew(entity)) {
			em.persist(entity);
			return entity;
		} else {
			return em.merge(entity);
		}
	}

	/**
	 * Saves the given entity via {@link #save(Object)} and flushes the {@link EntityManager} afterwards.
	 *
	 * @param entity the entity to save
	 * @return the saved entity
	 */
	@Transactional
	@Override
	public <S extends T> S saveAndFlush(S entity) {

		S result = save(entity);
		flush();

		return result;
	}

	/**
	 * Saves all given entities via {@link #save(Object)}.
	 *
	 * @param entities must not be {@literal null}
	 * @return the saved entities in iteration order
	 * @throws IllegalArgumentException if {@code entities} is {@literal null}
	 */
	@Transactional
	@Override
	public <S extends T> List<S> saveAll(Iterable<S> entities) {

		Assert.notNull(entities, "Entities must not be null!");

		List<S> result = new ArrayList<>();
		for (S entity : entities) {
			result.add(save(entity));
		}

		return result;
	}

	/**
	 * Flushes the {@link EntityManager}.
	 */
	@Transactional
	@Override
	public void flush() {
		em.flush();
	}

	/**
	 * Reads the given {@link TypedQuery} into a {@link Page} using the domain class of this repository.
	 *
	 * @param query the content query
	 * @param pageable the paging information
	 * @param spec can be {@literal null}
	 * @return the requested page
	 * @deprecated use {@link #readPage(TypedQuery, Class, Pageable, Specification)} instead
	 */
	@Deprecated
	protected Page<T> readPage(TypedQuery<T> query, Pageable pageable, @Nullable Specification<T> spec) {
		return readPage(query, getDomainClass(), pageable, spec);
	}

	/**
	 * Reads the given {@link TypedQuery} into a {@link Page}, applying offset and page size of the given
	 * {@link Pageable} if it is paged. The count query built from {@code spec} and {@code domainClass} is handed to
	 * {@link PageableExecutionUtils} as a supplier.
	 *
	 * @param query the content query
	 * @param domainClass the type to count
	 * @param pageable the paging information
	 * @param spec can be {@literal null}
	 * @return the requested page
	 * @throws IllegalArgumentException if the offset of {@code pageable} does not fit into an {@code int}
	 */
	protected <S extends T> Page<S> readPage(TypedQuery<S> query, final Class<S> domainClass, Pageable pageable,
			@Nullable Specification<S> spec) {

		if (pageable.isPaged()) {
			query.setFirstResult(toIntOffset(pageable.getOffset()));
			query.setMaxResults(pageable.getPageSize());
		}

		return PageableExecutionUtils.getPage(query.getResultList(), pageable,
				() -> executeCountQuery(getCountQuery(spec, domainClass)));
	}

	/**
	 * Converts a {@code long} offset into the {@code int} expected by the JPA query API, rejecting values that would
	 * be truncated instead of silently wrapping them.
	 *
	 * @param offset the offset to convert
	 * @return the offset as {@code int}
	 * @throws IllegalArgumentException if {@code offset} is outside of the {@code int} range
	 */
	private static int toIntOffset(long offset) {

		if ((int) offset != offset) {
			throw new IllegalArgumentException(
					String.format("Page offset %d is out of the int range supported by a JPA query!", offset));
		}

		return (int) offset;
	}

	/**
	 * Creates a {@link TypedQuery} for the given {@link Specification}, sorted as requested by the given
	 * {@link Pageable} if it is paged and unsorted otherwise.
	 *
	 * @param spec can be {@literal null}
	 * @param pageable the paging information providing the sort order
	 * @return the query for the domain class of this repository
	 */
	protected TypedQuery<T> getQuery(@Nullable Specification<T> spec, Pageable pageable) {

		Sort sort = pageable.isPaged() ? pageable.getSort() : Sort.unsorted();
		return getQuery(spec, getDomainClass(), sort);
	}

	/**
	 * Creates a {@link TypedQuery} for the given {@link Specification} and domain class, sorted as requested by the
	 * given {@link Pageable} if it is paged and unsorted otherwise.
	 *
	 * @param spec can be {@literal null}
	 * @param domainClass the type to select
	 * @param pageable the paging information providing the sort order
	 * @return the query for the given domain class
	 */
	protected <S extends T> TypedQuery<S> getQuery(@Nullable Specification<S> spec, Class<S> domainClass,
			Pageable pageable) {

		Sort sort = pageable.isPaged() ? pageable.getSort() : Sort.unsorted();
		return getQuery(spec, domainClass, sort);
	}

	/**
	 * Creates a {@link TypedQuery} for the given {@link Specification} and {@link Sort}.
	 *
	 * @param spec can be {@literal null}
	 * @param sort the sort order to apply
	 * @return the query for the domain class of this repository
	 */
	protected TypedQuery<T> getQuery(@Nullable Specification<T> spec, Sort sort) {
		return getQuery(spec, getDomainClass(), sort);
	}

	/**
	 * Creates a {@link TypedQuery} selecting the root of the given domain class, restricted by the given
	 * {@link Specification} and ordered by the given {@link Sort} if it is sorted. Lock mode and query hints of the
	 * repository method metadata are applied to the resulting query.
	 *
	 * @param spec can be {@literal null}
	 * @param domainClass must not be {@literal null}
	 * @param sort the sort order to apply
	 * @return the query for the given domain class
	 */
	protected <S extends T> TypedQuery<S> getQuery(@Nullable Specification<S> spec, Class<S> domainClass, Sort sort) {

		CriteriaBuilder builder = em.getCriteriaBuilder();
		CriteriaQuery<S> query = builder.createQuery(domainClass);

		Root<S> root = applySpecificationToCriteria(spec, domainClass, query);
		query.select(root);

		if (sort.isSorted()) {
			query.orderBy(toOrders(sort, root, builder));
		}

		return applyRepositoryMethodMetadata(em.createQuery(query));
	}

	/**
	 * Creates a count query for the given {@link Specification} using the domain class of this repository.
	 *
	 * @param spec can be {@literal null}
	 * @return the count query
	 * @deprecated use {@link #getCountQuery(Specification, Class)} instead
	 */
	@Deprecated
	protected TypedQuery<Long> getCountQuery(@Nullable Specification<T> spec) {
		return getCountQuery(spec, getDomainClass());
	}

	/**
	 * Creates a count query for the given {@link Specification} and domain class. A distinct count is selected if the
	 * specification marked the criteria query as distinct.
	 *
	 * @param spec can be {@literal null}
	 * @param domainClass must not be {@literal null}
	 * @return the count query
	 */
	protected <S extends T> TypedQuery<Long> getCountQuery(@Nullable Specification<S> spec, Class<S> domainClass) {

		CriteriaBuilder builder = em.getCriteriaBuilder();
		CriteriaQuery<Long> query = builder.createQuery(Long.class);

		Root<S> root = applySpecificationToCriteria(spec, domainClass, query);

		if (query.isDistinct()) {
			query.select(builder.countDistinct(root));
		} else {
			query.select(builder.count(root));
		}

		// Drop any ordering the specification might have added to the criteria query.
		query.orderBy(Collections.<Order> emptyList());

		return em.createQuery(query);
	}

	/**
	 * Adds a root for the given domain class to the given {@link CriteriaQuery} and, if a {@link Specification} is
	 * given and produces a predicate, uses that predicate as the where clause.
	 *
	 * @param spec can be {@literal null}
	 * @param domainClass must not be {@literal null}
	 * @param query must not be {@literal null}
	 * @return the root created for {@code domainClass}
	 * @throws IllegalArgumentException if {@code domainClass} or {@code query} is {@literal null}
	 */
	private <S, U extends T> Root<U> applySpecificationToCriteria(@Nullable Specification<U> spec, Class<U> domainClass,
			CriteriaQuery<S> query) {

		Assert.notNull(domainClass, "Domain class must not be null!");
		Assert.notNull(query, "CriteriaQuery must not be null!");

		Root<U> root = query.from(domainClass);

		if (spec == null) {
			return root;
		}

		CriteriaBuilder builder = em.getCriteriaBuilder();
		Predicate predicate = spec.toPredicate(root, query, builder);

		if (predicate != null) {
			query.where(predicate);
		}

		return root;
	}

	/**
	 * Applies lock mode and query hints of the repository method metadata to the given query. Returns the query
	 * unchanged if no metadata is set.
	 *
	 * @param query the query to customize
	 * @return the customized query
	 */
	private <S> TypedQuery<S> applyRepositoryMethodMetadata(TypedQuery<S> query) {

		if (metadata == null) {
			return query;
		}

		LockModeType type = metadata.getLockModeType();
		TypedQuery<S> toReturn = type == null ? query : query.setLockMode(type);

		applyQueryHints(toReturn);

		return toReturn;
	}

	/**
	 * Sets every hint returned by {@link #getQueryHints()}, including fetch graphs, on the given query.
	 *
	 * @param query the query to apply the hints to
	 */
	private void applyQueryHints(Query query) {

		for (Entry<String, Object> hint : getQueryHints().withFetchGraphs(em)) {
			query.setHint(hint.getKey(), hint.getValue());
		}
	}

	/**
	 * Executes the given count query and sums up all values it returns, treating {@literal null} elements as zero.
	 *
	 * @param query must not be {@literal null}
	 * @return the sum of all returned values
	 * @throws IllegalArgumentException if {@code query} is {@literal null}
	 */
	private static long executeCountQuery(TypedQuery<Long> query) {

		Assert.notNull(query, "TypedQuery must not be null!");

		List<Long> totals = query.getResultList();
		long total = 0L;

		for (Long element : totals) {
			total += element == null ? 0 : element;
		}

		return total;
	}

	/**
	 * @param pageable the paging information to inspect
	 * @return whether the given {@link Pageable} is unpaged
	 */
	private static boolean isUnpaged(Pageable pageable) {
		return pageable.isUnpaged();
	}

	/**
	 * {@link Specification} restricting the id attribute of an entity to a collection of values. The
	 * {@link ParameterExpression} created while building the predicate is kept in {@link #parameter} so that the
	 * caller can bind the actual ids to the query afterwards.
	 */
	private static final class ByIdsSpecification<T> implements Specification<T> {

		private static final long serialVersionUID = 1L;

		private final JpaEntityInformation<T, ?> entityInformation;

		// Assigned by toPredicate(...); null until the specification has been applied to a query.
		@Nullable ParameterExpression<Collection<?>> parameter;

		ByIdsSpecification(JpaEntityInformation<T, ?> entityInformation) {
			this.entityInformation = entityInformation;
		}

		/**
		 * Builds an {@code in} predicate on the id attribute against a freshly created collection parameter and
		 * stores that parameter in {@link #parameter}.
		 */
		@Override
		@SuppressWarnings({ "rawtypes", "unchecked" })
		public Predicate toPredicate(Root<T> root, CriteriaQuery<?> query, CriteriaBuilder cb) {

			Path<?> path = root.get(entityInformation.getIdAttribute());

			// cb.parameter(Collection.class) is typed with the raw Collection, hence the detour via the raw type.
			parameter = (ParameterExpression<Collection<?>>) (ParameterExpression) cb.parameter(Collection.class);
			return path.in(parameter);
		}
	}

	/**
	 * {@link Specification} that delegates predicate creation for an {@link Example} to
	 * {@link QueryByExamplePredicateBuilder}.
	 */
	private static class ExampleSpecification<T> implements Specification<T> {

		private static final long serialVersionUID = 1L;

		private final Example<T> example;
		private final EscapeCharacter escapeCharacter;

		/**
		 * Creates a new {@link ExampleSpecification}.
		 *
		 * @param example must not be {@literal null}
		 * @param escapeCharacter must not be {@literal null}
		 * @throws IllegalArgumentException if one of the arguments is {@literal null}
		 */
		ExampleSpecification(Example<T> example, EscapeCharacter escapeCharacter) {

			Assert.notNull(example, "Example must not be null!");
			Assert.notNull(escapeCharacter, "EscapeCharacter must not be null!");

			this.example = example;
			this.escapeCharacter = escapeCharacter;
		}

		@Override
		public Predicate toPredicate(Root<T> root, CriteriaQuery<?> query, CriteriaBuilder cb) {
			return QueryByExamplePredicateBuilder.getPredicate(root, cb, example, escapeCharacter);
		}
	}
}