package com.datastax.oss.driver.internal.core.type.codec;
import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.shaded.guava.common.collect.Maps;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.nio.ByteBuffer;
import java.util.LinkedHashMap;
import java.util.Map;
import net.jcip.annotations.ThreadSafe;
/**
 * Codec for CQL {@code map} types, delegating every key to {@code keyCodec} and every value to
 * {@code valueCodec}.
 *
 * <p>Maps returned by {@link #decode} and {@link #parse} are {@link LinkedHashMap}s, so entry order
 * follows the order of the serialized or textual form.
 */
@ThreadSafe
public class MapCodec<KeyT, ValueT> implements TypeCodec<Map<KeyT, ValueT>> {

  private final DataType cqlType;
  private final GenericType<Map<KeyT, ValueT>> javaType;
  private final TypeCodec<KeyT> keyCodec;
  private final TypeCodec<ValueT> valueCodec;

  /**
   * Creates a new map codec.
   *
   * @param cqlType the CQL type handled by this codec, returned as-is by {@link #getCqlType()}.
   * @param keyCodec the codec used for every map key.
   * @param valueCodec the codec used for every map value.
   */
  public MapCodec(DataType cqlType, TypeCodec<KeyT> keyCodec, TypeCodec<ValueT> valueCodec) {
    this.cqlType = cqlType;
    this.keyCodec = keyCodec;
    this.valueCodec = valueCodec;
    // Derived from the element codecs rather than passed in, so it always agrees with them.
    this.javaType = GenericType.mapOf(keyCodec.getJavaType(), valueCodec.getJavaType());
  }

  @NonNull
  @Override
  public GenericType<Map<KeyT, ValueT>> getJavaType() {
    return javaType;
  }

  @NonNull
  @Override
  public DataType getCqlType() {
    return cqlType;
  }

  /**
   * Returns whether {@code value} is a {@link Map} this codec can handle.
   *
   * <p>An empty map is always accepted. Otherwise only the first entry is checked against the key
   * and value codecs; the remaining entries are not inspected.
   */
  @Override
  public boolean accepts(@NonNull Object value) {
    if (value instanceof Map) {
      Map<?, ?> map = (Map<?, ?>) value;
      if (map.isEmpty()) {
        return true;
      }
      Map.Entry<?, ?> entry = map.entrySet().iterator().next();
      return keyCodec.accepts(entry.getKey()) && valueCodec.accepts(entry.getValue());
    }
    return false;
  }

  /**
   * Serializes {@code value}.
   *
   * <p>The layout is a 4-byte entry count followed, for each entry, by the 4-byte length and bytes
   * of the encoded key, then the 4-byte length and bytes of the encoded value.
   *
   * @return the encoded buffer (position 0, limit at the end of the data), or {@code null} if
   *     {@code value} is {@code null}.
   * @throws NullPointerException if a key or value is {@code null}, or encodes to {@code null}.
   * @throws IllegalArgumentException if encoding a key or value raises a {@link
   *     ClassCastException} (the original exception is kept as the cause).
   */
  @Override
  @Nullable
  public ByteBuffer encode(
      @Nullable Map<KeyT, ValueT> value, @NonNull ProtocolVersion protocolVersion) {
    if (value == null) {
      return null;
    }
    // Encode every element first so the exact size of the result is known before allocating it.
    ByteBuffer[] encodedElements = new ByteBuffer[value.size() * 2];
    int i = 0;
    int toAllocate = 4; // entry count
    for (Map.Entry<KeyT, ValueT> entry : value.entrySet()) {
      KeyT key = entry.getKey();
      ValueT val = entry.getValue();
      if (key == null) {
        throw new NullPointerException("Map keys cannot be null");
      }
      if (val == null) {
        throw new NullPointerException("Map values cannot be null");
      }

      ByteBuffer encodedKey;
      try {
        encodedKey = keyCodec.encode(key, protocolVersion);
      } catch (ClassCastException e) {
        throw new IllegalArgumentException("Invalid type for key: " + key.getClass(), e);
      }
      if (encodedKey == null) {
        throw new NullPointerException("Map keys cannot encode to CQL NULL");
      }
      encodedElements[i++] = encodedKey;
      toAllocate += 4 + encodedKey.remaining();

      ByteBuffer encodedValue;
      try {
        encodedValue = valueCodec.encode(val, protocolVersion);
      } catch (ClassCastException e) {
        throw new IllegalArgumentException("Invalid type for value: " + val.getClass(), e);
      }
      if (encodedValue == null) {
        throw new NullPointerException("Map values cannot encode to CQL NULL");
      }
      encodedElements[i++] = encodedValue;
      toAllocate += 4 + encodedValue.remaining();
    }

    ByteBuffer result = ByteBuffer.allocate(toAllocate);
    result.putInt(value.size());
    for (ByteBuffer encodedElement : encodedElements) {
      // put(ByteBuffer) advances the source's position. Copy from a duplicate so that buffers
      // handed out by the element codecs are left untouched, even if a codec returns the same
      // instance more than once.
      ByteBuffer element = encodedElement.duplicate();
      result.putInt(element.remaining());
      result.put(element);
    }
    result.flip();
    return result;
  }

  /**
   * Deserializes a CQL map.
   *
   * <p>A {@code null} or empty buffer decodes to an empty map. A negative key or value length
   * decodes to a {@code null} key or value. The position of {@code bytes} is not modified.
   */
  @Nullable
  @Override
  public Map<KeyT, ValueT> decode(
      @Nullable ByteBuffer bytes, @NonNull ProtocolVersion protocolVersion) {
    if (bytes == null || bytes.remaining() == 0) {
      return new LinkedHashMap<>(0);
    }
    ByteBuffer input = bytes.duplicate();
    int size = input.getInt();
    Map<KeyT, ValueT> result = Maps.newLinkedHashMapWithExpectedSize(size);
    for (int i = 0; i < size; i++) {
      KeyT key;
      int keySize = input.getInt();
      if (keySize < 0) {
        key = null;
      } else {
        // Hand the codec a view bounded to exactly this element, then skip past it.
        ByteBuffer encodedKey = input.slice();
        encodedKey.limit(keySize);
        key = keyCodec.decode(encodedKey, protocolVersion);
        input.position(input.position() + keySize);
      }
      ValueT value;
      int valueSize = input.getInt();
      if (valueSize < 0) {
        value = null;
      } else {
        ByteBuffer encodedValue = input.slice();
        encodedValue.limit(valueSize);
        value = valueCodec.decode(encodedValue, protocolVersion);
        input.position(input.position() + valueSize);
      }
      result.put(key, value);
    }
    return result;
  }

  /**
   * Formats {@code value} as a CQL map literal, e.g. {@code {k1:v1,k2:v2}}, using the element
   * codecs to format each key and value. Returns {@code "NULL"} for a {@code null} map.
   */
  @NonNull
  @Override
  public String format(@Nullable Map<KeyT, ValueT> value) {
    if (value == null) {
      return "NULL";
    }
    StringBuilder sb = new StringBuilder();
    sb.append("{");
    boolean first = true;
    for (Map.Entry<KeyT, ValueT> e : value.entrySet()) {
      if (first) {
        first = false;
      } else {
        sb.append(",");
      }
      sb.append(keyCodec.format(e.getKey()));
      sb.append(":");
      sb.append(valueCodec.format(e.getValue()));
    }
    sb.append("}");
    return sb.toString();
  }

  /**
   * Parses a CQL map literal such as {@code {k1 : v1, k2 : v2}}.
   *
   * @return the parsed map, or {@code null} if {@code value} is {@code null}, empty, or {@code
   *     NULL} (case-insensitive).
   * @throws IllegalArgumentException if {@code value} is not a well-formed map literal, including
   *     when it ends before the closing {@code '}'}.
   */
  @Nullable
  @Override
  public Map<KeyT, ValueT> parse(@Nullable String value) {
    if (value == null || value.isEmpty() || value.equalsIgnoreCase("NULL")) {
      return null;
    }
    int idx = ParseUtils.skipSpaces(value, 0);
    if (idx >= value.length()) {
      throw new IllegalArgumentException(
          String.format("Malformed map value \"%s\", missing opening '{'", value));
    }
    // Check the character before advancing so the error reports the offending position.
    if (value.charAt(idx) != '{') {
      throw new IllegalArgumentException(
          String.format(
              "cannot parse map value from \"%s\", at character %d expecting '{' but got '%c'",
              value, idx, value.charAt(idx)));
    }
    idx = ParseUtils.skipSpaces(value, idx + 1);
    if (idx < value.length() && value.charAt(idx) == '}') {
      return new LinkedHashMap<>(0);
    }

    Map<KeyT, ValueT> map = new LinkedHashMap<>();
    while (idx < value.length()) {
      int n = skipCqlValue(value, idx);
      KeyT k = keyCodec.parse(value.substring(idx, n));
      idx = ParseUtils.skipSpaces(value, n);
      if (idx >= value.length()) {
        break; // input ended mid-entry: reported as a missing '}' below
      }
      if (value.charAt(idx) != ':') {
        throw new IllegalArgumentException(
            String.format(
                "Cannot parse map value from \"%s\", at character %d expecting ':' but got '%c'",
                value, idx, value.charAt(idx)));
      }
      idx = ParseUtils.skipSpaces(value, idx + 1);

      n = skipCqlValue(value, idx);
      ValueT v = valueCodec.parse(value.substring(idx, n));
      map.put(k, v);
      idx = ParseUtils.skipSpaces(value, n);
      if (idx >= value.length()) {
        break; // input ended without the closing brace
      }
      char separator = value.charAt(idx);
      if (separator == '}') {
        return map;
      }
      if (separator != ',') {
        throw new IllegalArgumentException(
            String.format(
                "Cannot parse map value from \"%s\", at character %d expecting ',' but got '%c'",
                value, idx, separator));
      }
      idx = ParseUtils.skipSpaces(value, idx + 1);
    }
    throw new IllegalArgumentException(
        String.format("Malformed map value \"%s\", missing closing '}'", value));
  }

  /**
   * Returns the index just past the CQL value starting at {@code idx} in {@code value}. Any
   * {@link IllegalArgumentException} from {@code ParseUtils.skipCQLValue} is rethrown with the
   * map literal and position in the message, keeping the original as the cause.
   */
  private static int skipCqlValue(String value, int idx) {
    try {
      return ParseUtils.skipCQLValue(value, idx);
    } catch (IllegalArgumentException e) {
      throw new IllegalArgumentException(
          String.format(
              "Cannot parse map value from \"%s\", invalid CQL value at character %d",
              value, idx),
          e);
    }
  }
}