/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

public final class HnswGraphBuilder {
    public static final int DEFAULT_MAX_CONN = 16;
    public static final int DEFAULT_BEAM_WIDTH = 100;
    private static final long DEFAULT_RAND_SEED = 42L;
    public static final String HNSW_COMPONENT = "HNSW";
    public static long randSeed = 42L;
    private final int M;
    private final double ml;
    private final NeighborArray scratch;
    private final SplittableRandom random;
    private final RandomVectorScorerSupplier scorerSupplier;
    private final HnswGraphSearcher graphSearcher;
    private final GraphBuilderKnnCollector entryCandidates;
    private final GraphBuilderKnnCollector beamCandidates;
    final OnHeapHnswGraph hnsw;
    private InfoStream infoStream = InfoStream.getDefault();
    private final Set<Integer> initializedNodes;

    public static HnswGraphBuilder create(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) throws IOException {
        return new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
    }

    public static HnswGraphBuilder create(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
        HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(scorerSupplier, M, beamWidth, seed);
        hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
        return hnswGraphBuilder;
    }

    private HnswGraphBuilder(RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed) throws IOException {
        if (M <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.M = M;
        this.scorerSupplier = Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null");
        this.ml = M == 1 ? 1.0 : 1.0 / Math.log(1.0 * (double)M);
        this.random = new SplittableRandom(seed);
        this.hnsw = new OnHeapHnswGraph(M);
        this.graphSearcher = new HnswGraphSearcher(new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
        this.scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
        this.entryCandidates = new GraphBuilderKnnCollector(1);
        this.beamCandidates = new GraphBuilderKnnCollector(beamWidth);
        this.initializedNodes = new HashSet<Integer>();
    }

    public OnHeapHnswGraph build(int maxOrd) throws IOException {
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            this.infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
        }
        this.addVectors(maxOrd);
        return this.hnsw;
    }

    private void initializeFromGraph(HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
        assert (this.hnsw.size() == 0);
        for (int level = 0; level < initializerGraph.numLevels(); ++level) {
            HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
            while (it.hasNext()) {
                int oldOrd = it.nextInt();
                int newOrd = oldToNewOrdinalMap.get(oldOrd);
                this.hnsw.addNode(level, newOrd);
                if (level == 0) {
                    this.initializedNodes.add(newOrd);
                }
                NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
                initializerGraph.seek(level, oldOrd);
                int oldNeighbor = initializerGraph.nextNeighbor();
                while (oldNeighbor != Integer.MAX_VALUE) {
                    int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
                    newNeighbors.addOutOfOrder(newNeighbor, Float.NaN);
                    oldNeighbor = initializerGraph.nextNeighbor();
                }
            }
        }
    }

    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    public OnHeapHnswGraph getGraph() {
        return this.hnsw;
    }

    private void addVectors(int maxOrd) throws IOException {
        long start;
        long t = start = System.nanoTime();
        for (int node = 0; node < maxOrd; ++node) {
            if (this.initializedNodes.contains(node)) continue;
            this.addGraphNode(node);
            if (node % 10000 != 0 || !this.infoStream.isEnabled(HNSW_COMPONENT)) continue;
            t = this.printGraphBuildStatus(node, start, t);
        }
    }

    public void addGraphNode(int node) throws IOException {
        int level;
        RandomVectorScorer scorer = this.scorerSupplier.scorer(node);
        int nodeLevel = HnswGraphBuilder.getRandomGraphLevel(this.ml, this.random);
        int curMaxLevel = this.hnsw.numLevels() - 1;
        if (this.hnsw.entryNode() == -1) {
            for (int level2 = nodeLevel; level2 >= 0; --level2) {
                this.hnsw.addNode(level2, node);
            }
            return;
        }
        int[] eps = new int[]{this.hnsw.entryNode()};
        for (int level3 = nodeLevel; level3 > curMaxLevel; --level3) {
            this.hnsw.addNode(level3, node);
        }
        GraphBuilderKnnCollector candidates = this.entryCandidates;
        for (level = curMaxLevel; level > nodeLevel; --level) {
            candidates.clear();
            this.graphSearcher.searchLevel(candidates, scorer, level, eps, this.hnsw, null);
            eps = new int[]{candidates.popNode()};
        }
        candidates = this.beamCandidates;
        for (level = Math.min(nodeLevel, curMaxLevel); level >= 0; --level) {
            candidates.clear();
            this.graphSearcher.searchLevel(candidates, scorer, level, eps, this.hnsw, null);
            eps = candidates.popUntilNearestKNodes();
            this.hnsw.addNode(level, node);
            this.addDiverseNeighbors(level, node, candidates);
        }
    }

    private long printGraphBuildStatus(int node, long start, long t) {
        long now = System.nanoTime();
        this.infoStream.message(HNSW_COMPONENT, String.format(Locale.ROOT, "built %d in %d/%d ms", node, TimeUnit.NANOSECONDS.toMillis(now - t), TimeUnit.NANOSECONDS.toMillis(now - start)));
        return now;
    }

    private void addDiverseNeighbors(int level, int node, GraphBuilderKnnCollector candidates) throws IOException {
        NeighborArray neighbors = this.hnsw.getNeighbors(level, node);
        assert (neighbors.size() == 0);
        this.popToScratch(candidates);
        int maxConnOnLevel = level == 0 ? this.M * 2 : this.M;
        this.selectAndLinkDiverse(neighbors, this.scratch, maxConnOnLevel);
        int size = neighbors.size();
        for (int i = 0; i < size; ++i) {
            int nbr = neighbors.node[i];
            NeighborArray nbrsOfNbr = this.hnsw.getNeighbors(level, nbr);
            nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
            if (nbrsOfNbr.size() <= maxConnOnLevel) continue;
            int indexToRemove = this.findWorstNonDiverse(nbrsOfNbr, nbr);
            nbrsOfNbr.removeIndex(indexToRemove);
        }
    }

    private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
        for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; --i) {
            int cNode = candidates.node[i];
            float cScore = candidates.score[i];
            assert (cNode < this.hnsw.size());
            if (!this.diversityCheck(cNode, cScore, neighbors)) continue;
            neighbors.addInOrder(cNode, cScore);
        }
    }

    private void popToScratch(GraphBuilderKnnCollector candidates) {
        this.scratch.clear();
        int candidateCount = candidates.size();
        for (int i = 0; i < candidateCount; ++i) {
            float maxSimilarity = candidates.minimumScore();
            this.scratch.addInOrder(candidates.popNode(), maxSimilarity);
        }
    }

    private boolean diversityCheck(int candidate, float score, NeighborArray neighbors) throws IOException {
        RandomVectorScorer scorer = this.scorerSupplier.scorer(candidate);
        for (int i = 0; i < neighbors.size(); ++i) {
            float neighborSimilarity = scorer.score(neighbors.node[i]);
            if (!(neighborSimilarity >= score)) continue;
            return false;
        }
        return true;
    }

    private int findWorstNonDiverse(NeighborArray neighbors, int nodeOrd) throws IOException {
        RandomVectorScorer scorer = this.scorerSupplier.scorer(nodeOrd);
        int[] uncheckedIndexes = neighbors.sort(scorer);
        if (uncheckedIndexes == null) {
            return neighbors.size() - 1;
        }
        int uncheckedCursor = uncheckedIndexes.length - 1;
        for (int i = neighbors.size() - 1; i > 0 && uncheckedCursor >= 0; --i) {
            if (this.isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) {
                return i;
            }
            if (i != uncheckedIndexes[uncheckedCursor]) continue;
            --uncheckedCursor;
        }
        return neighbors.size() - 1;
    }

    private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor) throws IOException {
        float minAcceptedSimilarity = neighbors.score[candidateIndex];
        RandomVectorScorer scorer = this.scorerSupplier.scorer(neighbors.node[candidateIndex]);
        if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
            for (int i = candidateIndex - 1; i >= 0; --i) {
                float neighborSimilarity = scorer.score(neighbors.node[i]);
                if (!(neighborSimilarity >= minAcceptedSimilarity)) continue;
                return true;
            }
        } else {
            assert (candidateIndex > uncheckedIndexes[uncheckedCursor]);
            for (int i = uncheckedCursor; i >= 0; --i) {
                float neighborSimilarity = scorer.score(neighbors.node[uncheckedIndexes[i]]);
                if (!(neighborSimilarity >= minAcceptedSimilarity)) continue;
                return true;
            }
        }
        return false;
    }

    private static int getRandomGraphLevel(double ml, SplittableRandom random) {
        double randDouble;
        while ((randDouble = random.nextDouble()) == 0.0) {
        }
        return (int)(-Math.log(randDouble) * ml);
    }

    public static final class GraphBuilderKnnCollector
    implements KnnCollector {
        private final NeighborQueue queue;
        private final int k;
        private long visitedCount;

        public GraphBuilderKnnCollector(int k) {
            this.queue = new NeighborQueue(k, false);
            this.k = k;
        }

        public int size() {
            return this.queue.size();
        }

        public int popNode() {
            return this.queue.pop();
        }

        public int[] popUntilNearestKNodes() {
            while (this.size() > this.k()) {
                this.queue.pop();
            }
            return this.queue.nodes();
        }

        float minimumScore() {
            return this.queue.topScore();
        }

        public void clear() {
            this.queue.clear();
            this.visitedCount = 0L;
        }

        @Override
        public boolean earlyTerminated() {
            return false;
        }

        @Override
        public void incVisitedCount(int count) {
            this.visitedCount += (long)count;
        }

        @Override
        public long visitedCount() {
            return this.visitedCount;
        }

        @Override
        public long visitLimit() {
            return Long.MAX_VALUE;
        }

        @Override
        public int k() {
            return this.k;
        }

        @Override
        public boolean collect(int docId, float similarity) {
            return this.queue.insertWithOverflow(docId, similarity);
        }

        @Override
        public float minCompetitiveSimilarity() {
            return this.queue.size() >= this.k() ? this.queue.topScore() : Float.NEGATIVE_INFINITY;
        }

        @Override
        public TopDocs topDocs() {
            throw new IllegalArgumentException();
        }
    }
}

