/*
 * Decompiled with CFR 0.152.
 */
package org.shawn.games.Serendipity.NNUE;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;
import org.shawn.games.Serendipity.Chess.AccumulatorDiff;
import org.shawn.games.Serendipity.Chess.Board;
import org.shawn.games.Serendipity.Chess.Piece;
import org.shawn.games.Serendipity.Chess.PieceType;
import org.shawn.games.Serendipity.Chess.Side;
import org.shawn.games.Serendipity.Chess.Square;
import org.shawn.games.Serendipity.NNUE.AccumulatorStack;
import org.shawn.games.Serendipity.NNUE.Inference;
import org.shawn.games.Serendipity.NNUE.InferenceChooser;

/**
 * NNUE evaluation network: holds the 16-bit integer parameters loaded from a classpath resource and
 * provides the static helpers that map board state to input feature indices and to input/output buckets.
 *
 * <p>NOTE(review): {@code L1Weights} and {@code L1Biases} are package-visible mutable arrays (presumably
 * read by {@link AccumulatorStack}); treat all parameter arrays as read-only after construction.
 */
public class NNUE {
    /** Offset between the two colour planes of one feature vector (6 piece types x 64 squares). */
    private static final int COLOR_STRIDE = 384;
    /** Offset between consecutive piece types within one colour plane (one entry per square). */
    private static final int PIECE_STRIDE = 64;
    /** Number of first-layer outputs, i.e. the length of every L1 weight row and of the L1 biases. */
    static final int HIDDEN_SIZE = 1536;
    /** Number of input features per input bucket: 2 colours x 6 piece types x 64 squares. */
    static final int FEATURE_SIZE = 768;
    /** Number of output buckets; see {@link #chooseOutputBucket(Board)}. */
    static final int OUTPUT_BUCKETS = 8;
    /** Piece-count divisor used by {@link #chooseOutputBucket(Board)}. */
    private static final int DIVISOR = 4;
    /** Number of king input buckets; the L1 weight matrix has FEATURE_SIZE * INPUT_BUCKET_SIZE rows. */
    static final int INPUT_BUCKET_SIZE = 7;
    /**
     * Input bucket for each king square ordinal, indexed after orienting the square to the perspective
     * side (see {@link #orient(int, Side)}). Laid out as eight rows of eight ordinals.
     */
    private static final int[] INPUT_BUCKETS = {
        0, 0, 1, 1, 2, 2, 3, 3,
        4, 4, 4, 4, 5, 5, 5, 5,
        6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6,
    };
    // Not referenced in this class; presumably the eval scale and quantisation factors used by the
    // Inference implementations — confirm there.
    public static final int SCALE = 400;
    public static final int QA = 255;
    public static final int QB = 64;
    final short[][] L1Weights;
    final short[] L1Biases;
    /** Output-layer weights, one contiguous row of {@code 2 * HIDDEN_SIZE} values per output bucket. */
    private final short[][] L2Weights;
    /** One output bias per output bucket. */
    private final short[] outputBiases;
    /** Inference backend chosen once, at class-initialisation time, by {@link InferenceChooser}. */
    private static final Inference INFERENCE = InferenceChooser.chooseInference();

    /**
     * Loads a network from a classpath resource.
     *
     * <p>The resource is a flat sequence of little-endian 16-bit integers, read in this order:
     * <ol>
     *   <li>L1 weights: {@code FEATURE_SIZE * INPUT_BUCKET_SIZE} rows of {@code HIDDEN_SIZE} values;</li>
     *   <li>L1 biases: {@code HIDDEN_SIZE} values;</li>
     *   <li>L2 weights: {@code 2 * HIDDEN_SIZE} rows of {@code OUTPUT_BUCKETS} values, transposed on load
     *       into {@code L2Weights[bucket][i]} so each bucket's weights are contiguous;</li>
     *   <li>output biases: {@code OUTPUT_BUCKETS} values.</li>
     * </ol>
     * Any bytes after the output biases are ignored.
     *
     * @param filePath resource path, resolved with {@link Class#getResourceAsStream(String)}
     * @throws NullPointerException if no resource exists at {@code filePath}
     * @throws IOException if reading fails, including {@link java.io.EOFException} when the resource is
     *     shorter than the layout above
     */
    public NNUE(String filePath) throws IOException {
        this.L1Weights = new short[FEATURE_SIZE * INPUT_BUCKET_SIZE][HIDDEN_SIZE];
        this.L1Biases = new short[HIDDEN_SIZE];
        this.L2Weights = new short[OUTPUT_BUCKETS][2 * HIDDEN_SIZE];
        this.outputBiases = new short[OUTPUT_BUCKETS];

        // try-with-resources closes the stream even when a read fails part-way (e.g. a truncated file).
        // DataInputStream.readShort() reads its source one byte at a time, hence the buffer.
        try (InputStream resource = Objects.requireNonNull(
                     this.getClass().getResourceAsStream(filePath),
                     "NNUE network resource not found: " + filePath);
             DataInputStream networkData = new DataInputStream(new BufferedInputStream(resource))) {
            for (short[] row : this.L1Weights) {
                readLittleEndianShorts(networkData, row);
            }
            readLittleEndianShorts(networkData, this.L1Biases);
            // Stored as [2 * HIDDEN_SIZE][OUTPUT_BUCKETS]; transposed so each bucket's row is contiguous.
            for (int i = 0; i < 2 * HIDDEN_SIZE; i++) {
                for (int bucket = 0; bucket < OUTPUT_BUCKETS; bucket++) {
                    this.L2Weights[bucket][i] = readLittleEndianShort(networkData);
                }
            }
            readLittleEndianShorts(networkData, this.outputBiases);
        }
    }

    /** Fills {@code target} with consecutive little-endian 16-bit values read from {@code in}. */
    private static void readLittleEndianShorts(DataInputStream in, short[] target) throws IOException {
        for (int i = 0; i < target.length; i++) {
            target[i] = readLittleEndianShort(in);
        }
    }

    /** Reads one little-endian 16-bit value ({@link DataInputStream#readShort()} itself is big-endian). */
    private static short readLittleEndianShort(DataInputStream in) throws IOException {
        return Short.reverseBytes(in.readShort());
    }

    /**
     * Selects the output bucket as {@code (bitCount(board.getBitboard()) - 2) / DIVISOR}.
     *
     * <p>Assumes {@code getBitboard()} is the occupancy of all pieces, so positions with 2..32 pieces map
     * to buckets 0..{@code OUTPUT_BUCKETS - 1} — TODO confirm.
     */
    public static int chooseOutputBucket(Board board) {
        return (Long.bitCount(board.getBitboard()) - 2) / DIVISOR;
    }

    /**
     * Evaluates {@code board} with {@code network}.
     *
     * <p>Chooses the output bucket from the piece count, obtains the accumulators via
     * {@code accumulators.refreshAndGet(board)}, and runs that bucket's output layer on the configured
     * {@link Inference} backend.
     *
     * @return the value computed by {@code Inference.forward}; presumably a score from the side to move's
     *     point of view — confirm against the Inference implementations
     */
    public static int evaluate(Board board, NNUE network, AccumulatorStack accumulators) {
        int chosenBucket = chooseOutputBucket(board);
        return INFERENCE.forward(accumulators.refreshAndGet(board), board.getSideToMove(),
                network.L2Weights[chosenBucket], network.outputBiases[chosenBucket]);
    }

    /** Returns the input bucket for {@code side}'s king on {@code board}. */
    public static int chooseInputBucket(Board board, Side side) {
        return chooseInputBucket(board.getKingSquare(side), side);
    }

    /** Returns the input bucket for a king of {@code side} standing on {@code square}. */
    public static int chooseInputBucket(Square square, Side side) {
        return INPUT_BUCKETS[orient(square.ordinal(), side)];
    }

    /**
     * Returns the feature index of {@code piece} on {@code square} as seen from {@code perspective}:
     * {@code colour * COLOR_STRIDE + pieceType * PIECE_STRIDE + orientedSquare}. For a non-white
     * perspective both the colour ordinal and the square are flipped. Assuming two colours, six piece
     * types and 64 squares, the result lies in {@code [0, FEATURE_SIZE)}.
     */
    public static int getIndex(Square square, Piece piece, Side perspective) {
        int color = piece.getPieceSide().ordinal();
        if (perspective != Side.WHITE) {
            color ^= 1;
        }
        return color * COLOR_STRIDE
                + piece.getPieceType().ordinal() * PIECE_STRIDE
                + orient(square.ordinal(), perspective);
    }

    /** Convenience overload of {@link #getIndex(Square, Piece, Side)} for one diff entry. */
    public static int getIndex(AccumulatorDiff.DiffInfo diff, Side perspective) {
        return getIndex(diff.square, diff.piece, perspective);
    }

    /**
     * Returns whether {@code diff} moves {@code perspective}'s own king into a different input bucket
     * (presumably meaning that perspective's accumulator must be rebuilt instead of updated).
     *
     * <p>Only the first added entry is checked for a king. The assertions encode the expected diff shape:
     * the second added entry (if any) is not a king, and when the first added entry is the perspective's
     * king, the first removed entry is a king (its previous square).
     *
     * <p>NOTE(review): {@code diff.getAdded(0)} is read unconditionally, so this assumes every diff adds at
     * least one piece — confirm for null moves.
     */
    public static boolean requiresRefresh(AccumulatorDiff diff, Side perspective) {
        assert diff.getAddedCount() <= 1 || diff.getAdded(1).piece.getPieceType() != PieceType.KING;
        boolean ownKingMoved = diff.getAdded(0).piece.getPieceType() == PieceType.KING
                && diff.getAdded(0).piece.getPieceSide() == perspective;
        if (!ownKingMoved) {
            return false;
        }
        assert diff.getRemoved(0).piece.getPieceType() == PieceType.KING;
        Square prevKing = diff.getRemoved(0).square;
        Square currKing = diff.getAdded(0).square;
        return chooseInputBucket(prevKing, perspective) != chooseInputBucket(currKing, perspective);
    }

    /**
     * Returns {@code squareOrdinal} unchanged for white, otherwise with bits 3-5 flipped ({@code ^ 0x38}),
     * which mirrors the rank if ordinals are rank-major with A1 = 0.
     */
    private static int orient(int squareOrdinal, Side side) {
        return side == Side.WHITE ? squareOrdinal : squareOrdinal ^ 0x38;
    }
}

