/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification.document;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef;

// A simplistic Lucene-based Naive Bayes classifier; see http://en.wikipedia.org/wiki/Naive_Bayes_classifier
/** * A simplistic Lucene based NaiveBayes classifier, see {@code http://en.wikipedia.org/wiki/Naive_Bayes_classifier} * * @lucene.experimental */
public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifier implements DocumentClassifier<BytesRef> {
Analyzer to be used for tokenizing document fields
/** * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing document fields */
protected Map<String, Analyzer> field2analyzer;
Creates a new NaiveBayes classifier.
  • indexReader – the reader on the index to be used for classification
  • query – a Query to eventually filter the docs used for training the classifier, or null if all the indexed docs should be used
  • classFieldName – the name of the field used as the output for the classifier NOTE: must not be heavely analyzed as the returned class will be a token indexed for this field
  • textFieldNames – the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
/** * Creates a new NaiveBayes classifier. * * @param indexReader the reader on the index to be used for classification * @param query a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null} * if all the indexed docs should be used * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be heavely analyzed * as the returned class will be a token indexed for this field * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10 */
public SimpleNaiveBayesDocumentClassifier(IndexReader indexReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) { super(indexReader, null, query, classFieldName, textFieldNames); this.field2analyzer = field2analyzer; } @Override public ClassificationResult<BytesRef> assignClass(Document document) throws IOException { List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document); ClassificationResult<BytesRef> assignedClass = null; double maxscore = -Double.MAX_VALUE; for (ClassificationResult<BytesRef> c : assignedClasses) { if (c.getScore() > maxscore) { assignedClass = c; maxscore = c.getScore(); } } return assignedClass; } @Override public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException { List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document); Collections.sort(assignedClasses); return assignedClasses; } @Override public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException { List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document); Collections.sort(assignedClasses); return assignedClasses.subList(0, max); } private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException { List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>(); Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>(); Map<String, Float> fieldName2boost = new LinkedHashMap<>(); Terms classes = MultiTerms.getTerms(indexReader, classFieldName); if (classes != null) { TermsEnum classesEnum = classes.iterator(); BytesRef c; analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost); int docsWithClassSize = countDocsWithClass(); while ((c = classesEnum.next()) != null) { double classScore = 0; Term term = new Term(this.classFieldName, c); for (String fieldName : textFieldNames) { List<String[]> 
tokensArrays = fieldName2tokensArray.get(fieldName); double fieldScore = 0; for (String[] fieldTokensArray : tokensArrays) { fieldScore += calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * fieldName2boost.get(fieldName); } classScore += fieldScore; } assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore)); } } return normClassificationResults(assignedClasses); }
This methods performs the analysis for the seed document and extract the boosts if present. This is done only one time for the Seed Document.
  • inputDocument – the seed unseen document
  • fieldName2tokensArray – a map that associated to a field name the list of token arrays for all its values
  • fieldName2boost – a map that associates the boost to the field
/** * This methods performs the analysis for the seed document and extract the boosts if present. * This is done only one time for the Seed Document. * * @param inputDocument the seed unseen document * @param fieldName2tokensArray a map that associated to a field name the list of token arrays for all its values * @param fieldName2boost a map that associates the boost to the field * @throws IOException If there is a low-level I/O error */
private void analyzeSeedDocument(Document inputDocument, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException { for (int i = 0; i < textFieldNames.length; i++) { String fieldName = textFieldNames[i]; float boost = 1; List<String[]> tokenizedValues = new LinkedList<>(); if (fieldName.contains("^")) { String[] field2boost = fieldName.split("\\^"); fieldName = field2boost[0]; boost = Float.parseFloat(field2boost[1]); } IndexableField[] fieldValues = inputDocument.getFields(fieldName); for (IndexableField fieldValue : fieldValues) { TokenStream fieldTokens = fieldValue.tokenStream(field2analyzer.get(fieldName), null); String[] fieldTokensArray = getTokenArray(fieldTokens); tokenizedValues.add(fieldTokensArray); } fieldName2tokensArray.put(fieldName, tokenizedValues); fieldName2boost.put(fieldName, boost); textFieldNames[i] = fieldName; } }
Returns a token array from the TokenStream in input
  • tokenizedText – the tokenized content of a field
  • IOException – If tokenization fails because there is a low-level I/O error
Returns:a String array of the resulting tokens
/** * Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input * * @param tokenizedText the tokenized content of a field * @return a {@code String} array of the resulting tokens * @throws java.io.IOException If tokenization fails because there is a low-level I/O error */
protected String[] getTokenArray(TokenStream tokenizedText) throws IOException { Collection<String> tokens = new LinkedList<>(); CharTermAttribute charTermAttribute = tokenizedText.addAttribute(CharTermAttribute.class); tokenizedText.reset(); while (tokenizedText.incrementToken()) { tokens.add(charTermAttribute.toString()); } tokenizedText.end(); tokenizedText.close(); return tokens.toArray(new String[tokens.size()]); }
  • tokenizedText – the tokenized content of a field
  • fieldName – the input field name
  • term – the Term referring to the class to calculate the score of
  • docsWithClass – the total number of docs that have a class
Returns:a normalized score for the class
/** * @param tokenizedText the tokenized content of a field * @param fieldName the input field name * @param term the {@link Term} referring to the class to calculate the score of * @param docsWithClass the total number of docs that have a class * @return a normalized score for the class * @throws IOException If there is a low-level I/O error */
private double calculateLogLikelihood(String[] tokenizedText, String fieldName, Term term, int docsWithClass) throws IOException { // for each word double result = 0d; for (String word : tokenizedText) { // search with text:word AND class:c int hits = getWordFreqForClass(word, fieldName, term); // num : count the no of times the word appears in documents of class c (+1) double num = hits + 1; // +1 is added because of add 1 smoothing // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|) double den = getTextTermFreqForClass(term, fieldName) + docsWithClass; // P(w|c) = num/den double wordProbability = num / den; result += Math.log(wordProbability); } // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c)) double normScore = result / (tokenizedText.length); // this is normalized because if not, long text fields will always be more important than short fields return normScore; }
Returns the average number of unique terms times the number of docs belonging to the input class
  • term – the class term
Returns:the average number of unique terms
/** * Returns the average number of unique terms times the number of docs belonging to the input class * * @param term the class term * @return the average number of unique terms * @throws java.io.IOException If there is a low-level I/O error */
private double getTextTermFreqForClass(Term term, String fieldName) throws IOException { double avgNumberOfUniqueTerms; Terms terms = MultiTerms.getTerms(indexReader, fieldName); long numPostings = terms.getSumDocFreq(); // number of term/doc pairs avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc int docsWithC = indexReader.docFreq(term); return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c }
Returns the number of documents of the input class ( from the whole index or from a subset) that contains the word ( in a specific field or in all the fields if no one selected)
  • word – the token produced by the analyzer
  • fieldName – the field the word is coming from
  • term – the class term
Returns:number of documents of the input class
/** * Returns the number of documents of the input class ( from the whole index or from a subset) * that contains the word ( in a specific field or in all the fields if no one selected) * * @param word the token produced by the analyzer * @param fieldName the field the word is coming from * @param term the class term * @return number of documents of the input class * @throws java.io.IOException If there is a low-level I/O error */
private int getWordFreqForClass(String word, String fieldName, Term term) throws IOException { BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder(); BooleanQuery.Builder subQuery = new BooleanQuery.Builder(); subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD)); booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST)); booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST)); if (query != null) { booleanQuery.add(query, BooleanClause.Occur.MUST); } TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector(); indexSearcher.search(booleanQuery.build(), totalHitCountCollector); return totalHitCountCollector.getTotalHits(); } private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException { return Math.log((double) docCount(term)) - Math.log(docsWithClassSize); } private int docCount(Term term) throws IOException { return indexReader.docFreq(term); } }