/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.search.join;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;

import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.index.PrefixCodedTerms;
import org.apache.lucene.index.PrefixCodedTerms.TermIterator;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointInSetQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator;

/**
 * A {@code TermsIncludingScoreQuery} variant for point values.
 * <p>
 * Matches documents whose one-dimensional point value in {@link #field} equals one of the packed points supplied at
 * construction time, and gives each matching document the join score that was recorded alongside the matching point.
 * Subclasses only need to provide {@link #toString(byte[])} to render a single packed value.
 */
abstract class PointInSetIncludingScoreQuery extends Query implements Accountable {
  protected static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(PointInSetIncludingScoreQuery.class);

  /**
   * Renders dimension 0 of a packed point for the given numeric type ({@link Integer}, {@link Long}, {@link Float} or
   * {@link Double}); any other type yields {@code "unsupported"}. Final so the shared formatter cannot be reassigned.
   */
  static final BiFunction<byte[], Class<? extends Number>, String> toString = (value, numericType) -> {
    if (Integer.class.equals(numericType)) {
      return Integer.toString(IntPoint.decodeDimension(value, 0));
    } else if (Long.class.equals(numericType)) {
      return Long.toString(LongPoint.decodeDimension(value, 0));
    } else if (Float.class.equals(numericType)) {
      return Float.toString(FloatPoint.decodeDimension(value, 0));
    } else if (Double.class.equals(numericType)) {
      return Double.toString(DoublePoint.decodeDimension(value, 0));
    } else {
      return "unsupported";
    }
  };

  /** The join score mode this query was built with; participates in {@link #equals}/{@link #hashCode}. */
  final ScoreMode scoreMode;
  /** The query the join values came from; used for equality and RAM accounting. */
  final Query originalQuery;
  /**
   * Whether a document may hold several values in {@link #field}. When true, a document keeps the score of the
   * first matching value encountered during intersection; otherwise each match simply writes its score.
   */
  final boolean multipleValuesPerDocument;
  /** The query points, strictly ascending and free of duplicates (enforced by the constructor). */
  final PrefixCodedTerms sortedPackedPoints;
  /** Cached {@code sortedPackedPoints.hashCode()}. */
  final int sortedPackedPointsHashCode;
  /** The point field being matched. */
  final String field;
  /** Number of bytes of each packed (single-dimension) point. */
  final int bytesPerDim;

  /** Join scores parallel to {@link #sortedPackedPoints}: the i-th score belongs to the i-th point. */
  final List<Float> aggregatedJoinScores;

  private final long ramBytesUsed; // cache

  /**
   * A stream of packed points that also exposes, in {@link #score}, the join score belonging to the point most
   * recently returned by {@code next()}.
   */
  static abstract class Stream extends PointInSetQuery.Stream {

    float score;

  }

  /**
   * Builds the query by draining {@code packedPoints}.
   *
   * @param scoreMode                 join score mode (used for equality only in this class)
   * @param originalQuery             query that produced the join values
   * @param multipleValuesPerDocument whether documents may have several values in {@code field}
   * @param field                     point field to match against
   * @param bytesPerDim               length in bytes of every packed point
   * @param packedPoints              points in strictly ascending order, each paired with its score via
   *                                  {@link Stream#score}
   * @throws IllegalArgumentException if {@code bytesPerDim} is out of range, a point has the wrong length, or the
   *                                  points are duplicated or out of order
   */
  PointInSetIncludingScoreQuery(ScoreMode scoreMode, Query originalQuery, boolean multipleValuesPerDocument,
                                String field, int bytesPerDim, Stream packedPoints) {
    this.scoreMode = scoreMode;
    this.originalQuery = originalQuery;
    this.multipleValuesPerDocument = multipleValuesPerDocument;
    this.field = field;
    if (bytesPerDim < 1 || bytesPerDim > PointValues.MAX_NUM_BYTES) {
      throw new IllegalArgumentException("bytesPerDim must be > 0 and <= " + PointValues.MAX_NUM_BYTES + "; got " + bytesPerDim);
    }
    this.bytesPerDim = bytesPerDim;

    aggregatedJoinScores = new ArrayList<>();
    PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder();
    BytesRefBuilder previous = null;
    BytesRef current;
    while ((current = packedPoints.next()) != null) {
      if (current.length != bytesPerDim) {
        throw new IllegalArgumentException("packed point length should be " + bytesPerDim + " but got " + current.length + "; field=\"" + field + "\" bytesPerDim=" + bytesPerDim);
      }
      if (previous == null) {
        previous = new BytesRefBuilder();
      } else {
        int cmp = previous.get().compareTo(current);
        if (cmp == 0) {
          throw new IllegalArgumentException("unexpected duplicated value: " + current);
        } else if (cmp > 0) {
          // Print the BytesRef, not the builder: BytesRefBuilder has no value-printing toString and its
          // hashCode() throws, which would turn this IllegalArgumentException into an UnsupportedOperationException.
          throw new IllegalArgumentException("values are out of order: saw " + previous.get() + " before " + current);
        }
      }
      builder.add(field, current);
      // The stream's score refers to the point just returned, so record it in lock-step with the point.
      aggregatedJoinScores.add(packedPoints.score);
      previous.copyBytes(current);
    }
    sortedPackedPoints = builder.finish();
    sortedPackedPointsHashCode = sortedPackedPoints.hashCode();

    this.ramBytesUsed = BASE_RAM_BYTES +
        RamUsageEstimator.sizeOfObject(this.field) +
        RamUsageEstimator.sizeOfObject(this.originalQuery, RamUsageEstimator.QUERY_DEFAULT_RAM_BYTES_USED) +
        RamUsageEstimator.sizeOfObject(this.sortedPackedPoints);
  }

  @Override
  public void visit(QueryVisitor visitor) {
    if (visitor.acceptField(field)) {
      visitor.visitLeaf(this);
    }
  }

  @Override
  public final Weight createWeight(IndexSearcher searcher, org.apache.lucene.search.ScoreMode scoreMode, float boost) throws IOException {
    // NOTE: the `scoreMode` parameter (search-level ScoreMode) shadows the join ScoreMode field; neither it nor
    // `boost` is consulted below, so scores are the raw join scores.
    return new Weight(this) {

      @Override
      public void extractTerms(Set<Term> terms) {
        // Point-based query: there are no terms to report.
      }

      @Override
      public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        // Re-run the segment-level matching and check whether `doc` is among the hits.
        Scorer scorer = scorer(context);
        if (scorer != null) {
          int target = scorer.iterator().advance(doc);
          if (doc == target) {
            return Explanation.match(scorer.score(), "A match");
          }
        }
        return Explanation.noMatch("Not a match");
      }

      @Override
      public Scorer scorer(LeafReaderContext context) throws IOException {
        LeafReader reader = context.reader();
        FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(field);
        if (fieldInfo == null) {
          return null;
        }
        if (fieldInfo.getPointDataDimensionCount() != 1) {
          throw new IllegalArgumentException("field=\"" + field + "\" was indexed with numDims=" + fieldInfo.getPointDataDimensionCount() + " but this query has numDims=1");
        }
        if (fieldInfo.getPointNumBytes() != bytesPerDim) {
          throw new IllegalArgumentException("field=\"" + field + "\" was indexed with bytesPerDim=" + fieldInfo.getPointNumBytes() + " but this query has bytesPerDim=" + bytesPerDim);
        }
        PointValues values = reader.getPointValues(field);
        if (values == null) {
          return null;
        }

        // Matching is done eagerly: collect all matching docs and their scores for this segment up front.
        FixedBitSet result = new FixedBitSet(reader.maxDoc());
        float[] scores = new float[reader.maxDoc()];
        values.intersect(new MergePointVisitor(sortedPackedPoints, result, scores));
        // Report the real number of matches as the iterator cost (rather than an arbitrary constant) so that
        // conjunctions can pick the genuinely cheapest clause to lead.
        final long cost = result.cardinality();
        return new Scorer(this) {

          private final DocIdSetIterator disi = new BitSetIterator(result, cost);

          @Override
          public float score() throws IOException {
            return scores[docID()];
          }

          @Override
          public float getMaxScore(int upTo) throws IOException {
            // No useful upper bound is tracked for join scores.
            return Float.POSITIVE_INFINITY;
          }

          @Override
          public int docID() {
            return disi.docID();
          }

          @Override
          public DocIdSetIterator iterator() {
            return disi;
          }

        };
      }

      @Override
      public boolean isCacheable(LeafReaderContext ctx) {
        return true;
      }

    };
  }

  /**
   * Intersect visitor that merges the ascending query points against the visited index values, marking matching
   * documents in {@code result} and writing the matching point's join score into {@code scores}.
   * <p>
   * The point iterator and the score iterator are advanced in lock-step, so {@code nextScore} always belongs to
   * {@code nextQueryPoint}. Query points only ever move forward, which relies on values being visited in ascending
   * order for this one-dimensional field.
   */
  private class MergePointVisitor implements IntersectVisitor {

    private final FixedBitSet result;
    private final float[] scores;

    private final TermIterator iterator;
    private final Iterator<Float> scoreIterator;
    private BytesRef nextQueryPoint;
    private float nextScore;
    // Reusable view over the visited packed values; only `bytes` is swapped per call (offset 0, fixed length).
    private final BytesRef scratch = new BytesRef();

    private MergePointVisitor(PrefixCodedTerms sortedPackedPoints, FixedBitSet result, float[] scores) throws IOException {
      this.result = result;
      this.scores = scores;
      scratch.length = bytesPerDim;
      this.iterator = sortedPackedPoints.iterator();
      this.scoreIterator = aggregatedJoinScores.iterator();
      nextQueryPoint = iterator.next();
      if (scoreIterator.hasNext()) {
        nextScore = scoreIterator.next();
      }
    }

    @Override
    public void visit(int docID) {
      throw new IllegalStateException("shouldn't get here, since CELL_INSIDE_QUERY isn't emitted");
    }

    @Override
    public void visit(int docID, byte[] packedValue) {
      scratch.bytes = packedValue;
      while (nextQueryPoint != null) {
        int cmp = nextQueryPoint.compareTo(scratch);
        if (cmp == 0) {
          // Query point equals index point, so collect and return
          if (multipleValuesPerDocument) {
            // Keep the score of the first matching value seen for this document.
            if (result.get(docID) == false) {
              result.set(docID);
              scores[docID] = nextScore;
            }
          } else {
            result.set(docID);
            scores[docID] = nextScore;
          }
          break;
        } else if (cmp < 0) {
          // Query point is before index point, so we move to next query point
          nextQueryPoint = iterator.next();
          if (scoreIterator.hasNext()) {
            nextScore = scoreIterator.next();
          }
        } else {
          // Query point is after index point, so we don't collect and we return:
          break;
        }
      }
    }

    @Override
    public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
      while (nextQueryPoint != null) {
        scratch.bytes = minPackedValue;
        int cmpMin = nextQueryPoint.compareTo(scratch);
        if (cmpMin < 0) {
          // query point is before the start of this cell
          nextQueryPoint = iterator.next();
          if (scoreIterator.hasNext()) {
            nextScore = scoreIterator.next();
          }
          continue;
        }
        scratch.bytes = maxPackedValue;
        int cmpMax = nextQueryPoint.compareTo(scratch);
        if (cmpMax > 0) {
          // query point is after the end of this cell
          return Relation.CELL_OUTSIDE_QUERY;
        }

        // Never CELL_INSIDE_QUERY: each value in the cell must be checked individually via visit(int, byte[]).
        return Relation.CELL_CROSSES_QUERY;
      }

      // We exhausted all points in the query:
      return Relation.CELL_OUTSIDE_QUERY;
    }
  }

  @Override
  public final int hashCode() {
    int hash = classHash();
    hash = 31 * hash + scoreMode.hashCode();
    hash = 31 * hash + field.hashCode();
    hash = 31 * hash + originalQuery.hashCode();
    hash = 31 * hash + sortedPackedPointsHashCode;
    hash = 31 * hash + bytesPerDim;
    return hash;
  }

  @Override
  public final boolean equals(Object other) {
    return sameClassAs(other) &&
           equalsTo(getClass().cast(other));
  }

  /** Compares the same fields that {@link #hashCode()} mixes in; the cached hash is a cheap pre-check. */
  private boolean equalsTo(PointInSetIncludingScoreQuery other) {
    return other.scoreMode.equals(scoreMode) &&
           other.field.equals(field) &&
           other.originalQuery.equals(originalQuery) &&
           other.bytesPerDim == bytesPerDim &&
           other.sortedPackedPointsHashCode == sortedPackedPointsHashCode &&
           other.sortedPackedPoints.equals(sortedPackedPoints);
  }

  /**
   * Renders the query as {@code field:{v1 v2 ...}}, omitting the {@code field:} prefix when it equals the
   * given default field. Each value is rendered with {@link #toString(byte[])}.
   */
  @Override
  public final String toString(String field) {
    final StringBuilder sb = new StringBuilder();
    if (this.field.equals(field) == false) {
      sb.append(this.field);
      sb.append(':');
    }

    sb.append("{");

    TermIterator iterator = sortedPackedPoints.iterator();
    byte[] pointBytes = new byte[bytesPerDim];
    boolean first = true;
    for (BytesRef point = iterator.next(); point != null; point = iterator.next()) {
      if (first == false) {
        sb.append(" ");
      }
      first = false;
      System.arraycopy(point.bytes, point.offset, pointBytes, 0, pointBytes.length);
      sb.append(toString(pointBytes));
    }
    sb.append("}");
    return sb.toString();
  }

  /** Renders a single packed point of {@link #bytesPerDim} bytes as a human-readable value. */
  protected abstract String toString(byte[] value);

  @Override
  public long ramBytesUsed() {
    return ramBytesUsed;
  }
}