/*
 * Copyright (C) 2017-2017 DataStax Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.datastax.oss.protocol.internal;

import com.datastax.oss.protocol.internal.util.Flags;
import com.datastax.oss.protocol.internal.util.IntIntMap;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;

public class FrameCodec<B> {

  
Builds a new instance with the default codecs for a client (encoding requests, decoding responses).
/** * Builds a new instance with the default codecs for a client (encoding requests, decoding * responses). */
public static <B> FrameCodec<B> defaultClient( PrimitiveCodec<B> primitiveCodec, Compressor<B> compressor) { return new FrameCodec<>( primitiveCodec, compressor, new ProtocolV3ClientCodecs(), new ProtocolV4ClientCodecs(), new ProtocolV5ClientCodecs()); }
Builds a new instance with the default codecs for a server (decoding requests, encoding responses).
/** * Builds a new instance with the default codecs for a server (decoding requests, encoding * responses). */
public static <B> FrameCodec<B> defaultServer( PrimitiveCodec<B> primitiveCodec, Compressor<B> compressor) { return new FrameCodec<>( primitiveCodec, compressor, new ProtocolV3ServerCodecs(), new ProtocolV4ServerCodecs(), new ProtocolV5ServerCodecs()); } private final PrimitiveCodec<B> primitiveCodec; private final Compressor<B> compressor; private final IntIntMap<Message.Codec> encoders; private final IntIntMap<Message.Codec> decoders; public FrameCodec( PrimitiveCodec<B> primitiveCodec, Compressor<B> compressor, CodecGroup... codecGroups) { ProtocolErrors.check(primitiveCodec != null, "primitiveCodec can't be null"); ProtocolErrors.check(compressor != null, "compressor can't be null, use Compressor.none()"); this.primitiveCodec = primitiveCodec; this.compressor = compressor; IntIntMap.Builder<Message.Codec> encodersBuilder = IntIntMap.builder(); IntIntMap.Builder<Message.Codec> decodersBuilder = IntIntMap.builder(); CodecGroup.Registry registry = new CodecGroup.Registry() { @Override public CodecGroup.Registry addCodec(Message.Codec codec) { addEncoder(codec); addDecoder(codec); return this; } @Override public CodecGroup.Registry addEncoder(Message.Codec codec) { encodersBuilder.put(codec.protocolVersion, codec.opcode, codec); return this; } @Override public CodecGroup.Registry addDecoder(Message.Codec codec) { decodersBuilder.put(codec.protocolVersion, codec.opcode, codec); return this; } }; for (CodecGroup codecGroup : codecGroups) { codecGroup.registerCodecs(registry); } this.encoders = encodersBuilder.build(); this.decoders = decodersBuilder.build(); } public B encode(Frame frame) { int protocolVersion = frame.protocolVersion; Message request = frame.message; ProtocolErrors.check( protocolVersion >= ProtocolConstants.Version.V4 || frame.customPayload.isEmpty(), "Custom payload is not supported in protocol v%d", protocolVersion); ProtocolErrors.check( protocolVersion >= ProtocolConstants.Version.V4 || frame.warnings.isEmpty(), "Warnings are not supported in protocol v%d", protocolVersion); int opcode = request.opcode; Message.Codec encoder = encoders.get(protocolVersion, opcode); ProtocolErrors.check( encoder != null, "Unsupported opcode %s in protocol v%d", opcode, protocolVersion); int flags = 0; if (!(compressor instanceof NoopCompressor) && opcode != ProtocolConstants.Opcode.STARTUP) { flags = Flags.add(flags, ProtocolConstants.FrameFlag.COMPRESSED); } if (frame.tracing || frame.tracingId != null) { flags = Flags.add(flags, ProtocolConstants.FrameFlag.TRACING); } if (!frame.customPayload.isEmpty()) { flags = Flags.add(flags, ProtocolConstants.FrameFlag.CUSTOM_PAYLOAD); } if (!frame.warnings.isEmpty()) { flags = Flags.add(flags, ProtocolConstants.FrameFlag.WARNING); } if (protocolVersion == ProtocolConstants.Version.BETA) { flags = Flags.add(flags, ProtocolConstants.FrameFlag.USE_BETA); } int headerSize = headerEncodedSize(); if (!Flags.contains(flags, ProtocolConstants.FrameFlag.COMPRESSED)) { // No compression: we can optimize and do everything with a single allocation int messageSize = encoder.encodedSize(request); if (frame.tracingId != null) { messageSize += PrimitiveSizes.UUID; } if (!frame.customPayload.isEmpty()) { messageSize += PrimitiveSizes.sizeOfBytesMap(frame.customPayload); } if (!frame.warnings.isEmpty()) { messageSize += PrimitiveSizes.sizeOfStringList(frame.warnings); } B dest = primitiveCodec.allocate(headerSize + messageSize); encodeHeader(frame, flags, messageSize, dest); encodeTracingId(frame.tracingId, dest); encodeCustomPayload(frame.customPayload, dest); encodeWarnings(frame.warnings, dest); encoder.encode(dest, request, primitiveCodec); return dest; } else { // We need to compress first in order to know the body size // 1) Encode uncompressed message int uncompressedMessageSize = encoder.encodedSize(request); if (frame.tracingId != null) { uncompressedMessageSize += PrimitiveSizes.UUID; } if (!frame.customPayload.isEmpty()) { uncompressedMessageSize += PrimitiveSizes.sizeOfBytesMap(frame.customPayload); } if (!frame.warnings.isEmpty()) { uncompressedMessageSize += PrimitiveSizes.sizeOfStringList(frame.warnings); } B uncompressedMessage = primitiveCodec.allocate(uncompressedMessageSize); encodeTracingId(frame.tracingId, uncompressedMessage); encodeCustomPayload(frame.customPayload, uncompressedMessage); encodeWarnings(frame.warnings, uncompressedMessage); encoder.encode(uncompressedMessage, request, primitiveCodec); // 2) Compress and measure size, discard uncompressed buffer B compressedMessage = compressor.compress(uncompressedMessage); primitiveCodec.release(uncompressedMessage); int messageSize = primitiveCodec.sizeOf(compressedMessage); // 3) Encode final frame B header = primitiveCodec.allocate(headerSize); encodeHeader(frame, flags, messageSize, header); return primitiveCodec.concat(header, compressedMessage); } } public static int headerEncodedSize() { return 9; } private void encodeHeader(Frame frame, int flags, int messageSize, B dest) { int versionAndDirection = frame.protocolVersion; if (frame.message.isResponse) { versionAndDirection |= 0b1000_0000; } primitiveCodec.writeByte((byte) versionAndDirection, dest); primitiveCodec.writeByte((byte) flags, dest); primitiveCodec.writeUnsignedShort( frame.streamId & 0xFFFF, // see readStreamId() dest); primitiveCodec.writeByte((byte) frame.message.opcode, dest); primitiveCodec.writeInt(messageSize, dest); } private void encodeTracingId(UUID tracingId, B dest) { if (tracingId != null) { primitiveCodec.writeUuid(tracingId, dest); } } private void encodeCustomPayload(Map<String, ByteBuffer> customPayload, B dest) { if (!customPayload.isEmpty()) { primitiveCodec.writeBytesMap(customPayload, dest); } } private void encodeWarnings(List<String> warnings, B dest) { if (!warnings.isEmpty()) { primitiveCodec.writeStringList(warnings, dest); } } public Frame decode(B source) { int directionAndVersion = primitiveCodec.readByte(source); boolean isResponse = (directionAndVersion & 0b1000_0000) == 0b1000_0000; int protocolVersion = directionAndVersion & 0b0111_1111; int flags = primitiveCodec.readByte(source); boolean beta = Flags.contains(flags, ProtocolConstants.FrameFlag.USE_BETA); int streamId = readStreamId(source); int opcode = primitiveCodec.readByte(source); int length = primitiveCodec.readInt(source); int actualLength = primitiveCodec.sizeOf(source); ProtocolErrors.check( length == actualLength, "Declared length in header (%d) does not match actual length (%d)", length, actualLength); boolean decompressed = false; if (Flags.contains(flags, ProtocolConstants.FrameFlag.COMPRESSED)) { B newSource = compressor.decompress(source); // if decompress returns a different object, track this so we know to release it when done. if (newSource != source) { decompressed = true; source = newSource; } } int frameSize; int compressedFrameSize; if (decompressed) { frameSize = headerEncodedSize() + primitiveCodec.sizeOf(source); compressedFrameSize = headerEncodedSize() + length; // what we measured before decompressing } else { frameSize = headerEncodedSize() + length; compressedFrameSize = -1; } boolean isTracing = Flags.contains(flags, ProtocolConstants.FrameFlag.TRACING); UUID tracingId = (isResponse && isTracing) ? primitiveCodec.readUuid(source) : null; Map<String, ByteBuffer> customPayload = (Flags.contains(flags, ProtocolConstants.FrameFlag.CUSTOM_PAYLOAD)) ? primitiveCodec.readBytesMap(source) : Collections.emptyMap(); List<String> warnings = (isResponse && Flags.contains(flags, ProtocolConstants.FrameFlag.WARNING)) ? primitiveCodec.readStringList(source) : Collections.emptyList(); Message.Codec decoder = decoders.get(protocolVersion, opcode); ProtocolErrors.check( decoder != null, "Unsupported request opcode: %s in protocol %d", opcode, protocolVersion); Message response = decoder.decode(source, primitiveCodec); if (decompressed) { primitiveCodec.release(source); } return new Frame( protocolVersion, beta, streamId, isTracing, tracingId, frameSize, compressedFrameSize, customPayload, warnings, response); } private int readStreamId(B source) { int id = primitiveCodec.readUnsignedShort(source); // The protocol spec states that the stream id is a [short], but this is wrong: the stream id // is signed. Rather than adding a `readSignedShort` to PrimitiveCodec for this edge case, // handle the conversion here. return (short) id; }
Intermediary class to pass request/response codecs to the frame codec.

This is just so that we can have the codecs nicely grouped by protocol version.

/** * Intermediary class to pass request/response codecs to the frame codec. * * <p>This is just so that we can have the codecs nicely grouped by protocol version. */
public interface CodecGroup { interface Registry { Registry addCodec(Message.Codec codec);
Add a codec for encoding only; this helps catch programming errors if the client is only supposed to send a subset of the existing messages.
/** * Add a codec for encoding only; this helps catch programming errors if the client is only * supposed to send a subset of the existing messages. */
Registry addEncoder(Message.Codec codec);
Add a codec for decoding only; this helps catch programming errors if the client is only supposed to receive a subset of the existing messages.
/** * Add a codec for decoding only; this helps catch programming errors if the client is only * supposed to receive a subset of the existing messages. */
Registry addDecoder(Message.Codec codec); } void registerCodecs(Registry registry); } }