package org.apache.poi.poifs.macros;
import static org.apache.poi.util.StringUtil.endsWithIgnoreCase;
import static org.apache.poi.util.StringUtil.startsWithIgnoreCase;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import org.apache.poi.poifs.filesystem.DirectoryNode;
import org.apache.poi.poifs.filesystem.DocumentInputStream;
import org.apache.poi.poifs.filesystem.DocumentNode;
import org.apache.poi.poifs.filesystem.Entry;
import org.apache.poi.poifs.filesystem.FileMagic;
import org.apache.poi.poifs.filesystem.POIFSFileSystem;
import org.apache.poi.poifs.filesystem.OfficeXmlFileException;
import org.apache.poi.poifs.macros.Module.ModuleType;
import org.apache.poi.util.CodePageUtil;
import org.apache.poi.util.HexDump;
import org.apache.poi.util.IOUtils;
import org.apache.poi.util.LittleEndian;
import org.apache.poi.util.POILogFactory;
import org.apache.poi.util.POILogger;
import org.apache.poi.util.RLEDecompressingInputStream;
import org.apache.poi.util.StringUtil;
/**
 * Locates and reads VBA macro modules.
 * <p>
 * The input is either opened directly as a {@link POIFSFileSystem}, or treated as a
 * zip archive in which an entry whose name ends with {@link #VBA_PROJECT_OOXML} is
 * opened as a {@link POIFSFileSystem}. Modules are then collected from every
 * directory named {@link #VBA_PROJECT_POIFS} found below the filesystem root.
 */
public class VBAMacroReader implements Closeable {
/** Logger used for non-fatal parsing warnings. */
private static final POILogger LOGGER = POILogFactory.getLogger(VBAMacroReader.class);
/** Upper bound passed to the string readers / allocators in this class. */
private static final int MAX_STRING_LENGTH = 20000;
/** Zip entry name (suffix, case-insensitive) that holds the VBA project. */
protected static final String VBA_PROJECT_OOXML = "vbaProject.bin";
/** Name (case-insensitive) of the directory whose streams are read as macro modules. */
protected static final String VBA_PROJECT_POIFS = "VBA";
/** Backing filesystem; assigned by the constructors and set to null by {@link #close()}. */
private POIFSFileSystem fs;
/**
 * Creates a reader from a stream whose format is detected from its magic bytes.
 * <p>
 * The stream is first passed through {@code FileMagic.prepareToCheckMagic}; if the
 * detected magic is OLE2 it is opened directly, otherwise it is handed to
 * {@code openOOXML} and treated as a zip archive.
 *
 * @param rstream the stream to read the document from
 * @throws IOException if the stream cannot be read
 */
public VBAMacroReader(InputStream rstream) throws IOException {
    final InputStream sniffable = FileMagic.prepareToCheckMagic(rstream);
    if (FileMagic.valueOf(sniffable) == FileMagic.OLE2) {
        fs = new POIFSFileSystem(sniffable);
        return;
    }
    openOOXML(sniffable);
}
public VBAMacroReader(File file) throws IOException {
try {
this.fs = new POIFSFileSystem(file);
} catch (OfficeXmlFileException e) {
openOOXML(new FileInputStream(file));
}
}
/**
 * Creates a reader on top of an already opened filesystem.
 * The given instance is stored as-is and is the one closed by {@link #close()}.
 *
 * @param fs the filesystem to read macros from
 */
public VBAMacroReader(POIFSFileSystem fs) {
this.fs = fs;
}
/**
 * Scans a zip archive for the VBA project entry and opens it as the backing filesystem.
 * <p>
 * The first entry whose name ends (case-insensitively) with {@link #VBA_PROJECT_OOXML}
 * is used. The zip stream is closed by the try-with-resources on every path, so no
 * explicit close is needed on failure; this also keeps the original exception as the
 * primary one instead of letting a failing manual close replace it.
 *
 * @param zipFile stream over the zip archive; always closed by this method
 * @throws IOException if the archive or the embedded project cannot be read
 * @throws IllegalArgumentException if the archive contains no matching entry
 */
private void openOOXML(InputStream zipFile) throws IOException {
    try (ZipInputStream zis = new ZipInputStream(zipFile)) {
        ZipEntry zipEntry;
        while ((zipEntry = zis.getNextEntry()) != null) {
            if (endsWithIgnoreCase(zipEntry.getName(), VBA_PROJECT_OOXML)) {
                // The filesystem is built from the entry's bytes at the current zip position.
                this.fs = new POIFSFileSystem(zis);
                return;
            }
        }
    }
    throw new IllegalArgumentException("No VBA project found");
}
/**
 * Closes the backing filesystem and drops the reference to it.
 * <p>
 * Calling this method again after a successful close has no effect, as required by
 * {@link Closeable#close()}.
 *
 * @throws IOException if closing the filesystem fails
 */
@Override
public void close() throws IOException {
    if (fs != null) {
        fs.close();
        fs = null;
    }
}
/**
 * Reads all macro modules from the backing filesystem.
 * <p>
 * Three passes are made from the root: module streams, the module name map, and
 * the project properties. Every module found is then stamped with the charset
 * recorded on the module map before being returned.
 *
 * @return the modules keyed by the name they were stored under
 * @throws IOException if any of the streams cannot be read
 */
public Map<String, Module> readMacroModules() throws IOException {
    final ModuleMap collected = new ModuleMap();
    final Map<String, String> nameMap = new LinkedHashMap<>();

    findMacros(fs.getRoot(), collected);
    findModuleNameMap(fs.getRoot(), nameMap, collected);
    findProjectProperties(fs.getRoot(), nameMap, collected);

    final Map<String, Module> result = new HashMap<>();
    for (Map.Entry<String, ModuleImpl> e : collected.entrySet()) {
        final ModuleImpl impl = e.getValue();
        // The charset may have been updated while parsing; apply the final value.
        impl.charset = collected.charset;
        result.put(e.getKey(), impl);
    }
    return result;
}
/**
 * Reads all macro modules and returns only their text content.
 *
 * @return module content keyed by the same keys as {@link #readMacroModules()}
 * @throws IOException if the modules cannot be read
 */
public Map<String, String> readMacros() throws IOException {
    final Map<String, String> sources = new HashMap<>();
    for (Map.Entry<String, Module> e : readMacroModules().entrySet()) {
        final Module module = e.getValue();
        sources.put(e.getKey(), module.getContent());
    }
    return sources;
}
/**
 * Mutable holder for a single module while it is being assembled.
 */
protected static class ModuleImpl implements Module {
    /** Offset read from the dir stream; null when none has been recorded. */
    Integer offset;
    /** Bytes of the module as last stored by {@link #read(InputStream)}. */
    byte[] buf;
    /** Type assigned from the project properties, if any. */
    ModuleType moduleType;
    /** Charset used by {@link #getContent()} to decode {@link #buf}. */
    Charset charset;

    /**
     * Drains the given stream and stores its bytes in {@link #buf}.
     *
     * @param in the stream to copy from; not closed by this method
     * @throws IOException if copying fails
     */
    void read(InputStream in) throws IOException {
        try (ByteArrayOutputStream sink = new ByteArrayOutputStream()) {
            IOUtils.copy(in, sink);
            buf = sink.toByteArray();
        }
    }

    /** @return {@link #buf} decoded with {@link #charset} */
    public String getContent() {
        final String text = new String(buf, charset);
        return text;
    }

    /** @return the module type recorded for this module */
    public ModuleType geModuleType() {
        return this.moduleType;
    }
}
/**
 * Map of modules keyed by name, carrying the charset used to decode them.
 */
protected static class ModuleMap extends HashMap<String, ModuleImpl> {
// Starts as windows-1252; replaced when a PROJECT_CODEPAGE record is parsed from the dir stream.
Charset charset = StringUtil.WIN_1252;
}
/**
 * Walks the directory tree and reads macros from every directory named
 * {@link #VBA_PROJECT_POIFS} (case-insensitive). A matching directory is read
 * but not descended into any further.
 *
 * @param dir the directory to start from
 * @param modules receives the modules that are found
 * @throws IOException if a macro directory cannot be read
 */
protected void findMacros(DirectoryNode dir, ModuleMap modules) throws IOException {
    if (VBA_PROJECT_POIFS.equalsIgnoreCase(dir.getName())) {
        readMacros(dir, modules);
        return;
    }
    for (Entry child : dir) {
        if (!(child instanceof DirectoryNode)) {
            continue;
        }
        findMacros((DirectoryNode) child, modules);
    }
}
/**
 * Handles the offset value of a module-offset record from the dir stream.
 * <p>
 * If the module is not yet known, a new entry holding only the offset is created.
 * If it is known and already has bytes, those bytes are replaced by the result of
 * decompressing them starting at the offset. If it is known but has no bytes yet
 * (for example when the same stream name occurs twice), only the offset is updated.
 *
 * @param in the decompressed dir stream, positioned at the 4-byte offset value
 * @param streamName name of the module stream the offset belongs to
 * @param modules the modules collected so far
 * @throws IOException if the offset is negative or the module bytes cannot be decompressed
 */
private static void readModuleMetadataFromDirStream(RLEDecompressingInputStream in, String streamName, ModuleMap modules) throws IOException {
    int moduleOffset = in.readInt();
    ModuleImpl module = modules.get(streamName);
    if (module == null) {
        module = new ModuleImpl();
        module.offset = moduleOffset;
        modules.put(streamName, module);
        return;
    }
    if (module.buf == null) {
        // Nothing to decompress yet; remember the most recent offset instead of failing with an NPE.
        module.offset = moduleOffset;
        return;
    }
    if (moduleOffset < 0) {
        // A negative start index would otherwise surface as an ArrayIndexOutOfBoundsException.
        throw new IOException("Invalid module offset " + moduleOffset + " for '" + streamName + "'");
    }
    try (InputStream stream = new RLEDecompressingInputStream(
            new ByteArrayInputStream(module.buf, moduleOffset, module.buf.length - moduleOffset))) {
        module.read(stream);
    }
}
/**
 * Loads the bytes of one module stream into the module map.
 * <ul>
 * <li>Unknown module: a new entry is created and the raw stream bytes are stored.</li>
 * <li>Known module that already has bytes: nothing is done.</li>
 * <li>Known module without bytes: the stream is skipped to the recorded offset and
 *     decompressed from there; if decompression fails with an
 *     {@link IllegalArgumentException} or {@link IllegalStateException}, the whole
 *     stream is searched for a decompressible section instead.</li>
 * </ul>
 *
 * @param documentNode the module stream
 * @param name the name the module is stored under
 * @param modules the modules collected so far
 * @throws IOException if the stream cannot be read, no offset was recorded, or the
 *         offset cannot be skipped to
 */
private static void readModuleFromDocumentStream(DocumentNode documentNode, String name, ModuleMap modules) throws IOException {
    ModuleImpl module = modules.get(name);
    if (module == null) {
        module = new ModuleImpl();
        modules.put(name, module);
        try (InputStream raw = new DocumentInputStream(documentNode)) {
            module.read(raw);
        }
        return;
    }
    if (module.buf != null) {
        return;
    }
    if (module.offset == null) {
        throw new IOException("Module offset for '" + name + "' was never read.");
    }

    // First attempt: trust the recorded offset.
    InputStream inflated = null;
    InputStream source = new DocumentInputStream(documentNode);
    try {
        final long skipped = source.skip(module.offset);
        if (skipped != module.offset) {
            throw new IOException("tried to skip " + module.offset + " bytes, but actually skipped " + skipped + " bytes");
        }
        inflated = new RLEDecompressingInputStream(source);
        module.read(inflated);
        return;
    } catch (IllegalArgumentException | IllegalStateException e) {
        // Deliberately ignored: fall through to the brute-force search below.
    } finally {
        IOUtils.closeQuietly(source);
        IOUtils.closeQuietly(inflated);
    }

    // Second attempt: scan the whole stream for something that decompresses.
    final byte[] recovered;
    source = new DocumentInputStream(documentNode);
    try {
        recovered = findCompressedStreamWBruteForce(source);
    } finally {
        IOUtils.closeQuietly(source);
    }
    if (recovered != null) {
        module.read(new ByteArrayInputStream(recovered));
    }
}
/**
 * Skips exactly {@code n} bytes or fails.
 *
 * @param in the stream to skip in
 * @param n the number of bytes that must be skipped
 * @throws IOException if {@code IOUtils.skipFully} reports a count other than {@code n};
 *         the message distinguishes a negative count from a short skip
 */
private static void trySkip(InputStream in, long n) throws IOException {
    final long actual = IOUtils.skipFully(in, n);
    if (actual == n) {
        return;
    }
    final String message;
    if (actual < 0) {
        message = "Tried skipping " + n + " bytes, but no bytes were skipped. "
                + "The end of the stream has been reached or the stream is closed.";
    } else {
        message = "Tried skipping " + n + " bytes, but only " + actual + " bytes were skipped. "
                + "This should never happen with a non-corrupt file.";
    }
    throw new IOException(message);
}
// Expected 2-byte "reserved" values that separate the charset-encoded part of a
// dir-stream string record from its Unicode part; checked while parsing the dir stream.
// Marker expected in a MODULE_STREAM_NAME record.
private static final int STREAMNAME_RESERVED = 0x0032;
// Marker expected in a PROJECT_CONSTANTS record.
private static final int PROJECT_CONSTANTS_RESERVED = 0x003C;
// Marker expected in a PROJECT_HELP_FILE_PATH record.
private static final int HELP_FILE_PATH_RESERVED = 0x003D;
// Marker expected in a REFERENCE_NAME record.
private static final int REFERENCE_NAME_RESERVED = 0x003E;
// Marker expected in a PROJECT_DOC_STRING record.
private static final int DOC_STRING_RESERVED = 0x0040;
// Marker expected in a MODULE_DOC_STRING record.
private static final int MODULE_DOCSTRING_RESERVED = 0x0048;
/**
 * Reads all modules from one macro directory.
 * <p>
 * The first entry named "dir" (case-insensitive) is parsed for module metadata
 * before any module stream is touched. Afterwards every document entry is read as
 * a module, except "dir" itself and entries whose names start with "__SRP" or
 * "_VBA_PROJECT".
 *
 * @param macroDir the directory holding the macro streams
 * @param modules receives the modules that are read
 * @throws IOException if a stream cannot be read
 */
protected void readMacros(DirectoryNode macroDir, ModuleMap modules) throws IOException {
    for (String entryName : macroDir.getEntryNames()) {
        if (!"dir".equalsIgnoreCase(entryName)) {
            continue;
        }
        processDirStream(macroDir.getEntry(entryName), modules);
        break;
    }

    for (Entry entry : macroDir) {
        if (!(entry instanceof DocumentNode)) {
            continue;
        }
        final String name = entry.getName();
        final boolean notAModule = "dir".equalsIgnoreCase(name)
                || startsWithIgnoreCase(name, "__SRP")
                || startsWithIgnoreCase(name, "_VBA_PROJECT");
        if (notAModule) {
            continue;
        }
        readModuleFromDocumentStream((DocumentNode) entry, name, modules);
    }
}
/**
 * Searches the directory tree for a document named "project" (case-insensitive)
 * and parses it with {@link #readProjectProperties}.
 * <p>
 * Within one directory the search stops at the first matching document; directories
 * visited before it are searched recursively. Only document entries are treated as
 * the properties stream, so a directory that happens to carry the same name is
 * descended into rather than cast to a document.
 *
 * @param node the directory to search
 * @param moduleNameMap name mapping passed on to the properties parser
 * @param modules the modules collected so far
 * @throws IOException if the properties stream cannot be read
 */
protected void findProjectProperties(DirectoryNode node, Map<String, String> moduleNameMap, ModuleMap modules) throws IOException {
    for (Entry entry : node) {
        if (entry instanceof DocumentNode && "project".equalsIgnoreCase(entry.getName())) {
            try (DocumentInputStream dis = new DocumentInputStream((DocumentNode) entry)) {
                readProjectProperties(dis, moduleNameMap, modules);
                return;
            }
        } else if (entry instanceof DirectoryNode) {
            findProjectProperties((DirectoryNode) entry, moduleNameMap, modules);
        }
    }
}
/**
 * Searches the directory tree for a document named "projectwm" (case-insensitive)
 * and reads the module name records from it.
 * <p>
 * Within one directory the search stops at the first matching document; directories
 * visited before it are searched recursively. Only document entries are treated as
 * the name-map stream, so a directory that happens to carry the same name is
 * descended into rather than cast to a document.
 *
 * @param node the directory to search
 * @param moduleNameMap receives the name pairs that are read
 * @param modules supplies the charset used to decode the names
 * @throws IOException if the stream cannot be read
 */
protected void findModuleNameMap(DirectoryNode node, Map<String, String> moduleNameMap, ModuleMap modules) throws IOException {
    for (Entry entry : node) {
        if (entry instanceof DocumentNode && "projectwm".equalsIgnoreCase(entry.getName())) {
            try (DocumentInputStream dis = new DocumentInputStream((DocumentNode) entry)) {
                readNameMapRecords(dis, moduleNameMap, modules.charset);
                return;
            }
        } else if (entry instanceof DirectoryNode) {
            findModuleNameMap((DirectoryNode) entry, moduleNameMap, modules);
        }
    }
}
/**
 * Record ids recognised while parsing the dir stream.
 * <p>
 * A type either has a constant payload length or is marked as variable length,
 * in which case {@link #getConstantLength()} returns -1.
 */
private enum RecordType {
    MODULE_OFFSET(0x31),
    PROJECT_SYS_KIND(0x01),
    PROJECT_LCID(0x02),
    PROJECT_LCID_INVOKE(0x14),
    PROJECT_CODEPAGE(0x03),
    PROJECT_NAME(0x04),
    PROJECT_DOC_STRING(0x05),
    PROJECT_HELP_FILE_PATH(0x06),
    PROJECT_HELP_CONTEXT(0x07, 8),
    PROJECT_LIB_FLAGS(0x08),
    PROJECT_VERSION(0x09, 10),
    PROJECT_CONSTANTS(0x0C),
    PROJECT_MODULES(0x0F),
    DIR_STREAM_TERMINATOR(0x10),
    PROJECT_COOKIE(0x13),
    MODULE_NAME(0x19),
    MODULE_NAME_UNICODE(0x47),
    MODULE_STREAM_NAME(0x1A),
    MODULE_DOC_STRING(0x1C),
    MODULE_HELP_CONTEXT(0x1E),
    MODULE_COOKIE(0x2C),
    MODULE_TYPE_PROCEDURAL(0x21, 4),
    MODULE_TYPE_OTHER(0x22, 4),
    MODULE_PRIVATE(0x28, 4),
    REFERENCE_NAME(0x16),
    REFERENCE_REGISTERED(0x0D),
    REFERENCE_PROJECT(0x0E),
    REFERENCE_CONTROL_A(0x2F),
    REFERENCE_CONTROL_B(0x33),
    MODULE_TERMINATOR(0x2B),
    EOF(-1),
    UNKNOWN(-2);

    /** Sentinel length for records whose size is not fixed by their type. */
    private static final int VARIABLE_LENGTH = -1;

    private final int id;
    private final int constantLength;

    /** Declares a variable-length record type. */
    RecordType(int id) {
        this(id, VARIABLE_LENGTH);
    }

    /** Declares a record type with a fixed payload length. */
    RecordType(int id, int constantLength) {
        this.id = id;
        this.constantLength = constantLength;
    }

    /** @return the fixed payload length, or -1 for variable-length types */
    int getConstantLength() {
        return constantLength;
    }

    /**
     * Finds the type for a record id.
     *
     * @param id the id read from the stream
     * @return the first type declared with that id, or {@link #UNKNOWN} if there is none
     */
    static RecordType lookup(int id) {
        final RecordType[] all = values();
        for (int i = 0; i < all.length; i++) {
            if (all[i].id == id) {
                return all[i];
            }
        }
        return UNKNOWN;
    }
}
/**
 * Section of the dir stream the parser is currently in. The parser starts in
 * INFORMATION_RECORD and switches state when it meets a REFERENCE_NAME or
 * PROJECT_MODULES record.
 */
private enum DIR_STATE {
INFORMATION_RECORD,
REFERENCES_RECORD,
MODULES_RECORD
}
/**
 * Result of reading a string record that stores a charset-encoded value followed
 * by a Unicode value. When the Unicode part was not read because an unexpected
 * marker was found, that marker is kept as the pushback record id and the Unicode
 * value is the empty string; otherwise the pushback record id is -1.
 */
private static class ASCIIUnicodeStringPair {
    private final String ascii;
    private final String unicode;
    private final int pushbackRecordId;

    /** Pair without a Unicode part; remembers the marker that was read instead. */
    ASCIIUnicodeStringPair(String ascii, int pushbackRecordId) {
        this(ascii, "", pushbackRecordId);
    }

    /** Complete pair; no marker needs to be pushed back. */
    ASCIIUnicodeStringPair(String ascii, String unicode) {
        this(ascii, unicode, -1);
    }

    private ASCIIUnicodeStringPair(String ascii, String unicode, int pushbackRecordId) {
        this.ascii = ascii;
        this.unicode = unicode;
        this.pushbackRecordId = pushbackRecordId;
    }

    private String getAscii() {
        return this.ascii;
    }

    private String getUnicode() {
        return this.unicode;
    }

    private int getPushbackRecordId() {
        return this.pushbackRecordId;
    }
}
/**
 * Parses the compressed "dir" stream of a macro directory.
 * <p>
 * Records are read one after another until a read returns -1, an EOF record or a
 * DIR_STREAM_TERMINATOR record is met. Of all the records only three influence the
 * result: PROJECT_CODEPAGE sets the charset on {@code modules},
 * MODULE_STREAM_NAME remembers the current stream name, and MODULE_OFFSET
 * registers the offset for that stream name. Everything else is read or skipped
 * purely to keep the stream position in sync.
 *
 * @param dir the "dir" entry; must be a document entry
 * @param modules receives charset and per-module offsets
 * @throws IOException if the stream is malformed or cannot be read; the message
 *         carries the id of the record being processed
 */
private void processDirStream(Entry dir, ModuleMap modules) throws IOException {
    DocumentNode dirDocumentNode = (DocumentNode)dir;
    DIR_STATE dirState = DIR_STATE.INFORMATION_RECORD;
    try (DocumentInputStream dis = new DocumentInputStream(dirDocumentNode)) {
        String streamName = null;
        int recordId = 0;
        try (RLEDecompressingInputStream in = new RLEDecompressingInputStream(dis)) {
            while (true) {
                recordId = in.readShort();
                if (recordId == -1) {
                    break;
                }
                RecordType type = RecordType.lookup(recordId);
                if (type.equals(RecordType.EOF) || type.equals(RecordType.DIR_STREAM_TERMINATOR)) {
                    break;
                }
                switch (type) {
                    case PROJECT_VERSION:
                        trySkip(in, RecordType.PROJECT_VERSION.getConstantLength());
                        break;
                    case PROJECT_CODEPAGE:
                        in.readInt(); // record size, not needed
                        int codepage = in.readShort();
                        modules.charset = Charset.forName(CodePageUtil.codepageToEncoding(codepage, true));
                        break;
                    case MODULE_STREAM_NAME:
                        ASCIIUnicodeStringPair pair = readStringPair(in, modules.charset, STREAMNAME_RESERVED);
                        streamName = pair.getAscii();
                        break;
                    case PROJECT_DOC_STRING:
                        readStringPair(in, modules.charset, DOC_STRING_RESERVED);
                        break;
                    case PROJECT_HELP_FILE_PATH:
                        readStringPair(in, modules.charset, HELP_FILE_PATH_RESERVED);
                        break;
                    case PROJECT_CONSTANTS:
                        readStringPair(in, modules.charset, PROJECT_CONSTANTS_RESERVED);
                        break;
                    case REFERENCE_NAME:
                        if (dirState.equals(DIR_STATE.INFORMATION_RECORD)) {
                            dirState = DIR_STATE.REFERENCES_RECORD;
                        }
                        ASCIIUnicodeStringPair stringPair = readStringPair(in,
                                modules.charset, REFERENCE_NAME_RESERVED, false);
                        if (stringPair.getPushbackRecordId() == -1) {
                            break;
                        }
                        // The Unicode part was missing and the marker that was read is
                        // the id of the next record; only REFERENCE_REGISTERED is accepted.
                        if (stringPair.getPushbackRecordId() != RecordType.REFERENCE_REGISTERED.id) {
                            throw new IllegalArgumentException("Unexpected reserved character. "+
                                    "Expected "+Integer.toHexString(REFERENCE_NAME_RESERVED)
                                    + " or "+Integer.toHexString(RecordType.REFERENCE_REGISTERED.id)+
                                    " not: "+Integer.toHexString(stringPair.getPushbackRecordId()));
                        }
                        // intentional fall-through: the record id has already been consumed
                    case REFERENCE_REGISTERED:
                        int recLength = in.readInt();
                        trySkip(in, recLength);
                        break;
                    case MODULE_DOC_STRING:
                        int modDocStringLength = in.readInt();
                        readString(in, modDocStringLength, modules.charset);
                        int modDocStringReserved = in.readShort();
                        if (modDocStringReserved != MODULE_DOCSTRING_RESERVED) {
                            throw new IOException("Expected " + Integer.toHexString(MODULE_DOCSTRING_RESERVED) +
                                    " after module doc string before Unicode module doc string, but found: " +
                                    Integer.toHexString(modDocStringReserved));
                        }
                        int unicodeModDocStringLength = in.readInt();
                        readUnicodeString(in, unicodeModDocStringLength);
                        break;
                    case MODULE_OFFSET:
                        in.readInt(); // record size, not needed; the offset itself is read by the helper
                        readModuleMetadataFromDirStream(in, streamName, modules);
                        break;
                    case PROJECT_MODULES:
                        dirState = DIR_STATE.MODULES_RECORD;
                        in.readInt();
                        in.readShort();
                        break;
                    case REFERENCE_CONTROL_A:
                        int szTwiddled = in.readInt();
                        trySkip(in, szTwiddled);
                        // An optional name record may sit between the two halves of the control record.
                        int nextRecord = in.readShort();
                        if (nextRecord == RecordType.REFERENCE_NAME.id) {
                            readStringPair(in, modules.charset, REFERENCE_NAME_RESERVED);
                            nextRecord = in.readShort();
                        }
                        if (nextRecord != 0x30) {
                            throw new IOException("Expected 0x30 as Reserved3 in a ReferenceControl record");
                        }
                        int szExtended = in.readInt();
                        trySkip(in, szExtended);
                        break;
                    case MODULE_TERMINATOR:
                        in.readInt(); // 4-byte value following the terminator id, not needed
                        break;
                    default:
                        if (type.getConstantLength() > -1) {
                            trySkip(in, type.getConstantLength());
                        } else {
                            int recordLength = in.readInt();
                            trySkip(in, recordLength);
                        }
                        break;
                }
            }
        } catch (final IOException e) {
            throw new IOException(
                    "Error occurred while reading macros at section id "
                            + recordId + " (" + HexDump.shortToHex(recordId) + ")", e);
        }
    }
}
/**
 * Reads a charset-encoded string followed by its Unicode counterpart, failing if
 * the marker between the two is not the expected one.
 *
 * @param in the decompressed dir stream
 * @param charset charset of the first string
 * @param reservedByte the 2-byte marker expected between the two strings
 * @return both strings
 * @throws IOException if the marker differs or the stream ends early
 */
private ASCIIUnicodeStringPair readStringPair(RLEDecompressingInputStream in,
        Charset charset, int reservedByte) throws IOException {
    return readStringPair(in, charset, reservedByte, true);
}

/**
 * Reads a charset-encoded string followed by its Unicode counterpart.
 * <p>
 * Layout read: 4-byte length, string bytes, 2-byte marker, 4-byte length, Unicode
 * bytes. If the marker is not {@code reservedByte} the method either throws or,
 * when {@code throwOnUnexpectedReservedByte} is false, returns a pair that carries
 * the marker as pushback record id and has no Unicode part.
 *
 * @param in the decompressed dir stream
 * @param charset charset of the first string
 * @param reservedByte the 2-byte marker expected between the two strings
 * @param throwOnUnexpectedReservedByte whether an unexpected marker is an error
 * @return the strings that were read
 * @throws IOException if the marker is unexpected and must be reported, or the
 *         stream ends early
 */
private ASCIIUnicodeStringPair readStringPair(RLEDecompressingInputStream in,
        Charset charset, int reservedByte,
        boolean throwOnUnexpectedReservedByte) throws IOException {
    int nameLength = in.readInt();
    String ascii = readString(in, nameLength, charset);
    int reserved = in.readShort();
    if (reserved != reservedByte) {
        if (throwOnUnexpectedReservedByte) {
            throw new IOException("Expected " + Integer.toHexString(reservedByte) +
                    " after name before Unicode name, but found: " +
                    Integer.toHexString(reserved));
        }
        return new ASCIIUnicodeStringPair(ascii, reserved);
    }
    int unicodeNameRecordLength = in.readInt();
    String unicode = readUnicodeString(in, unicodeNameRecordLength);
    return new ASCIIUnicodeStringPair(ascii, unicode);
}
/**
 * Reads pairs of (charset-encoded name, Unicode name) and stores every pair whose
 * two parts are both non-blank after trimming.
 * <p>
 * Reading stops at the end of the stream, when a zero byte is directly followed by
 * another zero byte at the start of a record, or after the record limit is hit (in
 * which case a warning is logged).
 *
 * @param is the stream to read from
 * @param moduleNames receives the name pairs
 * @param charset charset of the first name of each pair
 * @throws IOException if reading fails for a reason other than end of stream
 */
protected void readNameMapRecords(InputStream is,
        Map<String, String> moduleNames, Charset charset) throws IOException {
    final int maxNameRecords = 10000;
    for (int records = 1; records < maxNameRecords; records++) {
        String mbcs;
        String unicode;
        try {
            int first = IOUtils.readByte(is);
            if (first == 0) {
                first = IOUtils.readByte(is);
                if (first == 0) {
                    return;
                }
            }
            mbcs = readMBCS(first, is, charset, MAX_STRING_LENGTH);
            unicode = readUnicode(is, MAX_STRING_LENGTH);
        } catch (EOFException e) {
            return;
        }
        final boolean bothPresent = mbcs.trim().length() > 0 && unicode.trim().length() > 0;
        if (bothPresent) {
            moduleNames.put(mbcs, unicode);
        }
    }
    // Only reachable when the loop ran out of iterations; every other exit returns.
    LOGGER.log(POILogger.WARN, "Hit max name records to read ("+maxNameRecords+"). Stopped early.");
}
/**
 * Reads byte pairs until a pair of two zero bytes is met or {@code maxLength}
 * bytes have been consumed, and decodes the collected pairs as UTF-16LE. The
 * terminating pair is consumed but not part of the result. A warning is logged
 * when the byte count reached {@code maxLength}.
 *
 * @param is the stream to read from
 * @param maxLength maximum number of bytes to consume
 * @return the decoded string
 * @throws IOException if the stream ends before a terminator is found
 */
private static String readUnicode(InputStream is, int maxLength) throws IOException {
    final ByteArrayOutputStream collected = new ByteArrayOutputStream();
    int consumed = 0;
    while (true) {
        final int lo = IOUtils.readByte(is);
        final int hi = IOUtils.readByte(is);
        consumed += 2;
        if ((lo + hi) == 0 || consumed >= maxLength) {
            break;
        }
        collected.write(lo);
        collected.write(hi);
    }
    if (consumed >= maxLength) {
        LOGGER.log(POILogger.WARN, "stopped reading unicode name after "+consumed+" bytes");
    }
    return new String(collected.toByteArray(), StandardCharsets.UTF_16LE);
}
/**
 * Collects bytes, starting with {@code firstByte}, until a byte that is not
 * greater than zero is read or {@code maxLength} bytes have been collected, and
 * decodes them with the given charset.
 *
 * @param firstByte the byte already read by the caller
 * @param is the stream to continue reading from
 * @param charset charset used to decode the collected bytes
 * @param maxLength maximum number of bytes to collect
 * @return the decoded string
 * @throws IOException if the stream ends before a terminator is found
 */
private static String readMBCS(int firstByte, InputStream is, Charset charset, int maxLength) throws IOException {
    final ByteArrayOutputStream collected = new ByteArrayOutputStream();
    int current = firstByte;
    for (int count = 0; current > 0 && count < maxLength; count++) {
        collected.write(current);
        current = IOUtils.readByte(is);
    }
    return new String(collected.toByteArray(), charset);
}
/**
 * Reads exactly {@code length} bytes and decodes them with the given charset.
 * The buffer is allocated through {@code IOUtils.safelyAllocate} with
 * {@link #MAX_STRING_LENGTH} as the limit.
 *
 * @param stream the stream to read from
 * @param length number of bytes to read
 * @param charset charset used to decode the bytes
 * @return the decoded string
 * @throws IOException if fewer than {@code length} bytes could be read
 */
private static String readString(InputStream stream, int length, Charset charset) throws IOException {
    final byte[] raw = IOUtils.safelyAllocate(length, MAX_STRING_LENGTH);
    final int got = IOUtils.readFully(stream, raw);
    if (got == length) {
        return new String(raw, 0, length, charset);
    }
    throw new IOException("Tried to read: "+length +
            ", but could only read: "+got);
}
/**
 * Parses the project properties stream and assigns a module type to every module
 * named by a {@code Document=}, {@code Module=} or {@code Class=} line.
 * <p>
 * The stream is decoded with the charset of {@code modules} and split into lines
 * on "\r\n" or "\n\r". Lines starting with "[" and lines without a value are
 * ignored. A value enclosed in double quotes is unquoted. For {@code Document}
 * lines, everything from "/&amp;H" onwards is dropped from the value when that
 * marker is present. Modules that cannot be found are reported with a warning.
 *
 * @param dis the properties stream; not closed by this method
 * @param moduleNameMap optional mapping applied to module names before lookup
 * @param modules the modules collected so far
 * @throws IOException if the stream cannot be read
 */
protected void readProjectProperties(DocumentInputStream dis,
        Map<String, String> moduleNameMap, ModuleMap modules) throws IOException {
    InputStreamReader reader = new InputStreamReader(dis, modules.charset);
    StringBuilder builder = new StringBuilder();
    char[] buffer = new char[512];
    int read;
    while ((read = reader.read(buffer)) >= 0) {
        builder.append(buffer, 0, read);
    }
    String properties = builder.toString();
    for (String line : properties.split("\r\n|\n\r")) {
        if (line.startsWith("[")) {
            continue;
        }
        String[] tokens = line.split("=");
        // No value: nothing to assign. This also covers lines such as "=", for which
        // split() returns an empty array and tokens[0] would not exist.
        if (tokens.length < 2) {
            continue;
        }
        String key = tokens[0];
        String value = tokens[1];
        if (value.length() > 1 && value.startsWith("\"") && value.endsWith("\"")) {
            value = value.substring(1, value.length() - 1);
        }
        if ("Document".equals(key)) {
            // Keep the whole value when the "/&H" suffix is missing instead of failing.
            int suffixStart = value.indexOf("/&H");
            String documentName = suffixStart < 0 ? value : value.substring(0, suffixStart);
            assignModuleType(documentName, ModuleType.Document, moduleNameMap, modules);
        } else if ("Module".equals(key)) {
            assignModuleType(value, ModuleType.Module, moduleNameMap, modules);
        } else if ("Class".equals(key)) {
            assignModuleType(value, ModuleType.Class, moduleNameMap, modules);
        }
    }
}

/** Sets the type of the named module, or logs a warning when the module is unknown. */
private void assignModuleType(String moduleName, ModuleType type,
        Map<String, String> moduleNameMap, ModuleMap modules) {
    ModuleImpl module = getModule(moduleName, moduleNameMap, modules);
    if (module != null) {
        module.moduleType = type;
    } else {
        LOGGER.log(POILogger.WARN, "couldn't find module with name: "+moduleName);
    }
}
/**
 * Looks up a module by name. If the name map has an entry for the name, the
 * mapped name is used for the lookup, otherwise the name itself.
 *
 * @param moduleName the name to resolve
 * @param moduleNameMap optional name translation
 * @param moduleMap the modules collected so far
 * @return the module, or null if none is stored under the resolved name
 */
private ModuleImpl getModule(String moduleName, Map<String, String> moduleNameMap, ModuleMap moduleMap) {
    final String lookupKey = moduleNameMap.containsKey(moduleName)
            ? moduleNameMap.get(moduleName)
            : moduleName;
    return moduleMap.get(lookupKey);
}
/**
 * Reads exactly {@code unicodeNameRecordLength} bytes and decodes them as UTF-16LE.
 * The buffer is allocated through {@code IOUtils.safelyAllocate} with
 * {@link #MAX_STRING_LENGTH} as the limit.
 *
 * @param in the decompressed dir stream
 * @param unicodeNameRecordLength number of bytes to read
 * @return the decoded string
 * @throws EOFException if fewer bytes than requested could be read
 * @throws IOException if reading fails
 */
private String readUnicodeString(RLEDecompressingInputStream in, int unicodeNameRecordLength) throws IOException {
    final byte[] raw = IOUtils.safelyAllocate(unicodeNameRecordLength, MAX_STRING_LENGTH);
    if (IOUtils.readFully(in, raw) != unicodeNameRecordLength) {
        throw new EOFException();
    }
    return new String(raw, StringUtil.UTF16LE);
}
/**
 * Fallback used when the recorded module offset does not lead to decompressible
 * data: scans the whole stream for a position that looks like the start of a
 * compressed section and tries to decompress from there.
 * <p>
 * A candidate is a 0x01 byte followed by a little-endian 16-bit value that is
 * non-zero and whose bits 12-14 equal 0b011. The first candidate whose
 * decompressed output has more than 9 bytes and contains "Attribute" within its
 * first 20 bytes is returned. If no candidate qualifies, the result of the last
 * decompression attempt is returned, which may be null.
 *
 * @param is the raw module stream; read fully but not closed
 * @return the decompressed bytes, or null
 * @throws IOException if the stream cannot be read
 */
private static byte[] findCompressedStreamWBruteForce(InputStream is) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    IOUtils.copy(is, bos);
    byte[] compressed = bos.toByteArray();
    byte[] decompressed = null;
    // A candidate needs the signature byte plus a 2-byte header, so the last two
    // positions can never start one; stopping there keeps getUShort() in bounds.
    for (int i = 0; i < compressed.length - 2; i++) {
        if (compressed[i] != 0x01) {
            continue;
        }
        int w = LittleEndian.getUShort(compressed, i+1);
        if (w <= 0 || (w & 0x7000) != 0x3000) {
            continue;
        }
        decompressed = tryToDecompress(new ByteArrayInputStream(compressed, i, compressed.length - i));
        if (decompressed != null && decompressed.length > 9) {
            int firstX = Math.min(20, decompressed.length);
            String start = new String(decompressed, 0, firstX, StringUtil.WIN_1252);
            if (start.contains("Attribute")) {
                return decompressed;
            }
        }
    }
    return decompressed;
}
/**
 * Attempts to decompress the given stream completely.
 *
 * @param is the stream to decompress
 * @return the decompressed bytes, or null if decompression failed with an
 *         {@link IllegalArgumentException}, {@link IOException} or
 *         {@link IllegalStateException}
 */
private static byte[] tryToDecompress(InputStream is) {
    try {
        final ByteArrayOutputStream sink = new ByteArrayOutputStream();
        IOUtils.copy(new RLEDecompressingInputStream(is), sink);
        return sink.toByteArray();
    } catch (IllegalArgumentException | IOException | IllegalStateException e) {
        return null;
    }
}
}