package org.apache.poi.poifs.crypt;
import static org.apache.poi.poifs.crypt.Decryptor.DEFAULT_POIFS_ENTRY;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.BitSet;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.ShortBufferException;
import org.apache.poi.EncryptedDocumentException;
import org.apache.poi.poifs.filesystem.DirectoryNode;
import org.apache.poi.poifs.filesystem.POIFSWriterEvent;
import org.apache.poi.poifs.filesystem.POIFSWriterListener;
import org.apache.poi.util.IOUtils;
import org.apache.poi.util.Internal;
import org.apache.poi.util.LittleEndian;
import org.apache.poi.util.LittleEndianConsts;
import org.apache.poi.util.POILogFactory;
import org.apache.poi.util.POILogger;
import org.apache.poi.util.TempFile;
/**
 * Output stream that collects written bytes in a fixed-size chunk buffer and
 * runs every chunk through a {@link Cipher} before handing it to the
 * underlying stream.
 *
 * <p>Two construction modes exist:</p>
 * <ul>
 * <li>with a {@link DirectoryNode}: the encrypted bytes are spooled to a
 *     temporary file; on {@link #close()} a document named
 *     {@code DEFAULT_POIFS_ENTRY} is created in that directory and the
 *     subclass is asked to create the encryption info entry;</li>
 * <li>with an {@link OutputStream}: the encrypted bytes are written straight
 *     to that stream.</li>
 * </ul>
 *
 * <p>A chunk size of {@code -1} selects streaming mode: a 4096 byte buffer is
 * used and the cipher is not re-initialized for each chunk. The chunk mask
 * and chunk index are derived as {@code size-1} and {@code bitCount(size-1)},
 * which only yields consistent values when the buffer size is a power of
 * two.</p>
 */
@Internal
public abstract class ChunkedCipherOutputStream extends FilterOutputStream {
    private static final POILogger LOG = POILogFactory.getLogger(ChunkedCipherOutputStream.class);

    /** Upper bound passed to {@code IOUtils.safelyAllocate} when the chunk buffer is created. */
    private static final int MAX_RECORD_LENGTH = 100_000;

    /** Chunk size value that selects streaming mode. */
    private static final int STREAMING = -1;

    /** Chunk size as passed to the constructor (may be {@link #STREAMING}). */
    private final int chunkSize;
    /** Shift applied to {@link #pos} to obtain the chunk index. */
    private final int chunkBits;
    /** Buffer holding the current chunk; it is encrypted in place. */
    private final byte[] chunk;
    /** Marks the positions in {@link #chunk} that must be emitted unencrypted. */
    private final BitSet plainByteFlags;
    /** Temporary spool file, or {@code null} when writing to a caller supplied stream. */
    private final File fileOut;
    /** Target directory, or {@code null} when writing to a caller supplied stream. */
    private final DirectoryNode dir;

    /** Write position; set back to 0 by {@link #writeChunk(boolean)} in streaming mode. */
    private long pos;
    /** Number of bytes accepted by the write methods so far. */
    private long totalPos;
    /** Number of bytes handed to the underlying stream so far. */
    private long written;
    private Cipher cipher;
    private boolean isClosed;

    /**
     * Creates a stream that spools its output to a temporary file and stores
     * the result in {@code dir} when the stream is closed.
     *
     * @param dir the directory that receives the encrypted document on close
     * @param chunkSize the chunk size, or {@code -1} for streaming mode
     * @throws IOException if the temporary file can't be created or opened,
     *         or if the initial cipher setup fails with an I/O error
     * @throws GeneralSecurityException if the initial cipher setup fails
     */
    public ChunkedCipherOutputStream(DirectoryNode dir, int chunkSize) throws IOException, GeneralSecurityException {
        super(null);
        this.chunkSize = chunkSize;
        int cs = chunkSize == STREAMING ? 4096 : chunkSize;
        this.chunk = IOUtils.safelyAllocate(cs, MAX_RECORD_LENGTH);
        this.plainByteFlags = new BitSet(cs);
        this.chunkBits = Integer.bitCount(cs-1);
        this.fileOut = TempFile.createTempFile("encrypted_package", "crypt");
        this.fileOut.deleteOnExit();
        this.out = new FileOutputStream(fileOut);
        this.dir = dir;
        try {
            this.cipher = initCipherForBlock(null, 0, false);
        } catch (IOException | GeneralSecurityException | RuntimeException e) {
            // The caller never receives this instance, so nobody else could
            // release the file handle or remove the spool file.
            try {
                out.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            if (!fileOut.delete()) {
                LOG.log(POILogger.ERROR, "Can't delete temporary encryption file: "+fileOut);
            }
            throw e;
        }
    }

    /**
     * Creates a stream that writes the encrypted bytes directly to
     * {@code stream}.
     *
     * @param stream the stream receiving the encrypted bytes
     * @param chunkSize the chunk size, or {@code -1} for streaming mode
     * @throws IOException if the initial cipher setup fails with an I/O error
     * @throws GeneralSecurityException if the initial cipher setup fails
     */
    public ChunkedCipherOutputStream(OutputStream stream, int chunkSize) throws IOException, GeneralSecurityException {
        super(stream);
        this.chunkSize = chunkSize;
        int cs = chunkSize == STREAMING ? 4096 : chunkSize;
        this.chunk = IOUtils.safelyAllocate(cs, MAX_RECORD_LENGTH);
        this.plainByteFlags = new BitSet(cs);
        this.chunkBits = Integer.bitCount(cs-1);
        this.fileOut = null;
        this.dir = null;
        this.cipher = initCipherForBlock(null, 0, false);
    }

    /**
     * Initializes the cipher for the given block, passing the cipher
     * currently held by this stream as the existing instance.
     *
     * @param block the block index
     * @param lastChunk whether the block is the last chunk
     * @return the cipher returned by
     *         {@link #initCipherForBlock(Cipher, int, boolean)}
     * @throws IOException if the cipher setup fails with an I/O error
     * @throws GeneralSecurityException if the cipher setup fails
     */
    public final Cipher initCipherForBlock(int block, boolean lastChunk) throws IOException, GeneralSecurityException {
        return initCipherForBlock(cipher, block, lastChunk);
    }

    /**
     * Variant of {@link #initCipherForBlock(Cipher, int, boolean)} used by
     * {@link #invokeCipher(int, boolean)}. This default implementation simply
     * delegates to {@code initCipherForBlock} with the given arguments.
     *
     * @param existing the cipher to hand on to {@code initCipherForBlock}
     * @param block the block index
     * @param lastChunk whether the block is the last chunk
     * @return the initialized cipher
     * @throws IOException if the cipher setup fails with an I/O error
     * @throws GeneralSecurityException if the cipher setup fails
     */
    @Internal
    protected Cipher initCipherForBlockNoFlush(Cipher existing, int block, boolean lastChunk)
    throws IOException, GeneralSecurityException {
        // honor the argument instead of silently substituting the field
        return initCipherForBlock(existing, block, lastChunk);
    }

    /**
     * Creates or re-initializes the cipher for the given block.
     *
     * @param existing a previously returned cipher; {@code null} on the call
     *        made from the constructors
     * @param block the block index
     * @param lastChunk whether the block is the last chunk
     * @return the cipher to use for the block
     * @throws IOException if the cipher setup fails with an I/O error
     * @throws GeneralSecurityException if the cipher setup fails
     */
    protected abstract Cipher initCipherForBlock(Cipher existing, int block, boolean lastChunk)
    throws IOException, GeneralSecurityException;

    /**
     * Called from {@link #close()} once the temporary file is complete.
     *
     * @param fileOut the temporary file holding the encrypted bytes
     * @param oleStreamSize the value of the write position at close time,
     *        narrowed to {@code int}
     * @throws GeneralSecurityException if the checksum calculation fails
     * @throws IOException if the file can't be processed
     */
    protected abstract void calculateChecksum(File fileOut, int oleStreamSize)
    throws GeneralSecurityException, IOException;

    /**
     * Called from {@link #close()} after the encrypted document has been
     * registered in the directory.
     *
     * @param dir the directory passed to the constructor
     * @param tmpFile the temporary file holding the encrypted bytes
     * @throws IOException if the entry can't be created
     * @throws GeneralSecurityException if a security operation fails
     */
    protected abstract void createEncryptionInfoEntry(DirectoryNode dir, File tmpFile)
    throws IOException, GeneralSecurityException;

    /**
     * Writes the low-order byte of {@code b} as data to be encrypted.
     */
    @Override
    public void write(int b) throws IOException {
        write(new byte[]{(byte)b});
    }

    /**
     * Writes all bytes of {@code b} as data to be encrypted.
     */
    @Override
    public void write(byte[] b) throws IOException {
        write(b, 0, b.length);
    }

    /**
     * Writes {@code len} bytes of {@code b}, starting at {@code off}, as data
     * to be encrypted.
     */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        write(b, off, len, false);
    }

    /**
     * Writes {@code len} bytes of {@code b}, starting at {@code off}, and
     * flags them so they end up unencrypted in the output.
     *
     * @param b the source buffer
     * @param off the offset of the first byte to write
     * @param len the number of bytes to write
     * @throws IOException if the range is invalid or writing a chunk fails
     */
    public void writePlain(byte[] b, int off, int len) throws IOException {
        write(b, off, len, true);
    }

    /**
     * Copies the given range into the chunk buffer and emits a chunk each
     * time the buffer has been filled up.
     *
     * @param b the source buffer
     * @param off the offset of the first byte to write
     * @param len the number of bytes to write; 0 is a no-op
     * @param writePlain if {@code true}, the copied positions are flagged to
     *        be restored to their unencrypted value after the cipher ran
     * @throws IOException if the range doesn't fit into {@code b} or writing
     *         a chunk fails
     */
    protected void write(byte[] b, int off, int len, boolean writePlain) throws IOException {
        if (len == 0) {
            return;
        }

        // "len > b.length - off" instead of "off + len > b.length": the sum
        // can overflow, which would let an invalid range through and fail
        // only after some chunks had already been written.
        if (off < 0 || len < 0 || len > b.length - off) {
            throw new IOException("not enough bytes in your input buffer");
        }

        final int chunkMask = getChunkMask();
        while (len > 0) {
            int posInChunk = (int)(pos & chunkMask);
            int nextLen = Math.min(chunk.length-posInChunk, len);
            System.arraycopy(b, off, chunk, posInChunk, nextLen);
            if (writePlain) {
                plainByteFlags.set(posInChunk, posInChunk+nextLen);
            }
            pos += nextLen;
            totalPos += nextLen;
            off += nextLen;
            len -= nextLen;
            // the buffer is full once the position is back on a chunk boundary
            if ((pos & chunkMask) == 0) {
                writeChunk(len > 0);
            }
        }
    }

    /**
     * @return the bit mask that maps a position to its offset inside the
     *         chunk buffer, i.e. the buffer length minus one
     */
    protected int getChunkMask() {
        return chunk.length-1;
    }

    /**
     * Encrypts the buffered chunk and writes it to the underlying stream.
     * Does nothing when the position is 0 or when the number of bytes written
     * equals the number of bytes accepted.
     *
     * @param continued {@code true} if more data of the current write call
     *        follows; in streaming mode the cipher is then advanced with
     *        {@code update} instead of {@code doFinal}
     * @throws IOException if the cipher can't be (re-)initialized or applied,
     *         or if writing to the underlying stream fails
     */
    protected void writeChunk(boolean continued) throws IOException {
        if (pos == 0 || totalPos == written) {
            return;
        }

        int posInChunk = (int)(pos & getChunkMask());

        // An offset of 0 means the buffer was just filled completely, so the
        // data belongs to the previous chunk index. Any other offset is a
        // partially filled buffer, which is treated as the last chunk.
        int index = (int)(pos >> chunkBits);
        boolean lastChunk;
        if (posInChunk==0) {
            index--;
            posInChunk = chunk.length;
            lastChunk = false;
        } else {
            lastChunk = true;
        }

        int ciLen;
        try {
            boolean doFinal = true;
            long oldPos = pos;
            // pos is zeroed before the cipher callbacks run; it stays zeroed
            // in streaming mode and is restored otherwise
            pos = 0;
            if (chunkSize == STREAMING) {
                if (continued) {
                    doFinal = false;
                }
            } else {
                cipher = initCipherForBlock(cipher, index, lastChunk);
                pos = oldPos;
            }
            ciLen = invokeCipher(posInChunk, doFinal);
        } catch (GeneralSecurityException e) {
            throw new IOException("can't re-/initialize cipher", e);
        }

        out.write(chunk, 0, ciLen);
        plainByteFlags.clear();
        written += ciLen;
    }

    /**
     * Encrypts the first {@code posInChunk} bytes of the chunk buffer in
     * place and afterwards restores the bytes flagged as plain.
     *
     * @param posInChunk the number of buffered bytes to process
     * @param doFinal {@code true} to call {@code Cipher.doFinal},
     *        {@code false} to call {@code Cipher.update}
     * @return the number of bytes the cipher stored in the chunk buffer
     * @throws GeneralSecurityException if the cipher operation fails
     * @throws IOException if re-initializing the cipher fails with an I/O error
     */
    protected int invokeCipher(int posInChunk, boolean doFinal) throws GeneralSecurityException, IOException {
        // keep a copy of the unencrypted bytes only if some must be restored
        byte[] plain = (plainByteFlags.isEmpty()) ? null : chunk.clone();

        int ciLen = (doFinal)
            ? cipher.doFinal(chunk, 0, posInChunk, chunk)
            : cipher.update(chunk, 0, posInChunk, chunk);

        // NOTE(review): presumably a workaround for the IBMJCE RC4 cipher not
        // being reset by doFinal - confirm against the provider documentation
        if (doFinal && "IBMJCE".equals(cipher.getProvider().getName()) && "RC4".equals(cipher.getAlgorithm())) {
            int index = (int)(pos >> chunkBits);
            boolean lastChunk;
            if (posInChunk==0) {
                index--;
                posInChunk = chunk.length;
                lastChunk = false;
            } else {
                lastChunk = true;
            }
            cipher = initCipherForBlockNoFlush(cipher, index, lastChunk);
        }

        if (plain != null) {
            int i = plainByteFlags.nextSetBit(0);
            while (i >= 0 && i < posInChunk) {
                chunk[i] = plain[i];
                i = plainByteFlags.nextSetBit(i+1);
            }
        }

        return ciLen;
    }

    /**
     * Writes the pending chunk and closes the underlying stream. When this
     * stream spools to a temporary file, the checksum callback is invoked,
     * the encrypted document is registered in the directory and the
     * encryption info entry is created. Calling this method again is a no-op.
     *
     * @throws IOException if writing or closing fails; a
     *         {@link GeneralSecurityException} raised by the callbacks is
     *         wrapped in an {@code IOException}
     */
    @Override
    public void close() throws IOException {
        if (isClosed) {
            LOG.log(POILogger.DEBUG, "ChunkedCipherOutputStream was already closed - ignoring");
            return;
        }

        isClosed = true;

        try {
            writeChunk(false);
        } catch (IOException | RuntimeException e) {
            // isClosed is already set, so a second close() would not get
            // here again - release the underlying stream right away
            try {
                out.close();
            } catch (IOException | RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }

        try {
            super.close();

            if (fileOut != null) {
                // the document consists of a long-sized header followed by the file content
                int oleStreamSize = (int)(fileOut.length()+LittleEndianConsts.LONG_SIZE);
                calculateChecksum(fileOut, (int)pos);
                dir.createDocument(DEFAULT_POIFS_ENTRY, oleStreamSize, new EncryptedPackageWriter());
                createEncryptionInfoEntry(dir, fileOut);
            }
        } catch (GeneralSecurityException e) {
            throw new IOException(e);
        }
    }

    /**
     * @return the internal chunk buffer itself, not a copy
     */
    protected byte[] getChunk() {
        return chunk;
    }

    /**
     * @return the internal bit set flagging the plain bytes of the current
     *         chunk, not a copy
     */
    protected BitSet getPlainByteFlags() {
        return plainByteFlags;
    }

    /**
     * @return the current write position
     */
    protected long getPos() {
        return pos;
    }

    /**
     * @return the number of bytes accepted by the write methods so far
     */
    protected long getTotalPos() {
        return totalPos;
    }

    /**
     * Hook that does nothing in this class.
     *
     * @param recordSize the size of the next record
     * @param isPlain whether the next record is plain
     */
    public void setNextRecordSize(int recordSize, boolean isPlain) {
    }

    /**
     * Writer callback that emits the write position as a little-endian long,
     * followed by the content of the temporary file, and then deletes the
     * temporary file.
     */
    private class EncryptedPackageWriter implements POIFSWriterListener {
        @Override
        public void processPOIFSWriterEvent(POIFSWriterEvent event) {
            try {
                try (OutputStream os = event.getStream();
                     FileInputStream fis = new FileInputStream(fileOut)) {
                    byte[] buf = new byte[LittleEndianConsts.LONG_SIZE];
                    LittleEndian.putLong(buf, 0, pos);
                    os.write(buf);
                    IOUtils.copy(fis, os);
                }

                // both streams are closed at this point
                if (!fileOut.delete()) {
                    LOG.log(POILogger.ERROR, "Can't delete temporary encryption file: "+fileOut);
                }
            } catch (IOException e) {
                throw new EncryptedDocumentException(e);
            }
        }
    }
}