/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.compile.linearization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.CoVariance;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMZip;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.ReBlock;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.lops.compile.linearization.IDagLinearizer;
import org.apache.sysds.lops.compile.linearization.LinearizerDepthFirst;

public class LinearizerMaxParallelism
extends IDagLinearizer {
    @Override
    public List<Lop> linearize(List<Lop> v) {
        ArrayList<Lop> operatorList;
        List<Lop> roots;
        List<Lop> v2 = v;
        boolean hasSpark = v.stream().anyMatch(LinearizerMaxParallelism::isDistributedOp);
        boolean hasGPU = v.stream().anyMatch(LinearizerMaxParallelism::isGPUOp);
        if (!hasSpark && !hasGPU) {
            return new LinearizerDepthFirst().linearize(v);
        }
        if (hasSpark) {
            HashMap sparkOpCount = new HashMap();
            roots = v.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
            HashSet sparkRoots = new HashSet();
            roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r, sparkOpCount, sparkRoots));
            sparkRoots.forEach(sr -> sr.setAsynchronous(true));
            operatorList = new ArrayList<Lop>();
            sparkRoots.forEach(r -> LinearizerMaxParallelism.depthFirst(r, operatorList, sparkOpCount, false));
            roots.forEach(r -> LinearizerMaxParallelism.depthFirst(r, operatorList, sparkOpCount, false));
            roots.forEach(Lop::resetVisitStatus);
            v2 = operatorList;
        }
        if (hasGPU) {
            HashMap gpuOpCount = new HashMap();
            roots = v2.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
            HashSet gpuRoots = new HashSet();
            roots.forEach(r -> OperatorOrderingUtils.collectGPURoots(r, gpuOpCount, gpuRoots));
            gpuRoots.forEach(sr -> sr.setAsynchronous(true));
            operatorList = new ArrayList();
            gpuRoots.forEach(r -> LinearizerMaxParallelism.depthFirst(r, operatorList, gpuOpCount, false));
            roots.forEach(r -> LinearizerMaxParallelism.depthFirst(r, operatorList, gpuOpCount, false));
            roots.forEach(Lop::resetVisitStatus);
            v2 = operatorList;
        }
        return v2;
    }

    private static void depthFirst(Lop root, ArrayList<Lop> opList, Map<Long, Integer> sparkOpCount, boolean sparkFirst) {
        if (root.isVisited()) {
            return;
        }
        if (root.getInputs().isEmpty()) {
            opList.add(root);
            root.setVisited();
            return;
        }
        Lop[] sortedInputs = root.getInputs().toArray(new Lop[0]);
        if (sparkFirst) {
            Arrays.sort(sortedInputs, (l1, l2) -> (Integer)sparkOpCount.get(l2.getID()) - (Integer)sparkOpCount.get(l1.getID()));
        } else {
            Arrays.sort(sortedInputs, Comparator.comparingInt(l -> (Integer)sparkOpCount.get(l.getID())));
        }
        for (Lop input : sortedInputs) {
            LinearizerMaxParallelism.depthFirst(input, opList, sparkOpCount, sparkFirst);
        }
        opList.add(root);
        root.setVisited();
    }

    private static boolean isDistributedOp(Lop lop) {
        return lop.isExecSpark() || lop instanceof UnaryCP && (((UnaryCP)lop).getOpCode().equalsIgnoreCase(Opcodes.PREFETCH.toString()) || ((UnaryCP)lop).getOpCode().equalsIgnoreCase(Opcodes.BROADCAST.toString()));
    }

    private static boolean isGPUOp(Lop lop) {
        return lop.isExecGPU() || lop instanceof UnaryCP && (((UnaryCP)lop).getOpCode().equalsIgnoreCase(Opcodes.PREFETCH.toString()) || ((UnaryCP)lop).getOpCode().equalsIgnoreCase(Opcodes.BROADCAST.toString()));
    }

    private static List<Lop> addAsyncEagerCheckpointLop(List<Lop> nodes) {
        ArrayList<Lop> nodesWithCheckpoint = new ArrayList<Lop>();
        for (Lop l : nodes) {
            if (LinearizerMaxParallelism.isCheckpointNeeded(l)) {
                ArrayList<Lop> oldInputs = new ArrayList<Lop>(l.getInputs());
                for (Lop in : oldInputs) {
                    if (in.getExecType() != Types.ExecType.SPARK) continue;
                    Checkpoint checkpoint = new Checkpoint(in, in.getDataType(), in.getValueType(), Checkpoint.getDefaultStorageLevelString(), true);
                    checkpoint.addOutput(l);
                    l.replaceInput(in, checkpoint);
                    in.removeOutput(l);
                    nodesWithCheckpoint.add(checkpoint);
                }
            }
            nodesWithCheckpoint.add(l);
        }
        return nodesWithCheckpoint;
    }

    private static boolean isCheckpointNeeded(Lop lop) {
        boolean actionOP = lop.getExecType() == Types.ExecType.SPARK && (lop.getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK || lop.getDataType() == Types.DataType.SCALAR || lop instanceof MapMultChain || lop instanceof PickByCount || lop instanceof MMZip || lop instanceof CentralMoment || lop instanceof CoVariance || lop instanceof MMTSJ) && !(lop instanceof Checkpoint) && !(lop instanceof ReBlock) && !(lop instanceof CSVReBlock) && !(lop instanceof UAggOuterChain) && !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused);
        boolean hasParameterizedOut = lop.getOutputs().stream().anyMatch(out -> out instanceof ParameterizedBuiltin || out instanceof GroupedAggregate || out instanceof GroupedAggregateM);
        return actionOP && !hasParameterizedOut;
    }
}

