/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.ars_nouveau.internal.vectorization;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.ars_nouveau.internal.vectorization.PanamaVectorConstants;
import org.apache.lucene.ars_nouveau.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.ars_nouveau.util.Constants;
import org.apache.lucene.ars_nouveau.util.SuppressForbidden;

/**
 * {@link VectorUtilSupport} implementation written against the incubating
 * {@code jdk.incubator.vector} API.
 *
 * <p>Every public operation follows the same shape: a vectorized "body" handles the largest prefix
 * that is a whole number of vectors, and a scalar loop finishes the remaining tail elements.
 * The vector species used by the bodies are derived once, in the static initializer, from
 * {@link PanamaVectorConstants}.
 */
final class PanamaVectorUtilSupport implements VectorUtilSupport {
    private static final VectorSpecies<Float> FLOAT_SPECIES;
    private static final VectorSpecies<Integer> INT_SPECIES;
    /** Byte species with as many lanes as {@link #INT_SPECIES}; {@code null} below 256 bits. */
    private static final VectorSpecies<Byte> BYTE_SPECIES;
    /** Short species with as many lanes as {@link #INT_SPECIES}; {@code null} below 256 bits. */
    private static final VectorSpecies<Short> SHORT_SPECIES;
    static final int VECTOR_BITSIZE;
    private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO;

    static {
        INT_SPECIES = PanamaVectorConstants.PRERERRED_INT_SPECIES;
        VECTOR_BITSIZE = PanamaVectorConstants.PREFERRED_VECTOR_BITSIZE;
        FLOAT_SPECIES = INT_SPECIES.withLanes(float.class);
        // Byte/short species are sized relative to the preferred bit size (1/4 and 1/2 of it), so
        // that a byte vector widens to exactly one int vector.
        if (VECTOR_BITSIZE >= 256) {
            BYTE_SPECIES = ByteVector.SPECIES_MAX.withShape(VectorShape.forBitSize(VECTOR_BITSIZE >> 2));
            SHORT_SPECIES = ShortVector.SPECIES_MAX.withShape(VectorShape.forBitSize(VECTOR_BITSIZE >> 1));
        } else {
            // The 128-bit code paths use fixed species (SPECIES_64 / SPECIES_128) instead.
            BYTE_SPECIES = null;
            SHORT_SPECIES = null;
        }
        ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = INT_SPECIES.length() >= 8;
    }

    PanamaVectorUtilSupport() {
    }

    /**
     * Computes {@code a * b + c} lane-wise, using a fused multiply-add only when
     * {@link Constants#HAS_FAST_VECTOR_FMA} is set and a separate multiply and add otherwise.
     */
    private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
        if (Constants.HAS_FAST_VECTOR_FMA) {
            return a.fma(b, c);
        }
        return a.mul(b).add(c);
    }

    /**
     * Scalar counterpart of {@link #fma(FloatVector, FloatVector, FloatVector)}: uses
     * {@link Math#fma(float, float, float)} only when {@link Constants#HAS_FAST_SCALAR_FMA} is set.
     */
    @SuppressForbidden(reason="Uses FMA only where fast and carefully contained")
    private static float fma(float a, float b, float c) {
        if (Constants.HAS_FAST_SCALAR_FMA) {
            return Math.fma(a, b, c);
        }
        return a * b + c;
    }

    /**
     * Returns the dot product of {@code a} and {@code b}.
     *
     * <p>The iteration count is taken from {@code a.length}; {@code b} is not length-checked here.
     */
    @Override
    public float dotProduct(float[] a, float[] b) {
        int i = 0;
        float res = 0.0f;
        // Only take the vector path when the input spans more than two full vectors.
        if (a.length > 2 * FLOAT_SPECIES.length()) {
            i += FLOAT_SPECIES.loopBound(a.length);
            res += dotProductBody(a, b, i);
        }
        // Scalar tail.
        for (; i < a.length; i++) {
            res = fma(a[i], b[i], res);
        }
        return res;
    }

    /**
     * Vectorized dot product over {@code [0, limit)}.
     *
     * @param limit end index (exclusive); the caller passes {@code FLOAT_SPECIES.loopBound(a.length)}
     */
    private float dotProductBody(float[] a, float[] b, int limit) {
        int i = 0;
        // Four independent accumulators, one per unrolled load pair.
        FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES);
        // Stop the 4x unrolled loop while four whole vectors still fit before limit.
        int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
        for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
            acc1 = fma(va, vb, acc1);

            FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
            FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
            acc2 = fma(vc, vd, acc2);

            FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
            FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
            acc3 = fma(ve, vf, acc3);

            FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
            FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
            acc4 = fma(vg, vh, acc4);
        }
        // Remaining whole vectors, one at a time.
        for (; i < limit; i += FLOAT_SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
            acc1 = fma(va, vb, acc1);
        }
        // Combine the accumulators pairwise, then reduce across lanes.
        FloatVector res1 = acc1.add(acc2);
        FloatVector res2 = acc3.add(acc4);
        return res1.add(res2).reduceLanes(VectorOperators.ADD);
    }

    /**
     * Returns the cosine similarity of {@code a} and {@code b}:
     * {@code sum(a[i]*b[i]) / sqrt(sum(a[i]^2) * sum(b[i]^2))}, with the final division and square
     * root carried out in {@code double}.
     */
    @Override
    public float cosine(float[] a, float[] b) {
        int i = 0;
        float sum = 0.0f;
        float norm1 = 0.0f;
        float norm2 = 0.0f;
        // Only take the vector path when the input spans more than two full vectors.
        if (a.length > 2 * FLOAT_SPECIES.length()) {
            i += FLOAT_SPECIES.loopBound(a.length);
            float[] ret = cosineBody(a, b, i);
            sum += ret[0];
            norm1 += ret[1];
            norm2 += ret[2];
        }
        // Scalar tail.
        for (; i < a.length; i++) {
            sum = fma(a[i], b[i], sum);
            norm1 = fma(a[i], a[i], norm1);
            norm2 = fma(b[i], b[i], norm2);
        }
        return (float)((double)sum / Math.sqrt((double)norm1 * (double)norm2));
    }

    /**
     * Vectorized cosine partial sums over {@code [0, limit)}.
     *
     * @return {@code {sum(a*b), sum(a*a), sum(b*b)}}
     */
    private float[] cosineBody(float[] a, float[] b, int limit) {
        int i = 0;
        // Two accumulators per quantity, one for each half of the 2x unrolled loop.
        FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector norm1_1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES);
        int unrolledLimit = limit - FLOAT_SPECIES.length();
        for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
            sum1 = fma(va, vb, sum1);
            norm1_1 = fma(va, va, norm1_1);
            norm2_1 = fma(vb, vb, norm2_1);

            FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
            FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
            sum2 = fma(vc, vd, sum2);
            norm1_2 = fma(vc, vc, norm1_2);
            norm2_2 = fma(vd, vd, norm2_2);
        }
        // Remaining whole vectors, one at a time.
        for (; i < limit; i += FLOAT_SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
            sum1 = fma(va, vb, sum1);
            norm1_1 = fma(va, va, norm1_1);
            norm2_1 = fma(vb, vb, norm2_1);
        }
        return new float[] {
            sum1.add(sum2).reduceLanes(VectorOperators.ADD),
            norm1_1.add(norm1_2).reduceLanes(VectorOperators.ADD),
            norm2_1.add(norm2_2).reduceLanes(VectorOperators.ADD)
        };
    }

    /**
     * Returns the squared Euclidean distance between {@code a} and {@code b}, i.e.
     * {@code sum((a[i] - b[i])^2)}.
     */
    @Override
    public float squareDistance(float[] a, float[] b) {
        int i = 0;
        float res = 0.0f;
        // Only take the vector path when the input spans more than two full vectors.
        if (a.length > 2 * FLOAT_SPECIES.length()) {
            i += FLOAT_SPECIES.loopBound(a.length);
            res += squareDistanceBody(a, b, i);
        }
        // Scalar tail.
        for (; i < a.length; i++) {
            float diff = a[i] - b[i];
            res = fma(diff, diff, res);
        }
        return res;
    }

    /**
     * Vectorized squared distance over {@code [0, limit)}.
     *
     * @param limit end index (exclusive); the caller passes {@code FLOAT_SPECIES.loopBound(a.length)}
     */
    private float squareDistanceBody(float[] a, float[] b, int limit) {
        int i = 0;
        // Four independent accumulators, one per unrolled load pair.
        FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
        FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES);
        int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
        for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
            FloatVector diff1 = va.sub(vb);
            acc1 = fma(diff1, diff1, acc1);

            FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
            FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
            FloatVector diff2 = vc.sub(vd);
            acc2 = fma(diff2, diff2, acc2);

            FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
            FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
            FloatVector diff3 = ve.sub(vf);
            acc3 = fma(diff3, diff3, acc3);

            FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
            FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
            FloatVector diff4 = vg.sub(vh);
            acc4 = fma(diff4, diff4, acc4);
        }
        // Remaining whole vectors, one at a time.
        for (; i < limit; i += FLOAT_SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
            FloatVector diff = va.sub(vb);
            acc1 = fma(diff, diff, acc1);
        }
        FloatVector res1 = acc1.add(acc2);
        FloatVector res2 = acc3.add(acc4);
        return res1.add(res2).reduceLanes(VectorOperators.ADD);
    }

    /** Returns the dot product of two signed-byte vectors; see {@link #dotProduct(MemorySegment, MemorySegment)}. */
    @Override
    public int dotProduct(byte[] a, byte[] b) {
        return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
    }

    /**
     * Returns the dot product of two segments interpreted as signed bytes.
     *
     * <p>The vector path is taken only for inputs of at least 16 bytes and when
     * {@link PanamaVectorConstants#HAS_FAST_INTEGER_VECTORS} is set.
     */
    public static int dotProduct(MemorySegment a, MemorySegment b) {
        assert a.byteSize() == b.byteSize();
        int i = 0;
        int res = 0;
        if (a.byteSize() >= 16L && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
            if (VECTOR_BITSIZE >= 512) {
                i += BYTE_SPECIES.loopBound(a.byteSize());
                res += dotProductBody512(a, b, i);
            } else if (VECTOR_BITSIZE == 256) {
                i += BYTE_SPECIES.loopBound(a.byteSize());
                res += dotProductBody256(a, b, i);
            } else {
                // The 128-bit body loads 8 bytes but advances by 4 (overlapping reads), so the
                // bound is computed on a size shortened by one 8-byte vector to keep the last
                // load inside the segment.
                i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
                res += dotProductBody128(a, b, i);
            }
        }
        // Scalar tail.
        for (; i < a.byteSize(); i++) {
            res += b.get(ValueLayout.JAVA_BYTE, i) * a.get(ValueLayout.JAVA_BYTE, i);
        }
        return res;
    }

    /** Byte dot product body for vectors of 512 bits or more: multiply as shorts, accumulate as ints. */
    private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) {
        IntVector acc = IntVector.zero(INT_SPECIES);
        for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
            ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, ByteOrder.LITTLE_ENDIAN);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, ByteOrder.LITTLE_ENDIAN);

            // Widen to 16 bits and multiply there (a product of two bytes fits in a short).
            Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, SHORT_SPECIES, 0);
            Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, SHORT_SPECIES, 0);
            Vector<Short> prod16 = va16.mul(vb16);

            // Widen the products to 32 bits before accumulating.
            Vector<Integer> prod32 = prod16.convertShape(VectorOperators.S2I, INT_SPECIES, 0);
            acc = acc.add(prod32);
        }
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /** Byte dot product body for 256-bit vectors: widen 8 bytes straight to 8 ints, then multiply. */
    private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) {
        IntVector acc = IntVector.zero(IntVector.SPECIES_256);
        for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
            ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, ByteOrder.LITTLE_ENDIAN);
            ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, ByteOrder.LITTLE_ENDIAN);

            Vector<Integer> va32 = va8.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);
            Vector<Integer> vb32 = vb8.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);
            acc = acc.add(va32.mul(vb32));
        }
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Byte dot product body for 128-bit vectors. Each iteration loads 8 bytes but consumes only
     * the first 4 (part 0 of the in-shape byte-to-short conversion) and advances by 4.
     */
    private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) {
        IntVector acc = IntVector.zero(IntVector.SPECIES_128);
        for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
            ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, ByteOrder.LITTLE_ENDIAN);
            ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, ByteOrder.LITTLE_ENDIAN);

            // Same-shape conversion: 4 shorts in a 64-bit vector.
            Vector<Short> va16 = va8.convert(VectorOperators.B2S, 0);
            Vector<Short> vb16 = vb8.convert(VectorOperators.B2S, 0);
            Vector<Short> prod16 = va16.mul(vb16);

            acc = acc.add(prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
        }
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Dot product of two int4-style vectors, at most one of which is nibble-packed.
     *
     * <p>When one side is packed, byte {@code i} of the packed array holds two values: its high
     * nibble is multiplied with {@code unpacked[i]} and its low nibble with
     * {@code unpacked[i + packed.length]}. When neither side is packed this is a plain byte dot
     * product.
     *
     * @param apacked whether {@code a} is the packed operand
     * @param bpacked whether {@code b} is the packed operand; must not be set together with
     *     {@code apacked} (asserted)
     */
    @Override
    public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
        assert !(apacked && bpacked);
        int i = 0;
        int res = 0;
        if (apacked || bpacked) {
            byte[] packed = apacked ? a : b;
            byte[] unpacked = apacked ? b : a;
            if (packed.length >= 32) {
                if (VECTOR_BITSIZE >= 512) {
                    i += ByteVector.SPECIES_256.loopBound(packed.length);
                    res += dotProductBody512Int4Packed(unpacked, packed, i);
                } else if (VECTOR_BITSIZE == 256) {
                    i += ByteVector.SPECIES_128.loopBound(packed.length);
                    res += dotProductBody256Int4Packed(unpacked, packed, i);
                } else if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
                    i += ByteVector.SPECIES_64.loopBound(packed.length);
                    res += dotProductBody128Int4Packed(unpacked, packed, i);
                }
            }
            // Scalar tail.
            for (; i < packed.length; i++) {
                byte packedByte = packed[i];
                byte unpacked1 = unpacked[i];
                byte unpacked2 = unpacked[i + packed.length];
                res += (packedByte & 0xF) * unpacked2;
                res += ((packedByte & 0xFF) >> 4) * unpacked1;
            }
        } else {
            // With 256-bit or wider vectors the generic byte dot product is used as is.
            if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) {
                return dotProduct(a, b);
            }
            if (a.length >= 32 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
                i += ByteVector.SPECIES_128.loopBound(a.length);
                res += int4DotProductBody128(a, b, i);
            }
            // Scalar tail.
            for (; i < a.length; i++) {
                res += b[i] * a[i];
            }
        }
        return res;
    }

    /**
     * Packed int4 dot product body using 256-bit byte loads and 512-bit short accumulators.
     *
     * <p>Input is processed in chunks of 4096 packed bytes (128 inner iterations of 32 bytes);
     * the short accumulators are widened to ints and summed after every chunk. Assuming both
     * operands hold values of at most 15 (presumably guaranteed for int4 data; confirm with
     * callers), a lane gains at most 225 per iteration, so 128 iterations stay below
     * {@link Short#MAX_VALUE}.
     */
    private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) {
        int sum = 0;
        for (int i = 0; i < limit; i += 4096) {
            ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512);
            ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512);
            int innerLimit = Math.min(limit - i, 4096);
            for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
                ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j);

                // Low nibbles pair with the second half of the unpacked array.
                ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length);
                ByteVector prod8 = vb8.and((byte)15).mul(va8);
                Vector<Short> prod16 = prod8.convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
                acc0 = acc0.add(prod16);

                // High nibbles pair with the first half of the unpacked array.
                ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j);
                ByteVector prod8a = vb8.lanewise(VectorOperators.LSHR, 4L).mul(vc8);
                Vector<Short> prod16a = prod8a.convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0);
                acc1 = acc1.add(prod16a);
            }
            // Widen both halves (parts 0 and 1) of each short accumulator to ints and reduce.
            IntVector intAcc0 = acc0.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0).reinterpretAsInts();
            IntVector intAcc1 = acc0.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 1).reinterpretAsInts();
            IntVector intAcc2 = acc1.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0).reinterpretAsInts();
            IntVector intAcc3 = acc1.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 1).reinterpretAsInts();
            sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(VectorOperators.ADD);
        }
        return sum;
    }

    /**
     * Packed int4 dot product body using 128-bit byte loads and 256-bit short accumulators.
     * Same scheme as {@link #dotProductBody512Int4Packed}, with chunks of 2048 packed bytes
     * (128 inner iterations of 16 bytes).
     */
    private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) {
        int sum = 0;
        for (int i = 0; i < limit; i += 2048) {
            ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256);
            ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256);
            int innerLimit = Math.min(limit - i, 2048);
            for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
                ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j);

                // Low nibbles pair with the second half of the unpacked array.
                ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length);
                ByteVector prod8 = vb8.and((byte)15).mul(va8);
                Vector<Short> prod16 = prod8.convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
                acc0 = acc0.add(prod16);

                // High nibbles pair with the first half of the unpacked array.
                ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j);
                ByteVector prod8a = vb8.lanewise(VectorOperators.LSHR, 4L).mul(vc8);
                Vector<Short> prod16a = prod8a.convertShape(VectorOperators.ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0);
                acc1 = acc1.add(prod16a);
            }
            IntVector intAcc0 = acc0.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0).reinterpretAsInts();
            IntVector intAcc1 = acc0.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 1).reinterpretAsInts();
            IntVector intAcc2 = acc1.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0).reinterpretAsInts();
            IntVector intAcc3 = acc1.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 1).reinterpretAsInts();
            sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(VectorOperators.ADD);
        }
        return sum;
    }

    /**
     * Packed int4 dot product body using 64-bit byte loads and 128-bit short accumulators, in
     * chunks of 1024 packed bytes (128 inner iterations of 8 bytes).
     *
     * <p>Unlike the wider variants, the byte products are widened with the sign-extending
     * {@code B2S} conversion and then masked with {@code 0xFF}, which yields the unsigned value
     * of each product byte.
     */
    private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) {
        int sum = 0;
        for (int i = 0; i < limit; i += 1024) {
            ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128);
            ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
            int innerLimit = Math.min(limit - i, 1024);
            for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) {
                ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j);

                // Low nibbles pair with the second half of the unpacked array.
                ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length);
                ByteVector prod8 = vb8.and((byte)15).mul(va8);
                ShortVector prod16 = prod8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                acc0 = acc0.add(prod16.and((short)255));

                // High nibbles pair with the first half of the unpacked array.
                va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j);
                prod8 = vb8.lanewise(VectorOperators.LSHR, 4L).mul(va8);
                prod16 = prod8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                acc1 = acc1.add(prod16.and((short)255));
            }
            IntVector intAcc0 = acc0.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc1 = acc0.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
            IntVector intAcc2 = acc1.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc3 = acc1.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
            sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(VectorOperators.ADD);
        }
        return sum;
    }

    /**
     * Unpacked int4 dot product body for 128-bit vectors. Each inner iteration covers 16 bytes as
     * two 8-byte halves, multiplies them as bytes, and accumulates the zero-extended products in
     * two short accumulators; these are flushed to an int sum every 1024 bytes.
     *
     * <p>NOTE(review): the byte-wide multiply is only exact if each product fits in 8 bits,
     * which presumably holds because inputs are 4-bit values; confirm with callers.
     */
    private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
        int sum = 0;
        for (int i = 0; i < limit; i += 1024) {
            ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128);
            ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128);
            int innerLimit = Math.min(limit - i, 1024);
            for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
                // First 8 bytes of the 16-byte step.
                ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j);
                ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j);
                ByteVector prod8 = va8.mul(vb8);
                ShortVector prod16 = prod8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                acc0 = acc0.add(prod16.and((short)255));

                // Second 8 bytes of the 16-byte step.
                va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8);
                vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8);
                prod8 = va8.mul(vb8);
                prod16 = prod8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                acc1 = acc1.add(prod16.and((short)255));
            }
            IntVector intAcc0 = acc0.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc1 = acc0.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
            IntVector intAcc2 = acc1.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc3 = acc1.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1).reinterpretAsInts();
            sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(VectorOperators.ADD);
        }
        return sum;
    }

    /** Returns the cosine similarity of two signed-byte vectors; see {@link #cosine(MemorySegment, MemorySegment)}. */
    @Override
    public float cosine(byte[] a, byte[] b) {
        return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
    }

    /**
     * Returns the cosine similarity of two segments interpreted as signed bytes:
     * {@code sum(a*b) / sqrt(sum(a*a) * sum(b*b))}, with the sums accumulated in {@code int} and
     * the final division and square root carried out in {@code double}.
     */
    public static float cosine(MemorySegment a, MemorySegment b) {
        // Same precondition that dotProduct and squareDistance assert.
        assert a.byteSize() == b.byteSize();
        int i = 0;
        int sum = 0;
        int norm1 = 0;
        int norm2 = 0;
        if (a.byteSize() >= 16L && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
            // Partial sums stay in int: passing them through float would round any value
            // above 2^24 and silently lose precision on large inputs.
            final int[] ret;
            if (VECTOR_BITSIZE >= 512) {
                i += BYTE_SPECIES.loopBound((int)a.byteSize());
                ret = cosineBody512(a, b, i);
            } else if (VECTOR_BITSIZE == 256) {
                i += BYTE_SPECIES.loopBound((int)a.byteSize());
                ret = cosineBody256(a, b, i);
            } else {
                // The 128-bit body loads 8 bytes but advances by 4 (overlapping reads), so the
                // bound is computed on a size shortened by one 8-byte vector.
                i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
                ret = cosineBody128(a, b, i);
            }
            sum += ret[0];
            norm1 += ret[1];
            norm2 += ret[2];
        }
        // Scalar tail.
        for (; i < a.byteSize(); i++) {
            byte elem1 = a.get(ValueLayout.JAVA_BYTE, i);
            byte elem2 = b.get(ValueLayout.JAVA_BYTE, i);
            sum += elem1 * elem2;
            norm1 += elem1 * elem1;
            norm2 += elem2 * elem2;
        }
        return (float)((double)sum / Math.sqrt((double)norm1 * (double)norm2));
    }

    /**
     * Byte cosine body for vectors of 512 bits or more: multiply as shorts, accumulate as ints.
     *
     * @return {@code {sum(a*b), sum(a*a), sum(b*b)}} over {@code [0, limit)}
     */
    private static int[] cosineBody512(MemorySegment a, MemorySegment b, int limit) {
        IntVector accSum = IntVector.zero(INT_SPECIES);
        IntVector accNorm1 = IntVector.zero(INT_SPECIES);
        IntVector accNorm2 = IntVector.zero(INT_SPECIES);
        for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
            ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, ByteOrder.LITTLE_ENDIAN);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, ByteOrder.LITTLE_ENDIAN);

            // 16-bit multiplies.
            Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, SHORT_SPECIES, 0);
            Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, SHORT_SPECIES, 0);
            Vector<Short> norm1_16 = va16.mul(va16);
            Vector<Short> norm2_16 = vb16.mul(vb16);
            Vector<Short> prod16 = va16.mul(vb16);

            // 32-bit accumulation.
            Vector<Integer> norm1_32 = norm1_16.convertShape(VectorOperators.S2I, INT_SPECIES, 0);
            Vector<Integer> norm2_32 = norm2_16.convertShape(VectorOperators.S2I, INT_SPECIES, 0);
            Vector<Integer> prod32 = prod16.convertShape(VectorOperators.S2I, INT_SPECIES, 0);
            accNorm1 = accNorm1.add(norm1_32);
            accNorm2 = accNorm2.add(norm2_32);
            accSum = accSum.add(prod32);
        }
        return new int[] {
            accSum.reduceLanes(VectorOperators.ADD),
            accNorm1.reduceLanes(VectorOperators.ADD),
            accNorm2.reduceLanes(VectorOperators.ADD)
        };
    }

    /**
     * Byte cosine body for 256-bit vectors: widen 8 bytes straight to 8 ints, then multiply.
     *
     * @return {@code {sum(a*b), sum(a*a), sum(b*b)}} over {@code [0, limit)}
     */
    private static int[] cosineBody256(MemorySegment a, MemorySegment b, int limit) {
        IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
        IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
        IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
        for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
            ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, ByteOrder.LITTLE_ENDIAN);
            ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, ByteOrder.LITTLE_ENDIAN);

            Vector<Integer> va32 = va8.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);
            Vector<Integer> vb32 = vb8.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);
            Vector<Integer> norm1_32 = va32.mul(va32);
            Vector<Integer> norm2_32 = vb32.mul(vb32);
            Vector<Integer> prod32 = va32.mul(vb32);
            accNorm1 = accNorm1.add(norm1_32);
            accNorm2 = accNorm2.add(norm2_32);
            accSum = accSum.add(prod32);
        }
        return new int[] {
            accSum.reduceLanes(VectorOperators.ADD),
            accNorm1.reduceLanes(VectorOperators.ADD),
            accNorm2.reduceLanes(VectorOperators.ADD)
        };
    }

    /**
     * Byte cosine body for 128-bit vectors. Each iteration loads 8 bytes but consumes only the
     * first 4 (part 0 of the in-shape byte-to-short conversion) and advances by 4.
     *
     * @return {@code {sum(a*b), sum(a*a), sum(b*b)}} over {@code [0, limit)}
     */
    private static int[] cosineBody128(MemorySegment a, MemorySegment b, int limit) {
        IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
        IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
        IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
        for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
            ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, ByteOrder.LITTLE_ENDIAN);
            ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, ByteOrder.LITTLE_ENDIAN);

            // Same-shape conversion: 4 shorts in a 64-bit vector.
            Vector<Short> va16 = va8.convert(VectorOperators.B2S, 0);
            Vector<Short> vb16 = vb8.convert(VectorOperators.B2S, 0);
            Vector<Short> norm1_16 = va16.mul(va16);
            Vector<Short> norm2_16 = vb16.mul(vb16);
            Vector<Short> prod16 = va16.mul(vb16);

            accNorm1 = accNorm1.add(norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
            accNorm2 = accNorm2.add(norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
            accSum = accSum.add(prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
        }
        return new int[] {
            accSum.reduceLanes(VectorOperators.ADD),
            accNorm1.reduceLanes(VectorOperators.ADD),
            accNorm2.reduceLanes(VectorOperators.ADD)
        };
    }

    /** Returns the squared distance of two signed-byte vectors; see {@link #squareDistance(MemorySegment, MemorySegment)}. */
    @Override
    public int squareDistance(byte[] a, byte[] b) {
        return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
    }

    /**
     * Returns {@code sum((a[i] - b[i])^2)} over two segments interpreted as signed bytes.
     *
     * <p>The vector path is taken only for inputs of at least 16 bytes and when
     * {@link PanamaVectorConstants#HAS_FAST_INTEGER_VECTORS} is set.
     */
    public static int squareDistance(MemorySegment a, MemorySegment b) {
        assert a.byteSize() == b.byteSize();
        int i = 0;
        int res = 0;
        if (a.byteSize() >= 16L && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
            if (VECTOR_BITSIZE >= 256) {
                i += BYTE_SPECIES.loopBound((int)a.byteSize());
                res += squareDistanceBody256(a, b, i);
            } else {
                i += ByteVector.SPECIES_64.loopBound((int)a.byteSize());
                res += squareDistanceBody128(a, b, i);
            }
        }
        // Scalar tail.
        for (; i < a.byteSize(); i++) {
            int diff = a.get(ValueLayout.JAVA_BYTE, i) - b.get(ValueLayout.JAVA_BYTE, i);
            res += diff * diff;
        }
        return res;
    }

    /** Byte squared-distance body for vectors of 256 bits or more: widen to ints, subtract, square. */
    private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) {
        IntVector acc = IntVector.zero(INT_SPECIES);
        for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
            ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, ByteOrder.LITTLE_ENDIAN);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, ByteOrder.LITTLE_ENDIAN);

            Vector<Integer> va32 = va8.convertShape(VectorOperators.B2I, INT_SPECIES, 0);
            Vector<Integer> vb32 = vb8.convertShape(VectorOperators.B2I, INT_SPECIES, 0);
            Vector<Integer> diff32 = va32.sub(vb32);
            acc = acc.add(diff32.mul(diff32));
        }
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Byte squared-distance body for 128-bit vectors: subtract as 8 shorts, then square the two
     * 4-lane halves (parts 0 and 1) as ints in separate accumulators.
     */
    private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) {
        IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
        IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
        for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
            ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, ByteOrder.LITTLE_ENDIAN);
            ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, ByteOrder.LITTLE_ENDIAN);

            // 16-bit subtract.
            Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
            Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
            Vector<Short> diff16 = va16.sub(vb16);

            // 32-bit multiply and accumulate, one half at a time.
            Vector<Integer> diff32_1 = diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
            Vector<Integer> diff32_2 = diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
            acc1 = acc1.add(diff32_1.mul(diff32_1));
            acc2 = acc2.add(diff32_2.mul(diff32_2));
        }
        return acc1.add(acc2).reduceLanes(VectorOperators.ADD);
    }

    /**
     * Searches {@code buffer[from..to)} for an element {@code >= target}.
     *
     * <p>When the vectorized path is enabled (int species of at least 8 lanes), the range is
     * probed in steps of {@code lanes + 1}: if the element just past the current window is already
     * {@code >= target}, the window is loaded and the number of lanes below {@code target} gives
     * the offset of the answer. Any remainder is scanned linearly.
     *
     * <p>NOTE(review): counting the lanes below {@code target} equals the index of the first
     * match only if {@code buffer} is sorted in ascending order, which is presumably guaranteed
     * by callers; confirm.
     *
     * @return the index of the matching element, or {@code to} if there is none
     */
    @Override
    public int findNextGEQ(int[] buffer, int target, int from, int to) {
        if (ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO) {
            for (; from + INT_SPECIES.length() < to; from += INT_SPECIES.length() + 1) {
                if (buffer[from + INT_SPECIES.length()] >= target) {
                    IntVector vector = IntVector.fromArray(INT_SPECIES, buffer, from);
                    VectorMask<Integer> mask = vector.compare(VectorOperators.LT, target);
                    return from + mask.trueCount();
                }
            }
        }
        for (int i = from; i < to; ++i) {
            if (buffer[i] >= target) {
                return i;
            }
        }
        return to;
    }
}

