package com.google.crypto.tink.subtle;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
/**
 * An output stream that encrypts everything written to it with a {@link NonceBasedStreamingAead}
 * and forwards the resulting ciphertext to an underlying stream.
 *
 * <p>The constructor writes the ciphertext header to the underlying stream. Plaintext is buffered
 * in {@code ptBuffer}. A buffered segment is encrypted (as a non-final segment) and written only
 * once more plaintext arrives after it. The last buffered segment is encrypted and marked as final
 * only by {@link #close()}, so the ciphertext is incomplete until the stream has been closed.
 *
 * <p>If encrypting or writing a segment fails, the stream refuses further writes. In that case
 * {@link #close()} only closes the underlying stream and does not emit a final segment, so the
 * output never looks like a complete ciphertext.
 *
 * <p>{@code write(byte[], int, int)} and {@code close()} are {@code synchronized}.
 */
class StreamingAeadEncryptingStream extends FilterOutputStream {
  /** Segment encrypter obtained from the streaming AEAD for the given associated data. */
  private final StreamSegmentEncrypter encrypter;

  /** Plaintext segment size reported by the streaming AEAD. */
  private final int plaintextSegmentSize;

  /** Plaintext of the current segment that has not been encrypted yet. */
  ByteBuffer ptBuffer;

  /** Scratch buffer that receives the ciphertext of one segment before it is written out. */
  ByteBuffer ctBuffer;

  /** True until {@link #close()} has been called. */
  boolean open;

  /**
   * Set when a segment could not be encrypted or written. Afterwards the buffers and the
   * encrypter are in an undefined state, so no further segments (in particular no final segment)
   * may be produced.
   */
  private boolean failed;

  /**
   * Creates the stream and immediately writes the ciphertext header to {@code ciphertextChannel}.
   *
   * @param streamAead the streaming AEAD providing the segment encrypter and segment sizes
   * @param ciphertextChannel the stream that receives the ciphertext; closed by {@link #close()}
   * @param associatedData associated data passed to the segment encrypter
   * @throws GeneralSecurityException propagated from creating the segment encrypter
   * @throws IOException if writing the header to {@code ciphertextChannel} fails
   */
  public StreamingAeadEncryptingStream(
      NonceBasedStreamingAead streamAead, OutputStream ciphertextChannel, byte[] associatedData)
      throws GeneralSecurityException, IOException {
    super(ciphertextChannel);
    encrypter = streamAead.newStreamSegmentEncrypter(associatedData);
    plaintextSegmentSize = streamAead.getPlaintextSegmentSize();
    ptBuffer = ByteBuffer.allocate(plaintextSegmentSize);
    ctBuffer = ByteBuffer.allocate(streamAead.getCiphertextSegmentSize());
    // The first plaintext segment is shortened by the ciphertext offset. This is presumably
    // because the header / offset occupies the start of the first ciphertext segment (TODO
    // confirm against NonceBasedStreamingAead).
    ptBuffer.limit(plaintextSegmentSize - streamAead.getCiphertextOffset());
    ByteBuffer header = encrypter.getHeader();
    byte[] headerBytes = new byte[header.remaining()];
    header.get(headerBytes);
    out.write(headerBytes);
    open = true;
  }

  /** Writes a single byte; only its low-order 8 bits are used. */
  @Override
  public void write(int b) throws IOException {
    write(new byte[] {(byte) b});
  }

  /** Writes all bytes of {@code b}. */
  @Override
  public void write(byte[] b) throws IOException {
    write(b, 0, b.length);
  }

  /**
   * Writes {@code length} bytes of {@code pt} starting at {@code offset}.
   *
   * <p>Every segment that is followed by further plaintext is encrypted and written to the
   * underlying stream. The remainder stays buffered until the next write or {@link #close()}.
   *
   * @throws IndexOutOfBoundsException if {@code offset}/{@code length} do not describe a range
   *     inside {@code pt}; nothing is buffered or written in that case
   * @throws IOException if the stream is closed, a previous write failed, or encrypting or
   *     writing a segment fails
   */
  @Override
  public synchronized void write(byte[] pt, int offset, int length) throws IOException {
    if (!open) {
      throw new IOException("Trying to write to closed stream");
    }
    if (failed) {
      throw new IOException("Trying to write to stream after a previous write failed");
    }
    // Validate up front so an invalid range cannot emit some segments before failing.
    // `offset > pt.length - length` avoids int overflow of `offset + length`.
    if (offset < 0 || length < 0 || offset > pt.length - length) {
      throw new IndexOutOfBoundsException(
          "offset=" + offset + ", length=" + length + ", array length=" + pt.length);
    }
    int startPosition = offset;
    int remaining = length;
    // Strict '>': data that exactly fills the buffer stays buffered, because only close() knows
    // whether that segment is the final one.
    while (remaining > ptBuffer.remaining()) {
      int sliceSize = ptBuffer.remaining();
      ByteBuffer slice = ByteBuffer.wrap(pt, startPosition, sliceSize);
      startPosition += sliceSize;
      remaining -= sliceSize;
      encryptAndWriteSegment(slice);
    }
    ptBuffer.put(pt, startPosition, remaining);
  }

  /**
   * Encrypts the buffered plaintext followed by {@code slice} as one non-final segment, writes
   * the ciphertext and resets {@code ptBuffer} to a full, empty segment. Any failure marks the
   * stream as failed.
   */
  private void encryptAndWriteSegment(ByteBuffer slice) throws IOException {
    boolean succeeded = false;
    try {
      ptBuffer.flip();
      ctBuffer.clear();
      try {
        encrypter.encryptSegment(ptBuffer, slice, false, ctBuffer);
      } catch (GeneralSecurityException ex) {
        throw new IOException(ex);
      }
      ctBuffer.flip();
      out.write(ctBuffer.array(), ctBuffer.position(), ctBuffer.remaining());
      ptBuffer.clear();
      // Segments after the first use the full plaintext segment size.
      ptBuffer.limit(plaintextSegmentSize);
      succeeded = true;
    } finally {
      if (!succeeded) {
        failed = true;
      }
    }
  }

  /**
   * Encrypts the remaining buffered plaintext as the final segment, writes it and closes the
   * underlying stream. Subsequent calls do nothing.
   *
   * <p>The underlying stream is closed even if the final segment cannot be produced. If a
   * previous write failed, no final segment is written and only the underlying stream is closed.
   *
   * @throws IOException if encrypting or writing the final segment, or closing the underlying
   *     stream, fails
   */
  @Override
  public synchronized void close() throws IOException {
    if (!open) {
      return;
    }
    // Mark closed first: finalizing consumes ptBuffer, so a retried close() must not try (and
    // possibly succeed) to emit a second, bogus final segment.
    open = false;
    try {
      if (!failed) {
        encryptAndWriteFinalSegment();
      }
    } finally {
      super.close();
    }
  }

  /** Encrypts the buffered plaintext as the final segment and writes the ciphertext. */
  private void encryptAndWriteFinalSegment() throws IOException {
    try {
      ptBuffer.flip();
      ctBuffer.clear();
      encrypter.encryptSegment(ptBuffer, true, ctBuffer);
    } catch (GeneralSecurityException ex) {
      throw new IOException(
          "ptBuffer.remaining():"
              + ptBuffer.remaining()
              + " ctBuffer.remaining():"
              + ctBuffer.remaining(),
          ex);
    }
    ctBuffer.flip();
    out.write(ctBuffer.array(), ctBuffer.position(), ctBuffer.remaining());
  }
}