/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.cql3.functions;

import java.nio.ByteBuffer;
import java.util.*;

import com.google.common.base.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.datastax.driver.core.TypeCodec;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.schema.Functions;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.transport.ProtocolVersion;

/**
 * Base class for user-defined-aggregates.
 */
public class UDAggregate extends AbstractFunction implements AggregateFunction { protected static final Logger logger = LoggerFactory.getLogger(UDAggregate.class); private final AbstractType<?> stateType; private final TypeCodec stateTypeCodec; private final TypeCodec returnTypeCodec; protected final ByteBuffer initcond; private final ScalarFunction stateFunction; private final ScalarFunction finalFunction; public UDAggregate(FunctionName name, List<AbstractType<?>> argTypes, AbstractType<?> returnType, ScalarFunction stateFunc, ScalarFunction finalFunc, ByteBuffer initcond) { super(name, argTypes, returnType); this.stateFunction = stateFunc; this.finalFunction = finalFunc; this.stateType = stateFunc != null ? stateFunc.returnType() : null; this.stateTypeCodec = stateType != null ? UDHelper.codecFor(UDHelper.driverType(stateType)) : null; this.returnTypeCodec = returnType != null ? UDHelper.codecFor(UDHelper.driverType(returnType)) : null; this.initcond = initcond; } public static UDAggregate create(Functions functions, FunctionName name, List<AbstractType<?>> argTypes, AbstractType<?> returnType, FunctionName stateFunc, FunctionName finalFunc, AbstractType<?> stateType, ByteBuffer initcond) throws InvalidRequestException { List<AbstractType<?>> stateTypes = new ArrayList<>(argTypes.size() + 1); stateTypes.add(stateType); stateTypes.addAll(argTypes); List<AbstractType<?>> finalTypes = Collections.singletonList(stateType); return new UDAggregate(name, argTypes, returnType, resolveScalar(functions, name, stateFunc, stateTypes), finalFunc != null ? 
resolveScalar(functions, name, finalFunc, finalTypes) : null, initcond); } public static UDAggregate createBroken(FunctionName name, List<AbstractType<?>> argTypes, AbstractType<?> returnType, ByteBuffer initcond, InvalidRequestException reason) { return new UDAggregate(name, argTypes, returnType, null, null, initcond) { public Aggregate newAggregate() throws InvalidRequestException { throw new InvalidRequestException(String.format("Aggregate '%s' exists but hasn't been loaded successfully for the following reason: %s. " + "Please see the server log for more details", this, reason.getMessage())); } }; } public boolean hasReferenceTo(Function function) { return stateFunction == function || finalFunction == function; } @Override public void addFunctionsTo(List<Function> functions) { functions.add(this); if (stateFunction != null) { stateFunction.addFunctionsTo(functions); if (finalFunction != null) finalFunction.addFunctionsTo(functions); } } public boolean isAggregate() { return true; } public boolean isNative() { return false; } public ScalarFunction stateFunction() { return stateFunction; } public ScalarFunction finalFunction() { return finalFunction; } public ByteBuffer initialCondition() { return initcond; } public AbstractType<?> stateType() { return stateType; } public Aggregate newAggregate() throws InvalidRequestException { return new Aggregate() { private long stateFunctionCount; private long stateFunctionDuration; private Object state; private boolean needsInit = true; public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) throws InvalidRequestException { maybeInit(protocolVersion); long startTime = System.nanoTime(); stateFunctionCount++; if (stateFunction instanceof UDFunction) { UDFunction udf = (UDFunction)stateFunction; if (udf.isCallableWrtNullable(values)) state = udf.executeForAggregate(protocolVersion, state, values); } else { throw new UnsupportedOperationException("UDAs only support UDFs"); } stateFunctionDuration += 
(System.nanoTime() - startTime) / 1000; } private void maybeInit(ProtocolVersion protocolVersion) { if (needsInit) { state = initcond != null ? UDHelper.deserialize(stateTypeCodec, protocolVersion, initcond.duplicate()) : null; stateFunctionDuration = 0; stateFunctionCount = 0; needsInit = false; } } public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException { maybeInit(protocolVersion); // final function is traced in UDFunction Tracing.trace("Executed UDA {}: {} call(s) to state function {} in {}\u03bcs", name(), stateFunctionCount, stateFunction.name(), stateFunctionDuration); if (finalFunction == null) return UDFunction.decompose(stateTypeCodec, protocolVersion, state); if (finalFunction instanceof UDFunction) { UDFunction udf = (UDFunction)finalFunction; Object result = udf.executeForAggregate(protocolVersion, state, Collections.emptyList()); return UDFunction.decompose(returnTypeCodec, protocolVersion, result); } throw new UnsupportedOperationException("UDAs only support UDFs"); } public void reset() { needsInit = true; } }; } private static ScalarFunction resolveScalar(Functions functions, FunctionName aName, FunctionName fName, List<AbstractType<?>> argTypes) throws InvalidRequestException { Optional<Function> fun = functions.find(fName, argTypes); if (!fun.isPresent()) throw new InvalidRequestException(String.format("Referenced state function '%s %s' for aggregate '%s' does not exist", fName, Arrays.toString(UDHelper.driverTypes(argTypes)), aName)); if (!(fun.get() instanceof ScalarFunction)) throw new InvalidRequestException(String.format("Referenced state function '%s %s' for aggregate '%s' is not a scalar function", fName, Arrays.toString(UDHelper.driverTypes(argTypes)), aName)); return (ScalarFunction) fun.get(); } @Override public boolean equals(Object o) { if (!(o instanceof UDAggregate)) return false; UDAggregate that = (UDAggregate) o; return Objects.equal(name, that.name) && Functions.typesMatch(argTypes, 
that.argTypes) && Functions.typesMatch(returnType, that.returnType) && Objects.equal(stateFunction, that.stateFunction) && Objects.equal(finalFunction, that.finalFunction) && ((stateType == that.stateType) || ((stateType != null) && stateType.equals(that.stateType, true))) // ignore freezing && Objects.equal(initcond, that.initcond); } @Override public int hashCode() { return Objects.hashCode(name, Functions.typeHashCode(argTypes), Functions.typeHashCode(returnType), stateFunction, finalFunction, stateType, initcond); } }