package org.springframework.data.repository.core.support;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.aop.interceptor.ExposeInvocationInterceptor;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.core.ResolvableType;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.convert.support.GenericConversionService;
import org.springframework.data.projection.DefaultMethodInvokingMethodInterceptor;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.core.EntityInformation;
import org.springframework.data.repository.core.NamedQueries;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments;
import org.springframework.data.repository.query.QueryLookupStrategy;
import org.springframework.data.repository.query.QueryLookupStrategy.Key;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider;
import org.springframework.data.repository.query.RepositoryQuery;
import org.springframework.data.repository.util.ClassUtils;
import org.springframework.data.repository.util.QueryExecutionConverters;
import org.springframework.data.repository.util.ReactiveWrapperConverters;
import org.springframework.data.repository.util.ReactiveWrappers;
import org.springframework.data.util.Pair;
import org.springframework.data.util.ReflectionUtils;
import org.springframework.lang.Nullable;
import org.springframework.transaction.interceptor.TransactionalProxy;
import org.springframework.util.Assert;
import org.springframework.util.ConcurrentReferenceHashMap;
import org.springframework.util.ConcurrentReferenceHashMap.ReferenceType;
/**
 * Base class for factories that create repository proxies for a given repository interface.
 * <p>
 * The created proxy targets the store-specific implementation returned by
 * {@link #getTargetRepository(RepositoryInformation)}. It is decorated with advice for:
 * <ul>
 * <li>method invocation validation (when supported);</li>
 * <li>exposing the current invocation;</li>
 * <li>registered {@link RepositoryProxyPostProcessor}s;</li>
 * <li>default interface methods;</li>
 * <li>query method execution ({@link QueryExecutorMethodInterceptor});</li>
 * <li>dispatching to repository fragments ({@link ImplementationMethodExecutionInterceptor}).</li>
 * </ul>
 * Store modules implement the abstract methods to supply entity information, the target repository and the
 * repository base class.
 */
@Slf4j
public abstract class RepositoryFactorySupport implements BeanClassLoaderAware, BeanFactoryAware {

	/**
	 * Argument converter used for reactive repositories.
	 * <p>
	 * When a reactive library is available, each non-{@literal null} argument that is not assignable to the invoked
	 * method's declared parameter type is converted via {@link ReactiveWrapperConverters}, provided a conversion
	 * exists. All other arguments are passed through unchanged. When no reactive library is available, the original
	 * argument array is returned as-is.
	 */
	private static final BiFunction<Method, Object[], Object[]> REACTIVE_ARGS_CONVERTER = (method, o) -> {

		if (ReactiveWrappers.isAvailable()) {

			Class<?>[] parameterTypes = method.getParameterTypes();
			// NOTE(review): assumes o.length == parameterTypes.length, as for regular proxied invocations — confirm.
			Object[] converted = new Object[o.length];

			for (int i = 0; i < parameterTypes.length; i++) {

				Class<?> parameterType = parameterTypes[i];
				Object value = o[i];

				if (value == null) {
					// converted[i] already holds null
					continue;
				}

				if (!parameterType.isAssignableFrom(value.getClass())
						&& ReactiveWrapperConverters.canConvert(value.getClass(), parameterType)) {
					converted[i] = ReactiveWrapperConverters.toWrapper(value, parameterType);
				} else {
					converted[i] = value;
				}
			}

			return converted;
		}

		return o;
	};

	/**
	 * Conversion service used to post-process query method results.
	 * <p>
	 * It starts from the default converters, adds the query execution converters, and removes the generic
	 * {@code Object} to {@code Object} convertible.
	 */
	static final GenericConversionService CONVERSION_SERVICE = new DefaultConversionService();

	static {
		QueryExecutionConverters.registerConvertersIn(CONVERSION_SERVICE);
		CONVERSION_SERVICE.removeConvertible(Object.class, Object.class);
	}

	/** Cache of computed {@link RepositoryInformation}, held with weak references. */
	private final Map<RepositoryInformationCacheKey, RepositoryInformation> repositoryInformationCache;
	private final List<RepositoryProxyPostProcessor> postProcessors;

	private Optional<Class<?>> repositoryBaseClass;
	private @Nullable QueryLookupStrategy.Key queryLookupStrategyKey;
	private List<QueryCreationListener<?>> queryPostProcessors;
	private NamedQueries namedQueries;
	private ClassLoader classLoader;
	private QueryMethodEvaluationContextProvider evaluationContextProvider;
	private BeanFactory beanFactory;

	/** Listener that records the {@link QueryMethod}s of every query created by this factory. */
	private final QueryCollectingQueryCreationListener collectingListener = new QueryCollectingQueryCreationListener();

	/**
	 * Creates a new factory with defaults: no named queries, the default class loader, the default evaluation context
	 * provider, and the internal query-collecting listener registered.
	 */
	@SuppressWarnings("null")
	public RepositoryFactorySupport() {

		this.repositoryInformationCache = new ConcurrentReferenceHashMap<>(16, ReferenceType.WEAK);
		this.postProcessors = new ArrayList<>();

		this.repositoryBaseClass = Optional.empty();
		this.namedQueries = PropertiesBasedNamedQueries.EMPTY;
		this.classLoader = org.springframework.util.ClassUtils.getDefaultClassLoader();
		this.evaluationContextProvider = QueryMethodEvaluationContextProvider.DEFAULT;
		this.queryPostProcessors = new ArrayList<>();
		this.queryPostProcessors.add(collectingListener);
	}

	/**
	 * Sets the strategy key passed to {@link #getQueryLookupStrategy(Key, QueryMethodEvaluationContextProvider)}.
	 *
	 * @param key may be {@literal null}.
	 */
	public void setQueryLookupStrategyKey(Key key) {
		this.queryLookupStrategyKey = key;
	}

	/**
	 * Configures the named queries handed to query lookup.
	 *
	 * @param namedQueries may be {@literal null}; then an empty set of named queries is used.
	 */
	public void setNamedQueries(NamedQueries namedQueries) {
		this.namedQueries = namedQueries == null ? PropertiesBasedNamedQueries.EMPTY : namedQueries;
	}

	/**
	 * Sets the class loader used to create repository proxies and projection factories.
	 *
	 * @param classLoader may be {@literal null}; then the default class loader is used.
	 */
	@Override
	public void setBeanClassLoader(ClassLoader classLoader) {
		this.classLoader = classLoader == null ? org.springframework.util.ClassUtils.getDefaultClassLoader() : classLoader;
	}

	/**
	 * Sets the bean factory handed to the projection factory.
	 *
	 * @param beanFactory the owning bean factory.
	 */
	@Override
	public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
		this.beanFactory = beanFactory;
	}

	/**
	 * Sets the evaluation context provider passed to query lookup.
	 *
	 * @param evaluationContextProvider may be {@literal null}; then
	 *          {@link QueryMethodEvaluationContextProvider#DEFAULT} is used.
	 */
	public void setEvaluationContextProvider(QueryMethodEvaluationContextProvider evaluationContextProvider) {
		this.evaluationContextProvider = evaluationContextProvider == null ? QueryMethodEvaluationContextProvider.DEFAULT
				: evaluationContextProvider;
	}

	/**
	 * Configures a custom repository base class that takes precedence over
	 * {@link #getRepositoryBaseClass(RepositoryMetadata)}.
	 *
	 * @param repositoryBaseClass may be {@literal null} to fall back to the store default.
	 */
	public void setRepositoryBaseClass(Class<?> repositoryBaseClass) {
		this.repositoryBaseClass = Optional.ofNullable(repositoryBaseClass);
	}

	/**
	 * Registers a listener notified for every query created for repository query methods.
	 *
	 * @param listener must not be {@literal null}.
	 * @throws IllegalArgumentException if {@code listener} is {@literal null}.
	 */
	public void addQueryCreationListener(QueryCreationListener<?> listener) {

		Assert.notNull(listener, "Listener must not be null!");
		this.queryPostProcessors.add(listener);
	}

	/**
	 * Registers a post-processor applied to every repository proxy created by this factory.
	 *
	 * @param processor must not be {@literal null}.
	 * @throws IllegalArgumentException if {@code processor} is {@literal null}.
	 */
	public void addRepositoryProxyPostProcessor(RepositoryProxyPostProcessor processor) {

		Assert.notNull(processor, "RepositoryProxyPostProcessor must not be null!");
		this.postProcessors.add(processor);
	}

	/**
	 * Hook for store modules to contribute additional repository fragments. Defaults to no fragments.
	 *
	 * @param metadata the repository metadata.
	 * @return the store-specific fragments, empty by default.
	 */
	protected RepositoryFragments getRepositoryFragments(RepositoryMetadata metadata) {
		return RepositoryFragments.empty();
	}

	/**
	 * Creates an empty composition with the method lookup matching the repository type. Reactive repositories also
	 * get the reactive argument converter.
	 */
	private RepositoryComposition getRepositoryComposition(RepositoryMetadata metadata) {

		RepositoryComposition composition = RepositoryComposition.empty();

		if (metadata.isReactiveRepository()) {
			return composition.withMethodLookup(MethodLookups.forReactiveTypes(metadata))
					.withArgumentConverter(REACTIVE_ARGS_CONVERTER);
		}

		return composition.withMethodLookup(MethodLookups.forRepositoryTypes(metadata));
	}

	/**
	 * Creates a repository proxy for the given interface without custom fragments.
	 *
	 * @param repositoryInterface must not be {@literal null}.
	 * @return the repository proxy.
	 */
	public <T> T getRepository(Class<T> repositoryInterface) {
		return getRepository(repositoryInterface, RepositoryFragments.empty());
	}

	/**
	 * Creates a repository proxy for the given interface backed by a single custom implementation fragment.
	 *
	 * @param repositoryInterface must not be {@literal null}.
	 * @param customImplementation the custom implementation object.
	 * @return the repository proxy.
	 */
	public <T> T getRepository(Class<T> repositoryInterface, Object customImplementation) {
		return getRepository(repositoryInterface, RepositoryFragments.just(customImplementation));
	}

	/**
	 * Creates a repository proxy for the given interface, composed of the given fragments, the store-specific
	 * fragments and the target repository.
	 *
	 * @param repositoryInterface must not be {@literal null}.
	 * @param fragments must not be {@literal null}.
	 * @return the repository proxy.
	 * @throws IllegalArgumentException if an argument is {@literal null}, or if the interface declares custom methods
	 *           but no implementation fragment was provided.
	 */
	@SuppressWarnings({ "unchecked" })
	public <T> T getRepository(Class<T> repositoryInterface, RepositoryFragments fragments) {

		// Validate before logging so a null interface surfaces as IllegalArgumentException rather than an NPE.
		Assert.notNull(repositoryInterface, "Repository interface must not be null!");
		Assert.notNull(fragments, "RepositoryFragments must not be null!");

		if (LOG.isDebugEnabled()) {
			LOG.debug("Initializing repository instance for {}…", repositoryInterface.getName());
		}

		RepositoryMetadata metadata = getRepositoryMetadata(repositoryInterface);
		RepositoryComposition composition = getRepositoryComposition(metadata, fragments);
		RepositoryInformation information = getRepositoryInformation(metadata, composition);

		validate(information, composition);

		Object target = getTargetRepository(information);

		ProxyFactory result = new ProxyFactory();
		result.setTarget(target);
		result.setInterfaces(repositoryInterface, Repository.class, TransactionalProxy.class);

		if (MethodInvocationValidator.supports(repositoryInterface)) {
			result.addAdvice(new MethodInvocationValidator());
		}

		result.addAdvisor(ExposeInvocationInterceptor.ADVISOR);

		postProcessors.forEach(processor -> processor.postProcess(result, information));

		if (DefaultMethodInvokingMethodInterceptor.hasDefaultMethods(repositoryInterface)) {
			result.addAdvice(new DefaultMethodInvokingMethodInterceptor());
		}

		ProjectionFactory projectionFactory = getProjectionFactory(classLoader, beanFactory);
		result.addAdvice(new QueryExecutorMethodInterceptor(information, projectionFactory));

		// The target repository is appended last so that explicit fragments take precedence.
		composition = composition.append(RepositoryFragment.implemented(target));
		result.addAdvice(new ImplementationMethodExecutionInterceptor(composition));

		T repository = (T) result.getProxy(classLoader);

		if (LOG.isDebugEnabled()) {
			LOG.debug("Finished creation of repository instance for {}.", repositoryInterface.getName());
		}

		return repository;
	}

	/**
	 * Creates the {@link ProjectionFactory} used for query method results.
	 *
	 * @param classLoader the class loader to configure on the factory.
	 * @param beanFactory the bean factory to configure on the factory.
	 * @return a {@link SpelAwareProxyProjectionFactory} configured with the given class loader and bean factory.
	 */
	protected ProjectionFactory getProjectionFactory(ClassLoader classLoader, BeanFactory beanFactory) {

		SpelAwareProxyProjectionFactory factory = new SpelAwareProxyProjectionFactory();
		factory.setBeanClassLoader(classLoader);
		factory.setBeanFactory(beanFactory);

		return factory;
	}

	/**
	 * Returns the {@link RepositoryMetadata} for the given repository interface.
	 *
	 * @param repositoryInterface the repository interface.
	 * @return the metadata as computed by {@link AbstractRepositoryMetadata#getMetadata(Class)}.
	 */
	protected RepositoryMetadata getRepositoryMetadata(Class<?> repositoryInterface) {
		return AbstractRepositoryMetadata.getMetadata(repositoryInterface);
	}

	/**
	 * Returns the {@link RepositoryInformation} for the given metadata and fragments, using the internal cache.
	 *
	 * @param metadata must not be {@literal null}.
	 * @param fragments must not be {@literal null}.
	 * @return the (possibly cached) repository information.
	 */
	protected RepositoryInformation getRepositoryInformation(RepositoryMetadata metadata, RepositoryFragments fragments) {
		return getRepositoryInformation(metadata, getRepositoryComposition(metadata, fragments));
	}

	/**
	 * Builds the full composition. The base composition comes first, then the given fragments, then the
	 * store-specific fragments.
	 */
	private RepositoryComposition getRepositoryComposition(RepositoryMetadata metadata, RepositoryFragments fragments) {

		Assert.notNull(metadata, "RepositoryMetadata must not be null!");
		Assert.notNull(fragments, "RepositoryFragments must not be null!");

		RepositoryComposition composition = getRepositoryComposition(metadata);
		RepositoryFragments repositoryAspects = getRepositoryFragments(metadata);

		return composition.append(fragments).append(repositoryAspects);
	}

	/**
	 * Looks up or creates the {@link RepositoryInformation} for the given metadata and composition.
	 */
	private RepositoryInformation getRepositoryInformation(RepositoryMetadata metadata,
			RepositoryComposition composition) {

		RepositoryInformationCacheKey cacheKey = new RepositoryInformationCacheKey(metadata, composition);

		return repositoryInformationCache.computeIfAbsent(cacheKey, key -> {

			// Only consult the store default when no custom base class is configured.
			Class<?> baseClass = repositoryBaseClass.orElseGet(() -> getRepositoryBaseClass(metadata));

			return new DefaultRepositoryInformation(metadata, baseClass, composition);
		});
	}

	/**
	 * Returns the query methods of all queries created so far by this factory.
	 *
	 * @return the live list maintained by the internal collecting listener.
	 */
	protected List<QueryMethod> getQueryMethods() {
		return collectingListener.getQueryMethods();
	}

	/**
	 * Returns the {@link EntityInformation} for the given domain class.
	 *
	 * @param domainClass the domain class.
	 * @return the entity information.
	 */
	public abstract <T, ID> EntityInformation<T, ID> getEntityInformation(Class<T> domainClass);

	/**
	 * Creates the store-specific target repository instance backing the proxy.
	 *
	 * @param metadata the repository information.
	 * @return the target repository.
	 */
	protected abstract Object getTargetRepository(RepositoryInformation metadata);

	/**
	 * Returns the default base class for repositories of this store. It is used when no custom base class was set via
	 * {@link #setRepositoryBaseClass(Class)}.
	 *
	 * @param metadata the repository metadata.
	 * @return the repository base class.
	 */
	protected abstract Class<?> getRepositoryBaseClass(RepositoryMetadata metadata);

	/**
	 * Returns the {@link QueryLookupStrategy} for the given key. The default returns none, which makes query method
	 * declarations fail at proxy creation time.
	 *
	 * @param key may be {@literal null}.
	 * @param evaluationContextProvider the evaluation context provider.
	 * @return the lookup strategy, empty by default.
	 */
	protected Optional<QueryLookupStrategy> getQueryLookupStrategy(@Nullable Key key,
			QueryMethodEvaluationContextProvider evaluationContextProvider) {
		return Optional.empty();
	}

	/**
	 * Validates the repository setup.
	 * <p>
	 * If the repository declares custom methods, the composition must not be empty and must validate its
	 * implementation. Then {@link #validate(RepositoryMetadata)} is invoked.
	 */
	private void validate(RepositoryInformation repositoryInformation, RepositoryComposition composition) {

		if (repositoryInformation.hasCustomMethod()) {

			if (composition.isEmpty()) {
				throw new IllegalArgumentException(
						String.format("You have custom methods in %s but not provided a custom implementation!",
								repositoryInformation.getRepositoryInterface()));
			}

			composition.validateImplementation();
		}

		validate(repositoryInformation);
	}

	/**
	 * Hook for store modules to validate the repository metadata. Does nothing by default.
	 *
	 * @param repositoryMetadata the repository metadata.
	 */
	protected void validate(RepositoryMetadata repositoryMetadata) {

	}

	/**
	 * Instantiates the repository base class of the given information reflectively, using a constructor matching the
	 * given arguments.
	 *
	 * @param information the repository information providing the base class.
	 * @param constructorArguments the constructor arguments.
	 * @return the created instance.
	 * @throws IllegalStateException if no matching constructor exists.
	 */
	protected final <R> R getTargetRepositoryViaReflection(RepositoryInformation information,
			Object... constructorArguments) {

		Class<?> baseClass = information.getRepositoryBaseClass();
		return getTargetRepositoryViaReflection(baseClass, constructorArguments);
	}

	/**
	 * Instantiates the given base class reflectively, using a constructor matching the given arguments.
	 *
	 * @param baseClass the class to instantiate.
	 * @param constructorArguments the constructor arguments; may contain {@literal null} elements.
	 * @return the created instance.
	 * @throws IllegalStateException if no matching constructor exists.
	 */
	@SuppressWarnings("unchecked")
	protected final <R> R getTargetRepositoryViaReflection(Class<?> baseClass, Object... constructorArguments) {

		Optional<Constructor<?>> constructor = ReflectionUtils.findConstructor(baseClass, constructorArguments);

		return constructor.map(it -> (R) BeanUtils.instantiateClass(it, constructorArguments))
				.orElseThrow(() -> new IllegalStateException(String.format(
						"No suitable constructor found on %s to match the given arguments: %s. Make sure you implement a constructor taking these",
						baseClass, getArgumentTypes(constructorArguments))));
	}

	/**
	 * Lists the runtime types of the given arguments for diagnostics. {@literal null} elements are reported as
	 * {@literal null} instead of causing a {@link NullPointerException}.
	 */
	private static List<Class<?>> getArgumentTypes(Object[] arguments) {

		return Arrays.stream(arguments) //
				.<Class<?>> map(it -> it == null ? null : it.getClass()) //
				.collect(Collectors.toList());
	}

	/**
	 * Interceptor that executes the {@link RepositoryQuery} resolved for a query method. Invocations of methods
	 * without a resolved query are passed on down the chain.
	 */
	public class QueryExecutorMethodInterceptor implements MethodInterceptor {

		private final Map<Method, RepositoryQuery> queries;
		private final QueryExecutionResultHandler resultHandler;

		/**
		 * Resolves the queries for all query methods of the given repository.
		 *
		 * @param repositoryInformation the repository information.
		 * @param projectionFactory the projection factory passed to query resolution.
		 * @throws IllegalStateException if the repository declares query methods but no lookup strategy is
		 *           available.
		 */
		public QueryExecutorMethodInterceptor(RepositoryInformation repositoryInformation,
				ProjectionFactory projectionFactory) {

			this.resultHandler = new QueryExecutionResultHandler(CONVERSION_SERVICE);

			Optional<QueryLookupStrategy> lookupStrategy = getQueryLookupStrategy(queryLookupStrategyKey,
					RepositoryFactorySupport.this.evaluationContextProvider);

			if (!lookupStrategy.isPresent() && repositoryInformation.hasQueryMethods()) {
				throw new IllegalStateException("You have defined query method in the repository but "
						+ "you don't have any query lookup strategy defined. The "
						+ "infrastructure apparently does not support query methods!");
			}

			this.queries = lookupStrategy //
					.map(it -> mapMethodsToQuery(repositoryInformation, it, projectionFactory)) //
					.orElse(Collections.emptyMap());
		}

		/**
		 * Resolves a query for every query method and notifies listeners about each created query.
		 */
		private Map<Method, RepositoryQuery> mapMethodsToQuery(RepositoryInformation repositoryInformation,
				QueryLookupStrategy lookupStrategy, ProjectionFactory projectionFactory) {

			return repositoryInformation.getQueryMethods().stream() //
					.map(method -> lookupQuery(method, repositoryInformation, lookupStrategy, projectionFactory)) //
					// deliberate side effect: collect() is a terminal op that visits every element, so peek always runs
					.peek(pair -> invokeListeners(pair.getSecond())) //
					.collect(Pair.toMap());
		}

		/**
		 * Resolves the query for a single method using the configured named queries.
		 */
		private Pair<Method, RepositoryQuery> lookupQuery(Method method, RepositoryInformation information,
				QueryLookupStrategy strategy, ProjectionFactory projectionFactory) {
			return Pair.of(method, strategy.resolveQuery(method, information, projectionFactory, namedQueries));
		}

		/**
		 * Notifies every registered listener whose declared generic type is assignable from the query's type.
		 */
		@SuppressWarnings({ "rawtypes", "unchecked" })
		private void invokeListeners(RepositoryQuery query) {

			for (QueryCreationListener listener : queryPostProcessors) {

				ResolvableType typeArgument = ResolvableType.forClass(QueryCreationListener.class, listener.getClass())
						.getGeneric(0);

				if (typeArgument != null && typeArgument.isAssignableFrom(ResolvableType.forClass(query.getClass()))) {
					listener.onCreation(query);
				}
			}
		}

		/**
		 * Executes the invocation and post-processes its result. An execution adapter is used when one is registered
		 * for the method's return type.
		 */
		@Override
		@Nullable
		public Object invoke(@SuppressWarnings("null") MethodInvocation invocation) throws Throwable {

			Method method = invocation.getMethod();

			QueryExecutionConverters.ExecutionAdapter executionAdapter = QueryExecutionConverters //
					.getExecutionAdapter(method.getReturnType());

			if (executionAdapter == null) {
				return resultHandler.postProcessInvocationResult(doInvoke(invocation), method);
			}

			return executionAdapter //
					.apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), method));
		}

		/**
		 * Executes the resolved query for the invoked method, or proceeds with the invocation chain if none exists.
		 */
		@Nullable
		private Object doInvoke(MethodInvocation invocation) throws Throwable {

			Method method = invocation.getMethod();

			if (hasQueryFor(method)) {
				return queries.get(method).execute(invocation.getArguments());
			}

			return invocation.proceed();
		}

		/**
		 * Returns whether a query was resolved for the given method.
		 */
		private boolean hasQueryFor(Method method) {
			return queries.containsKey(method);
		}
	}

	/**
	 * Interceptor that dispatches invocations to the {@link RepositoryComposition}. The Lombok-generated constructor
	 * rejects a {@literal null} composition.
	 */
	@RequiredArgsConstructor
	public class ImplementationMethodExecutionInterceptor implements MethodInterceptor {

		private final @NonNull RepositoryComposition composition;

		@Nullable
		@Override
		public Object invoke(@SuppressWarnings("null") MethodInvocation invocation) throws Throwable {

			Method method = invocation.getMethod();
			Object[] arguments = invocation.getArguments();

			try {
				return composition.invoke(method, arguments);
			} catch (Exception e) {
				// NOTE(review): unwrapReflectionException is expected to always rethrow — confirm; otherwise the
				// IllegalStateException below is reached.
				ClassUtils.unwrapReflectionException(e);
			}

			throw new IllegalStateException("Should not occur!");
		}
	}

	/**
	 * Listener that records the {@link QueryMethod} of every created {@link RepositoryQuery}.
	 */
	@Getter
	private static class QueryCollectingQueryCreationListener implements QueryCreationListener<RepositoryQuery> {

		private final List<QueryMethod> queryMethods = new ArrayList<>();

		@Override
		public void onCreation(RepositoryQuery query) {
			this.queryMethods.add(query.getQueryMethod());
		}
	}

	/**
	 * Cache key for {@link RepositoryInformation}, made of the repository interface name and the composition's hash
	 * code.
	 * <p>
	 * NOTE(review): two distinct compositions with colliding {@code hashCode()} values for the same interface name
	 * would share a cache entry. Confirm this is acceptable.
	 */
	@EqualsAndHashCode
	@Value
	private static class RepositoryInformationCacheKey {

		String repositoryInterfaceName;
		final long compositionHash;

		public RepositoryInformationCacheKey(RepositoryMetadata metadata, RepositoryComposition composition) {

			this.repositoryInterfaceName = metadata.getRepositoryInterface().getName();
			this.compositionHash = composition.hashCode();
		}
	}
}