/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark.utils;

import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.storage.StorageLevel;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.IndexedTensorBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.functions.CopyBinaryCellFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CopyMatrixBlockFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CopyMatrixBlockPairFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CopyTensorBlockFunction;
import org.apache.sysds.runtime.instructions.spark.functions.CopyTensorBlockPairFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.RecomputeNnzFunction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixCell;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
import scala.Tuple2;

/**
 * Static helpers shared by the Spark instructions.
 *
 * <p>Covers the following:
 * <ul>
 * <li>conversions between Spark {@link Tuple2} records and SystemDS indexed block/pair types;</li>
 * <li>partition-count heuristics;</li>
 * <li>RDD copy/caching helpers;</li>
 * <li>a few debug and validation utilities.</li>
 * </ul>
 */
public class SparkUtils {
    /** Storage level used for temporarily persisted RDDs; identical to {@link Checkpoint#DEFAULT_STORAGE_LEVEL}. */
    public static final StorageLevel DEFAULT_TMP = Checkpoint.DEFAULT_STORAGE_LEVEL;

    /** Wraps a (indexes, block) Spark tuple into an {@link IndexedMatrixValue}. */
    public static IndexedMatrixValue toIndexedMatrixBlock(Tuple2<MatrixIndexes, MatrixBlock> in) {
        return new IndexedMatrixValue(in._1(), in._2());
    }

    /** Wraps the given indexes and block into an {@link IndexedMatrixValue}. */
    public static IndexedMatrixValue toIndexedMatrixBlock(MatrixIndexes ix, MatrixBlock mb) {
        return new IndexedMatrixValue(ix, mb);
    }

    /** Wraps a (indexes, block) Spark tuple into an {@link IndexedTensorBlock}. */
    public static IndexedTensorBlock toIndexedTensorBlock(Tuple2<TensorIndexes, TensorBlock> in) {
        return new IndexedTensorBlock(in._1(), in._2());
    }

    /** Wraps the given indexes and tensor block into an {@link IndexedTensorBlock}. */
    public static IndexedTensorBlock toIndexedTensorBlock(TensorIndexes ix, TensorBlock mb) {
        return new IndexedTensorBlock(ix, mb);
    }

    /**
     * Converts an {@link IndexedMatrixValue} into a Spark tuple.
     * The value is cast to {@link MatrixBlock}, so a different {@link MatrixValue} subtype raises a ClassCastException.
     */
    public static Tuple2<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlock(IndexedMatrixValue in) {
        return new Tuple2<MatrixIndexes, MatrixBlock>(in.getIndexes(), (MatrixBlock) in.getValue());
    }

    /** Element-wise {@link #fromIndexedMatrixBlock(IndexedMatrixValue)} over a list; returns a new list. */
    public static List<Tuple2<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlock(List<IndexedMatrixValue> in) {
        return in.stream().map(imv -> SparkUtils.fromIndexedMatrixBlock(imv)).collect(Collectors.toList());
    }

    /** Converts an {@link IndexedMatrixValue} into a SystemDS {@link Pair} (value cast to {@link MatrixBlock}). */
    public static Pair<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlockToPair(IndexedMatrixValue in) {
        return new Pair<MatrixIndexes, MatrixBlock>(in.getIndexes(), (MatrixBlock) in.getValue());
    }

    /** Element-wise {@link #fromIndexedMatrixBlockToPair(IndexedMatrixValue)} over a list; returns a new list. */
    public static List<Pair<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlockToPair(List<IndexedMatrixValue> in) {
        return in.stream().map(imv -> SparkUtils.fromIndexedMatrixBlockToPair(imv)).collect(Collectors.toList());
    }

    /** Converts a SystemDS (key, frame block) {@link Pair} into a Spark tuple. */
    public static Tuple2<Long, FrameBlock> fromIndexedFrameBlock(Pair<Long, FrameBlock> in) {
        return new Tuple2<Long, FrameBlock>(in.getKey(), in.getValue());
    }

    /** Element-wise {@link #fromIndexedFrameBlock(Pair)} over a list; returns a new list. */
    public static List<Tuple2<Long, FrameBlock>> fromIndexedFrameBlock(List<Pair<Long, FrameBlock>> in) {
        return in.stream().map(ifv -> SparkUtils.fromIndexedFrameBlock(ifv)).collect(Collectors.toList());
    }

    /** Converts a list of (long, long) Spark tuples into a new list of SystemDS {@link Pair}s. */
    public static List<Pair<Long, Long>> toIndexedLong(List<Tuple2<Long, Long>> in) {
        return in.stream().map(e -> new Pair<Long, Long>(e._1(), e._2())).collect(Collectors.toList());
    }

    /** Converts a (key, frame block) Spark tuple into a SystemDS {@link Pair}. */
    public static Pair<Long, FrameBlock> toIndexedFrameBlock(Tuple2<Long, FrameBlock> in) {
        return new Pair<Long, FrameBlock>(in._1(), in._2());
    }

    /** Returns true iff the RDD has a partitioner and that partitioner is a {@link HashPartitioner}. */
    public static boolean isHashPartitioned(JavaPairRDD<?, ?> in) {
        return !in.rdd().partitioner().isEmpty() && in.rdd().partitioner().get() instanceof HashPartitioner;
    }

    /**
     * Preferred number of partitions for data described by {@code dc}.
     * If {@code dc.dimsKnown(true)} is false and an input RDD is given, the input's current
     * partition count is reused. (NOTE(review): {@code dimsKnown(true)} presumably also requires
     * nnz to be known; confirm.) Otherwise this delegates to
     * {@link #getNumPreferredPartitions(DataCharacteristics)}.
     *
     * @param dc data characteristics of the data to partition
     * @param in optional input RDD, may be null
     * @return preferred partition count
     */
    public static int getNumPreferredPartitions(DataCharacteristics dc, JavaPairRDD<?, ?> in) {
        if (!dc.dimsKnown(true) && in != null) {
            return in.getNumPartitions();
        }
        return SparkUtils.getNumPreferredPartitions(dc);
    }

    /** Same as {@link #getNumPreferredPartitions(DataCharacteristics, boolean)}, with empty blocks counted unless {@code dc.isNoEmptyBlocks()}. */
    public static int getNumPreferredPartitions(DataCharacteristics dc) {
        return SparkUtils.getNumPreferredPartitions(dc, !dc.isNoEmptyBlocks());
    }

    /**
     * Preferred number of partitions, computed as ceil(estimated partitioned size / HDFS block size), with a minimum of 1.
     * If the dimensions are unknown, the default Spark parallelism is returned instead.
     *
     * @param dc data characteristics of the data to partition
     * @param outputEmptyBlocks whether empty blocks are included in the size estimate
     * @return preferred partition count (at least 1 when dims are known)
     */
    public static int getNumPreferredPartitions(DataCharacteristics dc, boolean outputEmptyBlocks) {
        if (!dc.dimsKnown()) {
            return SparkExecutionContext.getDefaultParallelism(true);
        }
        double hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
        double matrixPSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(dc, outputEmptyBlocks);
        return (int) Math.max(Math.ceil(matrixPSize / hdfsBlockSize), 1.0);
    }

    /** Deep-copies every block of a binary-block matrix RDD; same as {@code copyBinaryBlockMatrix(in, true)}. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> copyBinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return SparkUtils.copyBinaryBlockMatrix(in, true);
    }

    /**
     * Copies the blocks of a binary-block matrix RDD.
     *
     * @param in input RDD
     * @param deep if false, values are mapped through {@code CopyMatrixBlockFunction(false)};
     *             if true, partitions are mapped through {@code CopyMatrixBlockPairFunction(true)}
     *             while preserving the partitioner
     * @return the copied RDD (lazy)
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> copyBinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deep) {
        if (!deep) {
            return in.mapValues(new CopyMatrixBlockFunction(false));
        }
        return in.mapPartitionsToPair(new CopyMatrixBlockPairFunction(true), true);
    }

    /** Deep-copies every block of a binary-block tensor RDD; same as {@code copyBinaryBlockTensor(in, true)}. */
    public static JavaPairRDD<TensorIndexes, BasicTensorBlock> copyBinaryBlockTensor(JavaPairRDD<TensorIndexes, BasicTensorBlock> in) {
        return SparkUtils.copyBinaryBlockTensor(in, true);
    }

    /**
     * Copies the blocks of a binary-block tensor RDD.
     *
     * @param in input RDD
     * @param deep if false, values are mapped through {@code CopyTensorBlockFunction(false)};
     *             if true, partitions are mapped through {@code CopyTensorBlockPairFunction(true)}
     *             while preserving the partitioner
     * @return the copied RDD (lazy)
     */
    public static JavaPairRDD<TensorIndexes, BasicTensorBlock> copyBinaryBlockTensor(JavaPairRDD<TensorIndexes, BasicTensorBlock> in, boolean deep) {
        if (!deep) {
            return in.mapValues(new CopyTensorBlockFunction(false));
        }
        return in.mapPartitionsToPair(new CopyTensorBlockPairFunction(true), true);
    }

    /**
     * Debug helper: runs {@code checkNonZeros()} and {@code checkSparseRows()} on every block of the
     * binary-block RDD bound to {@code varname}. This triggers a Spark job.
     * {@code ec} must be a {@link SparkExecutionContext}; otherwise a ClassCastException is thrown.
     */
    public static void checkSparsity(String varname, ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        sec.getBinaryMatrixBlockRDDHandleForVariable(varname).foreach(new CheckSparsityFunction());
    }

    /**
     * Extracts the start-line token from a Spark debug-info line.
     * The first 4 characters are dropped, and the text up to the first ':' is returned.
     * Lines shorter than 4 characters throw StringIndexOutOfBoundsException.
     */
    public static String getStartLineFromSparkDebugInfo(String line) {
        String withoutPrefix = line.substring(4);
        return withoutPrefix.split(":")[0];
    }

    /**
     * Builds the tree-drawing prefix of a Spark debug-info line.
     *
     * <p>The line is split on "|" and "+-". The leading tokens (all but the last) are re-joined with "|".
     * Then "+- " is appended if the line contains "+-", and "|  " otherwise.
     *
     * @param line one line of Spark debug output
     * @return the reconstructed prefix
     */
    public static String getPrefixFromSparkDebugInfo(String line) {
        String[] tokens = line.split("\\||\\+-");
        // split drops trailing empty tokens, so a line consisting only of separators yields an empty array
        StringBuilder prefix = new StringBuilder(tokens.length > 0 ? tokens[0] : "");
        for (int i = 1; i < tokens.length - 1; ++i) {
            prefix.append('|').append(tokens[i]);
        }
        prefix.append(line.contains("+-") ? "+- " : "|  ");
        return prefix.toString();
    }

    /**
     * Creates an RDD holding one empty block per block position of a matrix with characteristics {@code mc}.
     *
     * <p>Block positions are split into {@code par} contiguous ranges of {@code pNumBlocks} blocks each.
     * Each range is expanded lazily into empty blocks by {@link GenerateEmptyBlocks}.
     * {@code par} is at least 1, so a matrix with zero blocks yields an empty RDD.
     *
     * @param sc Spark context
     * @param mc matrix characteristics (dims and block size are used)
     * @return pair RDD of (indexes, empty block)
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> getEmptyBlockRDD(JavaSparkContext sc, DataCharacteristics mc) {
        long numBlocks = mc.getNumBlocks();
        long blen = mc.getBlocksize();
        long size = numBlocks * OptimizerUtils.estimateSizeEmptyBlock(
            Math.min(Math.max(mc.getRows(), 1L), blen),
            Math.min(Math.max(mc.getCols(), 1L), blen));
        // floating-point division: an integral division would truncate before ceil
        double sizeBasedPar = Math.ceil((double) size / InfrastructureAnalyzer.getHDFSBlockSize());
        double maxPar = 4.0 * Math.max((double) SparkExecutionContext.getDefaultParallelism(true), sizeBasedPar);
        // never more partitions than blocks, but at least one (parallelize rejects 0 slices)
        int par = (int) Math.max(1.0, Math.min(maxPar, (double) numBlocks));
        long pNumBlocks = (long) Math.ceil((double) numBlocks / (double) par);
        List<Long> offsets = LongStream.iterate(0L, n -> n + pNumBlocks)
            .limit(par).boxed().collect(Collectors.toList());
        return sc.parallelize(offsets, par).flatMapToPair(new GenerateEmptyBlocks(mc, pNumBlocks));
    }

    /**
     * Ensures the cell RDD is persisted at {@link #DEFAULT_TMP}.
     * If the input already has that storage level, it is returned unchanged.
     * Otherwise the cells are copied via {@link CopyBinaryCellFunction} and the copy is persisted.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixCell> cacheBinaryCellRDD(JavaPairRDD<MatrixIndexes, MatrixCell> input) {
        return !input.getStorageLevel().equals(DEFAULT_TMP)
            ? input.mapToPair(new CopyBinaryCellFunction()).persist(DEFAULT_TMP)
            : input;
    }

    /**
     * Derives characteristics from a cell RDD.
     * Rows and cols are the maximum row and column index, and nnz is the number of non-zero cells.
     * NOTE: uses Spark {@code reduce}, which fails on an RDD without elements.
     */
    public static DataCharacteristics computeDataCharacteristics(JavaPairRDD<MatrixIndexes, MatrixCell> input) {
        return input.map(new AnalyzeCellDataCharacteristics()).reduce(new AggregateDataCharacteristics());
    }

    /** Recomputes the number of non-zeros of the RDD backing {@code mo}. */
    public static long getNonZeros(MatrixObject mo) {
        return SparkUtils.getNonZeros(mo.getRDDHandle().getRDD());
    }

    /**
     * Recomputes the total number of non-zeros over all non-empty blocks of the RDD.
     * Returns 0 for an RDD without partitions. ({@code fold} is used instead of {@code reduce},
     * which would throw on an empty collection.)
     */
    public static long getNonZeros(JavaPairRDD<MatrixIndexes, MatrixBlock> input) {
        return input.filter(new FilterNonEmptyBlocksFunction())
            .values()
            .mapPartitions(new RecomputeNnzFunction())
            .fold(0L, (a, b) -> a + b);
    }

    /**
     * Calls {@code mo.acquireReadAndRelease()} when two conditions both hold:
     * <ul>
     * <li>the memory estimate based on the nnz upper bound of {@code mcOut} does not exceed the caching threshold;</li>
     * <li>that estimate is smaller than the exact-sparsity estimate of {@code mcOut}.</li>
     * </ul>
     * NOTE(review): presumably intended to pull ultra-sparse outputs into local memory; confirm.
     */
    public static void postprocessUltraSparseOutput(MatrixObject mo, DataCharacteristics mcOut) {
        long memUB = OptimizerUtils.estimateSizeExactSparsity(mcOut.getRows(), mcOut.getCols(), mcOut.getNonZerosBound());
        if (!OptimizerUtils.exceedsCachingThreshold(mcOut.getCols(), memUB)
            && memUB < OptimizerUtils.estimateSizeExactSparsity(mcOut)) {
            mo.acquireReadAndRelease();
        }
    }

    /**
     * Expands a block offset into the empty blocks of the range [offset, min(offset + pNumBlocks, numBlocks)).
     * Linear block positions are mapped to 1-based (row, col) block indexes in row-major order.
     */
    private static class GenerateEmptyBlocks
    implements PairFlatMapFunction<Long, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 630129586089106855L;
        private final DataCharacteristics _mc;
        private final long _pNumBlocks;

        public GenerateEmptyBlocks(DataCharacteristics mc, long pNumBlocks) {
            this._mc = mc;
            this._pNumBlocks = pNumBlocks;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Long arg0) throws Exception {
            long ncblks = this._mc.getNumColBlocks();
            long nblocksU = Math.min(arg0 + this._pNumBlocks, this._mc.getNumBlocks());
            return LongStream.range(arg0, nblocksU).mapToObj(i -> {
                long rix = 1L + i / ncblks;
                long cix = 1L + i % ncblks;
                int lrlen = UtilFunctions.computeBlockSize(this._mc.getRows(), rix, this._mc.getBlocksize());
                int lclen = UtilFunctions.computeBlockSize(this._mc.getCols(), cix, this._mc.getBlocksize());
                return new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(rix, cix), new MatrixBlock(lrlen, lclen, true));
            }).iterator();
        }
    }

    /** Merges two partial characteristics: max of rows/cols, sum of nnz, block size taken from the first argument. */
    private static class AggregateDataCharacteristics
    implements Function2<DataCharacteristics, DataCharacteristics, DataCharacteristics> {
        private static final long serialVersionUID = 4263886749699779994L;

        @Override
        public DataCharacteristics call(DataCharacteristics arg0, DataCharacteristics arg1) throws Exception {
            return new MatrixCharacteristics(
                Math.max(arg0.getRows(), arg1.getRows()),
                Math.max(arg0.getCols(), arg1.getCols()),
                arg0.getBlocksize(),
                arg0.getNonZeros() + arg1.getNonZeros());
        }
    }

    /** Maps a single cell to characteristics (row index, col index, block size 0, nnz 1 if non-zero else 0). */
    private static class AnalyzeCellDataCharacteristics
    implements Function<Tuple2<MatrixIndexes, MatrixCell>, DataCharacteristics> {
        private static final long serialVersionUID = 8899395272683723008L;

        @Override
        public DataCharacteristics call(Tuple2<MatrixIndexes, MatrixCell> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            long nnz = arg0._2().getValue() != 0.0 ? 1L : 0L;
            return new MatrixCharacteristics(ix.getRowIndex(), ix.getColumnIndex(), 0, nnz);
        }
    }

    /** Runs the block-level consistency checks {@code checkNonZeros()} and {@code checkSparseRows()}. */
    private static class CheckSparsityFunction
    implements VoidFunction<Tuple2<MatrixIndexes, MatrixBlock>> {
        private static final long serialVersionUID = 4150132775681848807L;

        @Override
        public void call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
            MatrixBlock block = arg._2();
            block.checkNonZeros();
            block.checkSparseRows();
        }
    }
}

