package org.apache.lucene.classification;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.utils.NearestFuzzyQuery;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
/**
 * A k-nearest-neighbour {@link Classifier} whose neighbours are retrieved with a
 * {@link NearestFuzzyQuery} built from the text to classify.
 *
 * <p>The {@code k} best matching documents are fetched from the index, their stored class field
 * values are read, and each class is scored from the (max-score normalised) scores of the
 * neighbours that carry it.
 */
public class KNearestFuzzyClassifier implements Classifier<BytesRef> {

  /** Names of the fields the fuzzy query is run against. */
  private final String[] textFieldNames;

  /** Name of the field holding the class value of each indexed document. */
  private final String classFieldName;

  /** Searcher used both to run the kNN query and to load the stored class field. */
  private final IndexSearcher indexSearcher;

  /** Number of neighbours requested from the index. */
  private final int k;

  /** Optional additional constraint on the candidate documents; may be {@code null}. */
  private final Query query;

  /** Analyzer handed to the {@link NearestFuzzyQuery}. */
  private final Analyzer analyzer;

  /**
   * Creates a {@link KNearestFuzzyClassifier}.
   *
   * @param indexReader the reader on the index used for classification
   * @param similarity the {@link Similarity} set on the internal {@link IndexSearcher}; when
   *     {@code null} a {@link BM25Similarity} is used
   * @param analyzer the {@link Analyzer} passed to the {@link NearestFuzzyQuery}
   * @param query an additional {@link Query} every neighbour must match, or {@code null} for none
   * @param k the number of neighbours to retrieve; also used to normalise class scores
   * @param classFieldName the name of the field holding the class value
   * @param textFieldNames the names of the fields the input text is matched against
   */
  public KNearestFuzzyClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k,
                                 String classFieldName, String... textFieldNames) {
    this.textFieldNames = textFieldNames;
    this.classFieldName = classFieldName;
    this.analyzer = analyzer;
    this.indexSearcher = new IndexSearcher(indexReader);
    if (similarity != null) {
      this.indexSearcher.setSimilarity(similarity);
    } else {
      this.indexSearcher.setSimilarity(new BM25Similarity());
    }
    this.query = query;
    this.k = k;
  }

  /**
   * Assigns the single best scoring class to the given text.
   *
   * @param text the text to classify
   * @return the class with the highest score, or {@code null} when no neighbour with a class
   *     value was found
   * @throws IOException if searching the index or loading a document fails
   */
  @Override
  public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
    TopDocs knnResults = knnSearch(text);
    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
    ClassificationResult<BytesRef> assignedClass = null;
    double maxscore = -Double.MAX_VALUE;
    for (ClassificationResult<BytesRef> cl : assignedClasses) {
      if (cl.getScore() > maxscore) {
        assignedClass = cl;
        maxscore = cl.getScore();
      }
    }
    return assignedClass;
  }

  /**
   * Returns every class found among the neighbours of the given text.
   *
   * @param text the text to classify
   * @return the candidate classes, sorted by the natural ordering of {@link ClassificationResult};
   *     empty when no neighbour with a class value was found
   * @throws IOException if searching the index or loading a document fails
   */
  @Override
  public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
    TopDocs knnResults = knnSearch(text);
    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
    Collections.sort(assignedClasses);
    return assignedClasses;
  }

  /**
   * Returns at most {@code max} classes found among the neighbours of the given text.
   *
   * @param text the text to classify
   * @param max the maximum number of classes to return; when it exceeds the number of classes
   *     found, all of them are returned
   * @return the first {@code max} candidate classes after sorting by the natural ordering of
   *     {@link ClassificationResult}
   * @throws IOException if searching the index or loading a document fails
   */
  @Override
  public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
    TopDocs knnResults = knnSearch(text);
    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
    Collections.sort(assignedClasses);
    // Clamp to the number of classes actually found: subList would otherwise throw
    // IndexOutOfBoundsException whenever fewer than max classes are available.
    int limit = Math.min(max, assignedClasses.size());
    return assignedClasses.subList(0, limit);
  }

  /**
   * Runs the nearest-neighbour search for the given text.
   *
   * <p>The query requires (all clauses are {@code MUST}) a match on the fuzzy query built over
   * every text field, a match on a {@code *} wildcard over the class field, and, when configured,
   * a match on the additional filter query.
   *
   * @param text the text to find neighbours for
   * @return the top {@code k} hits
   * @throws IOException if building the fuzzy query or searching the index fails
   */
  private TopDocs knnSearch(String text) throws IOException {
    BooleanQuery.Builder bq = new BooleanQuery.Builder();
    NearestFuzzyQuery nearestFuzzyQuery = new NearestFuzzyQuery(analyzer);
    for (String fieldName : textFieldNames) {
      nearestFuzzyQuery.addTerms(text, fieldName);
    }
    bq.add(nearestFuzzyQuery, BooleanClause.Occur.MUST);
    Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
    bq.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
    if (query != null) {
      bq.add(query, BooleanClause.Occur.MUST);
    }
    return indexSearcher.search(bq.build(), k);
  }

  /**
   * Turns the neighbours into one scored {@link ClassificationResult} per distinct class.
   *
   * <p>Each neighbour contributes its score divided by the score of the first hit to its class.
   * Class scores are divided by {@code k}; when fewer than {@code k} neighbours carried a class
   * value, the scores are rescaled by {@code k / (number of classified neighbours)}.
   *
   * @param topDocs the hits returned by the nearest-neighbour search
   * @return a new, mutable and unsorted list of class results; empty when there are no hits
   * @throws IOException if loading a hit's stored fields fails
   */
  private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
    ScoreDoc[] scoreDocs = topDocs.scoreDocs;
    if (scoreDocs.length == 0) {
      // Nothing to classify; guarding on the array itself (rather than on the hit count) makes
      // the scoreDocs[0] access below safe in every case.
      return new ArrayList<>();
    }

    Map<BytesRef, Integer> classCounts = new HashMap<>();
    Map<BytesRef, Double> classBoosts = new HashMap<>();
    // Score of the first hit, used to normalise every neighbour's score.
    float maxScore = scoreDocs[0].score;
    for (ScoreDoc scoreDoc : scoreDocs) {
      IndexableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
      if (storableField != null) {
        BytesRef cl = new BytesRef(storableField.stringValue());
        double singleBoost = scoreDoc.score / maxScore;
        classCounts.merge(cl, 1, Integer::sum);
        classBoosts.merge(cl, singleBoost, Double::sum);
      }
    }

    List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
    int sumdoc = 0;
    for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
      int count = entry.getValue();
      double normBoost = classBoosts.get(entry.getKey()) / count;
      temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), (count * normBoost) / (double) k));
      sumdoc += count;
    }

    if (sumdoc >= k) {
      return temporaryList;
    }

    // Fewer than k neighbours had a class value: rescale so scores are relative to the
    // neighbours that were actually counted.
    List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
    for (ClassificationResult<BytesRef> cr : temporaryList) {
      returnList.add(new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
    }
    return returnList;
  }

  /**
   * Returns a description of this classifier's configuration (fields, k, filter query and
   * similarity).
   */
  @Override
  public String toString() {
    return "KNearestFuzzyClassifier{" +
        "textFieldNames=" + Arrays.toString(textFieldNames) +
        ", classFieldName='" + classFieldName + '\'' +
        ", k=" + k +
        ", query=" + query +
        ", similarity=" + indexSearcher.getSimilarity() +
        '}';
  }
}