package org.apache.lucene.util.bkd;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FutureArrays;
import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.MSBRadixSorter;
import org.apache.lucene.util.RadixSelector;
import org.apache.lucene.util.Selector;
import org.apache.lucene.util.Sorter;
/**
 * Radix selector that partitions BKD points around a split point on a single dimension.
 *
 * <p>Points are ordered by the {@code bytesPerDim} bytes of the selected dimension, then by the
 * bytes of the data-only dimensions (dimensions at index {@code >= numIndexDims}), then by the
 * trailing {@code Integer.BYTES} docID bytes. That concatenation ({@code bytesSorted} bytes long)
 * is referred to below as the "sort key".
 *
 * <p>Points held by a {@link HeapPointWriter} are selected in place. Points held by an
 * {@link OfflinePointWriter} are streamed into newly created left/right writers using a byte
 * histogram, recursing on the bucket that contains the partition point.
 *
 * <p>NOTE(review): instances reuse mutable buffers ({@code histogram}, {@code scratch},
 * {@code partitionBucket}, {@code offlineBuffer}) across calls, so an instance should not be
 * shared between concurrent callers.
 */
public final class BKDRadixSelector {
  /** Number of histogram buckets: one per unsigned byte value. */
  private static final int HISTOGRAM_SIZE = 256;
  /** Upper bound, in bytes, of the reusable buffer used to read offline points. */
  private static final int MAX_SIZE_OFFLINE_BUFFER = 1024 * 8;

  /** Per-byte-value point counts at the current common-prefix position. */
  private final long[] histogram;
  /** Number of bytes per dimension. */
  private final int bytesPerDim;
  /** Length of the sort key: selected dimension + data-only dimensions + docID. */
  private final int bytesSorted;
  /** Length of a packed point value (all data dimensions). */
  private final int packedBytesLength;
  /** Length of a packed point value followed by its docID. */
  private final int packedBytesDocIDLength;
  /** Budget of points this selector may buffer on heap. */
  private final int maxPointsSortInHeap;
  /** Reusable read buffer for offline readers, sized to a whole number of points. */
  private final byte[] offlineBuffer;
  /** Chosen byte value at each sort-key position; the first bytesPerDim entries form the split value. */
  private final int[] partitionBucket;
  /** Scratch copy of one point's sort key. */
  private final byte[] scratch;
  /** Directory in which offline temporary files are created. */
  private final Directory tempDir;
  /** Name prefix for offline temporary files. */
  private final String tempFileNamePrefix;
  /** Number of data dimensions and number of index dimensions. */
  private final int numDataDims, numIndexDims;

  /**
   * Creates a selector.
   *
   * @param numDataDims number of data dimensions of each point
   * @param numIndexDims number of index dimensions; in this selector, dimensions at index
   *     {@code >= numIndexDims} only take part in tie-breaking
   * @param bytesPerDim number of bytes per dimension
   * @param maxPointsSortInHeap budget of points that may be buffered on heap
   * @param tempDir directory for offline temporary files
   * @param tempFileNamePrefix name prefix for offline temporary files
   */
  public BKDRadixSelector(int numDataDims, int numIndexDims, int bytesPerDim, int maxPointsSortInHeap, Directory tempDir, String tempFileNamePrefix) {
    this.bytesPerDim = bytesPerDim;
    this.numDataDims = numDataDims;
    this.numIndexDims = numIndexDims;
    this.packedBytesLength = numDataDims * bytesPerDim;
    this.packedBytesDocIDLength = packedBytesLength + Integer.BYTES;
    // Sort key: selected dimension, then the data-only dimensions, then the docID.
    this.bytesSorted = bytesPerDim + (numDataDims - numIndexDims) * bytesPerDim + Integer.BYTES;
    this.maxPointsSortInHeap = maxPointsSortInHeap;
    // Size the offline buffer to a whole number of points.
    int numberOfPointsOffline = MAX_SIZE_OFFLINE_BUFFER / packedBytesDocIDLength;
    this.offlineBuffer = new byte[numberOfPointsOffline * packedBytesDocIDLength];
    this.partitionBucket = new int[bytesSorted];
    this.histogram = new long[HISTOGRAM_SIZE];
    this.scratch = new byte[bytesSorted];
    this.tempDir = tempDir;
    this.tempFileNamePrefix = tempFileNamePrefix;
  }

  /**
   * Partitions the points in {@code [from, to)} so that the {@code partitionPoint - from} smallest
   * points (by sort key) end up in {@code partitionSlices[0]} and the rest in
   * {@code partitionSlices[1]}.
   *
   * <p>Heap points are reordered in place and both output slices reference the input writer.
   * Offline points are copied into new left/right writers, and the input offline writer is
   * destroyed.
   *
   * @param points the points to partition
   * @param partitionSlices output array; slots 0 and 1 receive the left and right slices
   * @param from first point (inclusive)
   * @param to last point (exclusive)
   * @param partitionPoint index of the first point that goes to the right slice
   * @param dim dimension to select on
   * @param dimCommonPrefix number of leading sort-key bytes already known to be shared by all points
   * @return the {@code bytesPerDim} bytes of the split value in dimension {@code dim}
   * @throws IllegalArgumentException if {@code partitionPoint} is not in {@code [from, to)}
   * @throws IOException if reading or writing offline points fails
   */
  public byte[] select(PathSlice points, PathSlice[] partitionSlices, long from, long to, long partitionPoint, int dim, int dimCommonPrefix) throws IOException {
    checkArgs(from, to, partitionPoint);
    assert partitionSlices.length > 1 : "[partition slices] must be > 1, got " + partitionSlices.length;

    // On heap: select in place; both slices share the original writer.
    if (points.writer instanceof HeapPointWriter) {
      byte[] partition = heapRadixSelect((HeapPointWriter) points.writer, dim, Math.toIntExact(from), Math.toIntExact(to), Math.toIntExact(partitionPoint), dimCommonPrefix);
      partitionSlices[0] = new PathSlice(points.writer, from, partitionPoint - from);
      partitionSlices[1] = new PathSlice(points.writer, partitionPoint, to - partitionPoint);
      return partition;
    }

    // Offline: stream the points into freshly created left/right writers.
    OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points.writer;
    try (PointWriter left = getPointWriter(partitionPoint - from, "left" + dim);
         PointWriter right = getPointWriter(to - partitionPoint, "right" + dim)) {
      partitionSlices[0] = new PathSlice(left, 0, partitionPoint - from);
      partitionSlices[1] = new PathSlice(right, 0, to - partitionPoint);
      return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, dimCommonPrefix, dim);
    }
  }

  /**
   * Validates that {@code partitionPoint} lies in {@code [from, to)}.
   *
   * @throws IllegalArgumentException if it does not
   */
  void checkArgs(long from, long to, long partitionPoint) {
    if (partitionPoint < from) {
      throw new IllegalArgumentException("partitionPoint must be >= from");
    }
    if (partitionPoint >= to) {
      throw new IllegalArgumentException("partitionPoint must be < to");
    }
  }

  /**
   * Reads the offline points in {@code [from, to)} and finds the length of the sort-key prefix
   * they all share, starting the comparison at {@code dimCommonPrefix}. It also fills
   * {@link #histogram} with the distribution of the byte at that position.
   *
   * <p>On return, {@code partitionBucket[0..result)} holds the shared prefix bytes. If every point
   * has the same full sort key, the result is {@code bytesSorted} and the histogram is left
   * untouched by this call.
   *
   * @return the common prefix length, in sort-key bytes
   */
  private int findCommonPrefixAndHistogram(OfflinePointWriter points, long from, long to, int dim, int dimCommonPrefix) throws IOException {
    int commonPrefixPosition = bytesSorted;
    final int offset = dim * bytesPerDim;
    try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
      assert commonPrefixPosition > dimCommonPrefix;
      reader.next();
      PointValue pointValue = reader.pointValue();
      BytesRef packedValueDocID = pointValue.packedValueDocIDBytes();
      // Seed scratch with the first point's sort key: selected dimension ...
      System.arraycopy(packedValueDocID.bytes, packedValueDocID.offset + offset, scratch, 0, bytesPerDim);
      // ... followed by the data-only dimensions and docID.
      System.arraycopy(packedValueDocID.bytes, packedValueDocID.offset + numIndexDims * bytesPerDim, scratch, bytesPerDim, (numDataDims - numIndexDims) * bytesPerDim + Integer.BYTES);

      for (long i = from + 1; i < to; i++) {
        reader.next();
        pointValue = reader.pointValue();
        if (commonPrefixPosition == dimCommonPrefix) {
          // The prefix cannot shrink any further: just finish the histogram and stop comparing.
          histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
          for (long j = i + 1; j < to; j++) {
            reader.next();
            pointValue = reader.pointValue();
            histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
          }
          break;
        } else {
          // Compare the dimension bytes in [dimCommonPrefix, commonPrefixPosition), clamped to the dimension.
          final int startIndex = (dimCommonPrefix > bytesPerDim) ? bytesPerDim : dimCommonPrefix;
          final int endIndex = (commonPrefixPosition > bytesPerDim) ? bytesPerDim : commonPrefixPosition;
          packedValueDocID = pointValue.packedValueDocIDBytes();
          int j = FutureArrays.mismatch(scratch, startIndex, endIndex, packedValueDocID.bytes, packedValueDocID.offset + offset + startIndex, packedValueDocID.offset + offset + endIndex);
          if (j == -1) {
            if (commonPrefixPosition > bytesPerDim) {
              // The dimension matches; compare the tie-break bytes (data-only dims + docID).
              final int startTieBreak = numIndexDims * bytesPerDim;
              final int endTieBreak = startTieBreak + commonPrefixPosition - bytesPerDim;
              int k = FutureArrays.mismatch(scratch, bytesPerDim, commonPrefixPosition,
                  packedValueDocID.bytes, packedValueDocID.offset + startTieBreak, packedValueDocID.offset + endTieBreak);
              if (k != -1) {
                commonPrefixPosition = bytesPerDim + k;
                // All (i - from) points read so far share scratch's byte at the new position.
                Arrays.fill(histogram, 0);
                histogram[scratch[commonPrefixPosition] & 0xff] = i - from;
              }
            }
          } else {
            commonPrefixPosition = dimCommonPrefix + j;
            // All (i - from) points read so far share scratch's byte at the new position.
            Arrays.fill(histogram, 0);
            histogram[scratch[commonPrefixPosition] & 0xff] = i - from;
          }
          if (commonPrefixPosition != bytesSorted) {
            histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
          }
        }
      }
    }

    // Record the shared prefix bytes as the chosen buckets for those positions.
    for (int i = 0; i < commonPrefixPosition; i++) {
      partitionBucket[i] = scratch[i] & 0xff;
    }
    return commonPrefixPosition;
  }

  /**
   * Returns the unsigned value of the sort-key byte at {@code commonPrefixPosition} for a point.
   * Positions below {@code bytesPerDim} read the selected dimension (at {@code offset}); later
   * positions read the data-only dimensions and docID.
   */
  private int getBucket(int offset, int commonPrefixPosition, PointValue pointValue) {
    int bucket;
    if (commonPrefixPosition < bytesPerDim) {
      BytesRef packedValue = pointValue.packedValue();
      bucket = packedValue.bytes[packedValue.offset + offset + commonPrefixPosition] & 0xff;
    } else {
      BytesRef packedValueDocID = pointValue.packedValueDocIDBytes();
      bucket = packedValueDocID.bytes[packedValueDocID.offset + numIndexDims * bytesPerDim + commonPrefixPosition - bytesPerDim] & 0xff;
    }
    return bucket;
  }

  /**
   * Performs one histogram pass over the offline points in {@code [from, to)}. Points below the
   * partition bucket go to {@code left}, points above it go to {@code right}, and points in the
   * bucket are collected into a delta writer that is processed recursively (or on heap if small
   * enough). The input {@code points} writer is destroyed.
   *
   * @return the {@code bytesPerDim} bytes of the split value
   */
  private byte[] buildHistogramAndPartition(OfflinePointWriter points, PointWriter left, PointWriter right,
                                            long from, long to, long partitionPoint, int iteration, int baseCommonPrefix, int dim) throws IOException {
    int commonPrefix = findCommonPrefixAndHistogram(points, from, to, dim, baseCommonPrefix);

    if (commonPrefix == bytesSorted) {
      // All points are identical: the first (partitionPoint - from) of them go left, matching the
      // size of the left slice. (Previously this passed partitionPoint, which overfills left when from > 0.)
      offlinePartition(points, left, right, null, from, to, dim, commonPrefix - 1, partitionPoint - from);
      return partitionPointFromCommonPrefix();
    }

    // Locate the bucket containing the partition point; leftCount counts points in smaller buckets.
    long leftCount = 0;
    long rightCount = 0;
    for (int i = 0; i < HISTOGRAM_SIZE; i++) {
      long size = histogram[i];
      if (leftCount + size > partitionPoint - from) {
        partitionBucket[commonPrefix] = i;
        break;
      }
      leftCount += size;
    }
    for (int i = partitionBucket[commonPrefix] + 1; i < HISTOGRAM_SIZE; i++) {
      rightCount += histogram[i];
    }

    long delta = histogram[partitionBucket[commonPrefix]];
    assert leftCount + rightCount + delta == to - from : (leftCount + rightCount + delta) + " / " + (to - from);

    if (commonPrefix == bytesSorted - 1) {
      // Last sort-key byte: resolve the bucket directly by count instead of recursing.
      long tieBreakCount = (partitionPoint - from - leftCount);
      offlinePartition(points, left, right, null, from, to, dim, commonPrefix, tieBreakCount);
      return partitionPointFromCommonPrefix();
    }

    PointWriter deltaPoints;
    try (PointWriter tempDeltaPoints = getDeltaPointWriter(left, right, delta, iteration)) {
      offlinePartition(points, left, right, tempDeltaPoints, from, to, dim, commonPrefix, 0);
      deltaPoints = tempDeltaPoints;
    }

    // Partition point relative to the start of the delta points.
    long newPartitionPoint = partitionPoint - from - leftCount;
    if (deltaPoints instanceof HeapPointWriter) {
      return heapPartition((HeapPointWriter) deltaPoints, left, right, dim, 0, Math.toIntExact(deltaPoints.count()), Math.toIntExact(newPartitionPoint), ++commonPrefix);
    } else {
      return buildHistogramAndPartition((OfflinePointWriter) deltaPoints, left, right, 0, deltaPoints.count(), newPartitionPoint, ++iteration, ++commonPrefix, dim);
    }
  }

  /**
   * Routes each offline point in {@code [from, to)} by comparing its sort-key byte at
   * {@code bytePosition} to {@code partitionBucket[bytePosition]}. Smaller values go to
   * {@code left} and larger values go to {@code right}. Equal values go to {@code deltaPoints}, or,
   * when {@code bytePosition} is the last sort-key byte, the first {@code numDocsTiebreak} of them
   * go left and the rest go right. Destroys {@code points} afterwards.
   */
  private void offlinePartition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints,
                                long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException {
    assert bytePosition == bytesSorted - 1 || deltaPoints != null;
    int offset = dim * bytesPerDim;
    long tiebreakCounter = 0;
    try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
      while (reader.next()) {
        PointValue pointValue = reader.pointValue();
        int bucket = getBucket(offset, bytePosition, pointValue);
        if (bucket < this.partitionBucket[bytePosition]) {
          left.append(pointValue);
        } else if (bucket > this.partitionBucket[bytePosition]) {
          right.append(pointValue);
        } else {
          if (bytePosition == bytesSorted - 1) {
            if (tiebreakCounter < numDocsTiebreak) {
              left.append(pointValue);
              tiebreakCounter++;
            } else {
              right.append(pointValue);
            }
          } else {
            deltaPoints.append(pointValue);
          }
        }
      }
    }
    points.destroy();
  }

  /** Returns the split value: the chosen bucket bytes for the {@code bytesPerDim} dimension positions. */
  private byte[] partitionPointFromCommonPrefix() {
    byte[] partition = new byte[bytesPerDim];
    for (int i = 0; i < bytesPerDim; i++) {
      partition[i] = (byte) partitionBucket[i];
    }
    return partition;
  }

  /**
   * Selects on heap, then appends points {@code [from, partitionPoint)} to {@code left} and
   * {@code [partitionPoint, to)} to {@code right}.
   *
   * @return the split value in dimension {@code dim}
   */
  private byte[] heapPartition(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException {
    byte[] partition = heapRadixSelect(points, dim, from, to, partitionPoint, commonPrefix);
    for (int i = from; i < to; i++) {
      PointValue value = points.getPackedValueSlice(i);
      if (i < partitionPoint) {
        left.append(value);
      } else {
        right.append(value);
      }
    }
    return partition;
  }

  /**
   * Reorders heap points in {@code [from, to)} in place so that the point at
   * {@code partitionPoint} is the one that would be there if sorted by sort key. Smaller points
   * come before it and larger or equal points come after it. The first {@code commonPrefixLength}
   * sort-key bytes are assumed shared and are not compared.
   *
   * @return the {@code bytesPerDim} bytes of dimension {@code dim} of the point at {@code partitionPoint}
   */
  private byte[] heapRadixSelect(HeapPointWriter points, int dim, int from, int to, int partitionPoint, int commonPrefixLength) {
    // Record offset of the first unshared byte of the selected dimension.
    final int dimOffset = dim * bytesPerDim + commonPrefixLength;
    // Dimension bytes left to compare; zero or negative once the shared prefix covers the dimension.
    final int dimCmpBytes = bytesPerDim - commonPrefixLength;
    // Chosen so that dataOffset + k == numIndexDims * bytesPerDim + (k - dimCmpBytes), i.e. the
    // (k - dimCmpBytes)-th tie-break byte (data-only dims + docID).
    final int dataOffset = numIndexDims * bytesPerDim - dimCmpBytes;
    new RadixSelector(bytesSorted - commonPrefixLength) {

      @Override
      protected void swap(int i, int j) {
        points.swap(i, j);
      }

      @Override
      protected int byteAt(int i, int k) {
        assert k >= 0 : "negative prefix " + k;
        if (k < dimCmpBytes) {
          return points.block[i * packedBytesDocIDLength + dimOffset + k] & 0xff;
        } else {
          return points.block[i * packedBytesDocIDLength + dataOffset + k] & 0xff;
        }
      }

      @Override
      protected Selector getFallbackSelector(int d) {
        // Leading bytes that need no comparison: the shared prefix plus the d bytes already
        // handled by the radix pass.
        final int skippedBytes = d + commonPrefixLength;
        final int dimStart = dim * bytesPerDim + skippedBytes;
        final int dimEnd = dim * bytesPerDim + bytesPerDim;
        // Tie-break region of each record: data-only dimensions followed by the docID.
        final int dataStart = numIndexDims * bytesPerDim;
        final int dataLength = (numDataDims - numIndexDims) * bytesPerDim + Integer.BYTES;
        return new IntroSelector() {

          @Override
          protected void swap(int i, int j) {
            points.swap(i, j);
          }

          @Override
          protected void setPivot(int i) {
            // scratch layout: [dimension bytes | tie-break bytes]
            if (skippedBytes < bytesPerDim) {
              System.arraycopy(points.block, i * packedBytesDocIDLength + dim * bytesPerDim, scratch, 0, bytesPerDim);
            }
            System.arraycopy(points.block, i * packedBytesDocIDLength + dataStart, scratch, bytesPerDim, dataLength);
          }

          @Override
          protected int compare(int i, int j) {
            if (skippedBytes < bytesPerDim) {
              int iOffset = i * packedBytesDocIDLength;
              int jOffset = j * packedBytesDocIDLength;
              int cmp = FutureArrays.compareUnsigned(points.block, iOffset + dimStart, iOffset + dimEnd, points.block, jOffset + dimStart, jOffset + dimEnd);
              if (cmp != 0) {
                return cmp;
              }
            }
            int iOffset = i * packedBytesDocIDLength + dataStart;
            int jOffset = j * packedBytesDocIDLength + dataStart;
            return FutureArrays.compareUnsigned(points.block, iOffset, iOffset + dataLength, points.block, jOffset, jOffset + dataLength);
          }

          @Override
          protected int comparePivot(int j) {
            if (skippedBytes < bytesPerDim) {
              int jOffset = j * packedBytesDocIDLength;
              int cmp = FutureArrays.compareUnsigned(scratch, skippedBytes, bytesPerDim, points.block, jOffset + dimStart, jOffset + dimEnd);
              if (cmp != 0) {
                return cmp;
              }
            }
            int jOffset = j * packedBytesDocIDLength + dataStart;
            return FutureArrays.compareUnsigned(scratch, bytesPerDim, bytesPerDim + dataLength, points.block, jOffset, jOffset + dataLength);
          }
        };
      }
    }.select(from, to, partitionPoint);

    byte[] partition = new byte[bytesPerDim];
    PointValue pointValue = points.getPackedValueSlice(partitionPoint);
    BytesRef packedValue = pointValue.packedValue();
    System.arraycopy(packedValue.bytes, packedValue.offset + dim * bytesPerDim, partition, 0, bytesPerDim);
    return partition;
  }

  /**
   * Sorts heap points in {@code [from, to)} in place by sort key for dimension {@code dim}. The
   * first {@code commonPrefixLength} sort-key bytes are assumed shared and are not compared.
   */
  public void heapRadixSort(final HeapPointWriter points, int from, int to, int dim, int commonPrefixLength) {
    // Same byte addressing as heapRadixSelect: dimension bytes first, then tie-break bytes.
    final int dimOffset = dim * bytesPerDim + commonPrefixLength;
    final int dimCmpBytes = bytesPerDim - commonPrefixLength;
    final int dataOffset = numIndexDims * bytesPerDim - dimCmpBytes;
    new MSBRadixSorter(bytesSorted - commonPrefixLength) {

      @Override
      protected int byteAt(int i, int k) {
        assert k >= 0 : "negative prefix " + k;
        if (k < dimCmpBytes) {
          return points.block[i * packedBytesDocIDLength + dimOffset + k] & 0xff;
        } else {
          return points.block[i * packedBytesDocIDLength + dataOffset + k] & 0xff;
        }
      }

      @Override
      protected void swap(int i, int j) {
        points.swap(i, j);
      }

      @Override
      protected Sorter getFallbackSorter(int k) {
        // Leading bytes that need no comparison: the shared prefix plus the k bytes already
        // handled by the radix pass.
        final int skippedBytes = k + commonPrefixLength;
        final int dimStart = dim * bytesPerDim + skippedBytes;
        final int dimEnd = dim * bytesPerDim + bytesPerDim;
        // Tie-break region of each record: data-only dimensions followed by the docID.
        final int dataStart = numIndexDims * bytesPerDim;
        final int dataLength = (numDataDims - numIndexDims) * bytesPerDim + Integer.BYTES;
        return new IntroSorter() {

          @Override
          protected void swap(int i, int j) {
            points.swap(i, j);
          }

          @Override
          protected void setPivot(int i) {
            // scratch layout: [dimension bytes | tie-break bytes]
            if (skippedBytes < bytesPerDim) {
              System.arraycopy(points.block, i * packedBytesDocIDLength + dim * bytesPerDim, scratch, 0, bytesPerDim);
            }
            System.arraycopy(points.block, i * packedBytesDocIDLength + dataStart, scratch, bytesPerDim, dataLength);
          }

          @Override
          protected int compare(int i, int j) {
            if (skippedBytes < bytesPerDim) {
              int iOffset = i * packedBytesDocIDLength;
              int jOffset = j * packedBytesDocIDLength;
              int cmp = FutureArrays.compareUnsigned(points.block, iOffset + dimStart, iOffset + dimEnd, points.block, jOffset + dimStart, jOffset + dimEnd);
              if (cmp != 0) {
                return cmp;
              }
            }
            int iOffset = i * packedBytesDocIDLength + dataStart;
            int jOffset = j * packedBytesDocIDLength + dataStart;
            return FutureArrays.compareUnsigned(points.block, iOffset, iOffset + dataLength, points.block, jOffset, jOffset + dataLength);
          }

          @Override
          protected int comparePivot(int j) {
            if (skippedBytes < bytesPerDim) {
              int jOffset = j * packedBytesDocIDLength;
              int cmp = FutureArrays.compareUnsigned(scratch, skippedBytes, bytesPerDim, points.block, jOffset + dimStart, jOffset + dimEnd);
              if (cmp != 0) {
                return cmp;
              }
            }
            int jOffset = j * packedBytesDocIDLength + dataStart;
            return FutureArrays.compareUnsigned(scratch, bytesPerDim, bytesPerDim + dataLength, points.block, jOffset, jOffset + dataLength);
          }
        };
      }
    }.sort(from, to);
  }

  /**
   * Creates the writer for points that fall into the partition bucket and need another pass: on
   * heap when {@code delta} fits in the remaining heap budget, otherwise offline.
   */
  private PointWriter getDeltaPointWriter(PointWriter left, PointWriter right, long delta, int iteration) throws IOException {
    if (delta <= getMaxPointsSortInHeap(left, right)) {
      return new HeapPointWriter(Math.toIntExact(delta), packedBytesLength);
    } else {
      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "delta" + iteration, delta);
    }
  }

  /** Returns the heap budget left after subtracting the size of any heap-backed left/right writers. */
  private int getMaxPointsSortInHeap(PointWriter left, PointWriter right) {
    int pointsUsed = 0;
    if (left instanceof HeapPointWriter) {
      pointsUsed += ((HeapPointWriter) left).size;
    }
    if (right instanceof HeapPointWriter) {
      pointsUsed += ((HeapPointWriter) right).size;
    }
    assert maxPointsSortInHeap >= pointsUsed;
    return maxPointsSortInHeap - pointsUsed;
  }

  /**
   * Creates an output writer for {@code count} points. It is heap-backed when {@code count} is at
   * most half the heap budget, so that a heap-backed left and right together stay within
   * {@code maxPointsSortInHeap}; otherwise it is offline.
   */
  PointWriter getPointWriter(long count, String desc) throws IOException {
    if (count <= maxPointsSortInHeap / 2) {
      int size = Math.toIntExact(count);
      return new HeapPointWriter(size, packedBytesLength);
    } else {
      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count);
    }
  }

  /** A contiguous range of {@code count} points starting at {@code start} within a {@link PointWriter}. */
  public static final class PathSlice {
    public final PointWriter writer;
    public final long start;
    public final long count;

    public PathSlice(PointWriter writer, long start, long count) {
      this.writer = writer;
      this.start = start;
      this.count = count;
    }

    @Override
    public String toString() {
      return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")";
    }
  }
}