/*
 * Decompiled with CFR 0.152.
 */
package me.jellysquid.mods.lithium.mixin.chunk.block_counting;

import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import me.jellysquid.mods.lithium.common.block.BlockCountingSection;
import me.jellysquid.mods.lithium.common.block.BlockStateFlagHolder;
import me.jellysquid.mods.lithium.common.block.BlockStateFlags;
import me.jellysquid.mods.lithium.common.block.TrackedBlockStatePredicate;
import net.minecraft.network.FriendlyByteBuf;
import net.minecraft.world.level.block.state.BlockState;
import net.minecraft.world.level.chunk.LevelChunkSection;
import net.minecraft.world.level.chunk.PalettedContainer;
import org.spongepowered.asm.mixin.Final;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.Redirect;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;
import org.spongepowered.asm.mixin.injection.callback.LocalCapture;

/**
 * Mixin applied to {@link LevelChunkSection} that implements {@link BlockCountingSection}.
 *
 * <p>For every flag bit reported by {@link BlockStateFlagHolder#getAllFlags()} the section keeps
 * a counter in {@link #countsByFlag}. The counters are
 * <ul>
 *   <li>filled synchronously while the target's {@code calculateCounts()} runs,</li>
 *   <li>or computed lazily through a {@link CompletableFuture} the first time
 *       {@link #anyMatch} is asked while no counters exist,</li>
 *   <li>and adjusted incrementally by the {@code setBlockState} injector.</li>
 * </ul>
 */
@Mixin(value={LevelChunkSection.class})
public abstract class ChunkSectionMixin
implements BlockCountingSection {
    /** Shadowed block state storage of the target section. */
    @Shadow
    @Final
    private PalettedContainer<BlockState> f_62972_;

    /** Per-flag counters, indexed by flag bit position; {@code null} while not computed. */
    @Unique
    private short[] countsByFlag = null;

    /** Pending or finished asynchronous computation of the counters; {@code null} if none was started. */
    @Unique
    private CompletableFuture<short[]> countsByFlagFuture;

    /**
     * Reports whether the counter belonging to the given predicate is non-zero.
     *
     * @param trackedBlockStatePredicate predicate whose {@code getIndex()} selects the counter
     * @param fallback value returned while the counters are not available yet
     * @return {@code fallback} if the counters could not be obtained without waiting, otherwise
     *         whether the selected counter differs from zero
     */
    @Override
    public boolean anyMatch(TrackedBlockStatePredicate trackedBlockStatePredicate, boolean fallback) {
        if (this.countsByFlag == null && !this.tryInitializeCountsByFlag()) {
            return fallback;
        }
        return this.countsByFlag[trackedBlockStatePredicate.getIndex()] != 0;
    }

    /**
     * Adopts the result of a finished asynchronous count, or starts one if none is pending.
     * Never waits for an unfinished computation.
     *
     * @return {@code true} if {@link #countsByFlag} was assigned from a completed future,
     *         {@code false} otherwise
     */
    private boolean tryInitializeCountsByFlag() {
        CompletableFuture<short[]> pending = this.countsByFlagFuture;
        if (pending != null && pending.isDone()) {
            try {
                this.countsByFlag = pending.get();
                return true;
            }
            catch (InterruptedException e) {
                // Keep the interrupt visible to the caller instead of silently clearing it.
                Thread.currentThread().interrupt();
                this.countsByFlagFuture = null;
            }
            catch (CancellationException | ExecutionException ignored) {
                // The computation did not produce a result; drop it so a new one is scheduled below.
                this.countsByFlagFuture = null;
            }
        }
        if (this.countsByFlagFuture == null) {
            // Capture the container in a local so the lambda does not capture `this`.
            PalettedContainer<BlockState> blockStateContainer = this.f_62972_;
            this.countsByFlagFuture = CompletableFuture.supplyAsync(() -> ChunkSectionMixin.calculateLithiumCounts(blockStateContainer));
        }
        return false;
    }

    /**
     * Builds a fresh counter array by visiting every (state, count) pair of the container.
     *
     * @param blockStateContainer container to count
     * @return new array of length {@code BlockStateFlags.NUM_FLAGS} holding the per-flag counts
     */
    private static short[] calculateLithiumCounts(PalettedContainer<BlockState> blockStateContainer) {
        short[] countsByFlag = new short[BlockStateFlags.NUM_FLAGS];
        blockStateContainer.m_63099_((state, count) -> ChunkSectionMixin.addToFlagCount(countsByFlag, state, count));
        return countsByFlag;
    }

    /**
     * Replaces the container count call inside the target's {@code calculateCounts()} so that each
     * visited pair is forwarded to the original consumer and also added to {@link #countsByFlag}.
     * The array is allocated by {@link #createFlagCounters}, which is injected at the head of the
     * same target method.
     */
    @Redirect(method={"calculateCounts()V"}, at=@At(value="INVOKE", target="Lnet/minecraft/world/chunk/PalettedContainer;count(Lnet/minecraft/world/chunk/PalettedContainer$Counter;)V"))
    private void initFlagCounters(PalettedContainer<BlockState> palettedContainer, PalettedContainer.CountConsumer<BlockState> consumer) {
        palettedContainer.m_63099_((state, count) -> {
            consumer.m_63144_(state, count);
            ChunkSectionMixin.addToFlagCount(this.countsByFlag, state, count);
        });
    }

    /**
     * Adds {@code change} to the counter of every flag bit set on {@code state}.
     *
     * @param countsByFlag counter array, indexed by flag bit position
     * @param state block state whose flags are read through {@link BlockStateFlagHolder}
     * @param change amount added to each matching counter (narrowed to {@code short} on store)
     */
    private static void addToFlagCount(short[] countsByFlag, BlockState state, int change) {
        int flags = ((BlockStateFlagHolder)state).getAllFlags();
        while (flags != 0) {
            int bit = Integer.numberOfTrailingZeros(flags);
            // Compound assignment narrows back to short, same as an explicit (short) cast.
            countsByFlag[bit] += change;
            flags &= ~(1 << bit);
        }
    }

    /** Allocates a zeroed counter array at the head of the target's {@code calculateCounts()}. */
    @Inject(method={"calculateCounts()V"}, at={@At(value="HEAD")})
    private void createFlagCounters(CallbackInfo ci) {
        this.countsByFlag = new short[BlockStateFlags.NUM_FLAGS];
    }

    /**
     * Runs at the head of {@code setBlockState}: if an asynchronous count exists, waits for it,
     * adopts its result and clears the future before the block change proceeds.
     */
    @Inject(method={"setBlockState(IIILnet/minecraft/block/BlockState;Z)Lnet/minecraft/block/BlockState;"}, at={@At(value="HEAD")})
    private void joinFuture(int x, int y, int z, BlockState state, boolean lock, CallbackInfoReturnable<BlockState> cir) {
        if (this.countsByFlagFuture != null) {
            this.countsByFlag = this.countsByFlagFuture.join();
            this.countsByFlagFuture = null;
        }
    }

    /** Discards the counters and any pending computation at the head of {@code fromPacket}. */
    @Inject(method={"fromPacket"}, at={@At(value="HEAD")})
    private void resetData(FriendlyByteBuf buf, CallbackInfo ci) {
        this.countsByFlag = null;
        this.countsByFlagFuture = null;
    }

    /**
     * Adjusts the counters for a single block change. Only flag bits that differ between the old
     * and the new state are touched. Does nothing while the counters are not available.
     *
     * @param newState state being written
     * @param oldState captured local of the target method holding the replaced state
     */
    @Inject(method={"setBlockState(IIILnet/minecraft/block/BlockState;Z)Lnet/minecraft/block/BlockState;"}, at={@At(value="INVOKE", target="Lnet/minecraft/block/BlockState;getFluidState()Lnet/minecraft/fluid/FluidState;", ordinal=0, shift=At.Shift.BEFORE)}, locals=LocalCapture.CAPTURE_FAILHARD)
    private void updateFlagCounters(int x, int y, int z, BlockState newState, boolean lock, CallbackInfoReturnable<BlockState> cir, BlockState oldState) {
        short[] countsByFlag = this.countsByFlag;
        if (countsByFlag == null) {
            return;
        }
        int prevFlags = ((BlockStateFlagHolder)oldState).getAllFlags();
        int flags = ((BlockStateFlagHolder)newState).getAllFlags();
        int changedFlags = prevFlags ^ flags;
        while (changedFlags != 0) {
            int bit = Integer.numberOfTrailingZeros(changedFlags);
            // -1 if the old state had the flag (it is being removed), +1 if it did not (it is being added).
            int delta = 1 - ((prevFlags >>> bit & 1) << 1);
            countsByFlag[bit] += delta;
            changedFlags &= ~(1 << bit);
        }
    }
}

