/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.BinaryM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.BinaryFrameFrameSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryFrameMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorBroadcastSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction;
import org.apache.sysds.runtime.instructions.spark.functions.MatrixScalarUnaryFunction;
import org.apache.sysds.runtime.instructions.spark.functions.MatrixVectorBinaryOpPartitionFunction;
import org.apache.sysds.runtime.instructions.spark.functions.OuterVectorBinaryOpFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReblockTensorFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateTensorFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysds.runtime.instructions.spark.functions.TensorTensorBinaryOpFunction;
import org.apache.sysds.runtime.instructions.spark.functions.TensorTensorBinaryOpPartitionFunction;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataUtils;

/**
 * Common base class for Spark binary instructions.
 *
 * <p>Provides {@link #parseInstruction(String)}, which creates the concrete subclass according to
 * the operand data types. It also provides shared helpers that subclasses use to:
 * <ul>
 *   <li>execute matrix-matrix, matrix-broadcast-vector, tensor-tensor, tensor-broadcast and
 *       matrix-scalar operations on RDDs;</li>
 *   <li>validate input characteristics;</li>
 *   <li>derive output characteristics.</li>
 * </ul>
 */
public abstract class BinarySPInstruction
extends ComputationSPInstruction {
    protected BinarySPInstruction(SPInstruction.SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(type, op, in1, in2, out, opcode, istr);
    }

    /**
     * Parses a binary Spark instruction string into the matching concrete instruction.
     *
     * <p>Strings starting with the Spark {@code map} prefix are treated as broadcast-based
     * operations. They are validated with {@code checkNumFields(parts, 5)} and their vector type
     * is read from {@code parts[5]}. All other strings are parsed via
     * {@link #parseBinaryInstruction(String, CPOperand, CPOperand, CPOperand)}.
     *
     * @param str the instruction string
     * @return the concrete instruction, or {@code null} if no implementation exists for the
     *         combination of operand data types (and, for frame-scalar, the opcode)
     * @throws DMLRuntimeException if a tensor is combined with a non-tensor operand
     */
    public static BinarySPInstruction parseInstruction(String str) {
        CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand in2 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        String opcode;
        boolean isBroadcast = false;
        BinaryM.VectorType vtype = null;
        if (str.startsWith("SPARK\u00b0map")) {
            String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
            InstructionUtils.checkNumFields(parts, 5);
            opcode = parts[0];
            in1.split(parts[1]);
            in2.split(parts[2]);
            out.split(parts[3]);
            // parts[4] is not used here; parts[5] holds the broadcast vector type
            vtype = BinaryM.VectorType.valueOf(parts[5]);
            isBroadcast = true;
        } else {
            opcode = BinarySPInstruction.parseBinaryInstruction(str, in1, in2, out);
        }
        Types.DataType dt1 = in1.getDataType();
        Types.DataType dt2 = in2.getDataType();
        Operator operator = InstructionUtils.parseExtendedBinaryOrBuiltinOperator(opcode, in1, in2);
        if (dt1 == Types.DataType.MATRIX || dt2 == Types.DataType.MATRIX) {
            if (dt1 == Types.DataType.MATRIX && dt2 == Types.DataType.MATRIX) {
                if (isBroadcast) {
                    return new BinaryMatrixBVectorSPInstruction(operator, in1, in2, out, vtype, opcode, str);
                }
                return new BinaryMatrixMatrixSPInstruction(operator, in1, in2, out, opcode, str);
            }
            if (dt1 == Types.DataType.FRAME && dt2 == Types.DataType.MATRIX) {
                return new BinaryFrameMatrixSPInstruction(operator, in1, in2, out, opcode, str);
            }
            // any remaining combination with one matrix operand is handled as matrix-scalar
            return new BinaryMatrixScalarSPInstruction(operator, in1, in2, out, opcode, str);
        }
        if (dt1 == Types.DataType.TENSOR || dt2 == Types.DataType.TENSOR) {
            if (dt1 == Types.DataType.TENSOR && dt2 == Types.DataType.TENSOR) {
                if (isBroadcast) {
                    return new BinaryTensorTensorBroadcastSPInstruction(operator, in1, in2, out, opcode, str);
                }
                return new BinaryTensorTensorSPInstruction(operator, in1, in2, out, opcode, str);
            }
            throw new DMLRuntimeException("Tensor binary operation not yet implemented for tensor-scalar, or tensor-matrix");
        }
        if (dt1 == Types.DataType.FRAME || dt2 == Types.DataType.FRAME) {
            if (dt1 == Types.DataType.FRAME && dt2 == Types.DataType.FRAME) {
                return new BinaryFrameFrameSPInstruction(operator, in1, in2, out, opcode, str);
            }
            if (dt1 == Types.DataType.FRAME && dt2 == Types.DataType.SCALAR && opcode.equalsIgnoreCase(Opcodes.PLUS.toString())) {
                return new BinaryMatrixScalarSPInstruction(operator, in1, in2, out, opcode, str);
            }
        }
        // unsupported operand combination
        return null;
    }

    /**
     * Splits an instruction of the form {@code opcode in1 in2 out} into its operands.
     *
     * @param instr the instruction string
     * @param in1 receives the first input operand
     * @param in2 receives the second input operand
     * @param out receives the output operand
     * @return the opcode
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(parts, 3);
        String opcode = parts[0];
        in1.split(parts[1]);
        in2.split(parts[2]);
        out.split(parts[3]);
        return opcode;
    }

    /**
     * Splits an instruction of the form {@code opcode in1 in2 in3 out} into its operands.
     *
     * @param instr the instruction string
     * @param in1 receives the first input operand
     * @param in2 receives the second input operand
     * @param in3 receives the third input operand
     * @param out receives the output operand
     * @return the opcode
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(parts, 4);
        String opcode = parts[0];
        in1.split(parts[1]);
        in2.split(parts[2]);
        in3.split(parts[3]);
        out.split(parts[4]);
        return opcode;
    }

    /**
     * Executes a cell-wise matrix-matrix binary operation as a block-wise join of both RDDs.
     * Vector operands are first replicated according to {@link #getNumReplicas}.
     *
     * @param ec the execution context (must be a {@link SparkExecutionContext})
     */
    protected void processMatrixMatrixBinaryInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        this.checkMatrixMatrixBinaryCharacteristics(sec);
        this.updateBinaryOutputDataCharacteristics(sec);
        JavaPairRDD in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD in2 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input2.getName());
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(this.input2.getName());
        DataCharacteristics mcOut = sec.getDataCharacteristics(this.output.getName());
        BinaryOperator bop = (BinaryOperator)this._optr;
        // replicate vector inputs ahead of the block-wise join (outer / matrix-vector cases)
        boolean rowvector = mc2.getRows() == 1L && mc1.getRows() > 1L;
        long numRepLeft = this.getNumReplicas(mc1, mc2, true);
        long numRepRight = this.getNumReplicas(mc1, mc2, false);
        if (numRepLeft > 1L) {
            in1 = in1.flatMapToPair((PairFlatMapFunction)new ReplicateVectorFunction(false, numRepLeft));
        }
        if (numRepRight > 1L) {
            in2 = in2.flatMapToPair((PairFlatMapFunction)new ReplicateVectorFunction(rowvector, numRepRight));
        }
        int numPrefPart = getNumJoinPartitions(in1, in2, mcOut);
        JavaPairRDD out = in1.join(in2, numPrefPart).mapValues((Function)new MatrixMatrixBinaryOpFunction(bop));
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        sec.addLineageRDD(this.output.getName(), this.input2.getName());
    }

    /**
     * Executes a cell-wise tensor-tensor binary operation as a block-wise join of both RDDs.
     *
     * <p>If the second input has fewer dimensions, it is first reblocked to the first input's
     * dimensionality. It is then replicated along every dimension for which
     * {@link #getNumDimReplicas} reports more than one replica.
     *
     * @param ec the execution context (must be a {@link SparkExecutionContext})
     */
    protected void processTensorTensorBinaryInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        this.checkTensorTensorBinaryCharacteristics(sec);
        this.updateBinaryTensorOutputDataCharacteristics(sec);
        JavaPairRDD<TensorIndexes, TensorBlock> in1 = sec.getBinaryTensorBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD in2 = sec.getBinaryTensorBlockRDDHandleForVariable(this.input2.getName());
        DataCharacteristics tc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics tc2 = sec.getDataCharacteristics(this.input2.getName());
        DataCharacteristics dcOut = sec.getDataCharacteristics(this.output.getName());
        BinaryOperator bop = (BinaryOperator)this._optr;
        if (tc2.getNumDims() < tc1.getNumDims()) {
            in2 = in2.flatMapToPair((PairFlatMapFunction)new ReblockTensorFunction(tc1.getNumDims(), tc1.getBlocksize()));
        }
        for (int i = 0; i < tc1.getNumDims(); ++i) {
            long numReps = this.getNumDimReplicas(tc1, tc2, i);
            if (numReps > 1L) {
                in2 = in2.flatMapToPair((PairFlatMapFunction)new ReplicateTensorFunction(i, numReps));
            }
        }
        int numPrefPart = getNumJoinPartitions(in1, in2, dcOut);
        JavaPairRDD out = in1.join(in2, numPrefPart).mapValues((Function)new TensorTensorBinaryOpFunction(bop));
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        sec.addLineageRDD(this.output.getName(), this.input2.getName());
    }

    /**
     * Chooses the number of partitions for a block-wise join of two RDDs.
     *
     * <p>The partition count of the first already hash-partitioned input (left first) is reused.
     * Otherwise the result is the sum of both inputs' partitions, capped at twice the preferred
     * partition count for the output.
     */
    private static int getNumJoinPartitions(JavaPairRDD in1, JavaPairRDD in2, DataCharacteristics dcOut) {
        if (SparkUtils.isHashPartitioned(in1)) {
            return in1.getNumPartitions();
        }
        if (SparkUtils.isHashPartitioned(in2)) {
            return in2.getNumPartitions();
        }
        return Math.min(in1.getNumPartitions() + in2.getNumPartitions(), 2 * SparkUtils.getNumPreferredPartitions(dcOut));
    }

    /**
     * Executes a matrix-vector (or outer vector) operation where the second input is broadcast.
     *
     * <p>When the first input is a column vector and the second a row vector, the outer variant
     * is used. Otherwise the broadcast vector is applied per partition of the first input.
     *
     * @param ec the execution context (must be a {@link SparkExecutionContext})
     * @param vtype the vector type of the broadcast input
     */
    protected void processMatrixBVectorBinaryInstruction(ExecutionContext ec, BinaryM.VectorType vtype) {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        this.checkMatrixMatrixBinaryCharacteristics(sec);
        String rddVar = this.input1.getName();
        String bcastVar = this.input2.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(rddVar);
        PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar);
        DataCharacteristics mc1 = sec.getDataCharacteristics(rddVar);
        DataCharacteristics mc2 = sec.getDataCharacteristics(bcastVar);
        BinaryOperator bop = (BinaryOperator)this._optr;
        boolean isOuter = mc1.getRows() > 1L && mc1.getCols() == 1L && mc2.getRows() == 1L && mc2.getCols() > 1L;
        JavaPairRDD out;
        if (isOuter) {
            out = in1.flatMapToPair((PairFlatMapFunction)new OuterVectorBinaryOpFunction(bop, in2));
        } else {
            out = in1.mapPartitionsToPair((PairFlatMapFunction)new MatrixVectorBinaryOpPartitionFunction(bop, in2, vtype), true);
        }
        this.updateBinaryOutputDataCharacteristics(sec);
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), rddVar);
        sec.addLineageBroadcast(this.output.getName(), bcastVar);
    }

    /**
     * Executes a tensor-tensor operation where the second input is broadcast.
     *
     * <p>Dimensions of size 1 in the broadcast input are flagged for replication.
     * NOTE(review): {@code setBlocksize} is applied to the broadcast variable's characteristics
     * object returned by the context. This presumably aligns it with the RDD input in place;
     * confirm that the mutation is intended.
     *
     * @param ec the execution context (must be a {@link SparkExecutionContext})
     */
    protected void processTensorTensorBroadcastBinaryInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        this.checkTensorTensorBinaryCharacteristics(sec);
        String rddVar = this.input1.getName();
        String bcastVar = this.input2.getName();
        JavaPairRDD<TensorIndexes, TensorBlock> in1 = sec.getBinaryTensorBlockRDDHandleForVariable(rddVar);
        DataCharacteristics dc1 = sec.getDataCharacteristics(rddVar);
        DataCharacteristics dc2 = sec.getDataCharacteristics(bcastVar).setBlocksize(dc1.getBlocksize());
        PartitionedBroadcast<TensorBlock> in2 = sec.getBroadcastForTensorVariable(bcastVar);
        BinaryOperator bop = (BinaryOperator)this._optr;
        boolean[] replicateDim = new boolean[dc2.getNumDims()];
        for (int i = 0; i < replicateDim.length; ++i) {
            replicateDim[i] = dc2.getDim(i) == 1L;
        }
        JavaPairRDD out = in1.mapPartitionsToPair((PairFlatMapFunction)new TensorTensorBinaryOpPartitionFunction(bop, in2, replicateDim), true);
        this.updateBinaryTensorOutputDataCharacteristics(sec);
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), rddVar);
        sec.addLineageBroadcast(this.output.getName(), bcastVar);
    }

    /**
     * Executes a matrix-scalar operation.
     *
     * <p>Whichever input is the matrix becomes the RDD input; the other is read as a scalar,
     * bound as the operator's constant, and applied to every block.
     *
     * @param ec the execution context (must be a {@link SparkExecutionContext})
     */
    protected void processMatrixScalarBinaryInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        boolean matrixFirst = this.input1.getDataType() == Types.DataType.MATRIX;
        String rddVar = matrixFirst ? this.input1.getName() : this.input2.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(rddVar);
        CPOperand scalar = matrixFirst ? this.input2 : this.input1;
        ScalarObject constant = ec.getScalarInput(scalar);
        ScalarOperator sc_op = (ScalarOperator)this._optr;
        sc_op = sc_op.setConstant(constant.getDoubleValue());
        JavaPairRDD out = in1.mapValues((Function)new MatrixScalarUnaryFunction(sc_op));
        this.updateUnaryOutputDataCharacteristics(sec, rddVar, this.output.getName());
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), rddVar);
    }

    /**
     * Infers the output characteristics ({@code rows(in1) x cols(in2)}) of a binary operation
     * whose output shape is not yet known. Non-zeros are left untouched.
     *
     * @param sec the Spark execution context
     * @param checkCommonDim whether to require {@code cols(in1) == rows(in2)}
     * @return the (possibly updated) output characteristics
     * @throws DMLRuntimeException if the inputs' dimensions are unknown, their block sizes
     *         differ, or the common dimension mismatches
     */
    protected DataCharacteristics updateBinaryMMOutputDataCharacteristics(SparkExecutionContext sec, boolean checkCommonDim) {
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(this.input2.getName());
        DataCharacteristics mcOut = sec.getDataCharacteristics(this.output.getName());
        if (!mcOut.dimsKnown()) {
            if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
                throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
            }
            if (mc1.getBlocksize() != mc2.getBlocksize()) {
                throw new DMLRuntimeException("Incompatible block sizes for BinarySPInstruction.");
            }
            if (checkCommonDim && mc1.getCols() != mc2.getRows()) {
                throw new DMLRuntimeException("Incompatible dimensions for BinarySPInstruction");
            }
            // 3-arg overload: the 4-arg overload's last parameter is the nnz count, so passing
            // the blocksize there would record a bogus number of non-zeros
            mcOut.set(mc1.getRows(), mc2.getCols(), mc1.getBlocksize());
        }
        return mcOut;
    }

    /**
     * Updates the output characteristics of a cbind/rbind append.
     *
     * <p>If the output's non-zeros are unknown but both inputs' are known, the output receives
     * their sum.
     *
     * @param sec the Spark execution context
     * @param cbind {@code true} for column-wise append, {@code false} for row-wise append
     */
    protected void updateBinaryAppendOutputDataCharacteristics(SparkExecutionContext sec, boolean cbind) {
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(this.input2.getName());
        DataCharacteristics mcOut = sec.getDataCharacteristics(this.output.getName());
        MetaDataUtils.updateAppendDataCharacteristics(mc1, mc2, mcOut, cbind);
        if (!mcOut.nnzKnown() && mc1.nnzKnown() && mc2.nnzKnown()) {
            mcOut.setNonZeros(mc1.getNonZeros() + mc2.getNonZeros());
        }
    }

    /**
     * Determines how often the blocks of one input must be replicated so that both inputs can be
     * joined block-by-block.
     *
     * @param mc1 characteristics of the left input
     * @param mc2 characteristics of the right input
     * @param left {@code true} to compute replicas of the left input, {@code false} for the right
     * @return the number of replicas, or 1 if no replication is needed
     */
    protected long getNumReplicas(DataCharacteristics mc1, DataCharacteristics mc2, boolean left) {
        if (left) {
            // outer: the left column vector is replicated across the right's column blocks
            if (mc1.getCols() == 1L) {
                return mc2.getNumColBlocks();
            }
        } else {
            // row vector: replicated across the left input's row blocks
            if (mc2.getRows() == 1L && mc1.getRows() > 1L) {
                return mc1.getNumRowBlocks();
            }
            // column vector: replicated across the left input's column blocks
            // (mc2 has a single column, hence a single column block)
            if (mc2.getCols() == 1L && mc1.getCols() > 1L) {
                return mc1.getNumColBlocks();
            }
        }
        return 1L;
    }

    /**
     * Determines how often the second tensor must be replicated along dimension {@code dim}.
     *
     * <p>Replication is needed if the second tensor lacks that dimension, or has size 1 in it
     * while the first tensor is larger.
     *
     * @param dc1 characteristics of the first tensor
     * @param dc2 characteristics of the second tensor
     * @param dim the dimension index
     * @return the number of blocks of the first tensor in {@code dim}, or 1 if no replication is needed
     */
    protected long getNumDimReplicas(DataCharacteristics dc1, DataCharacteristics dc2, int dim) {
        if (dim >= dc2.getNumDims() || (dc2.getDim(dim) == 1L && dc1.getDim(dim) > 1L)) {
            return dc1.getNumBlocks(dim);
        }
        return 1L;
    }

    /**
     * Validates that two matrices can be combined cell-wise.
     *
     * <p>Accepted combinations:
     * <ul>
     *   <li>equal shapes;</li>
     *   <li>matrix and column vector with equal rows;</li>
     *   <li>matrix and row vector with equal columns;</li>
     *   <li>column vector and row vector (outer).</li>
     * </ul>
     * The block sizes must also match.
     *
     * @param sec the Spark execution context
     * @throws DMLRuntimeException on unknown dimensions, mismatching shapes or block sizes
     */
    protected void checkMatrixMatrixBinaryCharacteristics(SparkExecutionContext sec) {
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(this.input2.getName());
        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("Unknown dimensions matrix-matrix binary operations: [" + mc1.getRows() + "x" + mc1.getCols() + " vs " + mc2.getRows() + "x" + mc2.getCols() + "]");
        }
        if (!(mc1.getRows() == mc2.getRows() && mc1.getCols() == mc2.getCols() || mc1.getRows() == mc2.getRows() && mc2.getCols() == 1L || mc1.getCols() == mc2.getCols() && mc2.getRows() == 1L || mc1.getCols() == 1L && mc2.getRows() == 1L)) {
            throw new DMLRuntimeException("Dimensions mismatch matrix-matrix binary operations: [" + mc1.getRows() + "x" + mc1.getCols() + " vs " + mc2.getRows() + "x" + mc2.getCols() + "]");
        }
        if (mc1.getBlocksize() != mc2.getBlocksize()) {
            throw new DMLRuntimeException("Blocksize mismatch matrix-matrix binary operations: [" + mc1.getBlocksize() + "x" + mc1.getBlocksize() + " vs " + mc2.getBlocksize() + "x" + mc2.getBlocksize() + "]");
        }
    }

    /**
     * Validates that two tensors can be combined cell-wise.
     *
     * <p>The second tensor may have at most as many dimensions as the first. Each of its
     * dimensions must equal the first tensor's or be 1.
     *
     * @param sec the Spark execution context
     * @throws DMLRuntimeException on unknown or mismatching dimensions
     */
    protected void checkTensorTensorBinaryCharacteristics(SparkExecutionContext sec) {
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(this.input2.getName());
        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("Unknown dimensions tensor-tensor binary operations");
        }
        boolean dimensionMismatch = mc1.getNumDims() < mc2.getNumDims();
        for (int i = 0; !dimensionMismatch && i < mc2.getNumDims(); ++i) {
            dimensionMismatch = mc1.getDim(i) != mc2.getDim(i) && mc2.getDim(i) != 1L;
        }
        if (dimensionMismatch) {
            throw new DMLRuntimeException("Dimensions mismatch tensor-tensor binary operations");
        }
    }

    /**
     * Validates the inputs of a cbind/rbind append instruction.
     *
     * @param sec the Spark execution context
     * @param cbind {@code true} for column-wise append (rows must match), {@code false} for
     *        row-wise append (columns must match)
     * @param checkSingleBlk whether the combined columns must fit into a single block
     * @param checkAligned whether the first input's appended dimension must be a multiple of the blocksize
     * @throws DMLRuntimeException if any check fails
     */
    protected void checkBinaryAppendInputCharacteristics(SparkExecutionContext sec, boolean cbind, boolean checkSingleBlk, boolean checkAligned) {
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(this.input2.getName());
        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("The dimensions unknown for inputs");
        }
        if (cbind && mc1.getRows() != mc2.getRows()) {
            throw new DMLRuntimeException("The number of rows of inputs should match for append-cbind instruction");
        }
        if (!cbind && mc1.getCols() != mc2.getCols()) {
            throw new DMLRuntimeException("The number of columns of inputs should match for append-rbind instruction");
        }
        if (mc1.getBlocksize() != mc2.getBlocksize()) {
            throw new DMLRuntimeException("The block sizes do not match for input matrices");
        }
        if (checkSingleBlk && mc1.getCols() + mc2.getCols() > (long)mc1.getBlocksize()) {
            throw new DMLRuntimeException("Output must have at most one column block");
        }
        if (checkAligned && (cbind ? mc1.getCols() : mc1.getRows()) % (long)mc1.getBlocksize() != 0L) {
            throw new DMLRuntimeException("Input matrices are not aligned to blocksize boundaries. Wrong append selected");
        }
    }
}

