/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.concurrent.ConcurrentSkipListMap;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.fst.Builder;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
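
// A minimal usage sketch (illustrative only, not part of the original class): assuming an
// IndexReader over an index whose text field stores term vectors and whose class field holds
// "true"/"false" values, the classifier might be built and queried roughly as follows.
// The reader, analyzer and field names below ("category", "text") are placeholders.
//
//   IndexReader reader = DirectoryReader.open(directory);
//   Classifier<Boolean> classifier = new BooleanPerceptronClassifier(
//       reader, new StandardAnalyzer(), null, 1, null, "category", "text");
//   ClassificationResult<Boolean> result = classifier.assignClass("some unseen text");
//   Boolean assignedClass = result.getAssignedClass();
//   double score = result.getScore();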

/**
 * A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
 * <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
 * weights are calculated using
 * {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
 * and a per document basis and then a corresponding
 * {@link org.apache.lucene.util.fst.FST} is used for class assignment.
 *
 * @lucene.experimental
 */
public class BooleanPerceptronClassifier implements Classifier<Boolean> {

  private final Double bias;
  private final Terms textTerms;
  private final Analyzer analyzer;
  private final String textFieldName;
  private FST<Long> fst;
  /**
   * Creates a {@link BooleanPerceptronClassifier}
   *
   * @param indexReader    the reader on the index to be used for classification
   * @param analyzer       an {@link Analyzer} used to analyze unseen text
   * @param query          a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
   *                       if all the indexed docs should be used
   * @param batchSize      the size of the batch of docs to use for updating the perceptron weights
   * @param bias           the bias used for class separation
   * @param classFieldName the name of the field used as the output for the classifier
   * @param textFieldName  the name of the field used as input for the classifier
   * @throws IOException if the building of the underlying {@link FST} fails and / or {@link TermsEnum} for the text field
   *                     cannot be found
   */
  public BooleanPerceptronClassifier(IndexReader indexReader, Analyzer analyzer, Query query, Integer batchSize,
                                     Double bias, String classFieldName, String textFieldName) throws IOException {
    this.textTerms = MultiTerms.getTerms(indexReader, textFieldName);

    if (textTerms == null) {
      throw new IOException("term vectors need to be available for field " + textFieldName);
    }

    this.analyzer = analyzer;
    this.textFieldName = textFieldName;

    if (bias == null || bias == 0d) {
      // automatic assign the bias to be the average total term freq
      double t = (double) indexReader.getSumTotalTermFreq(textFieldName) / (double) indexReader.getDocCount(textFieldName);
      if (t != -1) {
        this.bias = t;
      } else {
        throw new IOException(
            "bias cannot be assigned since term vectors for field " + textFieldName + " do not exist");
      }
    } else {
      this.bias = bias;
    }

    // TODO : remove this map as soon as we have a writable FST
    SortedMap<String, Double> weights = new ConcurrentSkipListMap<>();

    TermsEnum termsEnum = textTerms.iterator();
    BytesRef textTerm;
    while ((textTerm = termsEnum.next()) != null) {
      weights.put(textTerm.utf8ToString(), (double) termsEnum.totalTermFreq());
    }
    updateFST(weights);

    IndexSearcher indexSearcher = new IndexSearcher(indexReader);

    int batchCount = 0;

    BooleanQuery.Builder q = new BooleanQuery.Builder();
    q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
    if (query != null) {
      q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
    }
    // run the search and use stored field values
    for (ScoreDoc scoreDoc : indexSearcher.search(q.build(), Integer.MAX_VALUE).scoreDocs) {
      Document doc = indexSearcher.doc(scoreDoc.doc);

      IndexableField textField = doc.getField(textFieldName);

      // get the expected result
      IndexableField classField = doc.getField(classFieldName);

      if (textField != null && classField != null) {
        // assign class to the doc
        ClassificationResult<Boolean> classificationResult = assignClass(textField.stringValue());
        Boolean assignedClass = classificationResult.getAssignedClass();

        Boolean correctClass = Boolean.valueOf(classField.stringValue());
        long modifier = correctClass.compareTo(assignedClass);
        if (modifier != 0) {
          updateWeights(indexReader, scoreDoc.doc, assignedClass, weights, modifier, batchCount % batchSize == 0);
        }
        batchCount++;
      }
    }
    weights.clear(); // free memory while waiting for GC
  }

  private void updateWeights(IndexReader indexReader, int docId, Boolean assignedClass,
                             SortedMap<String, Double> weights, double modifier, boolean updateFST) throws IOException {
    TermsEnum cte = textTerms.iterator();

    // get the doc term vectors
    Terms terms = indexReader.getTermVector(docId, textFieldName);

    if (terms == null) {
      throw new IOException("term vectors must be stored for field " + textFieldName);
    }

    TermsEnum termsEnum = terms.iterator();

    BytesRef term;

    while ((term = termsEnum.next()) != null) {
      cte.seekExact(term);
      if (assignedClass != null) {
        long termFreqLocal = termsEnum.totalTermFreq();
        // update weights
        Long previousValue = Util.get(fst, term);
        String termString = term.utf8ToString();
        weights.put(termString, previousValue == null ? 0 : Math.max(0, previousValue + modifier * termFreqLocal));
      }
    }
    if (updateFST) {
      updateFST(weights);
    }
  }

  private void updateFST(SortedMap<String, Double> weights) throws IOException {
    PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
    Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
    BytesRefBuilder scratchBytes = new BytesRefBuilder();
    IntsRefBuilder scratchInts = new IntsRefBuilder();
    for (Map.Entry<String, Double> entry : weights.entrySet()) {
      scratchBytes.copyChars(entry.getKey());
      fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry.getValue().longValue());
    }
    fst = fstBuilder.finish();
  }

  @Override
  public ClassificationResult<Boolean> assignClass(String text) throws IOException {
    Long output = 0L;
    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
      CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
      tokenStream.reset();
      while (tokenStream.incrementToken()) {
        String s = charTermAttribute.toString();
        Long d = Util.get(fst, new BytesRef(s));
        if (d != null) {
          output += d;
        }
      }
      tokenStream.end();
    }

    double score = 1 - Math.exp(-1 * Math.abs(bias - output.doubleValue()) / bias);
    return new ClassificationResult<>(output >= bias, score);
  }

  @Override
  public List<ClassificationResult<Boolean>> getClasses(String text) throws IOException {
    return null;
  }

  @Override
  public List<ClassificationResult<Boolean>> getClasses(String text, int max) throws IOException {
    return null;
  }
}