package org.springframework.data.jpa.repository.support;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;

import javax.persistence.EntityManager;
import javax.persistence.LockModeType;

import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.repository.EntityGraph;
import org.springframework.data.querydsl.EntityPathResolver;
import org.springframework.data.querydsl.QSort;
import org.springframework.data.querydsl.QuerydslPredicateExecutor;
import org.springframework.data.repository.support.PageableExecutionUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

import com.querydsl.core.NonUniqueResultException;
import com.querydsl.core.types.EntityPath;
import com.querydsl.core.types.OrderSpecifier;
import com.querydsl.core.types.Predicate;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.PathBuilder;
import com.querydsl.jpa.JPQLQuery;
import com.querydsl.jpa.impl.AbstractJPAQuery;
/**
 * Querydsl-based implementation of {@link QuerydslPredicateExecutor} that executes {@link Predicate}s against JPA
 * through an {@link EntityManager}.
 * <p>
 * Every query is rooted at the {@link EntityPath} resolved for the entity's Java type. When a
 * {@link CrudMethodMetadata} is available, the query hints derived from it (including, presumably, any
 * {@link EntityGraph} declared on the repository method — see {@code QueryHints#withFetchGraphs}) and its
 * {@link LockModeType} are applied to the queries created by {@link #createQuery(Predicate...)}.
 *
 * @param <T> the entity type handled by this executor.
 */
public class QuerydslJpaPredicateExecutor<T> implements QuerydslPredicateExecutor<T> {

    private final JpaEntityInformation<T, ?> entityInformation;
    private final EntityPath<T> path;
    private final Querydsl querydsl;
    private final EntityManager entityManager;
    private final @Nullable CrudMethodMetadata metadata;

    /**
     * Creates a new {@link QuerydslJpaPredicateExecutor}.
     *
     * @param entityInformation describes the entity; its Java type is used to resolve the root {@link EntityPath}.
     * @param entityManager the {@link EntityManager} used to create and execute queries.
     * @param resolver resolves the {@link EntityPath} for the entity's Java type.
     * @param metadata source of query hints and lock mode; may be {@literal null}, in which case no hints and no
     *          lock mode are applied.
     */
    public QuerydslJpaPredicateExecutor(JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager,
            EntityPathResolver resolver, @Nullable CrudMethodMetadata metadata) {

        this.entityInformation = entityInformation;
        this.metadata = metadata;
        this.path = resolver.createPath(entityInformation.getJavaType());
        this.querydsl = new Querydsl(entityManager, new PathBuilder<T>(path.getType(), path.getMetadata()));
        this.entityManager = entityManager;
    }

    /**
     * Returns the single entity matching the given {@link Predicate}, or {@link Optional#empty()} if there is none.
     *
     * @param predicate must not be {@literal null}.
     * @return the matching entity wrapped in an {@link Optional}.
     * @throws IllegalArgumentException if {@code predicate} is {@literal null}.
     * @throws IncorrectResultSizeDataAccessException if Querydsl reports more than one result.
     */
    @Override
    public Optional<T> findOne(Predicate predicate) {

        Assert.notNull(predicate, "Predicate must not be null!");

        try {
            return Optional.ofNullable(createQuery(predicate).select(path).fetchOne());
        } catch (NonUniqueResultException ex) {
            // Translate the Querydsl exception into Spring's data access hierarchy, keeping the cause.
            throw new IncorrectResultSizeDataAccessException(ex.getMessage(), 1, ex);
        }
    }

    /**
     * Returns all entities matching the given {@link Predicate}.
     *
     * @param predicate must not be {@literal null}.
     * @return the matching entities.
     * @throws IllegalArgumentException if {@code predicate} is {@literal null}.
     */
    @Override
    public List<T> findAll(Predicate predicate) {

        Assert.notNull(predicate, "Predicate must not be null!");

        return createQuery(predicate).select(path).fetch();
    }

    /**
     * Returns all entities matching the given {@link Predicate}, ordered by the given {@link OrderSpecifier}s.
     *
     * @param predicate must not be {@literal null}.
     * @param orders must not be {@literal null}.
     * @return the matching entities in the requested order.
     * @throws IllegalArgumentException if {@code predicate} or {@code orders} is {@literal null}.
     */
    @Override
    public List<T> findAll(Predicate predicate, OrderSpecifier<?>... orders) {

        Assert.notNull(predicate, "Predicate must not be null!");
        Assert.notNull(orders, "Order specifiers must not be null!");

        return executeSorted(createQuery(predicate).select(path), orders);
    }

    /**
     * Returns all entities matching the given {@link Predicate}, ordered by the given {@link Sort}.
     *
     * @param predicate must not be {@literal null}.
     * @param sort must not be {@literal null}.
     * @return the matching entities in the requested order.
     * @throws IllegalArgumentException if {@code predicate} or {@code sort} is {@literal null}.
     */
    @Override
    public List<T> findAll(Predicate predicate, Sort sort) {

        Assert.notNull(predicate, "Predicate must not be null!");
        Assert.notNull(sort, "Sort must not be null!");

        return executeSorted(createQuery(predicate).select(path), sort);
    }

    /**
     * Returns all entities, without any filtering, ordered by the given {@link OrderSpecifier}s.
     *
     * @param orders must not be {@literal null}.
     * @return all entities in the requested order.
     * @throws IllegalArgumentException if {@code orders} is {@literal null}.
     */
    @Override
    public List<T> findAll(OrderSpecifier<?>... orders) {

        Assert.notNull(orders, "Order specifiers must not be null!");

        // An empty predicate array yields a query without any where-restriction.
        return executeSorted(createQuery(new Predicate[0]).select(path), orders);
    }

    /**
     * Returns a {@link Page} of entities matching the given {@link Predicate}.
     *
     * @param predicate must not be {@literal null}.
     * @param pageable must not be {@literal null}.
     * @return the requested page.
     * @throws IllegalArgumentException if {@code predicate} or {@code pageable} is {@literal null}.
     */
    @Override
    public Page<T> findAll(Predicate predicate, Pageable pageable) {

        Assert.notNull(predicate, "Predicate must not be null!");
        Assert.notNull(pageable, "Pageable must not be null!");

        final JPQLQuery<?> countQuery = createCountQuery(predicate);
        JPQLQuery<T> query = querydsl.applyPagination(pageable, createQuery(predicate).select(path));

        // The count is handed over as a supplier instead of being executed eagerly here.
        return PageableExecutionUtils.getPage(query.fetch(), pageable, countQuery::fetchCount);
    }

    /**
     * Returns the number of entities matching the given {@link Predicate}.
     * <p>
     * Note: unlike the finder methods, this method does not itself reject a {@literal null} predicate; only the
     * varargs array is checked by {@link #createQuery(Predicate...)}.
     *
     * @param predicate the {@link Predicate} to count matches for.
     * @return the number of matching entities.
     */
    @Override
    public long count(Predicate predicate) {
        return createQuery(predicate).fetchCount();
    }

    /**
     * Checks whether at least one entity matches the given {@link Predicate}.
     * <p>
     * Note: unlike the finder methods, this method does not itself reject a {@literal null} predicate; only the
     * varargs array is checked by {@link #createQuery(Predicate...)}.
     *
     * @param predicate the {@link Predicate} to check.
     * @return {@literal true} if at least one entity matches, {@literal false} otherwise.
     */
    @Override
    public boolean exists(Predicate predicate) {

        // Probe for a single row with a constant projection instead of counting every match:
        // the result is the same, but the database may stop at the first hit.
        return createQuery(predicate).select(Expressions.ONE).fetchFirst() != null;
    }

    /**
     * Creates a new {@link JPQLQuery} rooted at the entity path and restricted by the given {@link Predicate}s. The
     * hints returned by {@code getQueryHints().withFetchGraphs(entityManager)} are set on the query and, if the
     * {@link CrudMethodMetadata} provides a {@link LockModeType}, that lock mode is applied as well.
     *
     * @param predicate the restrictions to apply; the array must not be {@literal null} but may be empty.
     * @return the {@link JPQLQuery}, without a projection selected yet.
     * @throws IllegalArgumentException if the {@code predicate} array is {@literal null}.
     */
    protected JPQLQuery<?> createQuery(Predicate... predicate) {

        Assert.notNull(predicate, "Predicate must not be null!");

        AbstractJPAQuery<?, ?> query = doCreateQuery(getQueryHints().withFetchGraphs(entityManager), predicate);
        CrudMethodMetadata methodMetadata = getRepositoryMethodMetadata();

        if (methodMetadata == null) {
            return query;
        }

        LockModeType type = methodMetadata.getLockModeType();
        return type == null ? query : query.setLockMode(type);
    }

    /**
     * Creates a new {@link JPQLQuery} intended for counting. It uses the hints returned by
     * {@code forCounts()} and, in contrast to {@link #createQuery(Predicate...)}, applies neither
     * {@code withFetchGraphs(...)} nor a lock mode.
     *
     * @param predicate the restrictions to apply; may be {@literal null}, in which case no restriction is added.
     * @return the {@link JPQLQuery} to run the count on.
     */
    protected JPQLQuery<?> createCountQuery(@Nullable Predicate... predicate) {
        return doCreateQuery(getQueryHintsForCount(), predicate);
    }

    /**
     * Returns the {@link CrudMethodMetadata} this executor was created with.
     *
     * @return the metadata, or {@literal null} if none was supplied.
     */
    @Nullable
    private CrudMethodMetadata getRepositoryMethodMetadata() {
        return metadata;
    }

    /**
     * Returns the {@link QueryHints} derived from the entity information and the method metadata.
     *
     * @return the hints, or {@link QueryHints.NoHints#INSTANCE} if no metadata is available.
     */
    private QueryHints getQueryHints() {
        return metadata == null ? QueryHints.NoHints.INSTANCE : DefaultQueryHints.of(entityInformation, metadata);
    }

    /**
     * Returns the {@link QueryHints} to be used for count queries, i.e. the regular hints narrowed by
     * {@code forCounts()}.
     *
     * @return the count hints, or {@link QueryHints.NoHints#INSTANCE} if no metadata is available.
     */
    private QueryHints getQueryHintsForCount() {
        return metadata == null ? QueryHints.NoHints.INSTANCE
                : DefaultQueryHints.of(entityInformation, metadata).forCounts();
    }

    /**
     * Creates the underlying Querydsl JPA query for the entity path, adds the given restrictions and sets every
     * given hint on it.
     *
     * @param hints the hints to set on the query.
     * @param predicate the restrictions to apply; skipped entirely if {@literal null}.
     * @return the configured query.
     */
    private AbstractJPAQuery<?, ?> doCreateQuery(QueryHints hints, @Nullable Predicate... predicate) {

        AbstractJPAQuery<?, ?> query = querydsl.createQuery(path);

        if (predicate != null) {
            query = query.where(predicate);
        }

        for (Entry<String, Object> hint : hints) {
            query.setHint(hint.getKey(), hint.getValue());
        }

        return query;
    }

    /**
     * Executes the given query sorted by the given {@link OrderSpecifier}s, which are wrapped into a {@link QSort}.
     *
     * @param query the query to execute.
     * @param orders the order specifiers to sort by.
     * @return the sorted result.
     */
    private List<T> executeSorted(JPQLQuery<T> query, OrderSpecifier<?>... orders) {
        return executeSorted(query, new QSort(orders));
    }

    /**
     * Applies the given {@link Sort} to the query and fetches the result.
     *
     * @param query the query to execute.
     * @param sort the sort to apply.
     * @return the sorted result.
     */
    private List<T> executeSorted(JPQLQuery<T> query, Sort sort) {
        return querydsl.applySorting(sort, query).fetch();
    }
}