package org.springframework.data.jpa.support;
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import javax.persistence.Entity;
import javax.persistence.MappedSuperclass;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.ResourceLoaderAware;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.env.Environment;
import org.springframework.core.env.StandardEnvironment;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternUtils;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.lang.Nullable;
import org.springframework.orm.jpa.persistenceunit.MutablePersistenceUnitInfo;
import org.springframework.orm.jpa.persistenceunit.PersistenceUnitPostProcessor;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
 * {@link PersistenceUnitPostProcessor} that scans the classpath below a configured base package and adds what it finds
 * to the {@link MutablePersistenceUnitInfo}.
 * <p>
 * Classes annotated with {@link Entity} or {@link MappedSuperclass} are registered as managed classes. If a
 * {@link #setMappingFileNamePattern(String) mapping file name pattern} is configured, resources matching
 * {@code classpath*:<base package path>/<pattern>} are additionally registered as mapping files, using their path
 * relative to the classpath root (e.g. {@code com/acme/orm.xml}).
 */
public class ClasspathScanningPersistenceUnitPostProcessor
		implements PersistenceUnitPostProcessor, ResourceLoaderAware, EnvironmentAware {

	private static final Logger LOG = LoggerFactory.getLogger(ClasspathScanningPersistenceUnitPostProcessor.class);

	/** The package to scan, in dotted notation (e.g. {@code com.acme}). */
	private final String basePackage;

	/** Resolves the mapping file pattern; replaced by one derived from the injected {@link ResourceLoader}. */
	private ResourcePatternResolver mappingFileResolver = new PathMatchingResourcePatternResolver();

	private Environment environment = new StandardEnvironment();

	private ResourceLoader resourceLoader = new DefaultResourceLoader();

	/** Pattern for mapping files relative to the base package; mapping file scanning is skipped when unset. */
	private @Nullable String mappingFileNamePattern;

	/**
	 * Creates a new post processor scanning the given base package.
	 *
	 * @param basePackage the package to scan, must not be {@literal null} or empty.
	 */
	public ClasspathScanningPersistenceUnitPostProcessor(String basePackage) {

		Assert.hasText(basePackage, "Base package must not be null or empty!");

		this.basePackage = basePackage;
	}

	/**
	 * Sets the pattern, relative to the base package directory, used to look up mapping files. It is appended to
	 * {@code classpath*:<base package path>/} and resolved by the configured {@link ResourcePatternResolver}.
	 *
	 * @param mappingFilePattern must not be {@literal null} or empty.
	 */
	public void setMappingFileNamePattern(String mappingFilePattern) {

		Assert.hasText(mappingFilePattern, "Mapping file pattern must not be null or empty!");

		this.mappingFileNamePattern = mappingFilePattern;
	}

	/**
	 * Uses the given {@link ResourceLoader} both for class scanning and, adapted to a {@link ResourcePatternResolver},
	 * for resolving mapping files.
	 *
	 * @param resourceLoader must not be {@literal null}.
	 */
	@Override
	public void setResourceLoader(ResourceLoader resourceLoader) {

		Assert.notNull(resourceLoader, "ResourceLoader must not be null!");

		this.mappingFileResolver = ResourcePatternUtils.getResourcePatternResolver(resourceLoader);
		this.resourceLoader = resourceLoader;
	}

	/**
	 * Sets the {@link Environment} handed to the class scanner.
	 *
	 * @param environment must not be {@literal null}.
	 */
	@Override
	public void setEnvironment(Environment environment) {

		Assert.notNull(environment, "Environment must not be null!");

		this.environment = environment;
	}

	/**
	 * Registers all {@link Entity} and {@link MappedSuperclass} classes found below the base package as managed
	 * classes, and all mapping files matching the configured pattern (if any) as mapping files.
	 *
	 * @param pui the persistence unit info to populate.
	 */
	@Override
	public void postProcessPersistenceUnitInfo(MutablePersistenceUnitInfo pui) {

		ClassPathScanningCandidateComponentProvider provider = new ClassPathScanningCandidateComponentProvider(false);
		provider.setEnvironment(environment);
		provider.setResourceLoader(resourceLoader);
		provider.addIncludeFilter(new AnnotationTypeFilter(Entity.class));
		provider.addIncludeFilter(new AnnotationTypeFilter(MappedSuperclass.class));

		for (BeanDefinition definition : provider.findCandidateComponents(basePackage)) {

			String className = definition.getBeanClassName();

			// Skip definitions without a class name instead of registering (or logging) a null entry.
			if (className == null) {
				continue;
			}

			LOG.debug("Registering classpath-scanned entity {} in persistence unit info!", className);
			pui.addManagedClassName(className);
		}

		for (String location : scanForMappingFileLocations()) {
			LOG.debug("Registering classpath-scanned entity mapping file {} in persistence unit info!", location);
			pui.addMappingFileName(location);
		}
	}

	/**
	 * Resolves all resources matching {@code classpath*:<base package path>/<mapping file name pattern>} and returns
	 * their paths relative to the classpath root.
	 *
	 * @return the mapping file locations, or an empty set if no mapping file name pattern is configured.
	 * @throws IllegalStateException if the resources or their URIs cannot be obtained, or a resource's classpath
	 *           location cannot be determined.
	 */
	private Set<String> scanForMappingFileLocations() {

		if (!StringUtils.hasText(mappingFileNamePattern)) {
			return Collections.emptySet();
		}

		// Resource URIs always use '/' as separator, independent of the platform's File.separator, so the package
		// path has to be built with '/' as well.
		char slash = '/';
		String basePackagePathComponent = basePackage.replace('.', slash);
		String path = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + basePackagePathComponent + slash
				+ mappingFileNamePattern;

		Resource[] scannedResources;

		try {
			scannedResources = mappingFileResolver.getResources(path);
		} catch (IOException e) {
			throw new IllegalStateException(String.format("Cannot load mapping files from path %s!", path), e);
		}

		Set<String> mappingFileUris = new HashSet<>();

		for (Resource resource : scannedResources) {

			String resourcePath;

			try {
				resourcePath = getResourcePath(resource.getURI());
			} catch (IOException e) {
				throw new IllegalStateException(String.format("Couldn't get URI for %s!", resource.toString()), e);
			}

			mappingFileUris.add(toClasspathRelativePath(resourcePath, basePackagePathComponent, resource));
		}

		return mappingFileUris;
	}

	/**
	 * Cuts everything before the base package path from the given resource path, yielding the path relative to the
	 * classpath root (e.g. {@code /app/classes/com/acme/orm.xml} becomes {@code com/acme/orm.xml}). Only occurrences
	 * starting at a path segment boundary are considered, so a base package {@code com} does not match inside a
	 * directory named {@code telecom}. The leftmost such occurrence is used.
	 *
	 * @param resourcePath the decoded path of the scanned resource.
	 * @param basePackagePathComponent the base package using {@code '/'} as separator.
	 * @param resource the scanned resource, used for error reporting only.
	 * @return the classpath-relative path of the resource.
	 * @throws IllegalStateException if the base package path does not occur in {@code resourcePath}.
	 */
	private static String toClasspathRelativePath(String resourcePath, String basePackagePathComponent,
			Resource resource) {

		// Scanned resources live below "<base>/", so require the trailing separator as well.
		String packagePrefix = basePackagePathComponent + "/";

		if (resourcePath.startsWith(packagePrefix)) {
			return resourcePath;
		}

		int separatorIndex = resourcePath.indexOf("/" + packagePrefix);

		if (separatorIndex < 0) {
			throw new IllegalStateException(
					String.format("Cannot determine classpath location of mapping file %s; path %s does not contain %s!",
							resource, resourcePath, basePackagePathComponent));
		}

		// Skip the leading '/' so the result is relative to the classpath root.
		return resourcePath.substring(separatorIndex + 1);
	}

	/**
	 * Returns the decoded path of the given {@link URI}. For opaque URIs such as
	 * {@code jar:file:/foo/lib/somelib.jar!/com/acme/orm.xml}, the part after the last {@code "!/"} separator is
	 * returned ({@code /com/acme/orm.xml}); an opaque URI without such a separator yields its whole scheme-specific
	 * part instead of {@literal null}.
	 *
	 * @param uri the resource URI.
	 * @return the decoded path, never {@literal null} for URIs obtained from classpath resources.
	 */
	private static String getResourcePath(URI uri) {

		if (!uri.isOpaque()) {
			return uri.getPath();
		}

		// getSchemeSpecificPart() is decoded (unlike toString()), so names with e.g. spaces come out the same way
		// getPath() returns them for resources in plain directories.
		String schemeSpecificPart = uri.getSchemeSpecificPart();
		int archiveSeparatorIndex = schemeSpecificPart.lastIndexOf("!/");

		return archiveSeparatorIndex > -1 ? schemeSpecificPart.substring(archiveSeparatorIndex + 1) : schemeSpecificPart;
	}
}