/*
 * Decompiled with CFR 0.152.
 */
package org.sinytra.adapter.patch.transformer.dynamic;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multimap;
import com.mojang.logging.LogUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.function.Consumer;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FrameNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.LineNumberNode;
import org.objectweb.asm.tree.LocalVariableNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.sinytra.adapter.patch.PatchInstance;
import org.sinytra.adapter.patch.analysis.InstructionMatcher;
import org.sinytra.adapter.patch.analysis.LocalVarAnalyzer;
import org.sinytra.adapter.patch.analysis.LocalVariableLookup;
import org.sinytra.adapter.patch.analysis.MethodCallAnalyzer;
import org.sinytra.adapter.patch.api.MethodContext;
import org.sinytra.adapter.patch.api.MethodTransform;
import org.sinytra.adapter.patch.api.Patch;
import org.sinytra.adapter.patch.api.PatchContext;
import org.sinytra.adapter.patch.selector.AnnotationHandle;
import org.sinytra.adapter.patch.selector.AnnotationValueHandle;
import org.sinytra.adapter.patch.util.AdapterUtil;
import org.sinytra.adapter.patch.util.GeneratedVariables;
import org.sinytra.adapter.patch.util.SingleValueHandle;
import org.slf4j.Logger;

public class DynamicInjectorOrdinalPatch
implements MethodTransform {
    private static final Logger LOGGER = LogUtils.getLogger();
    private static final Map<String, OffsetUpdateHandler> OFFSET_HANDLERS = Map.of("INVOKE", InvokeOffsetHandler.INSTANCE, "RETURN", ReturnOffsetHandler.INSTANCE);

    @Override
    public Collection<String> getAcceptedAnnotations() {
        return Set.of("Lorg/spongepowered/asm/mixin/injection/Inject;", "Lorg/spongepowered/asm/mixin/injection/ModifyVariable;");
    }

    @Override
    public Patch.Result apply(ClassNode classNode, MethodNode methodNode, MethodContext methodContext, PatchContext context) {
        Type returnType = Type.getReturnType((String)methodNode.desc);
        List<HandlerInstance<?, ?>> offsetHandlers = DynamicInjectorOrdinalPatch.getOffsetHandlers(methodContext, returnType);
        if (offsetHandlers.isEmpty()) {
            return Patch.Result.PASS;
        }
        MethodContext.TargetPair cleanTarget = methodContext.findCleanInjectionTarget();
        if (cleanTarget == null) {
            return Patch.Result.PASS;
        }
        MethodContext.TargetPair dirtyTarget = methodContext.findDirtyInjectionTarget();
        if (dirtyTarget == null) {
            return Patch.Result.PASS;
        }
        boolean applied = false;
        for (HandlerInstance<?, ?> instance : offsetHandlers) {
            applied |= instance.apply(methodContext, classNode, methodNode, cleanTarget, dirtyTarget);
        }
        return applied ? Patch.Result.APPLY : Patch.Result.PASS;
    }

    private static List<HandlerInstance<?, ?>> getOffsetHandlers(MethodContext methodContext, Type returnType) {
        LocalVariableLookup cleanTable;
        ArrayList handlers = new ArrayList();
        AnnotationHandle annotation = methodContext.injectionPointAnnotationOrThrow();
        annotation.getValue("ordinal").ifPresent(atOrdinal -> {
            String target = annotation.getValue("target").map(AnnotationValueHandle::get).orElse(null);
            annotation.getValue("value").map(AnnotationValueHandle::get).map(OFFSET_HANDLERS::get).filter(handler -> !handler.requiresTarget() || target != null).ifPresent(h -> handlers.add(new HandlerInstance<OffsetUpdateHandler.Context, Integer>((UpdateHandler<OffsetUpdateHandler.Context, Integer>)h, new OffsetUpdateHandler.Context(target, (Integer)atOrdinal.get()), atOrdinal::set)));
        });
        if (methodContext.methodAnnotation().matchesDesc("Lorg/spongepowered/asm/mixin/injection/ModifyVariable;") && (cleanTable = methodContext.cleanLocalsTable()) != null) {
            methodContext.methodAnnotation().getValue("ordinal").flatMap(ordinal -> cleanTable.getByTypedOrdinal(returnType, (Integer)ordinal.get()).flatMap(lvn -> cleanTable.getTypedOrdinal((LocalVariableNode)lvn).map(o -> new LocalVar((LocalVariableNode)lvn, (int)o, true))).map(local -> new HandlerInstance<LocalVar, LocalVar>(ModifyVariableOffsetHandler.INSTANCE, (LocalVar)local, var -> ordinal.set(var.ordinal())))).or(() -> methodContext.methodAnnotation().getValue("index").flatMap(index -> Optional.ofNullable(cleanTable.getByIndexOrNull((Integer)index.get())).flatMap(lvn -> cleanTable.getTypedOrdinal((LocalVariableNode)lvn).map(o -> new LocalVar((LocalVariableNode)lvn, (int)o, false))).map(local -> new HandlerInstance<LocalVar, LocalVar>(ModifyVariableOffsetHandler.INSTANCE, (LocalVar)local, var -> index.set(var.lvn().index))))).ifPresent(handlers::add);
        }
        return handlers;
    }

    private record HandlerInstance<T, U>(UpdateHandler<T, U> handler, T context, Consumer<U> applicator) {
        public boolean apply(MethodContext methodContext, ClassNode classNode, MethodNode methodNode, MethodContext.TargetPair cleanTarget, MethodContext.TargetPair dirtyTarget) {
            Optional<U> updatedValue = this.handler.apply(methodContext, classNode, methodNode, cleanTarget, dirtyTarget, this.context);
            if (updatedValue.isPresent()) {
                U value = updatedValue.get();
                LOGGER.info(PatchInstance.MIXINPATCH, "Updating injection point ordinal of {}.{} from {} to {}", new Object[]{classNode.name, methodNode.name, this.context, value});
                this.applicator.accept(value);
                return true;
            }
            return false;
        }
    }

    private static class ModifyVariableOffsetHandler
    implements UpdateHandler<LocalVar, LocalVar> {
        private static final ModifyVariableOffsetHandler INSTANCE = new ModifyVariableOffsetHandler();

        private ModifyVariableOffsetHandler() {
        }

        @Override
        public Optional<LocalVar> apply(MethodContext methodContext, ClassNode classNode, MethodNode methodNode, MethodContext.TargetPair cleanTarget, MethodContext.TargetPair dirtyTarget, LocalVar local) {
            Type[] args = Type.getArgumentTypes((String)methodNode.desc);
            if (args.length < 1) {
                return Optional.empty();
            }
            Type targetType = args[0];
            if (targetType != Type.BOOLEAN_TYPE && targetType != Type.INT_TYPE && targetType != Type.FLOAT_TYPE) {
                return Optional.empty();
            }
            if (methodContext.methodAnnotation().getValue("slice").isPresent() && local.relative()) {
                return Optional.empty();
            }
            return ModifyVariableOffsetHandler.tryFindUpdatedIndex(targetType, cleanTarget, dirtyTarget, local).or(() -> ModifyVariableOffsetHandler.tryFindSyntheticVariableIndex(methodContext, methodNode, cleanTarget, dirtyTarget, local));
        }

        private static Optional<LocalVar> tryFindSyntheticVariableIndex(MethodContext methodContext, MethodNode methodNode, MethodContext.TargetPair cleanTarget, MethodContext.TargetPair dirtyTarget, LocalVar local) {
            List<LocalVariableNode> available;
            int ordinal = local.ordinal();
            Type variableType = Type.getReturnType((String)methodNode.desc);
            LocalVariableLookup cleanTable = new LocalVariableLookup(cleanTarget.methodNode());
            LocalVariableLookup dirtyTable = new LocalVariableLookup(dirtyTarget.methodNode());
            if (cleanTable.getForType(variableType).size() == dirtyTable.getForType(variableType).size() && (available = dirtyTable.getForType(variableType)).size() > ordinal) {
                int variableIndex = available.get((int)ordinal).index;
                List<AbstractInsnNode> cleanInsns = methodContext.findInjectionTargetInsns(cleanTarget);
                List<AbstractInsnNode> dirtyInsns = methodContext.findInjectionTargetInsns(dirtyTarget);
                if (cleanInsns.size() == 1 && dirtyInsns.size() == 1) {
                    for (AbstractInsnNode insn = cleanInsns.get(0); insn != null && !(insn instanceof LabelNode); insn = insn.getNext()) {
                        int dirtyIndex;
                        SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
                        if (handle == null || handle.get() != variableIndex) continue;
                        List<SingleValueHandle<Integer>> dirtyVars = ModifyVariableOffsetHandler.getUsedVariablesInLabel(dirtyInsns.get(0), insn.getOpcode());
                        if (dirtyVars.size() != 1 || (dirtyIndex = dirtyVars.get(0).get().intValue()) == variableIndex) break;
                        methodContext.methodAnnotation().getValue("argsOnly").ifPresent(h -> h.set(false));
                        LocalVariableNode lvn = dirtyTable.getByIndex(dirtyIndex);
                        return dirtyTable.getTypedOrdinal(lvn).map(o -> new LocalVar(lvn, (int)o));
                    }
                }
            }
            return Optional.empty();
        }

        private static List<SingleValueHandle<Integer>> getUsedVariablesInLabel(AbstractInsnNode start, int opcode) {
            ArrayList<SingleValueHandle<Integer>> list = new ArrayList<SingleValueHandle<Integer>>();
            for (AbstractInsnNode insn = start; insn != null && !(insn instanceof LabelNode); insn = insn.getNext()) {
                SingleValueHandle<Integer> handle;
                if (insn.getOpcode() != opcode || (handle = AdapterUtil.handleLocalVarInsnValue(insn)) == null) continue;
                list.add(handle);
            }
            return list;
        }

        private static Optional<LocalVar> tryFindUpdatedIndex(Type targetType, MethodContext.TargetPair cleanTarget, MethodContext.TargetPair dirtyTarget, LocalVar local) {
            int ordinal = local.ordinal();
            List<LocalVariableNode> cleanLocals = cleanTarget.methodNode().localVariables.stream().filter(l -> Type.getType((String)l.desc) == targetType).sorted(Comparator.comparingInt(l -> l.index)).toList();
            if (cleanLocals.size() <= ordinal) {
                return Optional.empty();
            }
            LocalVariableNode cleanLocal = cleanLocals.get(ordinal);
            if (!GeneratedVariables.isGeneratedVariableName(cleanLocal.name, Type.getType((String)cleanLocal.desc))) {
                return Optional.empty();
            }
            LocalVariableLookup dirtyVarLookup = new LocalVariableLookup(dirtyTarget.methodNode());
            List<LocalVariableNode> dirtyLocals = dirtyVarLookup.getForType(targetType);
            if (cleanLocals.size() != dirtyLocals.size() || dirtyLocals.size() <= ordinal) {
                return ModifyVariableOffsetHandler.findReplacementLocal(cleanTarget.methodNode(), dirtyTarget.methodNode(), cleanLocal).flatMap(var -> dirtyVarLookup.getTypedOrdinal((LocalVariableNode)var).map(o -> new LocalVar((LocalVariableNode)var, (int)o)));
            }
            LocalVariableNode dirtyLocal = dirtyLocals.get(ordinal);
            if (!local.relative() && dirtyLocal.index == local.lvn().index) {
                return Optional.empty();
            }
            OptionalInt dirtyNameOrdinal = GeneratedVariables.getGeneratedVariableOrdinal(dirtyLocal.name, Type.getType((String)dirtyLocal.desc));
            if (dirtyNameOrdinal.isEmpty() || local.relative() && ordinal == dirtyNameOrdinal.getAsInt()) {
                return Optional.empty();
            }
            if (cleanLocal.index != dirtyLocal.index && !local.relative()) {
                return Optional.of(new LocalVar(dirtyLocal, dirtyLocals.indexOf(dirtyLocal)));
            }
            return Optional.empty();
        }

        private static Optional<LocalVariableNode> findReplacementLocal(MethodNode cleanMethod, MethodNode dirtyMethod, LocalVariableNode desired) {
            InsnList desiredInitializerInsns = LocalVarAnalyzer.findInitializerInsns(cleanMethod, desired.index);
            List<LocalVariableNode> matches = dirtyMethod.localVariables.stream().filter(lvn -> desired.desc.equals(lvn.desc)).filter(lvn -> {
                InsnList insns = LocalVarAnalyzer.findInitializerInsns(dirtyMethod, lvn.index);
                return InstructionMatcher.test(desiredInitializerInsns, insns);
            }).toList();
            return matches.size() == 1 ? Optional.of(matches.get(0)) : Optional.empty();
        }
    }

    private static interface UpdateHandler<T, U> {
        public Optional<U> apply(MethodContext var1, ClassNode var2, MethodNode var3, MethodContext.TargetPair var4, MethodContext.TargetPair var5, T var6);
    }

    private record LocalVar(LocalVariableNode lvn, int ordinal, boolean relative) {
        public LocalVar(LocalVariableNode lvn, int ordinal) {
            this(lvn, ordinal, false);
        }
    }

    private static interface OffsetUpdateHandler
    extends UpdateHandler<Context, Integer> {
        default public boolean requiresTarget() {
            return false;
        }

        public record Context(@Nullable String target, int ordinal) {
        }
    }

    private static class InvokeOffsetHandler
    implements OffsetUpdateHandler {
        public static final InvokeOffsetHandler INSTANCE = new InvokeOffsetHandler();

        private InvokeOffsetHandler() {
        }

        @Override
        public boolean requiresTarget() {
            return true;
        }

        @Override
        public Optional<Integer> apply(MethodContext methodContext, ClassNode classNode, MethodNode methodNode, MethodContext.TargetPair cleanTarget, MethodContext.TargetPair dirtyTarget, OffsetUpdateHandler.Context context) {
            String target = context.target();
            int ordinal = context.ordinal();
            Multimap<String, MethodInsnNode> cleanCallsMap = MethodCallAnalyzer.getMethodCalls(cleanTarget.methodNode(), new ArrayList<String>());
            Multimap<String, MethodInsnNode> dirtyCallsMap = MethodCallAnalyzer.getMethodCalls(dirtyTarget.methodNode(), new ArrayList<String>());
            PatchContext patchContext = methodContext.patchContext();
            String cleanValue = patchContext.remap(target);
            Collection cleanCalls = cleanCallsMap.get((Object)cleanValue);
            String dirtyValue = patchContext.remap(target);
            Collection dirtyCalls = dirtyCallsMap.get((Object)dirtyValue);
            if (cleanCalls.size() != dirtyCalls.size()) {
                int insnRange = 5;
                List<InstructionMatcher> cleanMatchers = cleanCalls.stream().map(i -> MethodCallAnalyzer.findSurroundingInstructions(i, insnRange)).toList();
                List<InstructionMatcher> dirtyMatchers = dirtyCalls.stream().map(i -> MethodCallAnalyzer.findSurroundingInstructions(i, insnRange)).toList();
                if (ordinal >= 0 && ordinal < cleanMatchers.size()) {
                    InstructionMatcher original = cleanMatchers.get(ordinal);
                    List<InstructionMatcher> matches = dirtyMatchers.stream().filter(original::test).toList();
                    if (matches.size() == 1) {
                        return Optional.of(dirtyMatchers.indexOf(matches.get(0)));
                    }
                }
            }
            return Optional.empty();
        }
    }

    private static class ReturnOffsetHandler
    implements OffsetUpdateHandler {
        public static final OffsetUpdateHandler INSTANCE = new ReturnOffsetHandler();
        private static final Set<Integer> RETURN_OPCODES = Set.of(Integer.valueOf(177), Integer.valueOf(176), Integer.valueOf(172), Integer.valueOf(174), Integer.valueOf(175), Integer.valueOf(173));

        private ReturnOffsetHandler() {
        }

        @Override
        public Optional<Integer> apply(MethodContext methodContext, ClassNode classNode, MethodNode methodNode, MethodContext.TargetPair cleanTarget, MethodContext.TargetPair dirtyTarget, OffsetUpdateHandler.Context context) {
            int ordinal = context.ordinal();
            List<AbstractInsnNode> cleanReturnInsns = ReturnOffsetHandler.findReturnInsns(cleanTarget.methodNode());
            List<AbstractInsnNode> dirtyReturnInsns = ReturnOffsetHandler.findReturnInsns(dirtyTarget.methodNode());
            if (ordinal < cleanReturnInsns.size() && cleanReturnInsns.size() != dirtyReturnInsns.size()) {
                AbstractInsnNode cleanInsn = cleanReturnInsns.get(ordinal);
                InstructionMatcher original = new InstructionMatcher(cleanInsn, ReturnOffsetHandler.findReturnPrecedingInsns(cleanInsn), List.of());
                List<InstructionMatcher> dirtyMatchers = dirtyReturnInsns.stream().map(i -> new InstructionMatcher((AbstractInsnNode)i, ReturnOffsetHandler.findReturnPrecedingInsns(i), List.of())).toList();
                List<InstructionMatcher> matches = dirtyMatchers.stream().filter(m -> original.test((InstructionMatcher)m, 1)).toList();
                if (matches.size() == 1) {
                    return Optional.of(dirtyMatchers.indexOf(matches.get(0)));
                }
            }
            return Optional.empty();
        }

        private static List<AbstractInsnNode> findReturnPrecedingInsns(AbstractInsnNode insn) {
            ArrayList<AbstractInsnNode> insns = new ArrayList<AbstractInsnNode>();
            int maxSize = 6;
            for (AbstractInsnNode prev = insn.getPrevious(); prev != null && insns.size() < maxSize && !RETURN_OPCODES.contains(prev.getOpcode()); prev = prev.getPrevious()) {
                if (prev instanceof FrameNode || prev instanceof LineNumberNode || prev instanceof LabelNode) continue;
                insns.add(0, prev);
            }
            return insns;
        }

        private static List<AbstractInsnNode> findReturnInsns(MethodNode methodNode) {
            ImmutableList.Builder insns = ImmutableList.builder();
            for (AbstractInsnNode insn : methodNode.instructions) {
                if (!RETURN_OPCODES.contains(insn.getOpcode())) continue;
                insns.add((Object)insn);
            }
            return insns.build();
        }
    }
}

