package org.jdbi.v3.core.statement;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.argument.Argument;
import org.jdbi.v3.core.argument.Arguments;
import org.jdbi.v3.core.argument.NamedArgumentFinder;
import org.jdbi.v3.core.argument.internal.NamedArgumentFinderFactory;
import org.jdbi.v3.core.argument.internal.NamedArgumentFinderFactory.PrepareKey;
import org.jdbi.v3.core.qualifier.QualifiedType;
import org.jdbi.v3.core.result.ResultBearing;
import org.jdbi.v3.core.result.ResultIterator;
import org.jdbi.v3.core.result.ResultProducer;
import org.jdbi.v3.core.result.ResultProducers;
import org.jdbi.v3.core.result.ResultSetScanner;
import org.jdbi.v3.core.result.UnableToProduceResultException;
import org.jdbi.v3.core.statement.internal.PreparedBinding;
import static org.jdbi.v3.core.result.ResultProducers.returningGeneratedKeys;
/**
 * A batch of executions of one prepared SQL statement.
 *
 * <p>Parameters are bound as on any {@link SqlStatement}. Each call to {@link #add()} captures the
 * current binding as one batch part and installs a fresh, empty binding. When the batch is
 * executed, every captured part is bound to a single JDBC {@link PreparedStatement} and added to
 * its batch, and the batch is then run.
 */
public class PreparedBatch extends SqlStatement<PreparedBatch> implements ResultBearing {
    /** Batch parts captured by {@link #add()}; cleared after every execution attempt. */
    private final List<PreparedBinding> bindings = new ArrayList<>();

    /**
     * Prepared argument-finder functions, computed once per {@link PrepareKey} and reused for later
     * batch parts. Package-private; presumably read by the argument binder in this package (TODO confirm).
     */
    final Map<PrepareKey, Function<String, Optional<Function<Object, Argument>>>> preparedFinders = new HashMap<>();

    /**
     * Creates a batch for the given SQL on the given handle.
     *
     * @param handle the handle the batch executes against
     * @param sql    the SQL (or template) to prepare
     */
    public PreparedBatch(Handle handle, String sql) {
        super(handle, sql);
        final StatementContext context = getContext();
        context.setBinding(new PreparedBinding(context));
    }

    /**
     * Records a named-argument source for the current batch part. The prepared finder for the key
     * is computed only the first time that key is seen.
     */
    @Override
    PreparedBatch bindNamedArgumentFinder(NamedArgumentFinderFactory<?> factory, String prefix, Object value, Type type, Supplier<NamedArgumentFinder> backupArgumentFinder) {
        final PreparedBinding current = getBinding();
        final PrepareKey cacheKey = factory.keyFor(prefix, value);

        preparedFinders.computeIfAbsent(cacheKey,
            ignoredKey -> factory.prepareFor(getConfig(), this::buildArgument, prefix, value, type));

        current.prepareKeys.put(cacheKey, value);
        current.backupArgumentFinders.add(backupArgumentFinder);
        return this;
    }

    /**
     * Returns the current binding narrowed to {@link PreparedBinding}. The constructor, {@link #add()}
     * and a successful execution always install a {@code PreparedBinding} on the context.
     */
    @Override
    protected PreparedBinding getBinding() {
        return (PreparedBinding) super.getBinding();
    }

    /**
     * Builds a function that turns a value of {@code type} into an {@link Argument}.
     *
     * <p>If the configured {@link Arguments} can prepare the type up front, that prepared function
     * is returned. Otherwise the factory lookup is deferred until the argument is applied. In that
     * case an {@link UnableToCreateStatementException} is thrown when no factory matches.
     *
     * @param type the qualified type of the values to convert
     * @return a value-to-argument function
     */
    Function<Object, Argument> buildArgument(QualifiedType<?> type) {
        final Function<Object, Argument> resolvePerValue = value ->
            (position, statement, context) ->
                context.getConfig(Arguments.class)
                    .findFor(type, value)
                    .orElseThrow(() -> new UnableToCreateStatementException("no argument factory for type " + type, context))
                    .apply(position, statement, context);

        return getContext().getConfig(Arguments.class)
            .prepareFor(type)
            .orElse(resolvePerValue);
    }

    /**
     * Executes the batch and hands its results to {@code mapper}.
     *
     * @param mapper the result-set scanner
     * @param <R>    the scanner's result type
     * @return whatever {@code mapper} produces
     */
    @Override
    public <R> R scanResultSet(ResultSetScanner<R> mapper) {
        final ResultBearing results = execute(ResultProducers.returningResults());
        return results.scanResultSet(mapper);
    }

    /**
     * Executes the batch and returns the update counts reported by JDBC {@code executeBatch}, one
     * per batch part. Parameters bound since the last {@link #add()} are added as a final part
     * first. The statement context is closed whether or not execution succeeds.
     *
     * @return a copy of the per-part update counts
     */
    public int[] execute() {
        try {
            final ExecutedBatch outcome = internalBatchExecute();
            return outcome.updateCounts;
        } finally {
            getContext().close();
        }
    }

    /**
     * Executes the batch and exposes the update counts as an iterator.
     *
     * <p>{@link #execute()} has already closed the statement context by the time the iterator is
     * returned. Closing the iterator closes that same context again (presumably a no-op; confirm in
     * {@code StatementContext}).
     *
     * @return an iterator over the per-part update counts
     */
    public ResultIterator<Integer> executeAndGetModCount() {
        final StatementContext context = getContext();
        final int[] counts = execute();

        return new ResultIterator<Integer>() {
            private int cursor;

            @Override
            public boolean hasNext() {
                return cursor < counts.length;
            }

            @Override
            public Integer next() {
                if (cursor >= counts.length) {
                    throw new NoSuchElementException();
                }
                return counts[cursor++];
            }

            @Override
            public StatementContext getContext() {
                return context;
            }

            @Override
            public void close() {
                context.close();
            }
        };
    }

    /**
     * Executes the batch and returns the generated keys for the given columns.
     *
     * @param columnNames names of the generated-key columns to return
     * @return a result bearing over the generated keys
     */
    public ResultBearing executeAndReturnGeneratedKeys(String... columnNames) {
        return execute(returningGeneratedKeys(columnNames));
    }

    /**
     * Executes the batch through a custom {@link ResultProducer}. The batch runs only when the
     * producer asks for the statement.
     *
     * <p>If the producer throws {@link SQLException}, this statement is closed and the failure is
     * rethrown wrapped in an {@link UnableToProduceResultException}. Any failure while closing is
     * attached to the original exception as a suppressed exception.
     *
     * @param producer turns the executed statement into a result
     * @param <R>      the result type
     * @return the producer's result
     */
    public <R> R execute(ResultProducer<R> producer) {
        try {
            return producer.produce(() -> internalBatchExecute().stmt, getContext());
        } catch (SQLException failure) {
            try {
                close();
            } catch (Exception closeFailure) {
                failure.addSuppressed(closeFailure);
            }
            throw new UnableToProduceResultException("Exception producing batch result", failure, getContext());
        }
    }

    /**
     * Creates the statement, binds every captured part, and runs the batch.
     *
     * <p>Once preparation begins, the captured parts are always cleared, even on failure. On
     * failure the context keeps the binding of the part being processed, so that binding is what
     * the thrown exception sees in its context.
     */
    private ExecutedBatch internalBatchExecute() {
        // Parameters bound since the last add() become the final batch part.
        if (!getBinding().isEmpty()) {
            add();
        }

        beforeTemplating();

        final StatementContext context = getContext();
        final ParsedSql parsed = parseSql();
        final String renderedSql = parsed.getSql();
        final ParsedParameters parameters = parsed.getParameters();

        try {
            createBatchStatement(context, renderedSql);

            // Nothing was added: the statement is still created, but no batch is run.
            if (bindings.isEmpty()) {
                return new ExecutedBatch(stmt, new int[0]);
            }

            beforeBinding();
            bindBatchParts(context, parameters);

            beforeExecution();
            return executeBatchStatement(context);
        } finally {
            bindings.clear();
        }
    }

    /** Creates the JDBC statement, registers its cleanup, and applies statement customizers. */
    private void createBatchStatement(StatementContext context, String sql) {
        try {
            final StatementBuilder statementBuilder = getHandle().getStatementBuilder();
            @SuppressWarnings("PMD.CloseResource")
            final Connection connection = getHandle().getConnection();
            stmt = statementBuilder.create(connection, sql, context);
            // The cleanable reads the stmt field when it runs, not when it is registered.
            addCleanable(() -> statementBuilder.close(connection, sql, stmt));
            getConfig(SqlStatements.class).customize(stmt);
        } catch (SQLException e) {
            throw new UnableToCreateStatementException(e, context);
        }
    }

    /** Binds each captured part to the statement and adds it to the JDBC batch. */
    private void bindBatchParts(StatementContext context, ParsedParameters parameters) {
        try {
            final ArgumentBinder<?> binder = new ArgumentBinder.Prepared(this, parameters, bindings.get(0));
            for (Binding part : bindings) {
                // Expose the part being bound on the context, so a failure reports this part.
                context.setBinding(part);
                binder.bind(part);
                stmt.addBatch();
            }
        } catch (SQLException e) {
            throw new UnableToExecuteStatementException("Exception while binding parameters", e, context);
        }
    }

    /** Runs the JDBC batch through the SQL logger, then installs a fresh binding for reuse. */
    private ExecutedBatch executeBatchStatement(StatementContext context) {
        try {
            final int[] updateCounts = SqlLoggerUtil.wrap(stmt::executeBatch, context, getConfig(SqlStatements.class).getSqlLogger());

            afterExecution();

            context.setBinding(new PreparedBinding(context));

            return new ExecutedBatch(stmt, updateCounts);
        } catch (SQLException e) {
            throw new UnableToExecuteStatementException(Batch.mungeBatchException(e), context);
        }
    }

    /**
     * Captures the parameters bound so far as one batch part and starts a fresh, empty binding.
     *
     * @return this batch
     * @throws IllegalStateException if no parameters were bound since the last add
     */
    public PreparedBatch add() {
        final PreparedBinding pending = getBinding();
        if (pending.isEmpty()) {
            throw new IllegalStateException("Attempt to add() an empty batch, you probably didn't mean to do this "
                + "- call add() *after* setting batch parameters");
        }

        bindings.add(pending);

        final StatementContext context = getContext();
        context.setBinding(new PreparedBinding(context));
        return this;
    }

    /**
     * Binds {@code args} positionally, starting at position 0, then captures them as a batch part.
     *
     * @param args positional parameter values
     * @return this batch
     * @throws IllegalStateException if the resulting binding is empty, as for {@link #add()}
     */
    public PreparedBatch add(Object... args) {
        int position = 0;
        for (Object arg : args) {
            bind(position++, arg);
        }
        add();
        return this;
    }

    /**
     * Binds each map entry by name, then captures them as a batch part.
     *
     * @param args named parameter values
     * @return this batch
     * @throws IllegalStateException if the resulting binding is empty, as for {@link #add()}
     */
    public PreparedBatch add(Map<String, ?> args) {
        bindMap(args);
        add();
        return this;
    }

    /**
     * Returns the number of batch parts captured by {@link #add()} and not yet executed.
     *
     * @return the count of pending batch parts
     */
    public int size() {
        return bindings.size();
    }

    /** The executed statement plus a private copy of its update counts. */
    private static class ExecutedBatch {
        final PreparedStatement stmt;
        final int[] updateCounts;

        ExecutedBatch(PreparedStatement stmt, int[] updateCounts) {
            this.stmt = stmt;
            this.updateCounts = Arrays.copyOf(updateCounts, updateCounts.length);
        }
    }
}