package com.google.crypto.tink.subtle;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NonWritableChannelException;
import java.nio.channels.SeekableByteChannel;
import java.security.GeneralSecurityException;
import java.util.Arrays;
class StreamingAeadSeekableDecryptingChannel implements SeekableByteChannel {
private static final int = 16;
private final SeekableByteChannel ciphertextChannel;
private final ByteBuffer ciphertextSegment;
private final ByteBuffer plaintextSegment;
private final ByteBuffer ;
private final long ciphertextChannelSize;
private final int numberOfSegments;
private final int lastCiphertextSegmentSize;
private final byte[] aad;
private final StreamSegmentDecrypter decrypter;
private long plaintextPosition;
private long plaintextSize;
private boolean ;
private boolean isCurrentSegmentDecrypted;
private int currentSegmentNr;
private boolean isopen;
private final int plaintextSegmentSize;
private final int ciphertextSegmentSize;
private final int ciphertextOffset;
private final int firstSegmentOffset;
public StreamingAeadSeekableDecryptingChannel(
NonceBasedStreamingAead streamAead,
SeekableByteChannel ciphertext,
byte[] associatedData) throws IOException, GeneralSecurityException {
decrypter = streamAead.newStreamSegmentDecrypter();
ciphertextChannel = ciphertext;
header = ByteBuffer.allocate(streamAead.getHeaderLength());
ciphertextSegmentSize = streamAead.getCiphertextSegmentSize();
ciphertextSegment = ByteBuffer.allocate(ciphertextSegmentSize);
plaintextSegmentSize = streamAead.getPlaintextSegmentSize();
plaintextSegment = ByteBuffer.allocate(plaintextSegmentSize + PLAINTEXT_SEGMENT_EXTRA_SIZE);
plaintextPosition = 0;
headerRead = false;
currentSegmentNr = -1;
isCurrentSegmentDecrypted = false;
ciphertextChannelSize = ciphertextChannel.size();
aad = Arrays.copyOf(associatedData, associatedData.length);
isopen = ciphertextChannel.isOpen();
int fullSegments = (int) (ciphertextChannelSize / ciphertextSegmentSize);
int remainder = (int) (ciphertextChannelSize % ciphertextSegmentSize);
int ciphertextOverhead = streamAead.getCiphertextOverhead();
if (remainder > 0) {
numberOfSegments = fullSegments + 1;
if (remainder < ciphertextOverhead) {
throw new IOException("Invalid ciphertext size");
}
lastCiphertextSegmentSize = remainder;
} else {
numberOfSegments = fullSegments;
lastCiphertextSegmentSize = ciphertextSegmentSize;
}
ciphertextOffset = streamAead.getCiphertextOffset();
firstSegmentOffset = ciphertextOffset - streamAead.getHeaderLength();
if (firstSegmentOffset < 0) {
throw new IOException("Invalid ciphertext offset or header length");
}
long overhead = (long) numberOfSegments * ciphertextOverhead + ciphertextOffset;
if (overhead > ciphertextChannelSize) {
throw new IOException("Ciphertext is too short");
}
plaintextSize = ciphertextChannelSize - overhead;
}
@Override
public synchronized String toString() {
StringBuilder res =
new StringBuilder();
String ctChannel;
try {
ctChannel = "position:" + ciphertextChannel.position();
} catch (IOException ex) {
ctChannel = "position: n/a";
}
res.append("StreamingAeadSeekableDecryptingChannel")
.append("\nciphertextChannel").append(ctChannel)
.append("\nciphertextChannelSize:").append(ciphertextChannelSize)
.append("\nplaintextSize:").append(plaintextSize)
.append("\nciphertextSegmentSize:").append(ciphertextSegmentSize)
.append("\nnumberOfSegments:").append(numberOfSegments)
.append("\nheaderRead:").append(headerRead)
.append("\nplaintextPosition:").append(plaintextPosition)
.append("\nHeader")
.append(" position:").append(header.position())
.append(" limit:").append(header.position())
.append("\ncurrentSegmentNr:").append(currentSegmentNr)
.append("\nciphertextSgement")
.append(" position:").append(ciphertextSegment.position())
.append(" limit:").append(ciphertextSegment.limit())
.append("\nisCurrentSegmentDecrypted:").append(isCurrentSegmentDecrypted)
.append("\nplaintextSegment")
.append(" position:").append(plaintextSegment.position())
.append(" limit:").append(plaintextSegment.limit());
return res.toString();
}
@Override
public synchronized long position() {
return plaintextPosition;
}
@Override
public synchronized SeekableByteChannel position(long newPosition) {
plaintextPosition = newPosition;
return this;
}
private boolean () throws IOException {
ciphertextChannel.position(header.position() + firstSegmentOffset);
ciphertextChannel.read(header);
if (header.remaining() > 0) {
return false;
} else {
header.flip();
try {
decrypter.init(header, aad);
headerRead = true;
} catch (GeneralSecurityException ex) {
throw new IOException(ex);
}
return true;
}
}
private int getSegmentNr(long plaintextPosition) {
return (int) ((plaintextPosition + ciphertextOffset) / plaintextSegmentSize);
}
private boolean tryLoadSegment(int segmentNr) throws IOException {
if (segmentNr < 0 || segmentNr >= numberOfSegments) {
throw new IOException("Invalid position");
}
boolean isLast = segmentNr == numberOfSegments - 1;
if (segmentNr == currentSegmentNr) {
if (isCurrentSegmentDecrypted) {
return true;
}
} else {
long ciphertextPosition = (long) segmentNr * ciphertextSegmentSize;
int segmentSize = ciphertextSegmentSize;
if (isLast) {
segmentSize = lastCiphertextSegmentSize;
}
if (segmentNr == 0) {
segmentSize -= ciphertextOffset;
ciphertextPosition = ciphertextOffset;
}
ciphertextChannel.position(ciphertextPosition);
ciphertextSegment.clear();
ciphertextSegment.limit(segmentSize);
currentSegmentNr = segmentNr;
isCurrentSegmentDecrypted = false;
}
if (ciphertextSegment.remaining() > 0) {
ciphertextChannel.read(ciphertextSegment);
}
if (ciphertextSegment.remaining() > 0) {
return false;
}
ciphertextSegment.flip();
plaintextSegment.clear();
try {
decrypter.decryptSegment(ciphertextSegment, segmentNr, isLast, plaintextSegment);
} catch (GeneralSecurityException ex) {
currentSegmentNr = -1;
throw new IOException("Failed to decrypt", ex);
}
plaintextSegment.flip();
isCurrentSegmentDecrypted = true;
return true;
}
private boolean reachedEnd() {
return (isCurrentSegmentDecrypted
&& currentSegmentNr == numberOfSegments - 1
&& plaintextSegment.remaining() == 0);
}
public synchronized int read(ByteBuffer dst, long start) throws IOException {
long oldPosition = position();
try {
position(start);
return read(dst);
} finally {
position(oldPosition);
}
}
@Override
public synchronized int read(ByteBuffer dst) throws IOException {
if (!isopen) {
throw new ClosedChannelException();
}
if (!headerRead) {
if (!tryReadHeader()) {
return 0;
}
}
int startPos = dst.position();
while (dst.remaining() > 0 && plaintextPosition < plaintextSize) {
int segmentNr = getSegmentNr(plaintextPosition);
int segmentOffset;
if (segmentNr == 0) {
segmentOffset = (int) plaintextPosition;
} else {
segmentOffset = (int) ((plaintextPosition + ciphertextOffset) % plaintextSegmentSize);
}
if (tryLoadSegment(segmentNr)) {
plaintextSegment.position(segmentOffset);
if (plaintextSegment.remaining() <= dst.remaining()) {
plaintextPosition += plaintextSegment.remaining();
dst.put(plaintextSegment);
} else {
int sliceSize = dst.remaining();
ByteBuffer slice = plaintextSegment.duplicate();
slice.limit(slice.position() + sliceSize);
dst.put(slice);
plaintextPosition += sliceSize;
plaintextSegment.position(plaintextSegment.position() + sliceSize);
}
} else {
break;
}
}
int read = dst.position() - startPos;
if (read == 0 && reachedEnd()) {
return -1;
}
return read;
}
@Override
public long size() {
return plaintextSize;
}
public synchronized long verifiedSize() throws IOException {
if (tryLoadSegment(numberOfSegments - 1)) {
return plaintextSize;
} else {
throw new IOException("could not verify the size");
}
}
@Override
public SeekableByteChannel truncate(long size) throws NonWritableChannelException {
throw new NonWritableChannelException();
}
@Override
public int write(ByteBuffer src) throws NonWritableChannelException {
throw new NonWritableChannelException();
}
@Override
public synchronized void close() throws IOException {
ciphertextChannel.close();
isopen = false;
}
@Override
public synchronized boolean isOpen() {
return isopen;
}
}