/*
 * Decompiled with CFR 0.152.
 */
package org.shawn.games.Serendipity.NNUE;

import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
import org.shawn.games.Serendipity.Chess.Side;
import org.shawn.games.Serendipity.NNUE.AccumulatorStack;
import org.shawn.games.Serendipity.NNUE.Inference;

/**
 * {@link Inference} implementation that processes accumulator arrays with the
 * JDK incubator Vector API, using the platform's preferred {@code short}
 * vector shape.
 *
 * <p>Every method walks indices {@code [0, UPPERBOUND)} in steps of one vector.
 * {@code UPPERBOUND} is {@code SHORT_SPECIES.loopBound(LAYER_SIZE)}, i.e. the
 * largest multiple of the lane count not exceeding {@code LAYER_SIZE}; there is
 * no scalar tail loop, so elements at or beyond {@code UPPERBOUND} are never
 * touched. All arrays passed in must therefore hold at least
 * {@code UPPERBOUND} elements (and {@code weights} in {@link #forward} at
 * least {@code LAYER_SIZE + UPPERBOUND}), otherwise the vector loads/stores
 * throw {@link IndexOutOfBoundsException}.
 */
public class SIMDInference implements Inference {
    private static final VectorSpecies<Short> SHORT_SPECIES = ShortVector.SPECIES_PREFERRED;

    /** Int species with the same bit width as {@link #SHORT_SPECIES} (half as many lanes). */
    private static final VectorSpecies<Integer> INT_SPECIES =
            SHORT_SPECIES.vectorShape().withLanes(int.class);

    /** Number of values read per accumulator; also the offset of the second weight half. */
    private static final int LAYER_SIZE = 1536;

    /** Upper clamp applied to accumulator values; also the first divisor of the raw sum. */
    private static final int CLAMP_MAX = 255;

    /** Multiplier applied to the biased sum before the final division. */
    private static final int EVAL_SCALE = 400;

    /** Final divisor of the scaled sum (numerically 255 * 64). */
    private static final int EVAL_DIVISOR = 16320;

    private static final int UPPERBOUND = SHORT_SPECIES.loopBound(LAYER_SIZE);

    /**
     * Computes the network output for the given side to move.
     *
     * <p>For each index {@code i} the accumulator value {@code x} of both
     * perspectives is clamped to {@code [0, 255]} and the term
     * {@code clamp(x) * (short) (clamp(x) * w)} is added to an {@code int} sum.
     * The side-to-move accumulator uses {@code weights[i]}, the opposing one
     * {@code weights[i + 1536]}. The total is then transformed as
     * {@code (sum / 255 + bias) * 400 / 16320} using integer division.
     *
     * @param accumulators pair of accumulators, looked up by side
     * @param side         perspective whose accumulator is paired with the first weight half
     * @param weights      output weights; first half for {@code side}, second half for {@code side.flip()}
     * @param bias         value added after the first division
     * @return the scaled evaluation
     */
    @Override
    public int forward(AccumulatorStack.AccumulatorPair accumulators, Side side, short[] weights, short bias) {
        AccumulatorStack.Accumulator us = accumulators.get(side);
        AccumulatorStack.Accumulator them = accumulators.get(side.flip());

        // Clamp bounds are loop-invariant, so build them once.
        ShortVector clampLow = ShortVector.zero(SHORT_SPECIES);
        ShortVector clampHigh = ShortVector.broadcast(SHORT_SPECIES, CLAMP_MAX);

        IntVector sum = IntVector.zero(INT_SPECIES);
        for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length()) {
            ShortVector usInputs = ShortVector.fromArray(SHORT_SPECIES, us.values, i)
                    .max(clampLow)
                    .min(clampHigh);
            ShortVector themInputs = ShortVector.fromArray(SHORT_SPECIES, them.values, i)
                    .max(clampLow)
                    .min(clampHigh);
            ShortVector usWeights = ShortVector.fromArray(SHORT_SPECIES, weights, i);
            ShortVector themWeights = ShortVector.fromArray(SHORT_SPECIES, weights, i + LAYER_SIZE);

            // First multiplication stays in 16-bit lanes and wraps on overflow.
            // NOTE(review): presumably the weights are quantised so that
            // 255 * w fits in a short -- confirm against the network format.
            ShortVector usWeightedTerms = usInputs.mul(usWeights);
            ShortVector themWeightedTerms = themInputs.mul(themWeights);

            // A short vector widens to two int vectors of the same shape:
            // part 0 holds the lower half of the lanes, part 1 the upper half.
            Vector<Integer> usInputsLo = usInputs.convert(VectorOperators.S2I, 0);
            Vector<Integer> usInputsHi = usInputs.convert(VectorOperators.S2I, 1);
            Vector<Integer> themInputsLo = themInputs.convert(VectorOperators.S2I, 0);
            Vector<Integer> themInputsHi = themInputs.convert(VectorOperators.S2I, 1);
            Vector<Integer> usWeightedTermsLo = usWeightedTerms.convert(VectorOperators.S2I, 0);
            Vector<Integer> usWeightedTermsHi = usWeightedTerms.convert(VectorOperators.S2I, 1);
            Vector<Integer> themWeightedTermsLo = themWeightedTerms.convert(VectorOperators.S2I, 0);
            Vector<Integer> themWeightedTermsHi = themWeightedTerms.convert(VectorOperators.S2I, 1);

            // Second multiplication by the clamped input happens in 32-bit lanes.
            sum = sum.add(usInputsLo.mul(usWeightedTermsLo))
                    .add(usInputsHi.mul(usWeightedTermsHi))
                    .add(themInputsLo.mul(themWeightedTermsLo))
                    .add(themInputsHi.mul(themWeightedTermsHi));
        }

        int eval = sum.reduceLanes(VectorOperators.ADD);
        eval /= CLAMP_MAX;
        eval += bias;
        eval *= EVAL_SCALE;
        return eval / EVAL_DIVISOR;
    }

    /**
     * Stores {@code from[i] + added[i]} into {@code to[i]} for every processed index.
     * Arithmetic is performed in 16-bit lanes.
     *
     * @param to    destination array
     * @param from  base values
     * @param added values to add
     */
    @Override
    public void add(short[] to, short[] from, short[] added) {
        for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length()) {
            ShortVector fromVector = ShortVector.fromArray(SHORT_SPECIES, from, i);
            ShortVector addVector = ShortVector.fromArray(SHORT_SPECIES, added, i);
            fromVector.add(addVector).intoArray(to, i);
        }
    }

    /**
     * Stores {@code from[i] - removed[i]} into {@code to[i]} for every processed index.
     * Arithmetic is performed in 16-bit lanes.
     *
     * @param to      destination array
     * @param from    base values
     * @param removed values to subtract
     */
    @Override
    public void sub(short[] to, short[] from, short[] removed) {
        for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length()) {
            ShortVector fromVector = ShortVector.fromArray(SHORT_SPECIES, from, i);
            ShortVector subVector = ShortVector.fromArray(SHORT_SPECIES, removed, i);
            fromVector.sub(subVector).intoArray(to, i);
        }
    }

    /**
     * Stores {@code from[i] + added[i] - subtracted[i]} into {@code to[i]} for
     * every processed index. Arithmetic is performed in 16-bit lanes.
     *
     * @param to         destination array
     * @param from       base values
     * @param added      values to add
     * @param subtracted values to subtract
     */
    @Override
    public void addSub(short[] to, short[] from, short[] added, short[] subtracted) {
        for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length()) {
            ShortVector fromVector = ShortVector.fromArray(SHORT_SPECIES, from, i);
            ShortVector addVector = ShortVector.fromArray(SHORT_SPECIES, added, i);
            ShortVector subVector = ShortVector.fromArray(SHORT_SPECIES, subtracted, i);
            fromVector.add(addVector).sub(subVector).intoArray(to, i);
        }
    }

    /**
     * Stores {@code from[i] + added[i] - subtracted1[i] - subtracted2[i]} into
     * {@code to[i]} for every processed index. Arithmetic is performed in
     * 16-bit lanes.
     *
     * @param to          destination array
     * @param from        base values
     * @param added       values to add
     * @param subtracted1 first set of values to subtract
     * @param subtracted2 second set of values to subtract
     */
    @Override
    public void addSubSub(short[] to, short[] from, short[] added, short[] subtracted1, short[] subtracted2) {
        for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length()) {
            ShortVector fromVector = ShortVector.fromArray(SHORT_SPECIES, from, i);
            ShortVector addVector = ShortVector.fromArray(SHORT_SPECIES, added, i);
            ShortVector subVector1 = ShortVector.fromArray(SHORT_SPECIES, subtracted1, i);
            ShortVector subVector2 = ShortVector.fromArray(SHORT_SPECIES, subtracted2, i);
            fromVector.add(addVector).sub(subVector1).sub(subVector2).intoArray(to, i);
        }
    }

    /**
     * Stores {@code from[i] + added1[i] + added2[i] - subtracted1[i] - subtracted2[i]}
     * into {@code to[i]} for every processed index. Arithmetic is performed in
     * 16-bit lanes.
     *
     * @param to          destination array
     * @param from        base values
     * @param added1      first set of values to add
     * @param added2      second set of values to add
     * @param subtracted1 first set of values to subtract
     * @param subtracted2 second set of values to subtract
     */
    @Override
    public void addAddSubSub(short[] to, short[] from, short[] added1, short[] added2, short[] subtracted1, short[] subtracted2) {
        for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length()) {
            ShortVector fromVector = ShortVector.fromArray(SHORT_SPECIES, from, i);
            ShortVector addVector1 = ShortVector.fromArray(SHORT_SPECIES, added1, i);
            ShortVector addVector2 = ShortVector.fromArray(SHORT_SPECIES, added2, i);
            ShortVector subVector1 = ShortVector.fromArray(SHORT_SPECIES, subtracted1, i);
            ShortVector subVector2 = ShortVector.fromArray(SHORT_SPECIES, subtracted2, i);
            fromVector.add(addVector1).add(addVector2).sub(subVector1).sub(subVector2).intoArray(to, i);
        }
    }
}

