package com.google.crypto.tink.subtle;
import com.google.crypto.tink.prf.Prf;
import com.google.errorprone.annotations.Immutable;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import java.util.Arrays;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
/**
 * A {@link Prf} computed as AES-CMAC (the construction of RFC 4493 / NIST SP 800-38B).
 *
 * <p>Subkeys K1 and K2 are derived once at construction time from the encryption of the
 * all-zero block. Each call to {@link #compute} obtains its own {@link Cipher}, so no Cipher
 * object is shared between calls.
 */
@Immutable
public final class PrfAesCmac implements Prf {
  // SecretKey is not a deeply-immutable type as far as the checker knows; it is never
  // mutated after construction.
  @SuppressWarnings("Immutable")
  private final SecretKey keySpec;

  // byte[] is mutable; these arrays are written only in the constructor and never exposed.
  @SuppressWarnings("Immutable")
  private final byte[] subKey1;

  @SuppressWarnings("Immutable")
  private final byte[] subKey2;

  /**
   * Returns a new raw AES/ECB/NoPadding cipher. It is used only as a single-block AES
   * permutation; CMAC chaining is done explicitly in {@link #compute}.
   */
  private static Cipher instance() throws GeneralSecurityException {
    return EngineFactory.CIPHER.getInstance("AES/ECB/NoPadding");
  }

  /** Returns a new AES cipher initialised for encryption under {@link #keySpec}. */
  private Cipher newEncryptingCipher() throws GeneralSecurityException {
    Cipher aes = instance();
    aes.init(Cipher.ENCRYPT_MODE, keySpec);
    return aes;
  }

  /**
   * Creates an AES-CMAC PRF for the given key.
   *
   * @param key raw AES key bytes; the length is checked by {@code
   *     Validators.validateAesKeySize} (presumably 16 or 32 bytes — confirm against Validators)
   * @throws GeneralSecurityException if the key size is rejected or the AES cipher cannot be
   *     obtained or initialised
   */
  public PrfAesCmac(final byte[] key) throws GeneralSecurityException {
    Validators.validateAesKeySize(key.length);
    keySpec = new SecretKeySpec(key, "AES");

    // Subkey generation: L = AES_K(0^128), K1 = dbl(L), K2 = dbl(K1).
    // Done inline (not in a helper) so that both subkey fields can be final.
    byte[] l = newEncryptingCipher().doFinal(new byte[AesUtil.BLOCK_SIZE]);
    subKey1 = AesUtil.dbl(l);
    subKey2 = AesUtil.dbl(subKey1);
  }

  /**
   * Computes the AES-CMAC tag of {@code data}, truncated to {@code outputLength} bytes.
   *
   * @param data the message; may be empty
   * @param outputLength number of leading tag bytes to return, in {@code [0, BLOCK_SIZE]}
   * @return the first {@code outputLength} bytes of the CMAC tag
   * @throws InvalidAlgorithmParameterException if {@code outputLength} is greater than {@code
   *     AesUtil.BLOCK_SIZE} or negative
   * @throws GeneralSecurityException if the AES cipher cannot be obtained or used
   */
  @Override
  public byte[] compute(final byte[] data, int outputLength) throws GeneralSecurityException {
    if (outputLength > AesUtil.BLOCK_SIZE) {
      throw new InvalidAlgorithmParameterException(
          "outputLength too large, max is " + AesUtil.BLOCK_SIZE + " bytes");
    }
    // Without this check a negative length would escape as an unchecked
    // NegativeArraySizeException from Arrays.copyOf below.
    if (outputLength < 0) {
      throw new InvalidAlgorithmParameterException(
          "outputLength must be non-negative, got " + outputLength);
    }
    Cipher aes = newEncryptingCipher();

    // n = number of blocks, treating an empty message as a single (incomplete) block.
    // (len - 1) / BLOCK_SIZE + 1 equals ceil(len / BLOCK_SIZE) for len >= 1 without overflow.
    int n = data.length == 0 ? 1 : (data.length - 1) / AesUtil.BLOCK_SIZE + 1;
    int lastBlockStart = (n - 1) * AesUtil.BLOCK_SIZE;

    // A complete final block is masked with K1. A partial (or empty) one is padded and masked
    // with K2.
    boolean lastBlockComplete = data.length > 0 && data.length % AesUtil.BLOCK_SIZE == 0;
    byte[] mLast;
    if (lastBlockComplete) {
      mLast = Bytes.xor(data, lastBlockStart, subKey1, 0, AesUtil.BLOCK_SIZE);
    } else {
      byte[] padded = AesUtil.cmacPad(Arrays.copyOfRange(data, lastBlockStart, data.length));
      mLast = Bytes.xor(padded, subKey2);
    }

    // CBC-MAC chaining over all blocks but the last, starting from the zero block.
    byte[] x = new byte[AesUtil.BLOCK_SIZE];
    for (int i = 0; i < n - 1; i++) {
      x = aes.doFinal(Bytes.xor(x, 0, data, i * AesUtil.BLOCK_SIZE, AesUtil.BLOCK_SIZE));
    }
    byte[] tag = aes.doFinal(Bytes.xor(mLast, x));
    return Arrays.copyOf(tag, outputLength);
  }
}