/*
 * Decompiled with CFR 0.152.
 */
package org.shawn.games.Serendipity.NNUE;

import org.shawn.games.Serendipity.Chess.AccumulatorDiff;
import org.shawn.games.Serendipity.Chess.Bitboard;
import org.shawn.games.Serendipity.Chess.Board;
import org.shawn.games.Serendipity.Chess.Piece;
import org.shawn.games.Serendipity.Chess.PieceType;
import org.shawn.games.Serendipity.Chess.Side;
import org.shawn.games.Serendipity.Chess.Square;
import org.shawn.games.Serendipity.Chess.move.Move;
import org.shawn.games.Serendipity.NNUE.AccumulatorCache;
import org.shawn.games.Serendipity.NNUE.Inference;
import org.shawn.games.Serendipity.NNUE.InferenceChooser;
import org.shawn.games.Serendipity.NNUE.NNUE;

/**
 * Stack of NNUE first-layer accumulators, one {@link AccumulatorPair} (white and black
 * perspective) per pushed move.
 *
 * <p>{@link #push} only records the move's feature diff and the new king buckets; no arithmetic
 * is done until {@link #refreshAndGet} is called. At that point each perspective is brought up
 * to date either by replaying diffs forward from the nearest already-computed accumulator in the
 * same king bucket, or, when none exists, by rebuilding from the {@link AccumulatorCache}.
 */
public class AccumulatorStack {
    private static final Inference INFERENCE = InferenceChooser.chooseInference();

    /** Number of first-layer values per perspective. */
    private static final int HIDDEN_SIZE = 1536;
    /**
     * Number of input features per king bucket; feature indices are offset by
     * {@code kingBucket * FEATURES_PER_BUCKET}. Presumably 12 piece kinds x 64 squares.
     */
    private static final int FEATURES_PER_BUCKET = 768;
    /** Number of stack slots; slot 0 is the root set by {@link #init}. */
    private static final int MAX_DEPTH = 246;

    private final NNUE network;
    private final AccumulatorPair[] stack;
    private final AccumulatorCache cache;
    private int top;

    /**
     * Preallocates every stack slot and the refresh cache for {@code network}.
     *
     * @param network the network whose biases and weights the accumulators use
     */
    public AccumulatorStack(NNUE network) {
        this.network = network;
        this.stack = new AccumulatorPair[MAX_DEPTH];
        this.cache = new AccumulatorCache(network);
        for (int i = 0; i < this.stack.length; ++i) {
            this.stack[i] = new AccumulatorPair();
        }
    }

    /** Discards the accumulator of the current move, returning to its parent. */
    public void pop() {
        assert this.top > 0 : "pop() on an empty accumulator stack";
        --this.top;
    }

    /**
     * Pushes a lazily-evaluated accumulator for the position reached by {@code move}.
     *
     * @param board board used to compute the new king buckets; presumably the position after
     *     {@code move} has been made (TODO confirm against callers)
     * @param move the move being made (not used by the update itself)
     * @param diff the feature additions/removals caused by {@code move}; kept by reference and
     *     read later during {@link #refreshAndGet}
     */
    public void push(Board board, Move move, AccumulatorDiff diff) {
        assert this.top + 1 < this.stack.length : "accumulator stack overflow";
        ++this.top;
        this.stack[this.top].makeMove(this.stack[this.top - 1], board, move, diff);
    }

    /**
     * Resets the stack to a single, fully computed root accumulator for {@code board}.
     *
     * @param board the root position
     */
    public void init(Board board) {
        this.top = 0;
        this.stack[0].init();
        this.stack[0].loadFromBoard(board);
    }

    /**
     * Debug helper: prints the current top accumulator for {@code side} to stdout as a
     * comma-separated list terminated by a newline.
     *
     * @param side the perspective to print
     */
    public void printAccumulator(Side side) {
        short[] values = this.stack[this.top].get(side).values;
        for (int i = 0; i < HIDDEN_SIZE; ++i) {
            System.out.print(values[i] + ", ");
        }
        // Terminate the line so later stdout output does not get glued onto the dump.
        System.out.println();
    }

    /**
     * Replays pending diffs for {@code side} forward from slot {@code fromIdx} (which must be
     * computed) up to and including slot {@code toIdx}.
     */
    private void efficientlyUpdate(int fromIdx, int toIdx, Side side) {
        for (int i = fromIdx; i < toIdx; ++i) {
            Accumulator from = this.stack[i].get(side);
            Accumulator to = this.stack[i + 1].get(side);
            assert (!from.needsRefresh && to.needsRefresh);
            to.efficientlyUpdate(from);
        }
    }

    /**
     * Brings the top accumulator for {@code side} up to date.
     *
     * <p>Walks down the stack looking for the nearest computed accumulator. Only accumulators in
     * the same king bucket as the top are eligible, because feature indices are offset by the
     * king bucket and diffs cannot be replayed across buckets. If one is found, the pending diffs
     * are replayed from it; otherwise only the top accumulator is rebuilt from the cache and the
     * intermediate slots stay stale.
     */
    private void refreshAccumulator(Board board, Side side) {
        Accumulator startingAccumulator = this.stack[this.top].get(side);
        if (!startingAccumulator.needsRefresh) {
            return;
        }
        int startingKingBucket = startingAccumulator.kingBucket;
        for (int currIdx = this.top - 1; currIdx >= 0; --currIdx) {
            Accumulator currAccumulator = this.stack[currIdx].get(side);
            if (currAccumulator.kingBucket != startingKingBucket) {
                break;
            }
            if (currAccumulator.needsRefresh) {
                continue;
            }
            this.efficientlyUpdate(currIdx, this.top, side);
            return;
        }
        startingAccumulator.updateFromCache(board);
    }

    /**
     * Ensures both perspectives of the current top accumulator are computed and returns it.
     *
     * @param board the current position; used only when a rebuild from the cache is needed
     * @return the up-to-date accumulator pair for the top of the stack
     */
    public AccumulatorPair refreshAndGet(Board board) {
        this.refreshAccumulator(board, Side.WHITE);
        this.refreshAccumulator(board, Side.BLACK);
        return this.stack[this.top];
    }

    /** The white and black perspective accumulators for one stack slot. */
    public class AccumulatorPair {
        Accumulator[] accumulators;

        /** Creates a pair of zeroed, perspective-less accumulators (filled in by makeMove). */
        public AccumulatorPair() {
            this.accumulators = new Accumulator[]{new Accumulator(), new Accumulator()};
        }

        /** Replaces both accumulators with fresh ones initialised to the network's L1 biases. */
        public void init() {
            this.accumulators = new Accumulator[]{
                new Accumulator(AccumulatorStack.this.network, Side.WHITE),
                new Accumulator(AccumulatorStack.this.network, Side.BLACK)
            };
        }

        /** Fully computes both perspectives for {@code board} via the cache. */
        public void loadFromBoard(Board board) {
            this.accumulators[0].loadFromBoard(board);
            this.accumulators[1].loadFromBoard(board);
        }

        /**
         * Marks both perspectives as pending an incremental update from {@code prev}.
         *
         * @param prev the parent slot
         * @param board board used to compute the king buckets
         * @param move the move made (not used)
         * @param diff the feature diff of the move
         */
        public void makeMove(AccumulatorPair prev, Board board, Move move, AccumulatorDiff diff) {
            this.accumulators[0].makeMove(prev.accumulators[0], board, diff);
            this.accumulators[1].makeMove(prev.accumulators[1], board, diff);
        }

        /** Returns the accumulator for {@code side}'s perspective. */
        public Accumulator get(Side side) {
            return this.accumulators[side.ordinal()];
        }
    }

    /** First-layer accumulator for a single perspective. */
    public class Accumulator {
        /** Accumulated first-layer values for this perspective. */
        short[] values;
        /** The perspective this accumulator is computed from. */
        Side color;
        /** Feature diff from the parent slot's accumulator to this one. */
        AccumulatorDiff diff;
        /** Input bucket from {@link NNUE#chooseInputBucket}; selects the weight block. */
        int kingBucket;
        /** True while {@link #values} is stale and must be recomputed before use. */
        boolean needsRefresh;

        public Accumulator() {
            this.values = new short[HIDDEN_SIZE];
        }

        public Accumulator(NNUE network, Side color) {
            this.values = (short[]) network.L1Biases.clone();
            this.color = color;
            this.needsRefresh = false;
        }

        /** values = prev + W[add] - W[sub], with indices offset into this accumulator's bucket. */
        private void addSub(Accumulator prev, int featureIndexToAdd, int featureIndexToSubtract) {
            int offset = this.kingBucket * FEATURES_PER_BUCKET;
            INFERENCE.addSub(this.values, prev.values,
                    AccumulatorStack.this.network.L1Weights[featureIndexToAdd + offset],
                    AccumulatorStack.this.network.L1Weights[featureIndexToSubtract + offset]);
        }

        /** values = prev + W[add] - W[sub1] - W[sub2], offset into this accumulator's bucket. */
        private void addSubSub(Accumulator prev, int featureIndexToAdd, int featureIndexToSubtract1,
                int featureIndexToSubtract2) {
            int offset = this.kingBucket * FEATURES_PER_BUCKET;
            INFERENCE.addSubSub(this.values, prev.values,
                    AccumulatorStack.this.network.L1Weights[featureIndexToAdd + offset],
                    AccumulatorStack.this.network.L1Weights[featureIndexToSubtract1 + offset],
                    AccumulatorStack.this.network.L1Weights[featureIndexToSubtract2 + offset]);
        }

        /** values = prev + W[add1] + W[add2] - W[sub1] - W[sub2], offset into this bucket. */
        private void addAddSubSub(Accumulator prev, int featureIndexToAdd1, int featureIndexToAdd2,
                int featureIndexToSubtract1, int featureIndexToSubtract2) {
            int offset = this.kingBucket * FEATURES_PER_BUCKET;
            INFERENCE.addAddSubSub(this.values, prev.values,
                    AccumulatorStack.this.network.L1Weights[featureIndexToAdd1 + offset],
                    AccumulatorStack.this.network.L1Weights[featureIndexToAdd2 + offset],
                    AccumulatorStack.this.network.L1Weights[featureIndexToSubtract1 + offset],
                    AccumulatorStack.this.network.L1Weights[featureIndexToSubtract2 + offset]);
        }

        /**
         * Computes this accumulator from {@code prev} by applying {@link #diff}.
         *
         * <p>The caller guarantees {@code prev} is computed and shares this accumulator's king
         * bucket.
         *
         * @throws IllegalStateException if the diff has a shape no update routine supports
         */
        private void efficientlyUpdate(Accumulator prev) {
            int addedCount = this.diff.getAddedCount();
            int removedCount = this.diff.getRemovedCount();
            if (addedCount == 1 && removedCount == 1) {
                int addedIndex = NNUE.getIndex(this.diff.getAdded(0), this.color);
                int removedIndex = NNUE.getIndex(this.diff.getRemoved(0), this.color);
                this.addSub(prev, addedIndex, removedIndex);
            } else if (addedCount == 1 && removedCount == 2) {
                int addedIndex = NNUE.getIndex(this.diff.getAdded(0), this.color);
                int removedIndex0 = NNUE.getIndex(this.diff.getRemoved(0), this.color);
                int removedIndex1 = NNUE.getIndex(this.diff.getRemoved(1), this.color);
                this.addSubSub(prev, addedIndex, removedIndex0, removedIndex1);
            } else if (addedCount == 2 && removedCount == 2) {
                int addedIndex0 = NNUE.getIndex(this.diff.getAdded(0), this.color);
                int addedIndex1 = NNUE.getIndex(this.diff.getAdded(1), this.color);
                int removedIndex0 = NNUE.getIndex(this.diff.getRemoved(0), this.color);
                int removedIndex1 = NNUE.getIndex(this.diff.getRemoved(1), this.color);
                this.addAddSubSub(prev, addedIndex0, addedIndex1, removedIndex0, removedIndex1);
            } else if (addedCount == 0 && removedCount == 0) {
                // No feature changed: the accumulator is identical to its parent.
                System.arraycopy(prev.values, 0, this.values, 0, HIDDEN_SIZE);
            } else {
                // Previously this fell into the 2/2 path (guarded only by an assert) and read
                // non-existent diff entries; fail loudly instead of corrupting the accumulator.
                throw new IllegalStateException("Unsupported accumulator diff: " + addedCount
                        + " added, " + removedCount + " removed");
            }
            this.needsRefresh = false;
        }

        /** Feature index (bucket offset included) of the lowest set square of {@code bitboard}. */
        private int lowestFeatureIndex(long bitboard, Piece piece, int offset) {
            return NNUE.getIndex(Square.squareAt(Bitboard.bitScanForward(bitboard)), piece, this.color)
                    + offset;
        }

        /**
         * Rebuilds this accumulator from the cache entry for its perspective and king bucket.
         *
         * <p>The entry is patched with the difference between its stored bitboards and
         * {@code board}, copied into {@link #values}, and then re-synced to {@code board}.
         */
        private void updateFromCache(Board board) {
            // Key the entry with the same bucket used for the weight offsets so an entry is never
            // patched with another bucket's weights.
            AccumulatorCache.Entry entry = AccumulatorStack.this.cache.get(this.color, this.kingBucket);
            int offset = this.kingBucket * FEATURES_PER_BUCKET;
            for (Side side : Side.values()) {
                for (PieceType pieceType : PieceType.validValues()) {
                    Piece piece = Piece.make(side, pieceType);
                    long oldBB = entry.getBitboard(side, pieceType);
                    long newBB = board.getBitboard(side, pieceType);
                    long removed = oldBB & ~newBB;
                    long added = newBB & ~oldBB;
                    // Pair removals with additions so each pair costs one addSub call instead of
                    // a separate sub and add.
                    while (added != 0L && removed != 0L) {
                        int removedIndex = this.lowestFeatureIndex(removed, piece, offset);
                        int addedIndex = this.lowestFeatureIndex(added, piece, offset);
                        INFERENCE.addSub(entry.storedAccumulator, entry.storedAccumulator,
                                AccumulatorStack.this.network.L1Weights[addedIndex],
                                AccumulatorStack.this.network.L1Weights[removedIndex]);
                        removed = Bitboard.extractLsb(removed);
                        added = Bitboard.extractLsb(added);
                    }
                    while (removed != 0L) {
                        int featureIndex = this.lowestFeatureIndex(removed, piece, offset);
                        INFERENCE.sub(entry.storedAccumulator, entry.storedAccumulator,
                                AccumulatorStack.this.network.L1Weights[featureIndex]);
                        removed = Bitboard.extractLsb(removed);
                    }
                    while (added != 0L) {
                        int featureIndex = this.lowestFeatureIndex(added, piece, offset);
                        INFERENCE.add(entry.storedAccumulator, entry.storedAccumulator,
                                AccumulatorStack.this.network.L1Weights[featureIndex]);
                        added = Bitboard.extractLsb(added);
                    }
                }
            }
            System.arraycopy(entry.storedAccumulator, 0, this.values, 0, HIDDEN_SIZE);
            entry.update(board);
            this.needsRefresh = false;
        }

        /** Marks this accumulator as pending an incremental update from {@code prev}. */
        private void makeMove(Accumulator prev, Board board, AccumulatorDiff diff) {
            this.diff = diff;
            this.color = prev.color;
            this.needsRefresh = true;
            this.kingBucket = NNUE.chooseInputBucket(board, this.color);
        }

        /** Fully computes this accumulator for {@code board} via the cache. */
        private void loadFromBoard(Board board) {
            this.kingBucket = NNUE.chooseInputBucket(board, this.color);
            this.needsRefresh = true;
            this.updateFromCache(board);
        }
    }
}

