package com.google.crypto.tink.subtle;
import com.google.crypto.tink.PublicKeyVerify;
import com.google.crypto.tink.subtle.Enums.HashType;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.interfaces.RSAPublicKey;
import java.util.Arrays;
/**
 * RSASSA-PSS signature verification implemented with {@link BigInteger} arithmetic and a
 * {@link MessageDigest} obtained from {@code EngineFactory.MESSAGE_DIGEST}.
 *
 * <p>The control flow mirrors RFC 8017: {@link #verify} corresponds to RSASSA-PSS-VERIFY
 * (section 8.1.2) and the private {@code emsaPssVerify} helper corresponds to EMSA-PSS-VERIFY
 * (section 9.1.2). The step numbers in the comments refer to those sections.
 */
public final class RsaSsaPssVerifyJce implements PublicKeyVerify {
  private final RSAPublicKey publicKey;
  private final HashType sigHash;
  private final HashType mgf1Hash;
  private final int saltLength;

  /**
   * Creates a verifier for the given public key and PSS parameters.
   *
   * @param pubKey the RSA public key whose modulus and public exponent are used for verification
   * @param sigHash the hash applied to the message and to {@code M'}
   * @param mgf1Hash the hash passed to {@code SubtleUtil.mgf1} when deriving the DB mask
   * @param saltLength the expected salt length in bytes; must not be negative
   * @throws GeneralSecurityException if {@code sigHash} or the modulus size is rejected by
   *     {@code Validators}, or if {@code saltLength} is negative
   */
  public RsaSsaPssVerifyJce(
      final RSAPublicKey pubKey, HashType sigHash, HashType mgf1Hash, int saltLength)
      throws GeneralSecurityException {
    Validators.validateSignatureHash(sigHash);
    Validators.validateRsaModulusSize(pubKey.getModulus().bitLength());
    // A negative salt length can never verify; without this guard it would surface later as an
    // unchecked array-index/array-size exception inside emsaPssVerify.
    if (saltLength < 0) {
      throw new GeneralSecurityException("salt length must not be negative");
    }
    this.publicKey = pubKey;
    this.sigHash = sigHash;
    this.mgf1Hash = mgf1Hash;
    this.saltLength = saltLength;
  }

  /**
   * Verifies {@code signature} over {@code data}.
   *
   * @param signature the signature bytes; must be exactly as long as the modulus in bytes
   * @param data the message that was signed
   * @throws GeneralSecurityException if the signature has the wrong length, is not smaller than
   *     the modulus, or fails the PSS encoding check
   */
  @Override
  public void verify(final byte[] signature, final byte[] data) throws GeneralSecurityException {
    BigInteger e = publicKey.getPublicExponent();
    BigInteger n = publicKey.getModulus();
    int modulusBits = n.bitLength();
    int nLengthInBytes = (modulusBits + 7) / 8;
    // The encoded message is modBits - 1 bits long, so it may be one byte shorter than n.
    int emBits = modulusBits - 1;
    int emLengthInBytes = (emBits + 7) / 8;

    // Step 1. Length checking.
    if (nLengthInBytes != signature.length) {
      throw new GeneralSecurityException("invalid signature's length");
    }

    // Step 2. RSA verification primitive: m = s^e mod n, with s required to be in [0, n).
    BigInteger s = SubtleUtil.bytes2Integer(signature);
    if (s.compareTo(n) >= 0) {
      throw new GeneralSecurityException("signature out of range");
    }
    BigInteger m = s.modPow(e, n);
    byte[] em = SubtleUtil.integer2Bytes(m, emLengthInBytes);

    // Step 3. PSS encoding verification.
    emsaPssVerify(data, em, emBits);
  }

  /**
   * Checks that {@code em} is a consistent PSS encoding of {@code m}.
   *
   * @param m the message that was signed
   * @param em the encoded message recovered from the signature
   * @param emBits the intended bit length of the encoded message
   * @throws GeneralSecurityException with message {@code "inconsistent"} if any check fails
   */
  private void emsaPssVerify(byte[] m, byte[] em, int emBits) throws GeneralSecurityException {
    // Step 2. Compute mHash = Hash(M).
    Validators.validateSignatureHash(sigHash);
    MessageDigest digest =
        EngineFactory.MESSAGE_DIGEST.getInstance(SubtleUtil.toDigestAlgo(this.sigHash));
    byte[] mHash = digest.digest(m);
    int hLen = digest.getDigestLength();
    int emLen = em.length;

    // Step 3. EM must have room for the hash, the salt, the 0x01 separator and the 0xbc trailer.
    if (emLen < hLen + this.saltLength + 2) {
      throw new GeneralSecurityException("inconsistent");
    }

    // Step 4. The rightmost byte of EM must be 0xbc.
    if (em[em.length - 1] != (byte) 0xbc) {
      throw new GeneralSecurityException("inconsistent");
    }

    // Step 5. Split EM into maskedDB || H || 0xbc.
    byte[] maskedDb = Arrays.copyOf(em, emLen - hLen - 1);
    byte[] h = Arrays.copyOfRange(em, maskedDb.length, maskedDb.length + hLen);

    // Number of leading bits of EM that lie above emBits and therefore must be zero.
    long unusedLeadingBits = (long) emLen * 8 - emBits;

    // Step 6. The leftmost 8 * emLen - emBits bits of maskedDB must all be zero.
    for (int i = 0; i < unusedLeadingBits; i++) {
      int bytePos = i / 8;
      int bitPos = 7 - i % 8;
      if (((maskedDb[bytePos] >> bitPos) & 1) != 0) {
        throw new GeneralSecurityException("inconsistent");
      }
    }

    // Step 7. dbMask = MGF(H, emLen - hLen - 1).
    byte[] dbMask = SubtleUtil.mgf1(h, emLen - hLen - 1, mgf1Hash);

    // Step 8. DB = maskedDB xor dbMask.
    byte[] db = new byte[dbMask.length];
    for (int i = 0; i < db.length; i++) {
      db[i] = (byte) (dbMask[i] ^ maskedDb[i]);
    }

    // Step 9. Clear exactly the leftmost 8 * emLen - emBits bits of DB. The bound is strict:
    // clearing one more bit would hide a non-zero padding bit from step 10 and, when 7 bits are
    // unused and the padding is empty, would erase the mandatory 0x01 separator.
    for (int i = 0; i < unusedLeadingBits; i++) {
      int bytePos = i / 8;
      int bitPos = 7 - i % 8;
      db[bytePos] = (byte) (db[bytePos] & ~(1 << bitPos));
    }

    // Step 10. DB must be (zero padding) || 0x01 || salt.
    int separatorPos = emLen - hLen - this.saltLength - 2;
    for (int i = 0; i < separatorPos; i++) {
      if (db[i] != 0) {
        throw new GeneralSecurityException("inconsistent");
      }
    }
    if (db[separatorPos] != (byte) 0x01) {
      throw new GeneralSecurityException("inconsistent");
    }

    // Step 11. The salt is the last saltLength bytes of DB.
    byte[] salt = Arrays.copyOfRange(db, db.length - this.saltLength, db.length);

    // Step 12. M' = (eight zero bytes) || mHash || salt.
    byte[] mPrime = new byte[8 + hLen + this.saltLength];
    System.arraycopy(mHash, 0, mPrime, 8, mHash.length);
    System.arraycopy(salt, 0, mPrime, 8 + hLen, salt.length);

    // Steps 13-14. H' = Hash(M') must equal H.
    byte[] hPrime = digest.digest(mPrime);
    if (!Bytes.equal(hPrime, h)) {
      throw new GeneralSecurityException("inconsistent");
    }
  }
}