/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocAndScoreQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.FullPrecisionFloatVectorSimilarityValuesSource;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LateInteractionFloatValuesSource;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;

/**
 * A query that runs a wrapped query, re-scores every live matching document with a {@link
 * DoubleValuesSource}, and keeps only the {@code n} documents with the highest new scores.
 *
 * <p>All of the work happens in {@link #rewrite(IndexSearcher)}, which returns the query produced
 * by {@code DocAndScoreQuery.createDocAndScoreQuery} for the retained hits. This class does not
 * override {@code createWeight}, so it relies on being rewritten before it is executed.
 */
public class RescoreTopNQuery extends Query {
    /** Maximum number of re-scored documents to keep; always {@code >= 1}. */
    private final int n;
    /** The query whose matches are re-scored. */
    private final Query query;
    /** Source of the new score for each matching document. */
    private final DoubleValuesSource valuesSource;

    /**
     * Creates a query that keeps the top {@code n} matches of {@code query} according to {@code
     * valuesSource}.
     *
     * @param query the query whose matches are re-scored
     * @param valuesSource supplies the new score of each matching document
     * @param n the maximum number of re-scored documents to keep
     * @throws IllegalArgumentException if {@code n < 1}
     * @throws NullPointerException if {@code query} or {@code valuesSource} is null
     */
    public RescoreTopNQuery(Query query, DoubleValuesSource valuesSource, int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1");
        }
        this.query = Objects.requireNonNull(query, "query");
        this.valuesSource = Objects.requireNonNull(valuesSource, "valuesSource");
        this.n = n;
    }

    /**
     * Executes the wrapped query over every segment and re-scores each live match. It then returns
     * a query built from the {@code n} best re-scored hits.
     *
     * <p>A document for which the values source has no value ({@code advanceExact} returns false)
     * is kept with a score of {@code 0}. The reported {@link TotalHits} is the exact number of live
     * documents matched by the wrapped query, not the number of retained hits.
     *
     * @param indexSearcher searcher used to rewrite and execute the wrapped query
     * @return the query produced by {@code DocAndScoreQuery.createDocAndScoreQuery} for the top hits
     * @throws IOException if reading the index fails
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        DoubleValuesSource rewrittenValueSource = this.valuesSource.rewrite(indexSearcher);
        // Decide on scoring from the source that will actually be used. A weight created with
        // COMPLETE_NO_SCORES is told that scores will not be used, so its scores must not be fed
        // into the values source.
        boolean needsScores = rewrittenValueSource.needsScores();
        ScoreMode scoreMode = needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;

        IndexReader reader = indexSearcher.getIndexReader();
        Query rewritten = indexSearcher.rewrite(this.query);
        Weight weight = indexSearcher.createWeight(rewritten, scoreMode, 1.0f);
        HitQueue queue = new HitQueue(this.n, false);
        int originalCount = 0;
        for (LeafReaderContext leaf : reader.leaves()) {
            Scorer innerScorer = weight.scorer(leaf);
            if (innerScorer == null) {
                continue; // the wrapped query matches nothing in this segment
            }
            // Weight#scorer does not apply deletions, so deleted docs must be filtered here.
            Bits liveDocs = leaf.reader().getLiveDocs();
            DoubleValues rescores =
                    rewrittenValueSource.getValues(leaf, getDoubleValues(needsScores, innerScorer));
            DocIdSetIterator iterator = innerScorer.iterator();
            for (int docId = iterator.nextDoc();
                    docId != DocIdSetIterator.NO_MORE_DOCS;
                    docId = iterator.nextDoc()) {
                if (liveDocs != null && !liveDocs.get(docId)) {
                    continue;
                }
                float score = rescores.advanceExact(docId) ? (float) rescores.doubleValue() : 0.0f;
                queue.insertWithOverflow(new ScoreDoc(leaf.docBase + docId, score));
                ++originalCount;
            }
        }
        // Copied in the queue's iteration order, which is not necessarily sorted by score.
        ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
        int i = 0;
        for (ScoreDoc topDoc : queue) {
            scoreDocs[i++] = topDoc;
        }
        TopDocs topDocs = new TopDocs(new TotalHits(originalCount, TotalHits.Relation.EQUAL_TO), scoreDocs);
        return DocAndScoreQuery.createDocAndScoreQuery(reader, topDocs);
    }

    /**
     * Returns the inner scorer's scores as {@link DoubleValues}, or {@code null} when the values
     * source does not read scores.
     *
     * @param needsScores whether the (rewritten) values source reads scores
     * @param innerScorer scorer of the wrapped query for the current segment
     */
    private static DoubleValues getDoubleValues(boolean needsScores, Scorer innerScorer) {
        if (!needsScores) {
            return null;
        }
        return DoubleValuesSource.fromScorer(innerScorer);
    }

    /** Hash combining the values source, the wrapped query and {@code n}. */
    @Override
    public int hashCode() {
        int result = this.valuesSource.hashCode();
        result = 31 * result + Objects.hash(this.query, this.n);
        return result;
    }

    /** Two instances are equal when they have the same class, wrapped query, values source and n. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        RescoreTopNQuery that = (RescoreTopNQuery) o;
        return Objects.equals(this.query, that.query)
                && Objects.equals(this.valuesSource, that.valuesSource)
                && this.n == that.n;
    }

    /** Delegates to the wrapped query; the values source is not visited. */
    @Override
    public void visit(QueryVisitor visitor) {
        this.query.visit(visitor);
    }

    @Override
    public String toString(String field) {
        return this.getClass().getSimpleName() + ":" + this.query.toString(field) + ":" + this.valuesSource.toString() + "[" + this.n + "]";
    }

    /**
     * Wraps {@code in} so that its matches are re-scored by a {@code
     * FullPrecisionFloatVectorSimilarityValuesSource} built from {@code targetVector} and {@code
     * field}. Only the top {@code n} are kept.
     *
     * @param in the query whose matches are re-scored
     * @param targetVector query vector handed to the values source
     * @param field vector field name handed to the values source
     * @param n the maximum number of re-scored documents to keep
     * @return the rescoring query
     */
    public static Query createFullPrecisionRescorerQuery(Query in, float[] targetVector, String field, int n) {
        FullPrecisionFloatVectorSimilarityValuesSource valuesSource =
                new FullPrecisionFloatVectorSimilarityValuesSource(targetVector, field);
        return new RescoreTopNQuery(in, valuesSource, n);
    }

    /**
     * Wraps {@code in} so that its matches are re-scored by a {@code
     * LateInteractionFloatValuesSource} built from the given field, multi-vector query and
     * similarity function. Only the top {@code n} are kept.
     *
     * @param in the query whose matches are re-scored
     * @param n the maximum number of re-scored documents to keep
     * @param fieldName field name handed to the values source
     * @param queryVector query vectors handed to the values source
     * @param vectorSimilarityFunction similarity function handed to the values source
     * @return the rescoring query
     */
    public static Query createLateInteractionQuery(Query in, int n, String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) {
        LateInteractionFloatValuesSource valuesSource =
                new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction);
        return new RescoreTopNQuery(in, valuesSource, n);
    }
}

