/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysds.runtime.instructions.spark.FrameAppendMSPInstruction;
import org.apache.sysds.runtime.instructions.spark.FrameAppendRSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.functions.MapInputSignature;
import org.apache.sysds.runtime.instructions.spark.functions.MapJoinSignature;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for n-ary builtin operations.
 *
 * <p>Two families of opcodes are handled:
 * <ul>
 *   <li>{@link Opcodes#CBIND} / {@link Opcodes#RBIND} over matrix inputs (shift, pad, union and
 *       merge of all input RDDs) or frame inputs (pairwise broadcast- or RDD-based appends), and</li>
 *   <li>element-wise n-ary operations {@link Opcodes#NMIN}, {@link Opcodes#NMAX}, {@link Opcodes#NP}
 *       (plus) and {@link Opcodes#NM} (multiply) over a mix of matrix and scalar inputs.</li>
 * </ul>
 */
public class BuiltinNarySPInstruction
extends SPInstruction
implements LineageTraceable {
    /** Block size assumed for frame append decisions when the output block size is unknown (&lt;= 0). */
    private static final int DEFAULT_FRAME_APPEND_BLOCKSIZE = 1000;

    public CPOperand[] inputs;
    public CPOperand output;

    protected BuiltinNarySPInstruction(CPOperand[] in, CPOperand out, String opcode, String istr) {
        super(SPInstruction.SPType.BuiltinNary, opcode, istr);
        this.inputs = in;
        this.output = out;
    }

    /**
     * Parses an instruction string of the form {@code opcode, in_1, ..., in_k, out}.
     *
     * @param str the serialized instruction
     * @return the parsed instruction; every part between the opcode and the last part is an input
     */
    public static BuiltinNarySPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand output = new CPOperand(parts[parts.length - 1]);
        CPOperand[] inputs = new CPOperand[parts.length - 2];
        for (int i = 1; i < parts.length - 1; ++i) {
            inputs[i - 1] = new CPOperand(parts[i]);
        }
        return new BuiltinNarySPInstruction(inputs, output, opcode, str);
    }

    /**
     * Executes the instruction.
     *
     * <p>Dispatches on the opcode, registers the output RDD and its data characteristics, and
     * records lineage from every non-scalar input.
     *
     * @param ec execution context; cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the opcode is not a supported n-ary opcode, if an element-wise
     *         n-ary operation has no matrix input, or if a forced append method is not supported
     *         for frames
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        String opcode = this.getOpcode();

        if (opcode.equals(Opcodes.CBIND.toString()) || opcode.equals(Opcodes.RBIND.toString())) {
            boolean cbind = opcode.equals(Opcodes.CBIND.toString());
            if (this.inputs[0].isMatrix()) {
                DataCharacteristics dcOut = computeAppendOutputDataCharacteristics(sec, this.inputs, cbind);
                this.setMatrixOutput(sec, this.appendMatrices(sec, dcOut, cbind), dcOut);
            } else {
                // Frames derive their output characteristics incrementally and track
                // broadcast lineage separately, so they use a dedicated path.
                this.processFrameAppend(sec, cbind);
            }
        } else if (isNaryElementwiseOpcode(opcode)) {
            DataCharacteristics dcOut = computeMinMaxOutputDataCharacteristics(sec, this.inputs);
            this.setMatrixOutput(sec, this.computeNaryElementwise(sec, opcode), dcOut);
        } else {
            throw new DMLRuntimeException("Unsupported opcode for BuiltinNarySPInstruction: " + opcode);
        }
    }

    /** Returns true for the element-wise n-ary opcodes (NMIN, NMAX, NP, NM). */
    private static boolean isNaryElementwiseOpcode(String opcode) {
        return ArrayUtils.contains(new String[]{Opcodes.NMIN.toString(), Opcodes.NMAX.toString(),
            Opcodes.NP.toString(), Opcodes.NM.toString()}, opcode);
    }

    /**
     * Appends all matrix inputs in a single shuffle.
     *
     * <p>Each input is shifted by the accumulated offset of its predecessors and padded to the
     * output block sizes. The results are unioned, and partially overlapping blocks are merged by key.
     *
     * @param sec   Spark execution context
     * @param dcOut final output characteristics (used for padding and partitioning)
     * @param cbind true for column-wise append, false for row-wise
     * @return the appended matrix RDD
     */
    private JavaPairRDD<MatrixIndexes, MatrixBlock> appendMatrices(SparkExecutionContext sec,
        DataCharacteristics dcOut, boolean cbind) {
        MatrixCharacteristics off = new MatrixCharacteristics(0L, 0L, dcOut.getBlocksize(), 0L);
        JavaPairRDD<MatrixIndexes, MatrixBlock> union = null;
        for (CPOperand input : this.inputs) {
            DataCharacteristics dcIn = sec.getDataCharacteristics(input.getName());
            // ShiftMatrix is constructed with the offset *before* this input is added to it.
            JavaPairRDD<MatrixIndexes, MatrixBlock> shifted = sec
                .getBinaryMatrixBlockRDDHandleForVariable(input.getName())
                .flatMapToPair(new AppendGSPInstruction.ShiftMatrix(off, dcIn, cbind))
                .mapToPair(new PadBlocksFunction(dcOut));
            union = (union == null) ? shifted : union.union(shifted);
            updateAppendDataCharacteristics(dcIn, off, cbind);
        }
        int numPartOut = SparkUtils.getNumPreferredPartitions(dcOut);
        return RDDAggregateUtils.mergeByKey(union, numPartOut, false);
    }

    /**
     * Appends frame inputs pairwise, left to right, and registers output, schema and lineage.
     *
     * <p>Each step uses a broadcast-based append when forced ({@code MR_MAPPEND}). For an
     * unforced cbind, it is also used when the right input is at most one block wide and passes
     * the Spark broadcast memory budget check. Otherwise an RDD-based append is used.
     *
     * @param sec   Spark execution context
     * @param cbind true for column-wise append, false for row-wise
     * @throws DMLRuntimeException if a forced append method other than MR_MAPPEND/MR_RAPPEND is set
     */
    private void processFrameAppend(SparkExecutionContext sec, boolean cbind) {
        String firstName = this.inputs[0].getName();
        JavaPairRDD<Long, FrameBlock> outFrame = sec.getFrameBinaryBlockRDDHandleForVariable(firstName);
        DataCharacteristics dcOut = new MatrixCharacteristics(sec.getDataCharacteristics(firstName));
        FrameObject fo = new FrameObject(sec.getFrameObject(firstName));
        boolean[] broadcasted = new boolean[this.inputs.length]; // index 0 is never broadcast
        // updateAppendDataCharacteristics does not change the block size, so this is loop-invariant.
        int blkSize = dcOut.getBlocksize() <= 0 ? DEFAULT_FRAME_APPEND_BLOCKSIZE : dcOut.getBlocksize();

        for (int i = 1; i < this.inputs.length; ++i) {
            String inName = this.inputs[i].getName();
            DataCharacteristics dcIn = sec.getDataCharacteristics(inName);
            broadcasted[i] = BinaryOp.FORCED_APPEND_METHOD == BinaryOp.AppendMethod.MR_MAPPEND
                || (BinaryOp.FORCED_APPEND_METHOD == null && cbind && dcIn.getCols() <= (long) blkSize
                    && OptimizerUtils.checkSparkBroadcastMemoryBudget(dcOut.getCols(), dcIn.getCols(), blkSize, dcIn.getNonZeros()));
            if (broadcasted[i]) {
                outFrame = FrameAppendMSPInstruction.appendFrameMSP(outFrame, sec.getBroadcastForFrameVariable(inName));
            } else {
                if (BinaryOp.FORCED_APPEND_METHOD != null && BinaryOp.FORCED_APPEND_METHOD != BinaryOp.AppendMethod.MR_RAPPEND) {
                    throw new DMLRuntimeException("Forced append type [" + BinaryOp.FORCED_APPEND_METHOD + "] is not supported for frames");
                }
                JavaPairRDD<Long, FrameBlock> right = sec.getFrameBinaryBlockRDDHandleForVariable(inName);
                // Row count of the left side must be taken before dcOut is updated below.
                outFrame = FrameAppendRSPInstruction.appendFrameRSP(outFrame, right, dcOut.getRows(), cbind);
            }
            updateAppendDataCharacteristics(dcIn, dcOut, cbind);
            if (cbind) {
                fo.setSchema(fo.mergeSchemas(sec.getFrameObject(inName)));
            }
        }

        String outName = this.output.getName();
        sec.getDataCharacteristics(outName).set(dcOut);
        sec.setRDDHandleForVariable(outName, outFrame);
        sec.getFrameObject(outName).setSchema(fo.getSchema());
        for (int i = 0; i < this.inputs.length; ++i) {
            if (broadcasted[i]) {
                sec.addLineageBroadcast(outName, this.inputs[i].getName());
            } else {
                sec.addLineageRDD(outName, this.inputs[i].getName());
            }
        }
    }

    /**
     * Computes an element-wise n-ary operation.
     *
     * <p>All matrix inputs are joined by block index into a {@code MatrixBlock[]} per key. The
     * operator for {@code opcode} is then applied to each array together with the scalar inputs.
     *
     * @param sec    Spark execution context
     * @param opcode one of NMIN, NMAX, NP, NM
     * @return the result RDD
     * @throws DMLRuntimeException if no input is a matrix
     */
    private JavaPairRDD<MatrixIndexes, MatrixBlock> computeNaryElementwise(SparkExecutionContext sec, String opcode) {
        List<ScalarObject> scalars = sec.getScalarInputs(this.inputs);
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> joined = null;
        for (CPOperand input : this.inputs) {
            if (!input.getDataType().isMatrix()) {
                continue;
            }
            JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = sec.getBinaryMatrixBlockRDDHandleForVariable(input.getName());
            if (joined == null) {
                joined = tmp.mapValues(new MapInputSignature());
            } else {
                joined = joined.join(tmp).mapValues(new MapJoinSignature());
            }
        }
        if (joined == null) {
            throw new DMLRuntimeException("Nary operation " + opcode + " requires at least one matrix input.");
        }
        return joined.mapValues(new MinMaxAddMultFunction(opcode, scalars));
    }

    /**
     * Registers a matrix output: sets its characteristics and RDD handle.
     *
     * <p>RDD lineage is added from every non-scalar input.
     */
    private void setMatrixOutput(SparkExecutionContext sec, JavaPairRDD<MatrixIndexes, MatrixBlock> out,
        DataCharacteristics dcOut) {
        String outName = this.output.getName();
        sec.getDataCharacteristics(outName).set(dcOut);
        sec.setRDDHandleForVariable(outName, out);
        for (CPOperand input : this.inputs) {
            if (!input.isScalar()) {
                sec.addLineageRDD(outName, input.getName());
            }
        }
    }

    /**
     * Computes the output characteristics of appending all inputs.
     *
     * <p>The computation starts from a 0x0 shape (nnz 0) using the first input's block size.
     */
    private static DataCharacteristics computeAppendOutputDataCharacteristics(SparkExecutionContext sec, CPOperand[] inputs, boolean cbind) {
        DataCharacteristics mcIn1 = sec.getDataCharacteristics(inputs[0].getName());
        MatrixCharacteristics mcOut = new MatrixCharacteristics(0L, 0L, mcIn1.getBlocksize(), 0L);
        for (CPOperand input : inputs) {
            updateAppendDataCharacteristics(sec.getDataCharacteristics(input.getName()), mcOut, cbind);
        }
        return mcOut;
    }

    /**
     * Grows {@code out} in place by appending {@code in}.
     *
     * <p>For cbind, rows become the maximum and columns are summed. For rbind, rows are summed and
     * columns become the maximum. The nnz is summed while {@code out}'s nnz is known (!= -1) and
     * {@code in.dimsKnown(true)} holds; otherwise it is set to -1 (unknown).
     */
    private static void updateAppendDataCharacteristics(DataCharacteristics in, DataCharacteristics out, boolean cbind) {
        long rows = cbind ? Math.max(out.getRows(), in.getRows()) : out.getRows() + in.getRows();
        long cols = cbind ? out.getCols() + in.getCols() : Math.max(out.getCols(), in.getCols());
        out.setDimension(rows, cols);
        out.setNonZeros(out.getNonZeros() != -1L && in.dimsKnown(true) ? out.getNonZeros() + in.getNonZeros() : -1L);
    }

    /**
     * Computes the output characteristics for element-wise n-ary ops.
     *
     * <p>Rows and columns are the maxima over all matrix inputs. The block size is taken from the
     * last matrix input.
     */
    private static DataCharacteristics computeMinMaxOutputDataCharacteristics(SparkExecutionContext sec, CPOperand[] inputs) {
        MatrixCharacteristics mcOut = new MatrixCharacteristics();
        for (CPOperand input : inputs) {
            if (!input.getDataType().isMatrix()) {
                continue;
            }
            DataCharacteristics mcIn = sec.getDataCharacteristics(input.getName());
            mcOut.setRows(Math.max(mcOut.getRows(), mcIn.getRows()));
            mcOut.setCols(Math.max(mcOut.getCols(), mcIn.getCols()));
            mcOut.setBlocksize(mcIn.getBlocksize());
        }
        return mcOut;
    }

    @Override
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        return Pair.of(this.output.getName(), new LineageItem(this.getOpcode(), LineageItemUtils.getLineage(ec, this.inputs)));
    }

    /**
     * Applies an n-ary operator to the joined blocks of one block index plus all scalar inputs.
     *
     * <p>The operator is Plus for NP, Multiply for NM, and otherwise the builtin named by the
     * opcode without its first character (e.g. NMIN/NMAX map to the min/max builtins).
     */
    private static class MinMaxAddMultFunction
    implements Function<MatrixBlock[], MatrixBlock> {
        private static final long serialVersionUID = -4227447915387484397L;
        private final SimpleOperator _op;
        private final ScalarObject[] _scalars;

        public MinMaxAddMultFunction(String opcode, List<ScalarObject> scalars) {
            this._scalars = scalars.toArray(new ScalarObject[0]);
            this._op = new SimpleOperator(opcode.equals(Opcodes.NP.toString()) ? Plus.getPlusFnObject()
                : (opcode.equals(Opcodes.NM.toString()) ? Multiply.getMultiplyFnObject()
                : Builtin.getBuiltinFnObject(opcode.substring(1))));
        }

        @Override
        public MatrixBlock call(MatrixBlock[] v1) throws Exception {
            return MatrixBlock.naryOperations(this._op, v1, this._scalars, new MatrixBlock());
        }
    }

    /**
     * Pads a block to the dimensions expected at its block index in the output characteristics.
     *
     * <p>If the block has too few rows, empty rows are appended. Otherwise, if it has too few
     * columns, empty columns are appended. Blocks that already match are passed through unchanged.
     */
    public static class PadBlocksFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1291358959908299855L;
        private final DataCharacteristics _mcOut;

        public PadBlocksFunction(DataCharacteristics mcOut) {
            this._mcOut = mcOut;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            MatrixBlock mb = arg0._2();
            int brlen = UtilFunctions.computeBlockSize(this._mcOut.getRows(), ix.getRowIndex(), this._mcOut.getBlocksize());
            int bclen = UtilFunctions.computeBlockSize(this._mcOut.getCols(), ix.getColumnIndex(), this._mcOut.getBlocksize());
            if (brlen == mb.getNumRows() && bclen == mb.getNumColumns()) {
                return arg0; // already the expected size
            }
            if (brlen > mb.getNumRows()) {
                // rbind an empty block below
                mb = mb.append(new MatrixBlock(brlen - mb.getNumRows(), bclen, true), new MatrixBlock(), false);
            } else if (bclen > mb.getNumColumns()) {
                // cbind an empty block to the right
                mb = mb.append(new MatrixBlock(brlen, bclen - mb.getNumColumns(), true), new MatrixBlock(), true);
            }
            return new Tuple2<>(ix, mb);
        }
    }

    /**
     * Re-indexes a frame block into a single output block keyed 1.
     *
     * <p>The block is placed at rows {@code [index-1, index+numRows-2]} of a {@code max_rows}-row
     * frame. Only {@code max_rows <= 1000} is supported.
     *
     * <p>NOTE(review): this class is not referenced anywhere in this file; confirm it is still needed.
     */
    private static class AlignBlkTask
    implements PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
        private static final long serialVersionUID = 1333460067852261573L;
        private final long max_rows;

        public AlignBlkTask(long rows) {
            this.max_rows = rows;
        }

        @Override
        public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> longFrameBlockTuple2) throws Exception {
            if (this.max_rows > 1000L) {
                throw new NotImplementedException("Other Alignment strategies need to be implemented");
            }
            int index = longFrameBlockTuple2._1().intValue();
            FrameBlock fb = longFrameBlockTuple2._2();
            FrameBlock fbout = new FrameBlock(fb.getSchema());
            fbout.ensureAllocatedColumns((int) this.max_rows);
            fbout = fbout.leftIndexingOperations(fb, index - 1, index + fb.getNumRows() - 2, 0, fb.getNumColumns() - 1, null);
            List<Tuple2<Long, FrameBlock>> list = new ArrayList<>(1);
            list.add(new Tuple2<>(1L, fbout));
            return list.iterator();
        }
    }
}

