/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.index.sai.disk.v1.segment;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.SparseFixedBitSet;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.agrona.collections.IntArrayList;
import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.StorageAttachedIndex;
import org.apache.cassandra.index.sai.VectorQueryContext;
import org.apache.cassandra.index.sai.disk.PrimaryKeyMap;
import org.apache.cassandra.index.sai.disk.v1.PerColumnIndexFiles;
import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList;
import org.apache.cassandra.index.sai.disk.v1.segment.IndexSegmentSearcher;
import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v1.vector.DiskAnn;
import org.apache.cassandra.index.sai.disk.v1.vector.OnDiskOrdinalsMap;
import org.apache.cassandra.index.sai.disk.v1.vector.OptimizeFor;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.iterators.KeyRangeListIterator;
import org.apache.cassandra.index.sai.memory.VectorMemoryIndex;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.postings.IntArrayPostingList;
import org.apache.cassandra.index.sai.postings.PostingList;
import org.apache.cassandra.index.sai.utils.AtomicRatio;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.RangeUtil;
import org.apache.cassandra.tracing.Tracing;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Searches one on-disk vector index segment, backed by a {@link DiskAnn} graph, for ANN expressions.
 * <p>
 * When a query is restricted (by key range or by an already-materialized key list), the searcher
 * estimates how many graph nodes a filtered ANN search would visit. If the number of candidate rows
 * is at most that estimate (or the limit), the candidates are returned directly as a posting list
 * ("brute force") without searching the graph. Otherwise the graph search is restricted to the
 * ordinals of the admitted rows through a bitset.
 * <p>
 * The ratio of actually-visited to predicted nodes is accumulated in {@link #actualExpectedRatio}.
 * Once it has at least {@link #MIN_RATIO_UPDATES} samples, it scales later predictions.
 */
public class VectorIndexSegmentSearcher extends IndexSegmentSearcher {
    private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

    /** Number of ratio samples required before the observed actual/expected ratio is used. */
    private static final int MIN_RATIO_UPDATES = 10;
    /** Searches visiting fewer nodes than this never log a misprediction warning. */
    private static final int MIN_NODES_VISITED_TO_WARN = 1000;

    private final DiskAnn graph;
    /** Cap on brute-force rows independent of the cost model; Integer.MAX_VALUE means "no cap". */
    private final int globalBruteForceRows;
    /** Running ratio of actual to predicted nodes visited by graph searches. */
    private final AtomicRatio actualExpectedRatio = new AtomicRatio();
    /** One reusable bitset (sized to the graph) per calling thread; see {@link #bitSetForSearch()}. */
    private final ThreadLocal<SparseFixedBitSet> cachedBitSets;
    private final OptimizeFor optimizeFor;

    /**
     * Opens the DiskAnn graph for this segment's components.
     *
     * @throws IOException if the graph cannot be opened
     */
    VectorIndexSegmentSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, PerColumnIndexFiles perIndexFiles, SegmentMetadata segmentMetadata, StorageAttachedIndex index) throws IOException {
        super(primaryKeyMapFactory, perIndexFiles, segmentMetadata, index);
        this.graph = new DiskAnn(segmentMetadata.componentMetadatas, perIndexFiles, index.indexWriterConfig());
        this.cachedBitSets = ThreadLocal.withInitial(() -> new SparseFixedBitSet(this.graph.size()));
        this.globalBruteForceRows = Integer.MAX_VALUE;
        this.optimizeFor = index.indexWriterConfig().getOptimizeFor();
    }

    /** @return the RAM used by the graph, as reported by {@link DiskAnn#ramBytesUsed()} */
    @Override
    public long indexFileCacheSize() {
        return this.graph.ramBytesUsed();
    }

    /**
     * Runs an ANN search for {@code exp}, restricted to {@code keyRange}.
     *
     * @throws IllegalArgumentException if {@code exp} is not an ANN expression
     */
    @Override
    public KeyRangeIterator search(Expression exp, AbstractBounds<PartitionPosition> keyRange, QueryContext context) throws IOException {
        int limit = context.vectorContext().limit();
        if (logger.isTraceEnabled()) {
            logger.trace(this.index.identifier().logMessage("Searching on expression '{}'..."), exp);
        }
        if (exp.getIndexOperator() != Expression.IndexOperator.ANN) {
            throw new IllegalArgumentException(this.index.identifier().logMessage("Unsupported expression during ANN index query: " + exp));
        }

        int topK = this.optimizeFor.topKFor(limit);
        BitsOrPostingList bitsOrPostingList = this.bitsOrPostingListForKeyRange(context.vectorContext(), keyRange, topK);
        if (bitsOrPostingList.skipANN()) {
            return this.toPrimaryKeyIterator(bitsOrPostingList.postingList(), context);
        }

        float[] queryVector = this.index.termType().decomposeVector(exp.lower().value.raw.duplicate());
        VectorPostingList vectorPostings = this.graph.search(queryVector, topK, limit, bitsOrPostingList.getBits());
        // Only key-range-filtered searches carry a prediction (-1 otherwise) to feed back into the ratio.
        if (bitsOrPostingList.expectedNodesVisited >= 0) {
            this.updateExpectedNodes(vectorPostings.getVisitedCount(), bitsOrPostingList.expectedNodesVisited);
        }
        return this.toPrimaryKeyIterator(vectorPostings, context);
    }

    /**
     * Decides which rows of this segment {@code keyRange} admits and how they should be searched.
     *
     * @param context  vector query context, used to exclude rows via {@code shouldInclude} and
     *                 {@code bitsetForShadowedPrimaryKeys}
     * @param keyRange the partition range restricting the query
     * @param limit    number of results the subsequent search requests (callers pass topK)
     * @return one of the following:
     *         <ul>
     *         <li>an empty posting list when no row can match;</li>
     *         <li>a posting list of segment row ids when few enough rows match to brute force;</li>
     *         <li>the shadowed-key bitset (no estimate) when the range covers the whole ring or segment;</li>
     *         <li>otherwise, a bitset of admitted graph ordinals plus a predicted visit count.</li>
     *         </ul>
     * @throws UncheckedIOException if reading ordinals from the graph fails
     */
    private BitsOrPostingList bitsOrPostingListForKeyRange(VectorQueryContext context, AbstractBounds<PartitionPosition> keyRange, int limit) throws IOException {
        try (PrimaryKeyMap primaryKeyMap = this.primaryKeyMapFactory.newPerSSTablePrimaryKeyMap()) {
            // Unrestricted query: only shadowed keys need to be excluded.
            if (RangeUtil.coversFullRing(keyRange)) {
                return new BitsOrPostingList(context.bitsetForShadowedPrimaryKeys(this.metadata, primaryKeyMap, this.graph));
            }

            long minSSTableRowId = primaryKeyMap.ceiling(keyRange.left.getToken());
            // No row at or after the left bound, so nothing in this sstable can match.
            if (minSSTableRowId < 0) {
                return new BitsOrPostingList(PostingList.EMPTY);
            }
            long maxSSTableRowId = this.getMaxSSTableRowId(primaryKeyMap, keyRange.right);
            if (minSSTableRowId > maxSSTableRowId) {
                return new BitsOrPostingList(PostingList.EMPTY);
            }

            // The range spans the whole segment, so it restricts nothing beyond shadowed keys.
            if (minSSTableRowId <= this.metadata.minSSTableRowId && maxSSTableRowId >= this.metadata.maxSSTableRowId) {
                return new BitsOrPostingList(context.bitsetForShadowedPrimaryKeys(this.metadata, primaryKeyMap, this.graph));
            }

            // Clamp the range to this segment's rows.
            minSSTableRowId = Math.max(minSSTableRowId, this.metadata.minSSTableRowId);
            maxSSTableRowId = Math.min(maxSSTableRowId, this.metadata.maxSSTableRowId);

            // nRows counts every row id in the clamped range, including rows that shouldInclude later
            // rejects, so it may overestimate the number of candidates.
            int nRows = Math.toIntExact(maxSSTableRowId - minSSTableRowId + 1);
            int maxBruteForceRows = Math.min(this.globalBruteForceRows, this.maxBruteForceRows(limit, nRows, this.graph.size()));
            logger.trace("Search range covers {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}", nRows, maxBruteForceRows, this.graph.size(), limit);
            Tracing.trace("Search range covers {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}", nRows, maxBruteForceRows, this.graph.size(), limit);

            // Few enough rows: hand them back as a posting list and skip the graph entirely.
            if (nRows <= maxBruteForceRows) {
                IntArrayList postings = new IntArrayList(nRows, -1);
                for (long sstableRowId = minSSTableRowId; sstableRowId <= maxSSTableRowId; sstableRowId++) {
                    if (context.shouldInclude(sstableRowId, primaryKeyMap)) {
                        postings.addInt(this.metadata.toSegmentRowId(sstableRowId));
                    }
                }
                return new BitsOrPostingList(new IntArrayPostingList(postings.toIntArray()));
            }

            // Otherwise, build a bitset of the graph ordinals the search may return.
            SparseFixedBitSet bits = this.bitSetForSearch();
            boolean hasMatches = false;
            try (OnDiskOrdinalsMap.OrdinalsView ordinalsView = this.graph.getOrdinalsView()) {
                for (long sstableRowId = minSSTableRowId; sstableRowId <= maxSSTableRowId; sstableRowId++) {
                    if (!context.shouldInclude(sstableRowId, primaryKeyMap)) {
                        continue;
                    }
                    int ordinal = ordinalsView.getOrdinalForRowId(this.metadata.toSegmentRowId(sstableRowId));
                    // A negative ordinal means this row has no node in the graph; skip it.
                    if (ordinal >= 0) {
                        bits.set(ordinal);
                        hasMatches = true;
                    }
                }
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }

            if (!hasMatches) {
                return new BitsOrPostingList(PostingList.EMPTY);
            }
            return new BitsOrPostingList(bits, VectorMemoryIndex.expectedNodesVisited(limit, nRows, this.graph.size()));
        }
    }

    /**
     * Resolves the last sstable row id admitted by the range's right bound.
     * <p>
     * A minimum right bound is treated as "no upper bound". In that case, and when
     * {@code floor} finds no row, the segment's max row id is returned.
     */
    private long getMaxSSTableRowId(PrimaryKeyMap primaryKeyMap, PartitionPosition right) {
        if (right.isMinimum()) {
            return this.metadata.maxSSTableRowId;
        }
        long max = primaryKeyMap.floor(right.getToken());
        // NOTE(review): if floor() returns a negative value because `right` precedes every key in the
        // sstable, falling back to maxSSTableRowId widens an empty range to the whole segment. Confirm
        // floor()'s contract for the negative case.
        if (max < 0) {
            return this.metadata.maxSSTableRowId;
        }
        return max;
    }

    /** Returns this thread's cached bitset, cleared and ready for reuse. */
    private SparseFixedBitSet bitSetForSearch() {
        SparseFixedBitSet bits = this.cachedBitSets.get();
        bits.clear();
        return bits;
    }

    /**
     * Orders an already-materialized set of keys by vector similarity, keeping only those that fall
     * within this segment.
     * <p>
     * If few enough keys remain, they are returned without a graph search. Otherwise the graph is
     * searched, restricted to the ordinals of those keys.
     */
    @Override
    public KeyRangeIterator limitToTopKResults(QueryContext context, List<PrimaryKey> primaryKeys, Expression expression) throws IOException {
        int limit = context.vectorContext().limit();
        // dropWhile/takeWhile select exactly the keys in [minKey, maxKey] only when primaryKeys is
        // sorted ascending. Presumably callers guarantee this; verify.
        List<PrimaryKey> keysInRange = primaryKeys.stream()
                                                  .dropWhile(k -> k.compareTo(this.metadata.minKey) < 0)
                                                  .takeWhile(k -> k.compareTo(this.metadata.maxKey) <= 0)
                                                  .collect(Collectors.toList());
        if (keysInRange.isEmpty()) {
            return KeyRangeIterator.empty();
        }
        int topK = this.optimizeFor.topKFor(limit);
        if (this.shouldUseBruteForce(topK, limit, keysInRange.size())) {
            return new KeyRangeListIterator(this.metadata.minKey, this.metadata.maxKey, keysInRange);
        }

        try (PrimaryKeyMap primaryKeyMap = this.primaryKeyMapFactory.newPerSSTablePrimaryKeyMap()) {
            int maxSegmentRowId = this.metadata.toSegmentRowId(this.metadata.maxSSTableRowId);
            SparseFixedBitSet bits = this.bitSetForSearch();
            IntArrayList rowIds = new IntArrayList();
            try (OnDiskOrdinalsMap.OrdinalsView ordinalsView = this.graph.getOrdinalsView()) {
                for (PrimaryKey primaryKey : keysInRange) {
                    long sstableRowId = primaryKeyMap.rowIdFromPrimaryKey(primaryKey);
                    // Skip rows before this segment.
                    if (sstableRowId < this.metadata.minSSTableRowId) {
                        continue;
                    }
                    // Past the end of this segment: with ascending keys, no later key can be inside it.
                    if (sstableRowId > this.metadata.maxSSTableRowId) {
                        break;
                    }
                    int segmentRowId = this.metadata.toSegmentRowId(sstableRowId);
                    rowIds.addInt(segmentRowId);
                    int ordinal = ordinalsView.getOrdinalForRowId(segmentRowId);
                    if (ordinal >= 0) {
                        bits.set(ordinal);
                    }
                }
            }

            // Re-check with the exact number of rows that landed in this segment.
            if (this.shouldUseBruteForce(topK, limit, rowIds.size())) {
                return this.toPrimaryKeyIterator(new IntArrayPostingList(rowIds.toIntArray()), context);
            }

            float[] queryVector = this.index.termType().decomposeVector(expression.lower().value.raw.duplicate());
            VectorPostingList results = this.graph.search(queryVector, topK, limit, bits);
            // NOTE(review): unlike search(), this feeds the ratio a prediction that is already
            // ratio-adjusted and uses maxSegmentRowId (not rowIds.size()) as the permitted-ordinal count.
            // Confirm this is intended.
            this.updateExpectedNodes(results.getVisitedCount(), this.expectedNodesVisited(topK, maxSegmentRowId, this.graph.size()));
            return this.toPrimaryKeyIterator(results, context);
        }
    }

    /** Returns true when {@code numRows} candidates are cheap enough to skip the graph search. */
    private boolean shouldUseBruteForce(int topK, int limit, int numRows) {
        int maxBruteForceRows = Math.min(this.globalBruteForceRows, this.maxBruteForceRows(topK, numRows, this.graph.size()));
        logger.trace("SAI materialized {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}", numRows, maxBruteForceRows, this.graph.size(), limit);
        Tracing.trace("SAI materialized {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}", numRows, maxBruteForceRows, this.graph.size(), limit);
        return numRows <= maxBruteForceRows;
    }

    /** Largest candidate count worth brute-forcing: the predicted node visits, but never below {@code limit}. */
    private int maxBruteForceRows(int limit, int nPermittedOrdinals, int graphSize) {
        int expectedNodes = this.expectedNodesVisited(limit, nPermittedOrdinals, graphSize);
        return Math.max(limit, expectedNodes);
    }

    /**
     * Predicts nodes visited by a filtered graph search. The prediction scales the
     * {@link VectorMemoryIndex#expectedNodesVisited} estimate by the observed actual/expected ratio
     * once at least {@link #MIN_RATIO_UPDATES} samples have been recorded.
     */
    private int expectedNodesVisited(int limit, int nPermittedOrdinals, int graphSize) {
        double observedRatio = this.actualExpectedRatio.getUpdateCount() >= MIN_RATIO_UPDATES ? this.actualExpectedRatio.get() : 1.0;
        return (int) (observedRatio * VectorMemoryIndex.expectedNodesVisited(limit, nPermittedOrdinals, graphSize));
    }

    /**
     * Records a search's actual node visits against its prediction.
     * <p>
     * Logs a warning when a large search (at least {@link #MIN_NODES_VISITED_TO_WARN} visits) was
     * mispredicted by more than 2x in either direction.
     */
    private void updateExpectedNodes(int actualNodesVisited, int expectedNodesVisited) {
        assert expectedNodesVisited >= 0 : expectedNodesVisited;
        assert actualNodesVisited >= 0 : actualNodesVisited;
        // 2L: avoid int overflow when an estimate has saturated near Integer.MAX_VALUE.
        boolean mispredicted = actualNodesVisited > 2L * expectedNodesVisited || expectedNodesVisited > 2L * actualNodesVisited;
        // The size guard applies to both directions; small searches are too noisy to warn about.
        if (actualNodesVisited >= MIN_NODES_VISITED_TO_WARN && mispredicted) {
            logger.warn("Predicted visiting {} nodes, but actually visited {}", expectedNodesVisited, actualNodesVisited);
        }
        this.actualExpectedRatio.update(actualNodesVisited, expectedNodesVisited);
    }

    @Override
    public String toString() {
        return MoreObjects.toStringHelper(this).add("index", this.index).toString();
    }

    /** Closes the underlying graph. */
    @Override
    public void close() throws IOException {
        this.graph.close();
    }

    /**
     * Result of key-range resolution: exactly one of the following holds.
     * <ul>
     * <li>A posting list is present: skip the ANN search and use it directly.</li>
     * <li>An (optionally null) ordinal bitset is present: filter the graph search with it.</li>
     * </ul>
     * {@code expectedNodesVisited} is -1 when no prediction accompanies the result.
     */
    private static final class BitsOrPostingList {
        private final Bits bits;
        private final int expectedNodesVisited;
        private final PostingList postingList;

        public BitsOrPostingList(@Nullable Bits bits, int expectedNodesVisited) {
            this.bits = bits;
            this.expectedNodesVisited = expectedNodesVisited;
            this.postingList = null;
        }

        public BitsOrPostingList(@Nullable Bits bits) {
            this.bits = bits;
            this.postingList = null;
            this.expectedNodesVisited = -1;
        }

        public BitsOrPostingList(PostingList postingList) {
            this.bits = null;
            this.postingList = Preconditions.checkNotNull(postingList);
            this.expectedNodesVisited = -1;
        }

        /** @throws IllegalStateException if this result carries a posting list instead */
        @Nullable
        public Bits getBits() {
            Preconditions.checkState(!this.skipANN());
            return this.bits;
        }

        /** @throws IllegalStateException if this result carries bits instead */
        public PostingList postingList() {
            Preconditions.checkState(this.skipANN());
            return this.postingList;
        }

        /** @return true when the ANN search should be skipped in favor of {@link #postingList()} */
        public boolean skipANN() {
            return this.postingList != null;
        }
    }
}

