package org.springframework.data.jpa.repository.query;
import static java.util.regex.Pattern.*;
import static javax.persistence.metamodel.Attribute.PersistentAttributeType.*;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Member;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.persistence.EntityManager;
import javax.persistence.ManyToOne;
import javax.persistence.OneToOne;
import javax.persistence.Parameter;
import javax.persistence.Query;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.Expression;
import javax.persistence.criteria.Fetch;
import javax.persistence.criteria.From;
import javax.persistence.criteria.Join;
import javax.persistence.criteria.JoinType;
import javax.persistence.criteria.Path;
import javax.persistence.metamodel.Attribute;
import javax.persistence.metamodel.Attribute.PersistentAttributeType;
import javax.persistence.metamodel.Bindable;
import javax.persistence.metamodel.ManagedType;
import javax.persistence.metamodel.PluralAttribute;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Sort.Order;
import org.springframework.data.jpa.domain.JpaSort.JpaOrder;
import org.springframework.data.mapping.PropertyPath;
import org.springframework.data.util.Streamable;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
 * Static helpers for JPQL query strings and JPA Criteria API objects: deriving count and exists queries, detecting the
 * root alias of a query, appending {@link Sort} information to query strings and translating {@link Sort} into
 * Criteria API orders.
 * <p>
 * The string-based helpers rely on regular expressions rather than a JPQL parser, so they only recognize common query
 * shapes.
 */
public abstract class QueryUtils {

	/** Count query template; the placeholders are the count projection and the entity name. */
	public static final String COUNT_QUERY_STRING = "select count(%s) from %s x";

	/** Delete-all query template; the placeholder is the entity name. */
	public static final String DELETE_ALL_QUERY_STRING = "delete from %s x";

	// One or more of '.', '_', '$' or any character that is neither a separator, a control/format character nor
	// punctuation.
	private static final String IDENTIFIER = "[._$[\\P{Z}&&\\P{Cc}&&\\P{Cf}&&\\P{Punct}]]+";

	// A colon that is not preceded by another colon or by a backslash.
	static final String COLON_NO_DOUBLE_COLON = "(?<![:\\\\]):";

	static final String IDENTIFIER_GROUP = String.format("(%s)", IDENTIFIER);

	// Replacement applied to COUNT_MATCH: the count value, then group 5 ("from <entity> [as] "), group 6 (root alias)
	// and group 7 (rest of the query).
	private static final String COUNT_REPLACEMENT_TEMPLATE = "select count(%s) $5$6$7";

	// Count value: the original select list (COUNT_MATCH group 2).
	private static final String SIMPLE_COUNT_VALUE = "$2";

	// Count value: "distinct" (group 3) followed by the root alias (group 6).
	private static final String COMPLEX_COUNT_VALUE = "$3 $6";

	// Count value: the root alias (group 6) only.
	private static final String COMPLEX_COUNT_LAST_VALUE = "$6";

	// Trailing order by clause stripped from derived count queries; (?s) lets it span several lines.
	private static final String ORDER_BY_PART = "(?isu)\\s+order\\s+by\\s+.*";

	// Entity (group 1) and alias (group 2) following a "from"; the alias must not be where / group by / order by.
	private static final Pattern ALIAS_MATCH = compile("(?<=from)" //
			+ "(?:\\s)+" //
			+ IDENTIFIER_GROUP //
			+ "(?:\\sas)*" //
			+ "(?:\\s)+" //
			+ "(?!(?:where|group\\s*by|order\\s*by))(\\w+)", CASE_INSENSITIVE);

	// Groups: 1 complete select clause, 2 select list, 3 "distinct", 4 select list after an optional "distinct",
	// 5 "from <entity> [as] ", 6 root alias, 7 rest of the query.
	private static final Pattern COUNT_MATCH = compile("(select\\s+((distinct)?((?s).+?)?)\\s+)?(from\\s+" //
			+ IDENTIFIER //
			+ "(?:\\s+as)?\\s+)" //
			+ IDENTIFIER_GROUP //
			+ "(.*)", CASE_INSENSITIVE);

	private static final Pattern PROJECTION_CLAUSE = Pattern.compile("select\\s+(?:distinct\\s+)?(.+)\\s+from",
			Pattern.CASE_INSENSITIVE);

	private static final Pattern NO_DIGITS = Pattern.compile("\\D+");

	// Join clause; group 1 "fetch ", group 2 "as ", group 3 the join alias.
	private static final String JOIN = "join\\s+(fetch\\s+)?" + IDENTIFIER + "\\s+(as\\s+)?" + IDENTIFIER_GROUP;
	private static final Pattern JOIN_PATTERN = Pattern.compile(JOIN, Pattern.CASE_INSENSITIVE);

	private static final String EQUALS_CONDITION_STRING = "%s.%s = :%s";

	// DOTALL so that an order by clause on a later line of a multi-line query is still detected.
	private static final Pattern ORDER_BY = Pattern.compile(".*order\\s+by\\s+.*", CASE_INSENSITIVE | DOTALL);

	private static final Pattern NAMED_PARAMETER = Pattern
			.compile(COLON_NO_DOUBLE_COLON + IDENTIFIER + "|#" + IDENTIFIER, CASE_INSENSITIVE);

	// "select [...] new <identifier>(...)"
	private static final Pattern CONSTRUCTOR_EXPRESSION = compile(
			"select\\s+(.*\\s+)?new\\s+" + IDENTIFIER + "\\s*\\(.*\\)", CASE_INSENSITIVE | DOTALL);

	// A select list that starts (after optional whitespace) with the "new" keyword of a constructor expression.
	private static final Pattern CONSTRUCTOR_PREFIX = compile("\\s*new\\s", CASE_INSENSITIVE);

	private static final Map<PersistentAttributeType, Class<? extends Annotation>> ASSOCIATION_TYPES = //
			createAssociationTypes();

	private static final int QUERY_JOIN_ALIAS_GROUP_INDEX = 3;
	private static final int VARIABLE_NAME_GROUP_INDEX = 4;
	private static final int COMPLEX_COUNT_FIRST_INDEX = 3;

	// Finds any punctuation or whitespace character other than '.' and '_'.
	private static final Pattern PUNCTUATION_PATTERN = Pattern.compile(".*((?![._])[\\p{Punct}|\\s])");

	// "fn(args) as alias"; (?i:as) matches the keyword itself, whereas a character class such as [as|AS]+ would also
	// accept tokens like "s" or "sa".
	private static final Pattern FUNCTION_PATTERN = compile(
			"\\w+\\s*\\([\\w\\.,\\s'=]+\\)\\s+(?i:as)\\s+(([\\w\\.]+))");

	// "item as alias" for select items without parentheses.
	private static final Pattern FIELD_ALIAS_PATTERN = compile("\\s+[^\\s\\(\\)]+\\s+(?i:as)\\s+(([\\w\\.]+))");

	private static final String UNSAFE_PROPERTY_REFERENCE = "Sort expression '%s' must only contain property references or "
			+ "aliases used in the select clause. If you really want to use something other than that for sorting, please use "
			+ "JpaSort.unsafe(…)!";

	private QueryUtils() {
	}

	// Association attribute types mapped to the annotation whose attributes ("optional", "mappedBy") are inspected;
	// a null value means no annotation is inspected for that type.
	private static Map<PersistentAttributeType, Class<? extends Annotation>> createAssociationTypes() {

		Map<PersistentAttributeType, Class<? extends Annotation>> types = new HashMap<>();
		types.put(ONE_TO_ONE, OneToOne.class);
		types.put(ONE_TO_MANY, null);
		types.put(MANY_TO_ONE, ManyToOne.class);
		types.put(MANY_TO_MANY, null);
		types.put(ELEMENT_COLLECTION, null);

		return Collections.unmodifiableMap(types);
	}

	/**
	 * Builds a count query for the given entity restricted to the given identifier attributes.
	 *
	 * @param entityName entity name used in the from clause (root alias {@code x}).
	 * @param countQueryPlaceHolder expression placed inside {@code count(...)}.
	 * @param idAttributes identifier attribute names; each becomes {@code x.<name> = :<name>}, joined with {@code AND}.
	 * @return the count query string.
	 */
	public static String getExistsQueryString(String entityName, String countQueryPlaceHolder,
			Iterable<String> idAttributes) {

		String whereClause = Streamable.of(idAttributes).stream() //
				.map(idAttribute -> String.format(EQUALS_CONDITION_STRING, "x", idAttribute, idAttribute)) //
				.collect(Collectors.joining(" AND ", " WHERE ", ""));

		return String.format(COUNT_QUERY_STRING, countQueryPlaceHolder, entityName) + whereClause;
	}

	/**
	 * Formats the given template with the entity name. A null or empty entity name is rejected via
	 * {@link Assert#hasText}.
	 *
	 * @param template a {@link String#format} template with a single placeholder.
	 * @param entityName the entity name.
	 * @return the formatted query string.
	 */
	public static String getQueryString(String template, String entityName) {

		Assert.hasText(entityName, "Entity name must not be null or empty!");

		return String.format(template, entityName);
	}

	/**
	 * Appends the given {@link Sort} to the query, qualifying properties with the alias returned by
	 * {@link #detectAlias(String)}.
	 *
	 * @param query the query string.
	 * @param sort the sort to apply.
	 * @return the query with an order by clause appended, or the query unchanged if the sort is unsorted.
	 */
	public static String applySorting(String query, Sort sort) {
		return applySorting(query, sort, detectAlias(query));
	}

	/**
	 * Appends the given {@link Sort} to the query.
	 * <p>
	 * If the query already matches an order by clause, the new orders are appended after a comma. Otherwise,
	 * {@code order by} is added. A property is prefixed with {@code alias} unless one of the following holds:
	 * <ul>
	 * <li>it contains a {@code (} (function call);</li>
	 * <li>it starts with an outer join alias;</li>
	 * <li>it is a function or field alias from the select clause.</li>
	 * </ul>
	 *
	 * @param query the query string, must not be empty.
	 * @param sort the sort to apply.
	 * @param alias the root alias used to qualify property references, may be {@literal null}.
	 * @return the query with the sort applied, or the query unchanged if the sort is unsorted.
	 * @throws InvalidDataAccessApiUsageException if an order's property contains punctuation (other than '.' and '_')
	 *           or whitespace and the order is not an unsafe {@link JpaOrder}.
	 */
	public static String applySorting(String query, Sort sort, @Nullable String alias) {

		Assert.hasText(query, "Query must not be null or empty!");

		if (sort.isUnsorted()) {
			return query;
		}

		Set<String> joinAliases = getOuterJoinAliases(query);
		Set<String> selectionAliases = getFunctionAliases(query);
		selectionAliases.addAll(getFieldAliases(query));

		StringBuilder builder = new StringBuilder(query);
		builder.append(ORDER_BY.matcher(query).matches() ? ", " : " order by ");

		String separator = "";
		for (Order order : sort) {
			builder.append(separator).append(getOrderClause(joinAliases, selectionAliases, alias, order));
			separator = ", ";
		}

		return builder.toString();
	}

	// Renders a single "<expression> <direction>" order by item.
	private static String getOrderClause(Set<String> joinAliases, Set<String> selectionAlias, @Nullable String alias,
			Order order) {

		String property = order.getProperty();

		checkSortExpression(order);

		if (selectionAlias.contains(property)) {
			// Select clause aliases are referenced as-is, but ignore-case must still be honored.
			return String.format("%s %s", applyIgnoreCase(order, property), toJpaDirection(order));
		}

		// "(" indicates a function call; references starting with a join alias are already qualified.
		boolean qualifyReference = !property.contains("(")
				&& joinAliases.stream().noneMatch(joinAlias -> property.startsWith(joinAlias.concat(".")));

		String reference = qualifyReference && StringUtils.hasText(alias) ? String.format("%s.%s", alias, property)
				: property;

		return String.format("%s %s", applyIgnoreCase(order, reference), toJpaDirection(order));
	}

	// Wraps the expression in lower(...) when the order is case-insensitive.
	private static String applyIgnoreCase(Order order, String expression) {
		return order.isIgnoreCase() ? String.format("lower(%s)", expression) : expression;
	}

	/**
	 * Returns the aliases declared by join clauses ({@code join [fetch] <path> [as] <alias>}) in the query.
	 *
	 * @param query the query string.
	 * @return a mutable set of join aliases, never {@literal null}.
	 */
	static Set<String> getOuterJoinAliases(String query) {
		return collectGroup(JOIN_PATTERN, query, QUERY_JOIN_ALIAS_GROUP_INDEX);
	}

	// Aliases of select items written as "item as alias".
	private static Set<String> getFieldAliases(String query) {
		return collectGroup(FIELD_ALIAS_PATTERN, query, 1);
	}

	/**
	 * Returns the aliases given to function calls ({@code fn(args) as alias}) in the query.
	 *
	 * @param query the query string.
	 * @return a mutable set of function aliases, never {@literal null}.
	 */
	static Set<String> getFunctionAliases(String query) {
		return collectGroup(FUNCTION_PATTERN, query, 1);
	}

	// Collects the non-blank values of the given capture group over all matches of the pattern in the query.
	private static Set<String> collectGroup(Pattern pattern, String query, int group) {

		Set<String> result = new HashSet<>();
		Matcher matcher = pattern.matcher(query);

		while (matcher.find()) {
			String value = matcher.group(group);
			if (StringUtils.hasText(value)) {
				result.add(value);
			}
		}

		return result;
	}

	// "asc" / "desc" derived from the order's direction.
	private static String toJpaDirection(Order order) {
		return order.getDirection().name().toLowerCase(Locale.US);
	}

	/**
	 * Returns the alias of the first entity found after a {@code from} keyword, or {@literal null} if none was found.
	 *
	 * @param query the query string.
	 * @return the detected alias, may be {@literal null}.
	 */
	@Nullable
	@Deprecated
	public static String detectAlias(String query) {

		Matcher matcher = ALIAS_MATCH.matcher(query);

		return matcher.find() ? matcher.group(2) : null;
	}

	/**
	 * Creates a query that restricts {@code queryString} to the given entities.
	 * <p>
	 * For each entity, {@code <alias> = ?<n>} is appended, with the conditions joined by {@code or}. Each entity is
	 * bound as positional parameter {@code n}. For an empty iterable, the query is created unchanged.
	 *
	 * @param queryString the base query, must not be {@literal null}.
	 * @param entities the entities to bind, must not be {@literal null}.
	 * @param entityManager used to create the query, must not be {@literal null}.
	 * @return the created {@link Query}.
	 */
	public static <T> Query applyAndBind(String queryString, Iterable<T> entities, EntityManager entityManager) {

		Assert.notNull(queryString, "Querystring must not be null!");
		Assert.notNull(entities, "Iterable of entities must not be null!");
		Assert.notNull(entityManager, "EntityManager must not be null!");

		Iterator<T> iterator = entities.iterator();

		if (!iterator.hasNext()) {
			return entityManager.createQuery(queryString);
		}

		String alias = detectAlias(queryString);
		StringBuilder builder = new StringBuilder(queryString);
		builder.append(" where");

		int i = 0;

		while (iterator.hasNext()) {

			iterator.next();
			builder.append(String.format(" %s = ?%d", alias, ++i));

			if (iterator.hasNext()) {
				builder.append(" or");
			}
		}

		Query query = entityManager.createQuery(builder.toString());

		int position = 0;
		for (T entity : entities) {
			query.setParameter(++position, entity);
		}

		return query;
	}

	/**
	 * Derives a count query from the given query, letting the count projection be inferred.
	 *
	 * @param originalQuery the query to derive from, must not be empty.
	 * @return the count query.
	 */
	@Deprecated
	public static String createCountQueryFor(String originalQuery) {
		return createCountQueryFor(originalQuery, null);
	}

	/**
	 * Derives a count query from the given query.
	 * <p>
	 * If {@code countProjection} is {@literal null}, the count value is chosen as follows:
	 * <ul>
	 * <li>The original select list is used when it is a single expression that is neither a constructor expression nor
	 * a {@code count(...)} call.</li>
	 * <li>Otherwise, {@code distinct <alias>} is used if the select was distinct.</li>
	 * <li>Otherwise, the root alias is used.</li>
	 * </ul>
	 * A trailing order by clause is removed.
	 *
	 * @param originalQuery the query to derive from, must not be empty.
	 * @param countProjection explicit expression to count, may be {@literal null}. It is inserted literally.
	 * @return the count query.
	 */
	@Deprecated
	public static String createCountQueryFor(String originalQuery, @Nullable String countProjection) {

		Assert.hasText(originalQuery, "OriginalQuery must not be null or empty!");

		Matcher matcher = COUNT_MATCH.matcher(originalQuery);
		String countValue;

		if (countProjection == null) {

			boolean matches = matcher.matches();
			String variable = matches ? matcher.group(VARIABLE_NAME_GROUP_INDEX) : null;

			boolean useVariable = StringUtils.hasText(variable) //
					&& !CONSTRUCTOR_PREFIX.matcher(variable).lookingAt() // select [distinct] new com.example.User(...)
					&& !variable.startsWith("count(") //
					&& !variable.contains(",");

			String complexCountValue = matches && StringUtils.hasText(matcher.group(COMPLEX_COUNT_FIRST_INDEX))
					? COMPLEX_COUNT_VALUE
					: COMPLEX_COUNT_LAST_VALUE;

			countValue = useVariable ? SIMPLE_COUNT_VALUE : complexCountValue;

		} else {
			// The projection is user input: quote it so '$' and '\' are not treated as group references or escapes.
			countValue = Matcher.quoteReplacement(countProjection);
		}

		String countQuery = matcher.replaceFirst(String.format(COUNT_REPLACEMENT_TEMPLATE, countValue));

		return countQuery.replaceFirst(ORDER_BY_PART, "");
	}

	/**
	 * Returns whether any parameter of the query has a name containing a non-digit character. Purely numeric names
	 * are presumably positional parameters reported by name (verify against the targeted persistence providers).
	 *
	 * @param query must not be {@literal null}.
	 * @return {@literal true} if a named parameter was found.
	 */
	public static boolean hasNamedParameter(Query query) {

		Assert.notNull(query, "Query must not be null!");

		for (Parameter<?> parameter : query.getParameters()) {

			String name = parameter.getName();

			if (name != null && NO_DIGITS.matcher(name).find()) {
				return true;
			}
		}

		return false;
	}

	// Whether the query string contains a ":name" (not "::") or "#name" token.
	@Deprecated
	static boolean hasNamedParameter(@Nullable String query) {
		return StringUtils.hasText(query) && NAMED_PARAMETER.matcher(query).find();
	}

	/**
	 * Translates the given {@link Sort} into Criteria API orders rooted at {@code from}.
	 *
	 * @param sort the sort to translate.
	 * @param from the root to resolve property paths against, must not be {@literal null} for a sorted sort.
	 * @param cb the criteria builder, must not be {@literal null} for a sorted sort.
	 * @return the orders, empty if the sort is unsorted.
	 */
	public static List<javax.persistence.criteria.Order> toOrders(Sort sort, From<?, ?> from, CriteriaBuilder cb) {

		if (sort.isUnsorted()) {
			return Collections.emptyList();
		}

		Assert.notNull(from, "From must not be null!");
		Assert.notNull(cb, "CriteriaBuilder must not be null!");

		List<javax.persistence.criteria.Order> orders = new ArrayList<>();

		for (org.springframework.data.domain.Sort.Order order : sort) {
			orders.add(toJpaOrder(order, from, cb));
		}

		return orders;
	}

	/**
	 * Returns whether the query contains a constructor expression ({@code select ... new X(...)}).
	 *
	 * @param query must not be empty.
	 * @return {@literal true} if a constructor expression was found.
	 */
	public static boolean hasConstructorExpression(String query) {

		Assert.hasText(query, "Query must not be null or empty!");

		return CONSTRUCTOR_EXPRESSION.matcher(query).find();
	}

	/**
	 * Returns the trimmed projection between {@code select [distinct]} and {@code from}, or an empty string.
	 *
	 * @param query must not be empty.
	 * @return the projection, never {@literal null}.
	 */
	public static String getProjection(String query) {

		Assert.hasText(query, "Query must not be null or empty!");

		Matcher matcher = PROJECTION_CLAUSE.matcher(query);
		String projection = matcher.find() ? matcher.group(1) : "";

		return projection.trim();
	}

	// Converts a Spring Data order into a Criteria API order; ignore-case applies lower() to String expressions only.
	@SuppressWarnings("unchecked")
	private static javax.persistence.criteria.Order toJpaOrder(Order order, From<?, ?> from, CriteriaBuilder cb) {

		PropertyPath property = PropertyPath.from(order.getProperty(), from.getJavaType());
		Expression<?> expression = toExpressionRecursively(from, property);

		if (order.isIgnoreCase() && String.class.equals(expression.getJavaType())) {
			Expression<String> lower = cb.lower((Expression<String>) expression);
			return order.isAscending() ? cb.asc(lower) : cb.desc(lower);
		}

		return order.isAscending() ? cb.asc(expression) : cb.desc(expression);
	}

	/**
	 * Resolves the property path against {@code from} for use outside a selection.
	 *
	 * @param from the root to start at.
	 * @param property the property path.
	 * @return the resulting expression.
	 */
	static <T> Expression<T> toExpressionRecursively(From<?, ?> from, PropertyPath property) {
		return toExpressionRecursively(from, property, false);
	}

	/**
	 * Resolves the property path against {@code from}.
	 * <p>
	 * Each segment for which {@link #requiresOuterJoin} holds is resolved through a LEFT join, unless a LEFT fetch on it
	 * already exists. An existing LEFT join is reused. Once a segment is resolved with a plain {@code get},
	 * the remaining segments are navigated with {@code get} as well.
	 *
	 * @param from the root to start at.
	 * @param property the property path.
	 * @param isForSelection whether the expression is used in the select list.
	 * @return the resulting expression.
	 */
	@SuppressWarnings("unchecked")
	static <T> Expression<T> toExpressionRecursively(From<?, ?> from, PropertyPath property, boolean isForSelection) {

		Bindable<?> propertyPathModel;
		Bindable<?> model = from.getModel();
		String segment = property.getSegment();

		if (model instanceof ManagedType) {
			propertyPathModel = (Bindable<?>) ((ManagedType<?>) model).getAttribute(segment);
		} else {
			propertyPathModel = from.get(segment).getModel();
		}

		if (requiresOuterJoin(propertyPathModel, model instanceof PluralAttribute, !property.hasNext(), isForSelection)
				&& !isAlreadyFetched(from, segment)) {

			Join<?, ?> join = getOrCreateJoin(from, segment);

			return (Expression<T>) (property.hasNext() ? toExpressionRecursively(join, property.next(), isForSelection)
					: join);
		}

		Path<Object> path = from.get(segment);

		return (Expression<T>) (property.hasNext() ? toExpressionRecursively(path, property.next()) : path);
	}

	// Decides whether a path segment must be resolved through an explicit LEFT join.
	private static boolean requiresOuterJoin(@Nullable Bindable<?> propertyPathModel, boolean isPluralAttribute,
			boolean isLeafProperty, boolean isForSelection) {

		// No model could be resolved while navigating from a plural attribute: always join.
		if (propertyPathModel == null && isPluralAttribute) {
			return true;
		}

		if (!(propertyPathModel instanceof Attribute)) {
			return false;
		}

		Attribute<?, ?> attribute = (Attribute<?, ?>) propertyPathModel;

		// Only association-like attribute types are candidates for a join.
		if (!ASSOCIATION_TYPES.containsKey(attribute.getPersistentAttributeType())) {
			return false;
		}

		// A one-to-one with a non-empty mappedBy, i.e. navigated from the inverse side.
		boolean isInverseOptionalOneToOne = PersistentAttributeType.ONE_TO_ONE == attribute.getPersistentAttributeType()
				&& StringUtils.hasText(getAnnotationProperty(attribute, "mappedBy", ""));

		// A single-valued leaf outside the select list is referenced directly unless it is an inverse one-to-one.
		if (isLeafProperty && !isForSelection && !attribute.isCollection() && !isInverseOptionalOneToOne) {
			return false;
		}

		// Join when the association is optional (true if the annotation is absent or not inspected).
		return getAnnotationProperty(attribute, "optional", true);
	}

	// Reads an attribute of the association annotation on the attribute's Java member, falling back to defaultValue
	// when the type has no inspected annotation, the member is not annotatable or the annotation is absent.
	@SuppressWarnings("unchecked")
	private static <T> T getAnnotationProperty(Attribute<?, ?> attribute, String propertyName, T defaultValue) {

		Class<? extends Annotation> associationAnnotation = ASSOCIATION_TYPES.get(attribute.getPersistentAttributeType());

		if (associationAnnotation == null) {
			return defaultValue;
		}

		Member member = attribute.getJavaMember();

		if (!(member instanceof AnnotatedElement)) {
			return defaultValue;
		}

		Annotation annotation = AnnotationUtils.getAnnotation((AnnotatedElement) member, associationAnnotation);

		return annotation == null ? defaultValue : (T) AnnotationUtils.getValue(annotation, propertyName);
	}

	/**
	 * Navigates the remaining property path with plain {@code get} calls, without creating joins.
	 *
	 * @param path the path to start at.
	 * @param property the property path.
	 * @return the resulting expression.
	 */
	static Expression<Object> toExpressionRecursively(Path<Object> path, PropertyPath property) {

		Path<Object> result = path.get(property.getSegment());

		return property.hasNext() ? toExpressionRecursively(result, property.next()) : result;
	}

	// Returns an existing LEFT join on the attribute, or creates a new one.
	private static Join<?, ?> getOrCreateJoin(From<?, ?> from, String attribute) {

		for (Join<?, ?> join : from.getJoins()) {

			boolean sameName = join.getAttribute().getName().equals(attribute);

			if (sameName && join.getJoinType().equals(JoinType.LEFT)) {
				return join;
			}
		}

		return from.join(attribute, JoinType.LEFT);
	}

	// Whether a LEFT fetch on the attribute already exists.
	private static boolean isAlreadyFetched(From<?, ?> from, String attribute) {

		for (Fetch<?, ?> fetch : from.getFetches()) {

			boolean sameName = fetch.getAttribute().getName().equals(attribute);

			if (sameName && fetch.getJoinType().equals(JoinType.LEFT)) {
				return true;
			}
		}

		return false;
	}

	// Rejects sort properties containing punctuation (other than '.' and '_') or whitespace unless explicitly unsafe.
	private static void checkSortExpression(Order order) {

		if (order instanceof JpaOrder && ((JpaOrder) order).isUnsafe()) {
			return;
		}

		if (PUNCTUATION_PATTERN.matcher(order.getProperty()).find()) {
			throw new InvalidDataAccessApiUsageException(String.format(UNSAFE_PROPERTY_REFERENCE, order));
		}
	}
}