/*
 * Decompiled with CFR 0.152.
 */
package me.jellysquid.mods.lithium.mixin.chunk.block_counting;

import me.jellysquid.mods.lithium.common.block.BlockCountingSection;
import me.jellysquid.mods.lithium.common.block.BlockStateFlagHolder;
import me.jellysquid.mods.lithium.common.block.BlockStateFlags;
import me.jellysquid.mods.lithium.common.block.TrackedBlockStatePredicate;
import net.minecraft.network.FriendlyByteBuf;
import net.minecraft.world.level.block.state.BlockState;
import net.minecraft.world.level.chunk.LevelChunkSection;
import net.minecraft.world.level.chunk.PalettedContainer;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.Redirect;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;
import org.spongepowered.asm.mixin.injection.callback.LocalCapture;

@Mixin(value={LevelChunkSection.class})
public abstract class ChunkSectionMixin
implements BlockCountingSection {
    @Unique
    private short[] countsByFlag = new short[BlockStateFlags.NUM_FLAGS];

    @Shadow
    public abstract void m_63018_();

    @Override
    public boolean anyMatch(TrackedBlockStatePredicate indexedBlockStatePredicate) {
        return this.countsByFlag[indexedBlockStatePredicate.getIndex()] != 0;
    }

    @Redirect(method={"calculateCounts()V"}, at=@At(value="INVOKE", target="Lnet/minecraft/world/chunk/PalettedContainer;count(Lnet/minecraft/world/chunk/PalettedContainer$Counter;)V"))
    private void initFlagCounters(PalettedContainer<BlockState> palettedContainer, PalettedContainer.CountConsumer<BlockState> consumer) {
        palettedContainer.m_63099_((state, count) -> {
            consumer.m_63144_(state, count);
            int flags = ((BlockStateFlagHolder)state).getAllFlags();
            int size = this.countsByFlag.length;
            for (int i = 0; i < size && flags != 0; flags >>>= 1, ++i) {
                if ((flags & 1) == 0) continue;
                int n = i;
                this.countsByFlag[n] = (short)(this.countsByFlag[n] + count);
            }
        });
    }

    @Inject(method={"calculateCounts()V"}, at={@At(value="HEAD")})
    private void resetFlagCounters(CallbackInfo ci) {
        this.countsByFlag = new short[BlockStateFlags.NUM_FLAGS];
    }

    @Inject(method={"setBlockState(IIILnet/minecraft/block/BlockState;Z)Lnet/minecraft/block/BlockState;"}, at={@At(value="INVOKE", target="Lnet/minecraft/block/BlockState;getFluidState()Lnet/minecraft/fluid/FluidState;", ordinal=0, shift=At.Shift.BEFORE)}, locals=LocalCapture.CAPTURE_FAILHARD)
    private void updateFlagCounters(int x, int y, int z, BlockState newState, boolean lock, CallbackInfoReturnable<BlockState> cir, BlockState oldState) {
        int i;
        int prevFlags = ((BlockStateFlagHolder)oldState).getAllFlags();
        int flags = ((BlockStateFlagHolder)newState).getAllFlags();
        int flagsXOR = prevFlags ^ flags;
        while ((i = Integer.numberOfTrailingZeros(flagsXOR)) < 32) {
            int n = i;
            this.countsByFlag[n] = (short)(this.countsByFlag[n] + (1 - ((prevFlags >>> i & 1) << 1)));
            flagsXOR &= ~(1 << i);
        }
    }

    @Inject(method={"fromPacket(Lnet/minecraft/network/PacketByteBuf;)V"}, at={@At(value="RETURN")})
    private void initCounts(FriendlyByteBuf packetByteBuf, CallbackInfo ci) {
        this.m_63018_();
    }
}

