/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions.gp;

import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.DesignMatrix;
import dr.inference.model.Model;
import dr.inference.model.ModelListener;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.model.VariableListener;
import dr.math.MathUtils;
import dr.math.distributions.gp.AdditiveGaussianProcessDistribution;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.List;
import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionCommon_D64;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.linsol.LinearSolver;
import org.ejml.ops.CommonOps;

public class GaussianProcessPrediction
implements Reportable,
Loggable,
VariableListener,
ModelListener {
    private final AdditiveGaussianProcessDistribution gp;
    private final Parameter realizedValues;
    private final List<DesignMatrix> predictiveDesigns;
    private final int realizedDim;
    private final int predictiveDim;
    private final Parameter orderVariance;
    private final double[] prediction;
    private final double[] mean;
    private final DenseMatrix64F variance;
    private final DenseMatrix64F crossGramian;
    private DenseMatrix64F realizedPrecision;
    private final LinearSolver<DenseMatrix64F> solver;
    private final List<AdditiveGaussianProcessDistribution.BasisDimension> predictiveBases;
    private final List<AdditiveGaussianProcessDistribution.BasisDimension> crossBases;
    private final DenseMatrix64F crossRealized;
    private boolean predictionKnown;
    private boolean meanKnown;
    private boolean varianceKnown;
    private boolean crossRealizedKnown;
    private boolean crossGramianKnown;
    private LogColumn[] columns;

    public GaussianProcessPrediction(AdditiveGaussianProcessDistribution additiveGaussianProcessDistribution, Parameter parameter, List<DesignMatrix> list) {
        this.gp = additiveGaussianProcessDistribution;
        this.realizedValues = parameter;
        this.predictiveDesigns = list;
        this.realizedDim = additiveGaussianProcessDistribution.getDimension();
        this.predictiveDim = list.get(0).getRowDimension();
        this.orderVariance = additiveGaussianProcessDistribution.getOrderVariance();
        this.crossGramian = new DenseMatrix64F(this.predictiveDim, this.realizedDim);
        this.realizedPrecision = new DenseMatrix64F(this.realizedDim, this.realizedDim);
        this.crossRealized = new DenseMatrix64F(this.predictiveDim, this.realizedDim);
        this.mean = new double[this.predictiveDim];
        this.variance = new DenseMatrix64F(this.predictiveDim, this.predictiveDim);
        this.prediction = new double[this.predictiveDim];
        this.solver = LinearSolverFactory.symmPosDef(this.realizedDim);
        List<AdditiveGaussianProcessDistribution.BasisDimension> list2 = additiveGaussianProcessDistribution.getBases();
        this.predictiveBases = GaussianProcessPrediction.makePredictiveBases(list2, list);
        this.crossBases = GaussianProcessPrediction.makeCrossBases(list2, list);
        additiveGaussianProcessDistribution.addModelListener(this);
        parameter.addVariableListener(this);
        for (DesignMatrix designMatrix : list) {
            designMatrix.addVariableListener(this);
        }
        this.predictionKnown = false;
        this.meanKnown = false;
        this.varianceKnown = false;
        this.crossRealizedKnown = false;
        this.crossGramianKnown = false;
    }

    private static List<AdditiveGaussianProcessDistribution.BasisDimension> makeCrossBases(List<AdditiveGaussianProcessDistribution.BasisDimension> list, List<DesignMatrix> list2) {
        ArrayList<AdditiveGaussianProcessDistribution.BasisDimension> arrayList = new ArrayList<AdditiveGaussianProcessDistribution.BasisDimension>();
        for (int i = 0; i < list.size(); ++i) {
            AdditiveGaussianProcessDistribution.BasisDimension basisDimension = list.get(i);
            AdditiveGaussianProcessDistribution.BasisDimension basisDimension2 = new AdditiveGaussianProcessDistribution.BasisDimension(basisDimension.getKernel(), list2.get(i), basisDimension.getDesignMatrix1());
            arrayList.add(basisDimension2);
        }
        return arrayList;
    }

    private static List<AdditiveGaussianProcessDistribution.BasisDimension> makePredictiveBases(List<AdditiveGaussianProcessDistribution.BasisDimension> list, List<DesignMatrix> list2) {
        ArrayList<AdditiveGaussianProcessDistribution.BasisDimension> arrayList = new ArrayList<AdditiveGaussianProcessDistribution.BasisDimension>();
        for (int i = 0; i < list.size(); ++i) {
            AdditiveGaussianProcessDistribution.BasisDimension basisDimension = list.get(i);
            AdditiveGaussianProcessDistribution.BasisDimension basisDimension2 = new AdditiveGaussianProcessDistribution.BasisDimension(basisDimension.getKernel(), list2.get(i), list2.get(i));
            arrayList.add(basisDimension2);
        }
        return arrayList;
    }

    private void computePredictions() {
        int n;
        this.computeMean();
        this.computeVariance();
        if (!this.solver.setA(this.variance)) {
            throw new RuntimeException("Unable to decompose matrix");
        }
        DenseMatrix64F denseMatrix64F = ((CholeskyDecompositionCommon_D64)this.solver.getDecomposition()).getT();
        double[] dArray = new double[this.predictiveDim];
        for (n = 0; n < this.predictiveDim; ++n) {
            dArray[n] = MathUtils.nextGaussian();
        }
        for (n = 0; n < this.predictiveDim; ++n) {
            this.prediction[n] = 0.0;
            for (int i = 0; i < this.predictiveDim; ++i) {
                int n2 = n;
                this.prediction[n2] = this.prediction[n2] + denseMatrix64F.get(n, i) * dArray[i];
            }
            int n3 = n;
            this.prediction[n3] = this.prediction[n3] + this.mean[n];
        }
    }

    private double getPrediction(int n) {
        if (!this.predictionKnown) {
            this.computePredictions();
            this.predictionKnown = true;
        }
        if (n == this.predictiveDim - 1) {
            this.predictionKnown = false;
        }
        return this.prediction[n];
    }

    private void computeCrossGramian() {
        if (!this.crossGramianKnown) {
            AdditiveGaussianProcessDistribution.computeAdditiveGramian(this.crossGramian, this.crossBases, this.orderVariance);
            this.crossGramianKnown = true;
        }
    }

    private void computeCrossRealized() {
        if (!this.crossRealizedKnown) {
            this.computeCrossGramian();
            this.realizedPrecision = this.gp.getPrecisionAsMatrix();
            CommonOps.mult(this.crossGramian, this.realizedPrecision, this.crossRealized);
            this.crossRealizedKnown = true;
        }
    }

    private void computeMean() {
        if (!this.meanKnown) {
            this.computeCrossRealized();
            for (int i = 0; i < this.predictiveDim; ++i) {
                this.mean[i] = 0.0;
                for (int j = 0; j < this.realizedDim; ++j) {
                    int n = i;
                    this.mean[n] = this.mean[n] + this.crossRealized.get(i, j) * this.realizedValues.getParameterValue(j);
                }
            }
            this.meanKnown = true;
        }
    }

    private void computeVariance() {
        if (!this.varianceKnown) {
            this.computeCrossRealized();
            DenseMatrix64F denseMatrix64F = new DenseMatrix64F(this.predictiveDim, this.predictiveDim);
            AdditiveGaussianProcessDistribution.computeAdditiveGramian(denseMatrix64F, this.predictiveBases, this.orderVariance);
            DenseMatrix64F denseMatrix64F2 = new DenseMatrix64F(this.predictiveDim, this.predictiveDim);
            for (int i = 0; i < this.crossRealized.numRows; ++i) {
                for (int j = 0; j < this.crossGramian.numRows; ++j) {
                    double d = 0.0;
                    for (int k = 0; k < this.crossRealized.numCols; ++k) {
                        d += this.crossRealized.get(i, k) * this.crossGramian.get(j, k);
                    }
                    denseMatrix64F2.set(i, j, d);
                }
            }
            CommonOps.subtract((D1Matrix64F)denseMatrix64F, (D1Matrix64F)denseMatrix64F2, (D1Matrix64F)this.variance);
        }
    }

    private double[] getMean() {
        this.computeMean();
        return this.mean;
    }

    private double[] getVariance() {
        this.computeVariance();
        return this.variance.getData();
    }

    private LogColumn[] createColumns() {
        LogColumn[] logColumnArray = new LogColumn[this.predictiveDim];
        for (int i = 0; i < this.predictiveDim; ++i) {
            String string = "prediction" + (i + 1);
            final int n = i;
            logColumnArray[i] = new NumberColumn(string){

                @Override
                public double getDoubleValue() {
                    return GaussianProcessPrediction.this.getPrediction(n);
                }
            };
        }
        return logColumnArray;
    }

    @Override
    public LogColumn[] getColumns() {
        if (this.columns == null) {
            this.columns = this.createColumns();
        }
        return this.columns;
    }

    @Override
    public String getReport() {
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("mean:");
        for (double d : this.getMean()) {
            stringBuilder.append(" ").append(d);
        }
        stringBuilder.append("\n");
        stringBuilder.append("variance:");
        for (double d : this.getVariance()) {
            stringBuilder.append(" ").append(d);
        }
        stringBuilder.append("\n");
        stringBuilder.append("prediction:");
        for (int i = 0; i < this.predictiveDim; ++i) {
            stringBuilder.append(" ").append(this.getPrediction(i));
        }
        return stringBuilder.toString();
    }

    @Override
    public void modelChangedEvent(Model model, Object object, int n) {
        if (model != this.gp) {
            throw new IllegalArgumentException("Unknown model");
        }
        this.predictionKnown = false;
        this.meanKnown = false;
        this.varianceKnown = false;
        this.crossRealizedKnown = false;
        this.crossGramianKnown = false;
    }

    @Override
    public void modelRestored(Model model) {
        this.predictionKnown = false;
    }

    @Override
    public void variableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable != this.realizedValues) {
            if (variable instanceof DesignMatrix && this.predictiveDesigns.contains((DesignMatrix)variable)) {
                throw new IllegalArgumentException("Not yet implemented");
            }
            throw new IllegalArgumentException("Unknown variable");
        }
        this.predictionKnown = false;
        this.meanKnown = false;
    }
}

