package com.google.crypto.tink.subtle;
import com.google.crypto.tink.PublicKeySign;
import com.google.crypto.tink.subtle.Enums.HashType;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.RSAPublicKeySpec;
import javax.crypto.Cipher;
/**
 * RSA-SSA-PSS signer built on the raw RSA primitive from the JCE.
 *
 * <p>The data is encoded with EMSA-PSS (RFC 8017, section 9.1.1) and the encoded message is then
 * transformed with the raw RSA private-key operation (RSASP1). Every signature uses a freshly
 * generated random salt of {@code saltLength} bytes.
 */
public final class RsaSsaPssSignJce implements PublicKeySign {
  private final RSAPrivateCrtKey privateKey;
  /** Public key derived from {@link #privateKey}; used to double-check every signature. */
  private final RSAPublicKey publicKey;
  /** Hash used for the message digest and for H inside EMSA-PSS. */
  private final HashType sigHash;
  /** Hash used by the MGF1 mask generation function. */
  private final HashType mgf1Hash;
  /** Salt length in bytes; guaranteed non-negative by the constructor. */
  private final int saltLength;
  // Unpadded RSA: the PSS encoding is computed by this class, not by the provider.
  private static final String RAW_RSA_ALGORITHM = "RSA/ECB/NOPADDING";

  /**
   * Creates an RSA-SSA-PSS signer.
   *
   * @param priv the RSA private key in CRT form; its modulus size is checked with {@code
   *     Validators.validateRsaModulusSize}
   * @param sigHash the signature hash; checked with {@code Validators.validateSignatureHash}
   * @param mgf1Hash the hash used by MGF1
   * @param saltLength the salt length in bytes; must be non-negative
   * @throws GeneralSecurityException if a parameter is rejected or the public key cannot be
   *     derived from {@code priv}
   */
  public RsaSsaPssSignJce(
      final RSAPrivateCrtKey priv, HashType sigHash, HashType mgf1Hash, int saltLength)
      throws GeneralSecurityException {
    Validators.validateSignatureHash(sigHash);
    Validators.validateRsaModulusSize(priv.getModulus().bitLength());
    if (saltLength < 0) {
      // Otherwise the signer could be constructed but every sign() call would fail later with an
      // unchecked exception while building DB.
      throw new GeneralSecurityException("invalid salt length: " + saltLength);
    }
    this.privateKey = priv;
    KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("RSA");
    this.publicKey =
        (RSAPublicKey)
            kf.generatePublic(new RSAPublicKeySpec(priv.getModulus(), priv.getPublicExponent()));
    this.sigHash = sigHash;
    this.mgf1Hash = mgf1Hash;
    this.saltLength = saltLength;
  }

  /**
   * Signs {@code data} with RSA-SSA-PSS.
   *
   * @param data the data to sign
   * @return the output of the raw RSA private-key operation on the EMSA-PSS encoding of {@code
   *     data}
   * @throws GeneralSecurityException if the modulus is too short for the hash and salt lengths,
   *     or if the underlying RSA operation fails
   */
  @Override
  public byte[] sign(final byte[] data) throws GeneralSecurityException {
    int modBits = publicKey.getModulus().bitLength();
    // emBits = modBits - 1 keeps the encoded message strictly below 2^(modBits - 1) <= modulus.
    byte[] em = emsaPssEncode(data, modBits - 1);
    return rsasp1(em);
  }

  /**
   * Applies the raw RSA private-key operation to {@code m} and verifies the result.
   *
   * <p>The result is re-encrypted with the public key and compared with {@code m}, so that a
   * faulty private-key computation is never released as a signature (a wrong RSA-CRT result can
   * reveal a factor of the modulus).
   *
   * @throws RuntimeException if the re-encrypted value does not match {@code m}
   */
  private byte[] rsasp1(byte[] m) throws GeneralSecurityException {
    Cipher decryptCipher = EngineFactory.CIPHER.getInstance(RAW_RSA_ALGORITHM);
    decryptCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
    byte[] c = decryptCipher.doFinal(m);
    Cipher encryptCipher = EngineFactory.CIPHER.getInstance(RAW_RSA_ALGORITHM);
    encryptCipher.init(Cipher.ENCRYPT_MODE, this.publicKey);
    byte[] m0 = encryptCipher.doFinal(c);
    // Compare as integers: m can be one byte shorter than the modulus (when modBits - 1 is a
    // multiple of 8), so the byte arrays may differ in leading zeros.
    if (!new BigInteger(1, m).equals(new BigInteger(1, m0))) {
      throw new RuntimeException("Security bug: RSA signature computation error");
    }
    return c;
  }

  /**
   * Computes EMSA-PSS-ENCODE(m, emBits) following RFC 8017, section 9.1.1.
   *
   * @param m the message to encode
   * @param emBits the maximal bit length of the integer value of the encoded message
   * @return the encoded message EM, {@code ceil(emBits / 8)} bytes long
   * @throws GeneralSecurityException if {@code emBits} is too small for the hash and salt lengths
   */
  private byte[] emsaPssEncode(byte[] m, int emBits) throws GeneralSecurityException {
    // sigHash was validated in the constructor and is final, so it is not re-validated here.
    MessageDigest digest =
        EngineFactory.MESSAGE_DIGEST.getInstance(SubtleUtil.toDigestAlgo(this.sigHash));
    // Step 2: mHash = Hash(M).
    byte[] mHash = digest.digest(m);
    // Use the real output length; MessageDigest.getDigestLength() may return 0 for some providers.
    int hLen = mHash.length;
    int emLen = (emBits - 1) / 8 + 1; // ceil(emBits / 8)
    // Step 3.
    if (emLen < hLen + this.saltLength + 2) {
      throw new GeneralSecurityException("encoding error");
    }
    // Step 4: random salt.
    byte[] salt = Random.randBytes(this.saltLength);
    // Steps 5-6: M' = 0x00 (8 times) || mHash || salt, H = Hash(M').
    byte[] mPrime = new byte[8 + hLen + this.saltLength];
    System.arraycopy(mHash, 0, mPrime, 8, hLen);
    System.arraycopy(salt, 0, mPrime, 8 + hLen, salt.length);
    byte[] h = digest.digest(mPrime);
    // Steps 7-8: DB = PS || 0x01 || salt, where PS consists of psLen zero bytes.
    int dbLen = emLen - hLen - 1;
    int psLen = dbLen - this.saltLength - 1;
    byte[] db = new byte[dbLen];
    db[psLen] = (byte) 0x01;
    System.arraycopy(salt, 0, db, psLen + 1, salt.length);
    // Steps 9-10: maskedDB = DB xor MGF1(H, dbLen).
    byte[] dbMask = SubtleUtil.mgf1(h, dbLen, this.mgf1Hash);
    byte[] maskedDb = new byte[dbLen];
    for (int i = 0; i < dbLen; i++) {
      maskedDb[i] = (byte) (db[i] ^ dbMask[i]);
    }
    // Step 11: clear the leftmost 8 * emLen - emBits bits. Since emLen = ceil(emBits / 8) this is
    // at most 7 bits, all in the first byte (dbLen >= 1 because of the step 3 check).
    int unusedBits = 8 * emLen - emBits;
    maskedDb[0] = (byte) (maskedDb[0] & (0xff >>> unusedBits));
    // Step 12: EM = maskedDB || H || 0xbc.
    byte[] em = new byte[dbLen + hLen + 1];
    System.arraycopy(maskedDb, 0, em, 0, dbLen);
    System.arraycopy(h, 0, em, dbLen, hLen);
    em[dbLen + hLen] = (byte) 0xbc;
    return em;
  }
}