/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef;

/**
 * A simplistic Lucene based NaiveBayes classifier, with caching feature, see
 * <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
 * <p>
 * This is NOT an online classifier.
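 * <p>
 * A minimal usage sketch (for illustration only; the reader, analyzer and field names below are
 * example choices, not defaults of this class):
 * <pre class="prettyprint">
 * IndexReader reader = DirectoryReader.open(directory);
 * CachingNaiveBayesClassifier classifier = new CachingNaiveBayesClassifier(
 *     reader, new StandardAnalyzer(), null, "category", "body");
 * ClassificationResult&lt;BytesRef&gt; result = classifier.assignClass("text to be classified");
 * BytesRef assignedClass = result.getAssignedClass();
 * </pre>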
 *
 * @lucene.experimental
 */
public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {

  // for caching classes this will be the classification class list
  private final ArrayList<BytesRef> cclasses = new ArrayList<>();
  // a two-level cache: for each cached term, the inner map holds the number of hits per class
  private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
  // the term frequency in classes
  private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
  private boolean justCachedTerms;
  private int docsWithClassSize;
  /**
   * Creates a new NaiveBayes classifier with built-in caching. If you want less memory usage you could call
   * {@link #reInitCache(int, boolean) reInitCache()}.
   *
   * @param indexReader    the reader on the index to be used for classification
   * @param analyzer       an {@link Analyzer} used to analyze unseen text
   * @param query          a {@link Query} to optionally filter the docs used for training the classifier, or
   *                       {@code null} if all the indexed docs should be used
   * @param classFieldName the name of the field used as the output for the classifier
   * @param textFieldNames the names of the fields used as the inputs for the classifier
   */
  public CachingNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
    super(indexReader, analyzer, query, classFieldName, textFieldNames);
    // build the cache
    try {
      reInitCache(0, true);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
    String[] tokenizedText = tokenize(inputDocument);

    List<ClassificationResult<BytesRef>> assignedClasses = calculateLogLikelihood(tokenizedText);

    // normalization: the scores are transformed to a 0-1 range
    ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = super.normClassificationResults(assignedClasses);
    return assignedClassesNorm;
  }

  private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedText) throws IOException {
    // initialize the return List
    ArrayList<ClassificationResult<BytesRef>> ret = new ArrayList<>();
    for (BytesRef cclass : cclasses) {
      ClassificationResult<BytesRef> cr = new ClassificationResult<>(cclass, 0d);
      ret.add(cr);
    }
    // for each word
    for (String word : tokenizedText) {
      // search with text:word for all class:c
      Map<BytesRef, Integer> hitsInClasses = getWordFreqForClassess(word);
      // for each class
      for (BytesRef cclass : cclasses) {
        Integer hitsI = hitsInClasses.get(cclass);
        // if the word is out of scope hitsI could be null
        int hits = 0;
        if (hitsI != null) {
          hits = hitsI;
        }
        // num : count of the times the word appears in documents of class c (+1 because of add-one smoothing)
        double num = hits + 1;
        // den : for the whole dictionary, count of the times a word appears in documents of class c (+|V|)
        double den = classTermFreq.get(cclass) + docsWithClassSize;
        // P(w|c) = num/den
        double wordProbability = num / den;

        // modify the value in the result list item
        int removeIdx = -1;
        int i = 0;
        for (ClassificationResult<BytesRef> cr : ret) {
          if (cr.getAssignedClass().equals(cclass)) {
            removeIdx = i;
            break;
          }
          i++;
        }

        if (removeIdx >= 0) {
          ClassificationResult<BytesRef> toRemove = ret.get(removeIdx);
          ret.add(new ClassificationResult<>(toRemove.getAssignedClass(), toRemove.getScore() + Math.log(wordProbability)));
          ret.remove(removeIdx);
        }
      }
    }

    // log(P(d|c)) = log(P(w1|c)) + ... + log(P(wn|c))
    return ret;
  }

  private Map<BytesRef, Integer> getWordFreqForClassess(String word) throws IOException {
    Map<BytesRef, Integer> insertPoint;
    insertPoint = termCClassHitCache.get(word);

    // if we get the answer from the cache
    if (insertPoint != null) {
      if (!insertPoint.isEmpty()) {
        return insertPoint;
      }
    }

    Map<BytesRef, Integer> searched = new ConcurrentHashMap<>();

    // if we don't get the answer, but it's relevant, we must search for it and insert it into the cache
    if (insertPoint != null || !justCachedTerms) {
      for (BytesRef cclass : cclasses) {
        BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
        BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
        for (String textFieldName : textFieldNames) {
          subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
        }
        booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
        booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, cclass)), BooleanClause.Occur.MUST));
        if (query != null) {
          booleanQuery.add(query, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
        indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
        int ret = totalHitCountCollector.getTotalHits();
        if (ret != 0) {
          searched.put(cclass, ret);
        }
      }
      if (insertPoint != null) {
        // thread-safe, concurrent write
        termCClassHitCache.put(word, searched);
      }
    }

    return searched;
  }
  /**
   * This method builds the frame (skeleton) of the cache. The cache stores word occurrences in memory once they
   * have been searched for the first time. Used properly, this cache can give a 2-100x speedup but it can also
   * consume a lot of memory. To reduce memory consumption, terms with a very low occurrence in the index can be
   * left out of the cache. The second parameter switches the term searching behaviour: if {@code true}, only the
   * terms in the cache skeleton are searched; if {@code false}, terms that are not in the cache are searched as
   * well (but not cached).
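   * <p>
   * For example, assuming {@code classifier} is an already constructed instance, the cache could be rebuilt so
   * that only terms occurring in more than 5 documents are cached and all other terms are skipped at
   * classification time (an illustrative setting, not a recommended default):
   * <pre class="prettyprint">
   * classifier.reInitCache(5, true);
   * </pre>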
   *
   * @param minTermOccurrenceInCache only terms whose document frequency exceeds this value are cached;
   *                                 a higher value means a smaller cache
   * @param justCachedTerms          if {@code true}, only cached terms are searched at classification time and
   *                                 low occurrence terms are excluded entirely; if {@code false}, uncached terms
   *                                 are searched too, but not cached
   * @throws IOException if there is a low-level I/O error
   */
  public void reInitCache(int minTermOccurrenceInCache, boolean justCachedTerms) throws IOException {
    this.justCachedTerms = justCachedTerms;

    this.docsWithClassSize = countDocsWithClass();
    termCClassHitCache.clear();
    cclasses.clear();
    classTermFreq.clear();

    // build the cache for the words
    Map<String, Long> frequencyMap = new HashMap<>();
    for (String textFieldName : textFieldNames) {
      TermsEnum termsEnum = MultiTerms.getTerms(indexReader, textFieldName).iterator();
      while (termsEnum.next() != null) {
        BytesRef term = termsEnum.term();
        String termText = term.utf8ToString();
        long frequency = termsEnum.docFreq();
        Long lastfreq = frequencyMap.get(termText);
        if (lastfreq != null) frequency += lastfreq;
        frequencyMap.put(termText, frequency);
      }
    }
    for (Map.Entry<String, Long> entry : frequencyMap.entrySet()) {
      if (entry.getValue() > minTermOccurrenceInCache) {
        termCClassHitCache.put(entry.getKey(), new ConcurrentHashMap<BytesRef, Integer>());
      }
    }

    // fill the class list
    Terms terms = MultiTerms.getTerms(indexReader, classFieldName);
    TermsEnum termsEnum = terms.iterator();
    while ((termsEnum.next()) != null) {
      cclasses.add(BytesRef.deepCopyOf(termsEnum.term()));
    }

    // fill the classTermFreq map
    for (BytesRef cclass : cclasses) {
      double avgNumberOfUniqueTerms = 0;
      for (String textFieldName : textFieldNames) {
        terms = MultiTerms.getTerms(indexReader, textFieldName);
        long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
        avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount();
      }
      int docsWithC = indexReader.docFreq(new Term(classFieldName, cclass));
      classTermFreq.put(cclass, avgNumberOfUniqueTerms * docsWithC);
    }
  }
}