package com.google.crypto.tink.subtle;
import com.google.crypto.tink.Aead;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
/**
 * AES in EAX mode, implemented on top of the JCE "AES/ECB/NOPADDING" and "AES/CTR/NOPADDING"
 * ciphers.
 *
 * <p>Ciphertexts produced by {@link #encrypt} (and accepted by {@link #decrypt}) have the layout
 * {@code iv || ctrOutput || tag}, where {@code iv} is {@code ivSizeInBytes} random bytes and
 * {@code tag} is {@value #TAG_SIZE_IN_BYTES} bytes.
 *
 * <p>JCE {@link Cipher} objects are not shared between threads: each thread lazily obtains its own
 * ECB and CTR instance through a {@link ThreadLocal}. All instance fields are final and their
 * arrays are never modified after construction.
 */
public final class AesEaxJce implements Aead {
  /** Per-thread raw AES block cipher, used for OMAC and for deriving the masking subkeys. */
  private static final ThreadLocal<Cipher> ecbCipherPerThread =
      new ThreadLocal<Cipher>() {
        @Override
        protected Cipher initialValue() {
          try {
            return EngineFactory.CIPHER.getInstance("AES/ECB/NOPADDING");
          } catch (GeneralSecurityException cause) {
            throw new IllegalStateException(cause);
          }
        }
      };

  /** Per-thread AES-CTR cipher, used for the confidentiality part of EAX. */
  private static final ThreadLocal<Cipher> ctrCipherPerThread =
      new ThreadLocal<Cipher>() {
        @Override
        protected Cipher initialValue() {
          try {
            return EngineFactory.CIPHER.getInstance("AES/CTR/NOPADDING");
          } catch (GeneralSecurityException cause) {
            throw new IllegalStateException(cause);
          }
        }
      };

  static final int BLOCK_SIZE_IN_BYTES = 16;
  static final int TAG_SIZE_IN_BYTES = 16;

  /** L·x, where L = AES_K(0^128); XORed into a final OMAC block that is exactly 16 bytes long. */
  private final byte[] fullBlockMask;
  /** L·x², XORed into a padded final OMAC block that was shorter than 16 bytes. */
  private final byte[] partialBlockMask;
  private final SecretKeySpec keySpec;
  private final int ivSizeInBytes;

  /**
   * Creates an AES-EAX primitive.
   *
   * @param key the raw AES key; its length is checked by {@code Validators.validateAesKeySize}
   * @param ivSizeInBytes length of the random IV prepended to every ciphertext; must be 12 or 16
   * @throws IllegalArgumentException if {@code ivSizeInBytes} is neither 12 nor 16
   * @throws GeneralSecurityException if the key size is rejected or the AES cipher cannot be used
   */
  @SuppressWarnings("InsecureCryptoUsage")
  public AesEaxJce(final byte[] key, int ivSizeInBytes) throws GeneralSecurityException {
    boolean supportedIvSize = ivSizeInBytes == 12 || ivSizeInBytes == 16;
    if (!supportedIvSize) {
      throw new IllegalArgumentException("IV size should be either 12 or 16 bytes");
    }
    this.ivSizeInBytes = ivSizeInBytes;
    Validators.validateAesKeySize(key.length);
    keySpec = new SecretKeySpec(key, "AES");

    // Derive the OMAC masking subkeys from L = AES_K(0^128).
    Cipher blockCipher = ecbCipherPerThread.get();
    blockCipher.init(Cipher.ENCRYPT_MODE, keySpec);
    byte[] encryptedZeros = blockCipher.doFinal(new byte[BLOCK_SIZE_IN_BYTES]);
    fullBlockMask = doubleBlock(encryptedZeros);
    partialBlockMask = doubleBlock(fullBlockMask);
  }

  /**
   * Returns a new array holding {@code left[i] ^ right[i]} for every index of {@code left}.
   * Both arrays are expected to have the same length.
   */
  private static byte[] xor(final byte[] left, final byte[] right) {
    assert left.length == right.length;
    byte[] out = left.clone();
    for (int i = 0; i < out.length; i++) {
      out[i] ^= right[i];
    }
    return out;
  }

  /**
   * Multiplies a 16-byte big-endian block by x in GF(2^128): shift the whole block left by one
   * bit and, if the top bit fell off, XOR 0x87 into the last byte.
   */
  private static byte[] doubleBlock(final byte[] block) {
    int last = BLOCK_SIZE_IN_BYTES - 1;
    byte[] doubled = new byte[BLOCK_SIZE_IN_BYTES];
    for (int i = 0; i < last; i++) {
      // The low bit of (block[i] << 1) is always 0, so OR-ing in the carry bit equals XOR-ing it.
      int carry = (block[i + 1] & 0xff) >>> 7;
      doubled[i] = (byte) ((block[i] << 1) | carry);
    }
    int reduction = (block[0] & 0x80) != 0 ? 0x87 : 0;
    doubled[last] = (byte) ((block[last] << 1) ^ reduction);
    return doubled;
  }

  /**
   * Masks the final OMAC block. A complete 16-byte block is XORed with L·x. A shorter block is
   * padded with a single 0x80 byte after the data (zeros implied) and XORed with L·x².
   */
  private byte[] padLastBlock(final byte[] chunk) {
    if (chunk.length == BLOCK_SIZE_IN_BYTES) {
      return xor(chunk, fullBlockMask);
    }
    byte[] padded = partialBlockMask.clone();
    for (int i = 0; i < chunk.length; i++) {
      padded[i] ^= chunk[i];
    }
    padded[chunk.length] ^= 0x80;
    return padded;
  }

  /**
   * Computes the tweaked OMAC of {@code data[offset, offset + length)}.
   *
   * <p>The 16-byte block {@code 0^120 || tweakValue} is logically prepended to the data. The
   * combined message is then CBC-MACed, and its final block is masked by {@link #padLastBlock}.
   *
   * @param blockCipher an ECB cipher already initialised for encryption with {@link #keySpec}
   * @param tweakValue domain separator; EAX uses 0 (nonce), 1 (header) and 2 (ciphertext)
   */
  private byte[] omac(Cipher blockCipher, int tweakValue, final byte[] data, int offset, int length)
      throws IllegalBlockSizeException, BadPaddingException {
    assert length >= 0;
    assert 0 <= tweakValue && tweakValue <= 3;
    byte[] tweak = new byte[BLOCK_SIZE_IN_BYTES];
    tweak[BLOCK_SIZE_IN_BYTES - 1] = (byte) tweakValue;
    if (length == 0) {
      // The tweak block alone is the whole message, and it is a complete block.
      return blockCipher.doFinal(xor(tweak, fullBlockMask));
    }

    byte[] state = blockCipher.doFinal(tweak);
    int end = offset + length;
    int cursor = offset;
    // Absorb every block except the last one (which may be complete or partial).
    for (; end - cursor > BLOCK_SIZE_IN_BYTES; cursor += BLOCK_SIZE_IN_BYTES) {
      for (int i = 0; i < BLOCK_SIZE_IN_BYTES; i++) {
        state[i] ^= data[cursor + i];
      }
      state = blockCipher.doFinal(state);
    }
    byte[] lastBlock = padLastBlock(Arrays.copyOfRange(data, cursor, end));
    return blockCipher.doFinal(xor(state, lastBlock));
  }

  /**
   * Encrypts {@code plaintext} under a fresh random IV.
   *
   * @param plaintext data to encrypt
   * @param associatedData authenticated but unencrypted data; {@code null} is treated as empty
   * @return {@code iv || ctrOutput || tag}
   * @throws GeneralSecurityException if the output length would overflow an {@code int} or the
   *     underlying ciphers fail
   */
  @SuppressWarnings("InsecureCryptoUsage")
  @Override
  public byte[] encrypt(final byte[] plaintext, final byte[] associatedData)
      throws GeneralSecurityException {
    int maxPlaintextLength = Integer.MAX_VALUE - ivSizeInBytes - TAG_SIZE_IN_BYTES;
    if (plaintext.length > maxPlaintextLength) {
      throw new GeneralSecurityException("plaintext too long");
    }
    int tagOffset = ivSizeInBytes + plaintext.length;
    byte[] output = new byte[tagOffset + TAG_SIZE_IN_BYTES];
    byte[] iv = Random.randBytes(ivSizeInBytes);
    System.arraycopy(iv, 0, output, 0, ivSizeInBytes);

    Cipher blockCipher = ecbCipherPerThread.get();
    blockCipher.init(Cipher.ENCRYPT_MODE, keySpec);
    byte[] nonceMac = omac(blockCipher, 0, iv, 0, iv.length);
    byte[] header = associatedData == null ? new byte[0] : associatedData;
    byte[] headerMac = omac(blockCipher, 1, header, 0, header.length);

    // The nonce MAC doubles as the initial CTR counter block.
    Cipher streamCipher = ctrCipherPerThread.get();
    streamCipher.init(Cipher.ENCRYPT_MODE, keySpec, new IvParameterSpec(nonceMac));
    streamCipher.doFinal(plaintext, 0, plaintext.length, output, ivSizeInBytes);

    byte[] bodyMac = omac(blockCipher, 2, output, ivSizeInBytes, plaintext.length);
    for (int i = 0; i < TAG_SIZE_IN_BYTES; i++) {
      output[tagOffset + i] = (byte) (headerMac[i] ^ nonceMac[i] ^ bodyMac[i]);
    }
    return output;
  }

  /**
   * Verifies and decrypts a ciphertext produced by {@link #encrypt}.
   *
   * @param ciphertext {@code iv || ctrOutput || tag}
   * @param associatedData the associated data used at encryption time; {@code null} is treated as
   *     empty
   * @return the recovered plaintext
   * @throws GeneralSecurityException if the ciphertext is too short to contain an IV and a tag
   * @throws AEADBadTagException if the authentication tag does not match
   */
  @SuppressWarnings("InsecureCryptoUsage")
  @Override
  public byte[] decrypt(final byte[] ciphertext, final byte[] associatedData)
      throws GeneralSecurityException {
    int plaintextLength = ciphertext.length - ivSizeInBytes - TAG_SIZE_IN_BYTES;
    if (plaintextLength < 0) {
      throw new GeneralSecurityException("ciphertext too short");
    }
    Cipher blockCipher = ecbCipherPerThread.get();
    blockCipher.init(Cipher.ENCRYPT_MODE, keySpec);
    byte[] nonceMac = omac(blockCipher, 0, ciphertext, 0, ivSizeInBytes);
    byte[] header = associatedData == null ? new byte[0] : associatedData;
    byte[] headerMac = omac(blockCipher, 1, header, 0, header.length);
    byte[] bodyMac = omac(blockCipher, 2, ciphertext, ivSizeInBytes, plaintextLength);

    // OR together the differences of all tag bytes instead of exiting early, so the comparison
    // does not depend on the position of the first mismatching byte.
    int tagOffset = ciphertext.length - TAG_SIZE_IN_BYTES;
    int difference = 0;
    for (int i = 0; i < TAG_SIZE_IN_BYTES; i++) {
      difference |= ciphertext[tagOffset + i] ^ headerMac[i] ^ nonceMac[i] ^ bodyMac[i];
    }
    if (difference != 0) {
      throw new AEADBadTagException("tag mismatch");
    }

    // CTR is its own inverse, so the cipher is initialised in encrypt mode for decryption too.
    Cipher streamCipher = ctrCipherPerThread.get();
    streamCipher.init(Cipher.ENCRYPT_MODE, keySpec, new IvParameterSpec(nonceMac));
    return streamCipher.doFinal(ciphertext, ivSizeInBytes, plaintextLength);
  }
}