package org.apache.lucene.classification.utils;
import java.io.IOException;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;
/**
 * Splits the documents of an existing index into three new indexes (training, test and cross
 * validation), distributing the documents of each class (group of the class field) according to
 * the configured ratios.
 */
public class DatasetSplitter {

  private final double crossValidationRatio;
  private final double testRatio;

  /**
   * Creates a splitter with the given ratios. The ratios are stored as given; no range validation
   * is performed.
   *
   * @param testRatio fraction of each class's documents that may be written to the test index
   * @param crossValidationRatio fraction of each class's documents that may be written to the
   *     cross validation index
   */
  public DatasetSplitter(double testRatio, double crossValidationRatio) {
    this.crossValidationRatio = crossValidationRatio;
    this.testRatio = testRatio;
  }

  /**
   * Groups all documents of {@code originalIndex} by {@code classFieldName} and copies every
   * document into exactly one of the three target indexes.
   *
   * <p>Within each class, a document goes to the test index when the running document counter
   * (shared across all classes) is even and the class's test quota is not yet filled; otherwise
   * it goes to the cross validation index while that quota is not filled; otherwise it goes to
   * the training index. Each target index is committed and then force-merged down to at most
   * three segments.
   *
   * <p>{@code originalIndex} and all three writers are closed when this method returns, whether
   * it completes normally or throws.
   *
   * @param originalIndex reader on the source index; closed by this method
   * @param trainingIndex directory receiving the training documents
   * @param testIndex directory receiving the test documents
   * @param crossValidationIndex directory receiving the cross validation documents
   * @param analyzer analyzer used by the three index writers
   * @param termVectors whether copied fields store term vectors (with offsets and positions)
   * @param classFieldName name of the field holding the class; documents are grouped on it
   * @param fieldNames names of the fields to copy; when null or empty every stored field is copied
   * @throws IOException if reading, grouping or writing fails, or if no class could be found for
   *     {@code classFieldName}; non-I/O failures raised while splitting are wrapped
   */
  public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex,
      Directory crossValidationIndex, Analyzer analyzer, boolean termVectors, String classFieldName,
      String... fieldNames) throws IOException {
    // try-with-resources guarantees that every resource opened so far is closed, even when a later
    // writer fails to open or when one of the close() calls itself throws.
    try (IndexReader source = originalIndex;
        IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
        IndexWriter cvWriter =
            new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
        IndexWriter trainingWriter =
            new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer))) {

      int noOfClasses = countClasses(source, classFieldName);
      if (noOfClasses < 1) {
        // Fail with a descriptive message rather than asking the grouping search for zero groups.
        throw new IOException("no classes found for field '" + classFieldName + "'");
      }

      try {
        IndexSearcher indexSearcher = new IndexSearcher(source);
        GroupingSearch gs = new GroupingSearch(classFieldName);
        gs.setGroupSort(Sort.INDEXORDER);
        gs.setSortWithinGroup(Sort.INDEXORDER);
        gs.setAllGroups(true);
        // Allow every document of the index to be returned inside a single group.
        gs.setGroupDocsLimit(source.maxDoc());
        TopGroups<Object> topGroups =
            gs.search(indexSearcher, new MatchAllDocsQuery(), 0, noOfClasses);

        FieldType ft = new FieldType(TextField.TYPE_STORED);
        if (termVectors) {
          ft.setStoreTermVectors(true);
          ft.setStoreTermVectorOffsets(true);
          ft.setStoreTermVectorPositions(true);
        }

        // Running counter over all documents of all groups; only even positions are eligible for
        // the test index.
        int b = 0;
        for (GroupDocs<Object> group : topGroups.groups) {
          assert group.totalHits.relation == TotalHits.Relation.EQUAL_TO;
          long totalHits = group.totalHits.value;
          double testSize = totalHits * testRatio;
          double cvSize = totalHits * crossValidationRatio;
          int tc = 0; // documents of this group written to the test index
          int cvc = 0; // documents of this group written to the cross validation index
          for (ScoreDoc scoreDoc : group.scoreDocs) {
            Document doc = createNewDoc(source, ft, scoreDoc, fieldNames);
            if (b % 2 == 0 && tc < testSize) {
              testWriter.addDocument(doc);
              tc++;
            } else if (cvc < cvSize) {
              cvWriter.addDocument(doc);
              cvc++;
            } else {
              trainingWriter.addDocument(doc);
            }
            b++;
          }
        }

        testWriter.commit();
        cvWriter.commit();
        trainingWriter.commit();

        testWriter.forceMerge(3);
        cvWriter.forceMerge(3);
        trainingWriter.forceMerge(3);
      } catch (IOException e) {
        // Already the declared exception type: propagate unchanged instead of double wrapping.
        throw e;
      } catch (Exception e) {
        throw new IOException(e);
      }
    }
  }

  /**
   * Computes an upper bound for the number of distinct class values, summed over all leaves.
   *
   * <p>Per leaf: when sorted doc values exist for the field their value count is used. Otherwise
   * the sorted-set doc values count (if any) plus the number of indexed terms of the field (if
   * known) is used as an approximation.
   *
   * @return the estimated number of classes, clamped to {@link Integer#MAX_VALUE}
   */
  private static int countClasses(IndexReader index, String classFieldName) throws IOException {
    long total = 0;
    for (LeafReaderContext leaf : index.leaves()) {
      SortedDocValues sorted = leaf.reader().getSortedDocValues(classFieldName);
      if (sorted != null) {
        total += sorted.getValueCount();
        continue;
      }
      SortedSetDocValues sortedSet = leaf.reader().getSortedSetDocValues(classFieldName);
      if (sortedSet != null) {
        total += sortedSet.getValueCount();
      }
      // terms() yields null when the field has no indexed terms in this leaf, and size() may be
      // negative when the term count is not available; both cases contribute nothing.
      if (leaf.reader().terms(classFieldName) != null) {
        long termCount = leaf.reader().terms(classFieldName).size();
        if (termCount > 0) {
          total += termCount;
        }
      }
    }
    // Accumulate in a long and clamp, instead of silently narrowing into an int.
    return (int) Math.min(total, (long) Integer.MAX_VALUE);
  }

  /**
   * Builds a new document for the target indexes from the stored document referenced by
   * {@code scoreDoc}. Every copied field is re-created with the field type {@code ft}.
   *
   * <p>When {@code fieldNames} is non-empty, only those fields are copied (the value returned by
   * {@code getField}, using its string value); fields missing from the stored document are
   * skipped. Otherwise every stored field is copied using the first non-null of its reader,
   * binary, string or numeric value (numeric values are copied as their string form).
   */
  private Document createNewDoc(IndexReader originalIndex, FieldType ft, ScoreDoc scoreDoc,
      String[] fieldNames) throws IOException {
    Document doc = new Document();
    Document document = originalIndex.document(scoreDoc.doc);
    if (fieldNames != null && fieldNames.length > 0) {
      for (String fieldName : fieldNames) {
        IndexableField field = document.getField(fieldName);
        if (field != null) {
          doc.add(new Field(fieldName, field.stringValue(), ft));
        }
      }
    } else {
      for (IndexableField field : document.getFields()) {
        if (field.readerValue() != null) {
          doc.add(new Field(field.name(), field.readerValue(), ft));
        } else if (field.binaryValue() != null) {
          doc.add(new Field(field.name(), field.binaryValue(), ft));
        } else if (field.stringValue() != null) {
          doc.add(new Field(field.name(), field.stringValue(), ft));
        } else if (field.numericValue() != null) {
          doc.add(new Field(field.name(), field.numericValue().toString(), ft));
        }
      }
    }
    return doc;
  }
}