/*
 * Decompiled with CFR 0.152.
 */
package org.shawn.games.Serendipity.NNUE;

import org.shawn.games.Serendipity.Chess.Side;
import org.shawn.games.Serendipity.NNUE.AccumulatorStack;
import org.shawn.games.Serendipity.NNUE.Inference;

/**
 * {@link Inference} implementation written with plain element-by-element loops.
 *
 * <p>The activation used by {@link #forward} is "clamp to {@code [0, 255]}, then square".
 * It is precomputed once for every possible 16-bit input and stored in a lookup table,
 * so the evaluation loop only performs table reads, multiplications and additions.
 *
 * <p>Every method processes exactly {@value #ACCUMULATOR_LENGTH} elements per array;
 * arrays shorter than that cause an {@link ArrayIndexOutOfBoundsException}.
 */
public class ScalarInference
implements Inference {
    /** Number of elements read from each accumulator and written by each update method. */
    private static final int ACCUMULATOR_LENGTH = 1536;

    /** Upper bound of the clamp applied before squaring; also the first divisor in {@link #forward}. */
    private static final int ACTIVATION_CEILING = 255;

    /** Added to a 16-bit value to obtain its (non-negative) slot in the activation table. */
    private static final int TABLE_OFFSET = -Short.MIN_VALUE;

    /** Multiplier applied to the biased sum before the final division. */
    private static final int OUTPUT_SCALE = 400;

    /** Final divisor applied to the scaled sum. */
    private static final int OUTPUT_DIVISOR = 16320;

    /** Activation result for every 16-bit input, indexed by {@code input + TABLE_OFFSET}. */
    private static final int[] ACTIVATION_TABLE = buildActivationTable();

    /**
     * Fills a table with the squared-clamp activation of every value in the {@code short} range.
     *
     * @return a 65536-entry table whose slot {@code s} holds the activation of {@code s - 32768}
     */
    private static int[] buildActivationTable() {
        final int[] table = new int[1 << Short.SIZE];
        for (int slot = 0; slot < table.length; slot++) {
            table[slot] = squaredClamp(slot - TABLE_OFFSET);
        }
        return table;
    }

    /**
     * Clamps {@code raw} into {@code [0, 255]} and squares the result.
     *
     * @param raw value to activate
     * @return the square of the clamped value, between 0 and 65025 inclusive
     */
    private static int squaredClamp(int raw) {
        final int clamped;
        if (raw <= 0) {
            clamped = 0;
        } else if (raw >= ACTIVATION_CEILING) {
            clamped = ACTIVATION_CEILING;
        } else {
            clamped = raw;
        }
        return clamped * clamped;
    }

    /**
     * Computes the output value from a pair of accumulators.
     *
     * <p>The accumulator belonging to {@code side} is combined with the first
     * {@value #ACCUMULATOR_LENGTH} weights, and the accumulator of the flipped side with the
     * following {@value #ACCUMULATOR_LENGTH} weights. The sum of activated-value-times-weight
     * products is divided by 255, offset by {@code bias}, multiplied by 400 and finally divided
     * by 16320. All arithmetic is 32-bit integer arithmetic with truncating division.
     *
     * @param accumulators source of the two accumulators, looked up by side
     * @param side         side whose accumulator is paired with the first half of {@code weights}
     * @param weights      at least {@code 2 * 1536} weights
     * @param bias         value added after the first division
     * @return the scaled result
     */
    @Override
    public int forward(AccumulatorStack.AccumulatorPair accumulators, Side side, short[] weights, short bias) {
        final AccumulatorStack.Accumulator own = accumulators.get(side);
        final AccumulatorStack.Accumulator opposing = accumulators.get(side.flip());

        int sum = 0;
        for (int k = 0; k < ACCUMULATOR_LENGTH; k++) {
            final int ownTerm = ACTIVATION_TABLE[own.values[k] + TABLE_OFFSET] * weights[k];
            final int opposingTerm =
                    ACTIVATION_TABLE[opposing.values[k] + TABLE_OFFSET] * weights[k + ACCUMULATOR_LENGTH];
            sum += ownTerm + opposingTerm;
        }

        // Each step truncates separately, so the order of these operations must not change.
        final int biased = sum / ACTIVATION_CEILING + bias;
        final int scaled = biased * OUTPUT_SCALE;
        return scaled / OUTPUT_DIVISOR;
    }

    /**
     * Stores {@code from[i] + added[i]}, narrowed to {@code short}, into {@code to[i]}
     * for the first {@value #ACCUMULATOR_LENGTH} indices.
     *
     * @param to    destination array
     * @param from  base values
     * @param added values to add
     */
    @Override
    public void add(short[] to, short[] from, short[] added) {
        for (int k = 0; k < ACCUMULATOR_LENGTH; k++) {
            final int updated = from[k] + added[k];
            to[k] = (short) updated;
        }
    }

    /**
     * Stores {@code from[i] - removed[i]}, narrowed to {@code short}, into {@code to[i]}
     * for the first {@value #ACCUMULATOR_LENGTH} indices.
     *
     * @param to      destination array
     * @param from    base values
     * @param removed values to subtract
     */
    @Override
    public void sub(short[] to, short[] from, short[] removed) {
        for (int k = 0; k < ACCUMULATOR_LENGTH; k++) {
            final int updated = from[k] - removed[k];
            to[k] = (short) updated;
        }
    }

    /**
     * Stores {@code from[i] + added[i] - subtracted[i]}, narrowed to {@code short}, into
     * {@code to[i]} for the first {@value #ACCUMULATOR_LENGTH} indices.
     *
     * @param to         destination array
     * @param from       base values
     * @param added      values to add
     * @param subtracted values to subtract
     */
    @Override
    public void addSub(short[] to, short[] from, short[] added, short[] subtracted) {
        for (int k = 0; k < ACCUMULATOR_LENGTH; k++) {
            final int updated = from[k] + added[k] - subtracted[k];
            to[k] = (short) updated;
        }
    }

    /**
     * Stores {@code from[i] + added[i] - subtracted1[i] - subtracted2[i]}, narrowed to
     * {@code short}, into {@code to[i]} for the first {@value #ACCUMULATOR_LENGTH} indices.
     *
     * @param to          destination array
     * @param from        base values
     * @param added       values to add
     * @param subtracted1 first set of values to subtract
     * @param subtracted2 second set of values to subtract
     */
    @Override
    public void addSubSub(short[] to, short[] from, short[] added, short[] subtracted1, short[] subtracted2) {
        for (int k = 0; k < ACCUMULATOR_LENGTH; k++) {
            final int updated = from[k] + added[k] - subtracted1[k] - subtracted2[k];
            to[k] = (short) updated;
        }
    }

    /**
     * Stores {@code from[i] + added1[i] + added2[i] - subtracted1[i] - subtracted2[i]}, narrowed
     * to {@code short}, into {@code to[i]} for the first {@value #ACCUMULATOR_LENGTH} indices.
     *
     * @param to          destination array
     * @param from        base values
     * @param added1      first set of values to add
     * @param added2      second set of values to add
     * @param subtracted1 first set of values to subtract
     * @param subtracted2 second set of values to subtract
     */
    @Override
    public void addAddSubSub(short[] to, short[] from, short[] added1, short[] added2, short[] subtracted1, short[] subtracted2) {
        for (int k = 0; k < ACCUMULATOR_LENGTH; k++) {
            final int updated = from[k] + added1[k] + added2[k] - subtracted1[k] - subtracted2[k];
            to[k] = (short) updated;
        }
    }
}

