package org.apache.lucene.search.grouping;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.util.FixedBitSet;
/**
 * Collects, for every distinct group value reported by a {@link GroupSelector}, the single
 * document that is most competitive according to a {@link Sort} (the "group head").
 *
 * <p>A candidate is compared with the current head field by field, in sort order. The first
 * non-tied field decides. On a complete tie the document collected first stays the head. Head
 * doc ids are stored as top-level ids, i.e. the leaf doc id plus the leaf's {@code docBase}.
 *
 * @param <T> the type of the group value
 */
public abstract class AllGroupHeadsCollector<T> extends SimpleCollector {

  private final GroupSelector<T> groupSelector;

  /** The sort that defines which document of a group is its head. */
  protected final Sort sort;

  /** Per sort field: {@code -1} if the field is reversed, {@code 1} otherwise. */
  protected final int[] reversed;

  /** Index of the last sort field, i.e. {@code reversed.length - 1}. */
  protected final int compIDXEnd;

  /** Current head per group value; values are never {@code null}. */
  protected Map<T, GroupHead<T>> heads = new HashMap<>();

  /** The leaf currently being collected. */
  protected LeafReaderContext context;

  /** The scorer most recently passed to {@link #setScorer(Scorable)}. */
  protected Scorable scorer;

  /**
   * Creates a group-heads collector.
   *
   * <p>A pure relevance sort ({@link Sort#RELEVANCE}) gets an implementation that only tracks
   * scores. Any other sort gets one that uses a field comparator per sort field.
   *
   * @param selector produces the group value of each collected document
   * @param sort defines which document of a group is its head
   * @param <T> the type of the group value
   * @return a new collector
   */
  public static <T> AllGroupHeadsCollector<T> newCollector(GroupSelector<T> selector, Sort sort) {
    if (sort.equals(Sort.RELEVANCE)) {
      return new ScoringGroupHeadsCollector<>(selector, sort);
    }
    return new SortingGroupHeadsCollector<>(selector, sort);
  }

  private AllGroupHeadsCollector(GroupSelector<T> selector, Sort sort) {
    this.groupSelector = selector;
    this.sort = sort;
    final SortField[] sortFields = sort.getSort();
    this.reversed = new int[sortFields.length];
    for (int i = 0; i < sortFields.length; i++) {
      reversed[i] = sortFields[i].getReverse() ? -1 : 1;
    }
    this.compIDXEnd = this.reversed.length - 1;
  }

  /**
   * Returns the group heads as a bit set.
   *
   * @param maxDoc size of the returned bit set; every head doc id must be below it
   * @return a bit set of size {@code maxDoc} with one bit set per group head (top-level doc id)
   */
  public FixedBitSet retrieveGroupHeads(int maxDoc) {
    FixedBitSet bitSet = new FixedBitSet(maxDoc);
    for (GroupHead<T> groupHead : getCollectedGroupHeads()) {
      bitSet.set(groupHead.doc);
    }
    return bitSet;
  }

  /**
   * Returns the top-level doc ids of all group heads.
   *
   * <p>The order follows the iteration order of {@link #getCollectedGroupHeads()}. For the
   * default {@link HashMap}-backed storage that order is unspecified.
   *
   * @return one doc id per group
   */
  public int[] retrieveGroupHeads() {
    Collection<? extends GroupHead<T>> groupHeads = getCollectedGroupHeads();
    int[] docHeads = new int[groupHeads.size()];
    int i = 0;
    for (GroupHead<T> groupHead : groupHeads) {
      docHeads[i++] = groupHead.doc;
    }
    return docHeads;
  }

  /** Returns the number of groups seen so far, which is also the number of group heads. */
  public int groupHeadsSize() {
    return getCollectedGroupHeads().size();
  }

  /** Returns a live view of the current group heads (one per group value). */
  protected Collection<? extends GroupHead<T>> getCollectedGroupHeads() {
    return heads.values();
  }

  /**
   * Assigns {@code doc} to its group. The doc becomes the group's head if the group is new, or
   * if it sorts strictly before the current head.
   *
   * @param doc leaf-relative doc id
   */
  @Override
  public void collect(int doc) throws IOException {
    groupSelector.advanceTo(doc);
    T groupValue = groupSelector.currentValue();
    // A single lookup suffices: heads never maps a key to null, so null means "new group".
    GroupHead<T> groupHead = heads.get(groupValue);
    if (groupHead == null) {
      // The map key is taken from copyValue(), not currentValue(). Presumably currentValue()
      // may be a reused instance that must not be retained.
      groupValue = groupSelector.copyValue();
      heads.put(groupValue, newGroupHead(doc, groupValue, context, scorer));
      return;
    }

    // Walk the sort fields in order; the first non-tied field decides.
    for (int compIDX = 0; ; compIDX++) {
      final int c = reversed[compIDX] * groupHead.compare(compIDX, doc);
      if (c < 0) {
        // doc loses to the current head on this field: not competitive.
        return;
      } else if (c > 0) {
        // doc beats the current head on this field: it becomes the new head.
        break;
      } else if (compIDX == compIDXEnd) {
        // Tied on every field: keep the head that was collected first.
        return;
      }
    }
    groupHead.updateDocHead(doc);
  }

  /** Requests scores only when the sort needs them. */
  @Override
  public ScoreMode scoreMode() {
    return sort.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
  }

  /** Moves the selector and every existing group head to the new leaf. */
  @Override
  protected void doSetNextReader(LeafReaderContext context) throws IOException {
    groupSelector.setNextReader(context);
    this.context = context;
    for (GroupHead<T> head : heads.values()) {
      head.setNextReader(context);
    }
  }

  /** Records the scorer and hands it to every existing group head. */
  @Override
  public void setScorer(Scorable scorer) throws IOException {
    this.scorer = scorer;
    for (GroupHead<T> head : heads.values()) {
      head.setScorer(scorer);
    }
  }

  /**
   * Creates the head for a group whose first document is {@code doc}.
   *
   * @param doc leaf-relative doc id of the group's first document
   * @param value the (copied) group value
   * @param context the leaf currently being collected
   * @param scorer the current scorer
   * @return a new, non-null group head
   */
  protected abstract GroupHead<T> newGroupHead(
      int doc, T value, LeafReaderContext context, Scorable scorer) throws IOException;

  /**
   * The current head document of one group.
   *
   * @param <T> the type of the group value
   */
  public static abstract class GroupHead<T> {

    /** The group value this head belongs to. */
    public final T groupValue;

    /** Top-level doc id of the current head (leaf doc id plus the leaf's docBase). */
    public int doc;

    /** {@code docBase} of the leaf currently being collected. */
    protected int docBase;

    protected GroupHead(T groupValue, int doc, int docBase) {
      this.groupValue = groupValue;
      this.doc = doc + docBase;
      this.docBase = docBase;
    }

    /** Switches to a new leaf; subclasses must call {@code super}. */
    protected void setNextReader(LeafReaderContext ctx) throws IOException {
      this.docBase = ctx.docBase;
    }

    /** Receives the scorer for the current leaf. */
    protected abstract void setScorer(Scorable scorer) throws IOException;

    /**
     * Compares this head with {@code doc} on the sort field at {@code compIDX}. The result is in
     * the field's non-reversed order; the collector multiplies it by the field's reverse
     * multiplier. After that, a positive value means {@code doc} beats the head on this field,
     * a negative value means it loses, and zero means a tie.
     *
     * @param compIDX index of the sort field
     * @param doc leaf-relative doc id of the candidate
     */
    protected abstract int compare(int compIDX, int doc) throws IOException;

    /**
     * Makes {@code doc} the new head of this group.
     *
     * @param doc leaf-relative doc id
     */
    protected abstract void updateDocHead(int doc) throws IOException;
  }

  /** Collector for arbitrary sorts; heads compare via per-field {@link FieldComparator}s. */
  private static class SortingGroupHeadsCollector<T> extends AllGroupHeadsCollector<T> {

    protected SortingGroupHeadsCollector(GroupSelector<T> selector, Sort sort) {
      super(selector, sort);
    }

    @Override
    protected GroupHead<T> newGroupHead(
        int doc, T value, LeafReaderContext ctx, Scorable scorer) throws IOException {
      return new SortingGroupHead<>(sort, value, doc, ctx, scorer);
    }
  }

  /**
   * Group head backed by one single-slot comparator per sort field. Slot 0 holds the head's
   * values and is always the comparator's bottom, so candidates use {@code compareBottom}.
   */
  private static class SortingGroupHead<T> extends GroupHead<T> {

    final FieldComparator<?>[] comparators;
    final LeafFieldComparator[] leafComparators;

    protected SortingGroupHead(
        Sort sort, T groupValue, int doc, LeafReaderContext context, Scorable scorer)
        throws IOException {
      super(groupValue, doc, context.docBase);
      final SortField[] sortFields = sort.getSort();
      comparators = new FieldComparator<?>[sortFields.length];
      leafComparators = new LeafFieldComparator[sortFields.length];
      for (int i = 0; i < sortFields.length; i++) {
        // numHits = 1: only slot 0 is ever used.
        comparators[i] = sortFields[i].getComparator(1, i);
        leafComparators[i] = comparators[i].getLeafComparator(context);
        leafComparators[i].setScorer(scorer);
        leafComparators[i].copy(0, doc);
        leafComparators[i].setBottom(0);
      }
    }

    @Override
    public void setNextReader(LeafReaderContext ctx) throws IOException {
      super.setNextReader(ctx);
      for (int i = 0; i < comparators.length; i++) {
        leafComparators[i] = comparators[i].getLeafComparator(ctx);
      }
    }

    @Override
    protected void setScorer(Scorable scorer) throws IOException {
      for (LeafFieldComparator c : leafComparators) {
        c.setScorer(scorer);
      }
    }

    @Override
    public int compare(int compIDX, int doc) throws IOException {
      return leafComparators[compIDX].compareBottom(doc);
    }

    @Override
    public void updateDocHead(int doc) throws IOException {
      // Overwrite slot 0 with the new head's values and keep it as the bottom.
      for (LeafFieldComparator comparator : leafComparators) {
        comparator.copy(0, doc);
        comparator.setBottom(0);
      }
      this.doc = doc + docBase;
    }
  }

  /** Collector for {@link Sort#RELEVANCE}; heads only track the best score. */
  private static class ScoringGroupHeadsCollector<T> extends AllGroupHeadsCollector<T> {

    protected ScoringGroupHeadsCollector(GroupSelector<T> selector, Sort sort) {
      super(selector, sort);
    }

    @Override
    protected GroupHead<T> newGroupHead(
        int doc, T value, LeafReaderContext context, Scorable scorer) throws IOException {
      return new ScoringGroupHead<>(scorer, value, doc, context.docBase);
    }
  }

  /** Group head that keeps the highest score seen for its group. */
  private static class ScoringGroupHead<T> extends GroupHead<T> {

    private Scorable scorer;
    private float topScore;

    protected ScoringGroupHead(Scorable scorer, T groupValue, int doc, int docBase)
        throws IOException {
      super(groupValue, doc, docBase);
      assert scorer.docID() == doc;
      this.scorer = scorer;
      this.topScore = scorer.score();
    }

    @Override
    protected void setScorer(Scorable scorer) {
      this.scorer = scorer;
    }

    /**
     * Compares the scorer's current score with the head's score.
     *
     * <p>Side effect: when the candidate scores higher, {@code topScore} is raised right here,
     * not in {@link #updateDocHead}. This assumes the collector always promotes the candidate
     * after a positive result. That holds when the relevance sort is a single, non-reversed
     * score field (NOTE(review): relies on Sort.RELEVANCE having that shape; confirm).
     */
    @Override
    protected int compare(int compIDX, int doc) throws IOException {
      assert scorer.docID() == doc;
      assert compIDX == 0;
      float score = scorer.score();
      int c = Float.compare(score, topScore);
      if (c > 0) {
        topScore = score;
      }
      return c;
    }

    @Override
    protected void updateDocHead(int doc) throws IOException {
      this.doc = doc + docBase;
    }
  }
}