/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.utils.Statistics;

/**
 * Base class of the paramserv parameter server.
 *
 * <p>The server owns the global model and hands model copies to each worker through a
 * per-worker {@link BlockingQueue} of capacity one ({@link #_modelMap}). Worker gradients are
 * folded into the global model by a user-supplied DML aggregation function:
 * <ul>
 *   <li>BSP: gradients are accrued until every worker has reported, then applied once and the
 *       new model is broadcast to all workers;</li>
 *   <li>ASP: each worker's gradients are applied immediately and a fresh model copy is handed
 *       back to that worker only.</li>
 * </ul>
 * If a validation function and validation data were supplied, the model is validated at
 * every (pseudo) epoch boundary.
 *
 * <p>Subclasses provide {@link #push(int, ListObject)} and {@link #pull(int)}.
 */
public abstract class ParamServer {
    protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
    // Passed to ParamservUtils.accrueGradients when accumulating BSP gradients.
    protected static final boolean ACCRUE_BSP_GRADIENTS = true;
    // One capacity-1 queue per worker id, holding the next model copy for that worker.
    protected Map<Integer, BlockingQueue<ListObject>> _modelMap;
    private ListObject _model;
    protected ExecutionContext _ec;
    private Statement.PSUpdateType _updateType;
    private Statement.PSFrequency _freq;
    // Bound call of the aggregation function and the name of its single list output.
    private FunctionCallCPInstruction _inst;
    private String _outputName;
    // BSP bookkeeping: which workers have pushed in the current round, and their accrued gradients.
    private boolean[] _finishedStates;
    private ListObject _accGradients = null;
    // Validation setup; only populated when setupValFunc ran.
    private boolean _validationPossible;
    private FunctionCallCPInstruction _valInst;
    private String _lossOutput;
    private String _accuracyOutput;
    // Epoch tracking; disabled when _numBatchesPerEpoch == -1.
    private int _syncCounter = 0;
    private int _epochCounter = 0;
    private int _numBatchesPerEpoch;
    private int _numWorkers;

    protected ParamServer() {
    }

    /**
     * Creates the server, binds the aggregation (and optionally the validation) function, and
     * broadcasts the initial model to all workers.
     *
     * <p>NOTE: this constructor calls the overridable {@link #setupAggFunc} and
     * {@link #setupValFunc}; overrides run before subclass fields are initialized.
     *
     * @param model initial global model
     * @param aggFunc function key of the DML aggregation function (split via
     *        {@code DMLProgram.splitFunctionKey})
     * @param updateType update type; only BSP and ASP are supported
     * @param freq update frequency, used to detect epoch boundaries
     * @param ec execution context in which the DML functions are run
     * @param workerNum number of workers
     * @param valFunc function key of the validation function, or {@code null}
     * @param numBatchesPerEpoch batches per epoch; {@code -1} disables epoch tracking, and
     *        validation is only set up when this is positive
     * @param valFeatures validation features, or {@code null}
     * @param valLabels validation labels, or {@code null}
     */
    protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels) {
        _modelMap = new HashMap<>(workerNum);
        for (int i = 0; i < workerNum; i++) {
            _modelMap.put(i, new ArrayBlockingQueue<>(1));
        }
        _model = model;
        _ec = ec;
        _updateType = updateType;
        _freq = freq;
        _finishedStates = new boolean[workerNum];
        setupAggFunc(_ec, aggFunc);
        if (valFunc != null && numBatchesPerEpoch > 0 && valFeatures != null && valLabels != null) {
            setupValFunc(_ec, valFunc, valFeatures, valLabels);
        }
        _numBatchesPerEpoch = numBatchesPerEpoch;
        _numWorkers = workerNum;
        broadcastModel(true);
    }

    /**
     * Resolves the aggregation function and prepares the instruction that calls it.
     *
     * @param ec execution context holding the program
     * @param aggFunc function key of the aggregation function
     * @throws DMLRuntimeException if the function does not have exactly one output of type list
     */
    protected void setupAggFunc(ExecutionContext ec, String aggFunc) {
        String[] cfn = DMLProgram.splitFunctionKey(aggFunc);
        String ns = cfn[0];
        String fname = cfn[1];
        // fall back to the optimized function variant if no unoptimized one exists
        boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, fname, false);
        FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, opt);
        ArrayList<DataIdentifier> outputs = func.getOutputParams();
        if (outputs.size() != 1) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the updated model.", aggFunc));
        }
        if (outputs.get(0).getDataType() != Types.DataType.LIST) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", aggFunc));
        }
        _outputName = outputs.get(0).getName();
        _inst = createFunctionCall(ns, fname, opt, func, "aggregate function");
    }

    /**
     * Resolves the validation function, prepares its call instruction, and registers the
     * validation data as {@code val_features} / {@code val_labels} in the server's context.
     *
     * @param ec execution context holding the program
     * @param valFunc function key of the validation function
     * @param valFeatures validation features
     * @param valLabels validation labels
     * @throws DMLRuntimeException if the function does not have exactly two scalar outputs
     *         (loss, accuracy)
     */
    protected void setupValFunc(ExecutionContext ec, String valFunc, MatrixObject valFeatures, MatrixObject valLabels) {
        String[] cfn = DMLProgram.splitFunctionKey(valFunc);
        String ns = cfn[0];
        String fname = cfn[1];
        boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, fname, false);
        FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, opt);
        ArrayList<DataIdentifier> outputs = func.getOutputParams();
        if (outputs.size() != 2) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should provide the loss and the accuracy in that order", valFunc));
        }
        if (outputs.get(0).getDataType() != Types.DataType.SCALAR || outputs.get(1).getDataType() != Types.DataType.SCALAR) {
            throw new DMLRuntimeException(String.format("The outputs of the '%s' function should both be scalars", valFunc));
        }
        _lossOutput = outputs.get(0).getName();
        _accuracyOutput = outputs.get(1).getName();
        _valInst = createFunctionCall(ns, fname, opt, func, "validate function");
        _ec.setVariable("val_features", valFeatures);
        _ec.setVariable("val_labels", valLabels);
        _validationPossible = true;
    }

    /** Builds a call instruction binding every input/output parameter of {@code func} by name. */
    private static FunctionCallCPInstruction createFunctionCall(String ns, String fname, boolean opt, FunctionProgramBlock func, String description) {
        CPOperand[] boundInputs = func.getInputParams().stream()
            .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
            .toArray(CPOperand[]::new);
        ArrayList<String> outputNames = func.getOutputParams().stream()
            .map(DataIdentifier::getName)
            .collect(Collectors.toCollection(ArrayList::new));
        return new FunctionCallCPInstruction(ns, fname, opt, boundInputs, func.getInputParamNames(), outputNames, description);
    }

    /**
     * Receives a list object from a worker (presumably its gradients — see subclasses).
     *
     * @param workerID id of the sending worker
     * @param data list object sent by the worker
     */
    public abstract void push(int workerID, ListObject data);

    /**
     * Returns the model intended for the given worker.
     *
     * @param workerID id of the requesting worker
     * @return a model list object
     */
    public abstract ListObject pull(int workerID);

    /** Returns the current global model. */
    public ListObject getResult() {
        return _model;
    }

    /**
     * Folds one worker's gradients into the global model according to the update type.
     *
     * <p>Under BSP the gradients are accrued until all workers have reported; the model is then
     * updated once and broadcast to every worker. Under ASP the model is updated immediately and
     * a copy is handed back to {@code workerID} only. At (pseudo) epoch boundaries the epoch time
     * is logged and, if configured, the model is validated.
     *
     * @param workerID id of the worker that sent the gradients
     * @param gradients the worker's gradients
     * @throws DMLRuntimeException wrapping any failure of aggregation, validation, or broadcast
     */
    protected synchronized void updateGlobalModel(int workerID, ListObject gradients) {
        try {
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.", gradients.getDataSize() / 1024L, workerID));
            }
            switch (_updateType) {
                case BSP:
                    setFinishedState(workerID);
                    _accGradients = ParamservUtils.accrueGradients(_accGradients, gradients, ACCRUE_BSP_GRADIENTS);
                    if (allFinished()) {
                        updateGlobalModel(_accGradients);
                        _accGradients = null;
                        if (isBspEpochBoundary()) {
                            completeEpoch("[+] PARAMSERV: completed EPOCH ");
                        }
                        resetFinishedStates();
                        broadcastModel(true);
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Global parameter is broadcasted successfully.");
                        }
                    }
                    break;
                case ASP:
                    updateGlobalModel(gradients);
                    if (isAspEpochBoundary()) {
                        completeEpoch("[+] PARAMSERV: completed PSEUDO EPOCH (ASP) ");
                    }
                    broadcastModel(workerID);
                    break;
                default:
                    throw new DMLRuntimeException("Unsupported update: " + _updateType.name());
            }
        }
        catch (InterruptedException e) {
            // keep the interrupt visible to the calling worker thread
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Aggregation or validation service failed: ", e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Aggregation or validation service failed: ", e);
        }
    }

    /**
     * Decides whether a completed BSP round ends an epoch.
     *
     * <p>With EPOCH frequency every round is an epoch; with BATCH frequency every
     * {@code _numBatchesPerEpoch}-th round is. A zero batch count never triggers (instead of
     * dividing by zero).
     */
    private boolean isBspEpochBoundary() {
        if (_numBatchesPerEpoch == -1) {
            return false;
        }
        if (_freq == Statement.PSFrequency.EPOCH) {
            return true;
        }
        if (_freq == Statement.PSFrequency.BATCH) {
            _syncCounter++;
            return _numBatchesPerEpoch != 0 && _syncCounter % _numBatchesPerEpoch == 0;
        }
        return false;
    }

    /**
     * Decides whether an ASP update ends a pseudo epoch.
     *
     * <p>A pseudo epoch is {@code _numWorkers} pushes for EPOCH frequency, or
     * {@code _numWorkers * _numBatchesPerEpoch} pushes for BATCH frequency. Exact long arithmetic
     * is used so large counters are not affected by float rounding; a zero period never triggers.
     */
    private boolean isAspEpochBoundary() {
        if (_numBatchesPerEpoch == -1) {
            return false;
        }
        final long period;
        if (_freq == Statement.PSFrequency.EPOCH) {
            period = _numWorkers;
        } else if (_freq == Statement.PSFrequency.BATCH) {
            period = (long) _numWorkers * _numBatchesPerEpoch;
        } else {
            return false;
        }
        _syncCounter++;
        return period != 0 && _syncCounter % period == 0;
    }

    /** Logs the finished epoch, records its time, validates if possible, and advances counters. */
    private void completeEpoch(String message) {
        if (LOG.isInfoEnabled()) {
            LOG.info(message + _epochCounter);
        }
        timeEpoch();
        if (_validationPossible) {
            validate();
        }
        _epochCounter++;
        _syncCounter = 0;
    }

    /** Applies the given gradients to the global model, recording aggregation time if enabled. */
    private void updateGlobalModel(ListObject gradients) {
        Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
        _model = updateLocalModel(_ec, gradients, _model);
        if (tAgg != null) {
            Statistics.accPSAggregationTime((long) tAgg.stop());
        }
    }

    /**
     * Runs the aggregation function on {@code gradients} and {@code model} in the given context.
     *
     * <p>The inputs are bound as the variables {@code gradients} and {@code model}, and both are
     * cleaned up afterwards via {@code ParamservUtils.cleanupListObject}. For {@code model}, the
     * new model's status is passed to that cleanup.
     *
     * @param ec execution context to run in
     * @param gradients gradients to apply
     * @param model model to update
     * @return the list output of the aggregation function
     */
    protected ListObject updateLocalModel(ExecutionContext ec, ListObject gradients, ListObject model) {
        ec.setVariable("gradients", gradients);
        ec.setVariable("model", model);
        _inst.processInstruction(ec);
        ListObject newModel = ec.getListObject(_outputName);
        ParamservUtils.cleanupListObject(ec, "model", newModel.getStatus());
        ParamservUtils.cleanupListObject(ec, "gradients");
        return newModel;
    }

    /** True when every worker has reported in the current BSP round. */
    private boolean allFinished() {
        return !ArrayUtils.contains(_finishedStates, false);
    }

    private void resetFinishedStates() {
        Arrays.fill(_finishedStates, false);
    }

    private void setFinishedState(int workerID) {
        _finishedStates[workerID] = true;
    }

    /** Puts a copy of the current model into every worker's queue, optionally in parallel. */
    private void broadcastModel(boolean par) {
        IntStream stream = IntStream.range(0, _modelMap.size());
        (par ? stream.parallel() : stream).forEach(workerID -> {
            try {
                broadcastModel(workerID);
            }
            catch (InterruptedException e) {
                throw new DMLRuntimeException("Paramserv func: some error occurred when broadcasting model", e);
            }
        });
    }

    /**
     * Puts a copy of the current model into the given worker's queue, blocking while the queue
     * is full, and records broadcast time if statistics are enabled.
     */
    private void broadcastModel(int workerID) throws InterruptedException {
        Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
        _modelMap.get(workerID).put(ParamservUtils.copyList(_model, false));
        if (tBroad != null) {
            Statistics.accPSModelBroadcastTime((long) tBroad.stop());
        }
    }

    /**
     * When statistics are enabled, stops the paramserv execution timer, accumulates it, and logs
     * the accumulated execution time minus the accumulated validation time (divided by 1000 and
     * reported as seconds).
     */
    private void timeEpoch() {
        if (!DMLScript.STATISTICS) {
            return;
        }
        Statistics.accPSExecutionTime((long) Statistics.getPSExecutionTimer().stop());
        double executionTime = Statistics.getPSExecutionTime();
        double validationTime = Statistics.getPSValidationTime();
        double epochTime = executionTime - validationTime;
        if (LOG.isInfoEnabled()) {
            if (_validationPossible) {
                LOG.info("[+] PARAMSERV: epoch timer (excl. validation): " + epochTime / 1000.0 + " secs.");
            } else {
                LOG.info("[+] PARAMSERV: epoch timer: " + epochTime / 1000.0 + " secs.");
            }
        }
    }

    /**
     * Runs the validation function on the current model and logs the resulting loss and
     * accuracy; records validation time if statistics are enabled.
     */
    private void validate() {
        Timing tValidate = DMLScript.STATISTICS ? new Timing(true) : null;
        _ec.setVariable("model", _model);
        _valInst.processInstruction(_ec);
        // both outputs are assumed to be double scalars (a non-double scalar would fail the cast)
        double loss = ((DoubleObject) _ec.getVariable(_lossOutput)).getDoubleValue();
        double accuracy = ((DoubleObject) _ec.getVariable(_accuracyOutput)).getDoubleValue();
        ParamservUtils.cleanupListObject(_ec, "model");
        if (LOG.isInfoEnabled()) {
            LOG.info("[+] PARAMSERV: validation-loss: " + loss + " validation-accuracy: " + accuracy);
        }
        if (tValidate != null) {
            Statistics.accPSValidationTime((long) tValidate.stop());
        }
    }

    /** Returns the bound call instruction of the aggregation function. */
    public FunctionCallCPInstruction getAggInst() {
        return _inst;
    }
}

