/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification.utils;


import java.io.IOException;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;

Utility class for creating training / test / cross validation indexes from the original index.
/** * Utility class for creating training / test / cross validation indexes from the original index. */
public class DatasetSplitter { private final double crossValidationRatio; private final double testRatio;
Create a DatasetSplitter by giving test and cross validation IDXs sizes
Params:
  • testRatio – the ratio of the original index to be used for the test IDX as a double between 0.0 and 1.0
  • crossValidationRatio – the ratio of the original index to be used for the c.v. IDX as a double between 0.0 and 1.0
/** * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes * * @param testRatio the ratio of the original index to be used for the test IDX as a <code>double</code> between 0.0 and 1.0 * @param crossValidationRatio the ratio of the original index to be used for the c.v. IDX as a <code>double</code> between 0.0 and 1.0 */
public DatasetSplitter(double testRatio, double crossValidationRatio) { this.crossValidationRatio = crossValidationRatio; this.testRatio = testRatio; }
Split a given index into 3 indexes for training, test and cross validation tasks respectively
Params:
  • originalIndex – an LeafReader on the source index
  • trainingIndex – a Directory used to write the training index
  • testIndex – a Directory used to write the test index
  • crossValidationIndex – a Directory used to write the cross validation index
  • analyzer – Analyzer used to create the new docs
  • termVectors – true if term vectors should be kept
  • classFieldName – name of the field used as the label for classification; this must be indexed with sorted doc values
  • fieldNames – names of fields that need to be put in the new indexes or null if all should be used
Throws:
  • IOException – if any writing operation fails on any of the indexes
/** * Split a given index into 3 indexes for training, test and cross validation tasks respectively * * @param originalIndex an {@link org.apache.lucene.index.LeafReader} on the source index * @param trainingIndex a {@link Directory} used to write the training index * @param testIndex a {@link Directory} used to write the test index * @param crossValidationIndex a {@link Directory} used to write the cross validation index * @param analyzer {@link Analyzer} used to create the new docs * @param termVectors {@code true} if term vectors should be kept * @param classFieldName name of the field used as the label for classification; this must be indexed with sorted doc values * @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used * @throws IOException if any writing operation fails on any of the indexes */
public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException { // create IWs for train / test / cv IDXs IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer)); IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer)); IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer)); // get the exact no. of existing classes int noOfClasses = 0; for (LeafReaderContext leave : originalIndex.leaves()) { long valueCount = 0; SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName); if (classValues != null) { valueCount = classValues.getValueCount(); } else { SortedSetDocValues sortedSetDocValues = leave.reader().getSortedSetDocValues(classFieldName); if (sortedSetDocValues != null) { valueCount = sortedSetDocValues.getValueCount(); } } if (classValues == null) { // approximate with no. 
of terms noOfClasses += leave.reader().terms(classFieldName).size(); } noOfClasses += valueCount; } try { IndexSearcher indexSearcher = new IndexSearcher(originalIndex); GroupingSearch gs = new GroupingSearch(classFieldName); gs.setGroupSort(Sort.INDEXORDER); gs.setSortWithinGroup(Sort.INDEXORDER); gs.setAllGroups(true); gs.setGroupDocsLimit(originalIndex.maxDoc()); TopGroups<Object> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, noOfClasses); // set the type to be indexed, stored, with term vectors FieldType ft = new FieldType(TextField.TYPE_STORED); if (termVectors) { ft.setStoreTermVectors(true); ft.setStoreTermVectorOffsets(true); ft.setStoreTermVectorPositions(true); } int b = 0; // iterate over existing documents for (GroupDocs<Object> group : topGroups.groups) { assert group.totalHits.relation == TotalHits.Relation.EQUAL_TO; long totalHits = group.totalHits.value; double testSize = totalHits * testRatio; int tc = 0; double cvSize = totalHits * crossValidationRatio; int cvc = 0; for (ScoreDoc scoreDoc : group.scoreDocs) { // create a new document for indexing Document doc = createNewDoc(originalIndex, ft, scoreDoc, fieldNames); // add it to one of the IDXs if (b % 2 == 0 && tc < testSize) { testWriter.addDocument(doc); tc++; } else if (cvc < cvSize) { cvWriter.addDocument(doc); cvc++; } else { trainingWriter.addDocument(doc); } b++; } } // commit testWriter.commit(); cvWriter.commit(); trainingWriter.commit(); // merge testWriter.forceMerge(3); cvWriter.forceMerge(3); trainingWriter.forceMerge(3); } catch (Exception e) { throw new IOException(e); } finally { // close IWs testWriter.close(); cvWriter.close(); trainingWriter.close(); originalIndex.close(); } } private Document createNewDoc(IndexReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException { Document doc = new Document(); Document document = originalIndex.document(scoreDoc.doc); if (fieldNames != null && fieldNames.length > 0) { for (String 
fieldName : fieldNames) { IndexableField field = document.getField(fieldName); if (field != null) { doc.add(new Field(fieldName, field.stringValue(), ft)); } } } else { for (IndexableField field : document.getFields()) { if (field.readerValue() != null) { doc.add(new Field(field.name(), field.readerValue(), ft)); } else if (field.binaryValue() != null) { doc.add(new Field(field.name(), field.binaryValue(), ft)); } else if (field.stringValue() != null) { doc.add(new Field(field.name(), field.stringValue(), ft)); } else if (field.numericValue() != null) { doc.add(new Field(field.name(), field.numericValue().toString(), ft)); } } } return doc; } }