package com.datastax.oss.driver.internal.core.type.codec;
import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.data.UdtValue;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.UserDefinedType;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import net.jcip.annotations.ThreadSafe;
@ThreadSafe
public class UdtCodec implements TypeCodec<UdtValue> {
private final UserDefinedType cqlType;
public UdtCodec(@NonNull UserDefinedType cqlType) {
this.cqlType = cqlType;
}
@NonNull
@Override
public GenericType<UdtValue> getJavaType() {
return GenericType.UDT_VALUE;
}
@NonNull
@Override
public DataType getCqlType() {
return cqlType;
}
@Override
public boolean accepts(@NonNull Object value) {
return value instanceof UdtValue && ((UdtValue) value).getType().equals(cqlType);
}
@Override
public boolean accepts(@NonNull Class<?> javaClass) {
return UdtValue.class.equals(javaClass);
}
@Nullable
@Override
public ByteBuffer encode(@Nullable UdtValue value, @NonNull ProtocolVersion protocolVersion) {
if (value == null) {
return null;
}
if (!value.getType().equals(cqlType)) {
throw new IllegalArgumentException(
String.format(
"Invalid user defined type, expected %s but got %s", cqlType, value.getType()));
}
int toAllocate = 0;
int size = cqlType.getFieldTypes().size();
for (int i = 0; i < size; i++) {
ByteBuffer field = value.getBytesUnsafe(i);
toAllocate += 4 + (field == null ? 0 : field.remaining());
}
ByteBuffer result = ByteBuffer.allocate(toAllocate);
for (int i = 0; i < value.size(); i++) {
ByteBuffer field = value.getBytesUnsafe(i);
if (field == null) {
result.putInt(-1);
} else {
result.putInt(field.remaining());
result.put(field.duplicate());
}
}
return (ByteBuffer) result.flip();
}
@Nullable
@Override
public UdtValue decode(@Nullable ByteBuffer bytes, @NonNull ProtocolVersion protocolVersion) {
if (bytes == null) {
return null;
}
try {
ByteBuffer input = bytes.duplicate();
UdtValue value = cqlType.newValue();
int i = 0;
while (input.hasRemaining()) {
if (i > cqlType.getFieldTypes().size()) {
throw new IllegalArgumentException(
String.format(
"Too many fields in encoded UDT value, expected %d",
cqlType.getFieldTypes().size()));
}
int elementSize = input.getInt();
ByteBuffer element;
if (elementSize == -1) {
element = null;
} else {
element = input.slice();
element.limit(elementSize);
input.position(input.position() + elementSize);
}
value = value.setBytesUnsafe(i, element);
i += 1;
}
return value;
} catch (BufferUnderflowException e) {
throw new IllegalArgumentException("Not enough bytes to deserialize a UDT value", e);
}
}
@NonNull
@Override
public String format(@Nullable UdtValue value) {
if (value == null) {
return "NULL";
}
CodecRegistry registry = cqlType.getAttachmentPoint().getCodecRegistry();
StringBuilder sb = new StringBuilder("{");
int size = cqlType.getFieldTypes().size();
boolean first = true;
for (int i = 0; i < size; i++) {
if (first) {
first = false;
} else {
sb.append(",");
}
CqlIdentifier elementName = cqlType.getFieldNames().get(i);
sb.append(elementName.asCql(true));
sb.append(":");
DataType elementType = cqlType.getFieldTypes().get(i);
TypeCodec<Object> codec = registry.codecFor(elementType);
sb.append(codec.format(value.get(i, codec)));
}
sb.append("}");
return sb.toString();
}
@Nullable
@Override
public UdtValue parse(@Nullable String value) {
if (value == null || value.isEmpty() || value.equalsIgnoreCase("NULL")) {
return null;
}
UdtValue udt = cqlType.newValue();
int position = ParseUtils.skipSpaces(value, 0);
if (value.charAt(position++) != '{') {
throw new IllegalArgumentException(
String.format(
"Cannot parse UDT value from \"%s\", at character %d expecting '{' but got '%c'",
value, position, value.charAt(position)));
}
position = ParseUtils.skipSpaces(value, position);
if (value.charAt(position) == '}') {
return udt;
}
CodecRegistry registry = cqlType.getAttachmentPoint().getCodecRegistry();
while (position < value.length()) {
int n;
try {
n = ParseUtils.skipCQLId(value, position);
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException(
String.format(
"Cannot parse UDT value from \"%s\", cannot parse a CQL identifier at character %d",
value, position),
e);
}
CqlIdentifier id = CqlIdentifier.fromInternal(value.substring(position, n));
position = n;
if (!cqlType.contains(id)) {
throw new IllegalArgumentException(
String.format("Unknown field %s in value \"%s\"", id, value));
}
position = ParseUtils.skipSpaces(value, position);
if (value.charAt(position++) != ':') {
throw new IllegalArgumentException(
String.format(
"Cannot parse UDT value from \"%s\", at character %d expecting ':' but got '%c'",
value, position, value.charAt(position)));
}
position = ParseUtils.skipSpaces(value, position);
try {
n = ParseUtils.skipCQLValue(value, position);
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException(
String.format(
"Cannot parse UDT value from \"%s\", invalid CQL value at character %d",
value, position),
e);
}
String fieldValue = value.substring(position, n);
DataType fieldType = cqlType.getFieldTypes().get(cqlType.firstIndexOf(id));
TypeCodec<Object> codec = registry.codecFor(fieldType);
udt = udt.set(id, codec.parse(fieldValue), codec);
position = n;
position = ParseUtils.skipSpaces(value, position);
if (value.charAt(position) == '}') {
return udt;
}
if (value.charAt(position) != ',') {
throw new IllegalArgumentException(
String.format(
"Cannot parse UDT value from \"%s\", at character %d expecting ',' but got '%c'",
value, position, value.charAt(position)));
}
++position;
position = ParseUtils.skipSpaces(value, position);
}
throw new IllegalArgumentException(
String.format("Malformed UDT value \"%s\", missing closing '}'", value));
}
}