/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;

A simplistic Lucene based NaiveBayes classifier, see http://en.wikipedia.org/wiki/Naive_Bayes_classifier
@lucene.experimental
/** * A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code> * * @lucene.experimental */
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
IndexReader used to access the Classifier's index
/** * {@link org.apache.lucene.index.IndexReader} used to access the {@link org.apache.lucene.classification.Classifier}'s * index */
protected final IndexReader indexReader;
names of the fields to be used as input text
/** * names of the fields to be used as input text */
protected final String[] textFieldNames;
name of the field to be used as a class / category output
/** * name of the field to be used as a class / category output */
protected final String classFieldName;
Analyzer to be used for tokenizing unseen input text
/** * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text */
protected final Analyzer analyzer;
IndexSearcher to run searches on the index for retrieving frequencies
/** * {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies */
protected final IndexSearcher indexSearcher;
Query used to eventually filter the document set to be used to classify
/** * {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify */
protected final Query query;
Creates a new NaiveBayes classifier.
Params:
  • indexReader – the reader on the index to be used for classification
  • analyzer – an Analyzer used to analyze unseen text
  • query – a Query to eventually filter the docs used for training the classifier, or null if all the indexed docs should be used
  • classFieldName – the name of the field used as the output for the classifier NOTE: must not be havely analyzed as the returned class will be a token indexed for this field
  • textFieldNames – the name of the fields used as the inputs for the classifier, NO boosting supported per field
/** * Creates a new NaiveBayes classifier. * * @param indexReader the reader on the index to be used for classification * @param analyzer an {@link Analyzer} used to analyze unseen text * @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null} * if all the indexed docs should be used * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed * as the returned class will be a token indexed for this field * @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field */
public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) { this.indexReader = indexReader; this.indexSearcher = new IndexSearcher(this.indexReader); this.textFieldNames = textFieldNames; this.classFieldName = classFieldName; this.analyzer = analyzer; this.query = query; } @Override public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException { List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(inputDocument); ClassificationResult<BytesRef> assignedClass = null; double maxscore = -Double.MAX_VALUE; for (ClassificationResult<BytesRef> c : assignedClasses) { if (c.getScore() > maxscore) { assignedClass = c; maxscore = c.getScore(); } } return assignedClass; } @Override public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException { List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text); Collections.sort(assignedClasses); return assignedClasses; } @Override public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException { List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text); Collections.sort(assignedClasses); return assignedClasses.subList(0, max); }
Calculate probabilities for all classes for a given input text
Params:
  • inputDocument – the input text as a String
Throws:
Returns:a List of ClassificationResult, one for each existing class
/** * Calculate probabilities for all classes for a given input text * @param inputDocument the input text as a {@code String} * @return a {@code List} of {@code ClassificationResult}, one for each existing class * @throws IOException if assigning probabilities fails */
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException { List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>(); Terms classes = MultiTerms.getTerms(indexReader, classFieldName); if (classes != null) { TermsEnum classesEnum = classes.iterator(); BytesRef next; String[] tokenizedText = tokenize(inputDocument); int docsWithClassSize = countDocsWithClass(); while ((next = classesEnum.next()) != null) { if (next.length > 0) { Term term = new Term(this.classFieldName, next); double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize); assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal)); } } } // normalization; the values transforms to a 0-1 range return normClassificationResults(assignedClasses); }
count the number of documents in the index having at least a value for the 'class' field
Throws:
  • IOException – if accessing to term vectors or search fails
Returns:the no. of documents having a value for the 'class' field
/** * count the number of documents in the index having at least a value for the 'class' field * * @return the no. of documents having a value for the 'class' field * @throws IOException if accessing to term vectors or search fails */
protected int countDocsWithClass() throws IOException { Terms terms = MultiTerms.getTerms(this.indexReader, this.classFieldName); int docCount; if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector(); BooleanQuery.Builder q = new BooleanQuery.Builder(); q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST)); if (query != null) { q.add(query, BooleanClause.Occur.MUST); } indexSearcher.search(q.build(), classQueryCountCollector); docCount = classQueryCountCollector.getTotalHits(); } else { docCount = terms.getDocCount(); } return docCount; }
tokenize a String on this classifier's text fields and analyzer
Params:
  • text – the String representing an input text (to be classified)
Throws:
Returns:a String array of the resulting tokens
/** * tokenize a <code>String</code> on this classifier's text fields and analyzer * * @param text the <code>String</code> representing an input text (to be classified) * @return a <code>String</code> array of the resulting tokens * @throws IOException if tokenization fails */
protected String[] tokenize(String text) throws IOException { Collection<String> result = new LinkedList<>(); for (String textFieldName : textFieldNames) { try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) { CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class); tokenStream.reset(); while (tokenStream.incrementToken()) { result.add(charTermAttribute.toString()); } tokenStream.end(); } } return result.toArray(new String[result.size()]); } private double calculateLogLikelihood(String[] tokenizedText, Term term, int docsWithClass) throws IOException { // for each word double result = 0d; for (String word : tokenizedText) { // search with text:word AND class:c int hits = getWordFreqForClass(word, term); // num : count the no of times the word appears in documents of class c (+1) double num = hits + 1; // +1 is added because of add 1 smoothing // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|) double den = getTextTermFreqForClass(term) + docsWithClass; // P(w|c) = num/den double wordProbability = num / den; result += Math.log(wordProbability); } // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c)) return result; }
Returns the average number of unique terms times the number of docs belonging to the input class
Params:
  • term – the term representing the class
Throws:
Returns:the average number of unique terms
/** * Returns the average number of unique terms times the number of docs belonging to the input class * @param term the term representing the class * @return the average number of unique terms * @throws IOException if a low level I/O problem happens */
private double getTextTermFreqForClass(Term term) throws IOException { double avgNumberOfUniqueTerms = 0; for (String textFieldName : textFieldNames) { Terms terms = MultiTerms.getTerms(indexReader, textFieldName); long numPostings = terms.getSumDocFreq(); // number of term/doc pairs avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc } int docsWithC = indexReader.docFreq(term); return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c }
Returns the number of documents of the input class ( from the whole index or from a subset) that contains the word ( in a specific field or in all the fields if no one selected)
Params:
  • word – the token produced by the analyzer
  • term – the term representing the class
Throws:
Returns:the number of documents of the input class
/** * Returns the number of documents of the input class ( from the whole index or from a subset) * that contains the word ( in a specific field or in all the fields if no one selected) * @param word the token produced by the analyzer * @param term the term representing the class * @return the number of documents of the input class * @throws IOException if a low level I/O problem happens */
private int getWordFreqForClass(String word, Term term) throws IOException { BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder(); BooleanQuery.Builder subQuery = new BooleanQuery.Builder(); for (String textFieldName : textFieldNames) { subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD)); } booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST)); booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST)); if (query != null) { booleanQuery.add(query, BooleanClause.Occur.MUST); } TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); indexSearcher.search(booleanQuery.build(), totalHitCountCollector); return totalHitCountCollector.getTotalHits(); } private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException { return Math.log((double) docCount(term)) - Math.log(docsWithClassSize); } private int docCount(Term term) throws IOException { return indexReader.docFreq(term); }
Normalize the classification results based on the max score available
Params:
  • assignedClasses – the list of assigned classes
Returns:the normalized results
/** * Normalize the classification results based on the max score available * @param assignedClasses the list of assigned classes * @return the normalized results */
protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) { // normalization; the values transforms to a 0-1 range ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>(); if (!assignedClasses.isEmpty()) { Collections.sort(assignedClasses); // this is a negative number closest to 0 = a double smax = assignedClasses.get(0).getScore(); double sumLog = 0; // log(sum(exp(x_n-a))) for (ClassificationResult<BytesRef> cr : assignedClasses) { // getScore-smax <=0 (both negative, smax is the smallest abs() sumLog += Math.exp(cr.getScore() - smax); } // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n))) double loga = smax; loga += Math.log(sumLog); // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum)) for (ClassificationResult<BytesRef> cr : assignedClasses) { double scoreDiff = cr.getScore() - loga; returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff))); } } return returnList; } }