/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.operators;

import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.WishartGammalDistributionModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.GammaDistribution;
import dr.math.distributions.WishartDistribution;
import dr.math.distributions.WishartStatistics;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Gibbs operator that redraws the matrix parameter held in {@code inverseCorrelation}.
 *
 * <p>Each call to {@link #doOperation()} obtains the Wishart sufficient statistics from the
 * {@link ConjugateWishartStatisticsProvider}, rescales the outer-product matrix with randomly drawn
 * diagonal factors, combines it with the inverse of the prior scale matrix (when one is available),
 * draws a new matrix with {@code WishartDistribution.nextWishart} and finally normalizes the draw
 * (see {@link #normalizeToGetInverseCorrelation(double[][])}) before writing it into the parameter.
 */
public class CorrelationMatrixGibbsOperator
extends SimpleMCMCOperator
implements GibbsOperator {
    private static final String CORRELATION_OPERATOR = "correlationGibbsOperator";
    public static final String TREE_MODEL = "treeModel";
    public static final String DISTRIBUTION = "distribution";
    public static final String PRIOR = "prior";
    private final ConjugateWishartStatisticsProvider conjugateWishartProvider;
    private final MatrixParameterInterface inverseCorrelation;
    // priorStatistics and workingStatistics are assigned but never read inside this class.
    private Statistics priorStatistics;
    private Statistics workingStatistics;
    private double priorDf;
    // Inverse of the prior scale matrix; stays null when the prior exposes no scale matrix.
    private SymmetricMatrix priorInverseScaleMatrix;
    private final int dim;
    private double numberObservations;
    private double pathWeight = 1.0;
    // When true, the prior is re-read at the start of every doOperation() call.
    private boolean wishartIsModel = false;
    private WishartGammalDistributionModel priorModel = null;
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{AttributeRule.newDoubleRule("weight"), new ElementRule(AbstractMultivariateTraitLikelihood.class, true), new ElementRule(ConjugateWishartStatisticsProvider.class, true), new ElementRule(MultivariateDistributionLikelihood.class, 1, 2), new ElementRule(MatrixParameterInterface.class, true)};

        @Override
        public String getParserName() {
            return CorrelationMatrixGibbsOperator.CORRELATION_OPERATOR;
        }

        /**
         * Builds the operator from its XML element.
         *
         * @throws XMLParseException if the statistics provider (or its precision parameter) is
         *         missing, if the prior is not a {@link WishartStatistics}, or if the precision
         *         parameter is not square
         */
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            double weight = xMLObject.getDoubleAttribute("weight");
            ConjugateWishartStatisticsProvider provider = (ConjugateWishartStatisticsProvider)xMLObject.getChild(ConjugateWishartStatisticsProvider.class);
            // The syntax rules declare this child optional, yet the operator cannot work without it;
            // fail with a parse error instead of a NullPointerException.
            if (provider == null) {
                throw new XMLParseException("A ConjugateWishartStatisticsProvider child element is required by " + CorrelationMatrixGibbsOperator.CORRELATION_OPERATOR);
            }
            MatrixParameterInterface precision = provider.getPrecisionParameter();
            if (precision == null) {
                throw new XMLParseException("The ConjugateWishartStatisticsProvider does not supply a precision parameter");
            }
            MultivariateDistributionLikelihood priorLikelihood = (MultivariateDistributionLikelihood)xMLObject.getChild(MultivariateDistributionLikelihood.class);
            Object priorDistribution = priorLikelihood.getDistribution();
            if (!(priorDistribution instanceof WishartStatistics)) {
                throw new XMLParseException("Only a Wishart distribution is conjugate for Gibbs sampling");
            }
            if (precision.getColumnDimension() != precision.getRowDimension()) {
                throw new XMLParseException("The variance matrix is not square or of wrong dimension");
            }
            return new CorrelationMatrixGibbsOperator(provider, precision, (WishartStatistics)priorDistribution, null, weight);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a Gibbs operator on the inverse of a correlation matrix.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };
    private static final boolean DEBUG = false;

    /**
     * Packs the degrees of freedom and the inverted scale matrix of {@code wishartStatistics}
     * into a {@link Statistics} holder. The matrix entry is {@code null} when the distribution
     * reports no scale matrix.
     */
    private Statistics setupStatistics(WishartStatistics wishartStatistics) {
        double[][] scale = wishartStatistics.getScaleMatrix();
        double[][] inverseScale = null;
        if (scale != null) {
            inverseScale = new SymmetricMatrix(scale).inverse().toComponents();
        }
        return new Statistics(wishartStatistics.getDF(), inverseScale);
    }

    /**
     * Refreshes {@code priorDf} and {@code priorInverseScaleMatrix} from the given prior.
     * {@code priorInverseScaleMatrix} is reset to {@code null} when the prior has no scale matrix.
     */
    private void setupWishartStatistics(WishartStatistics wishartStatistics) {
        this.priorDf = wishartStatistics.getDF();
        this.priorInverseScaleMatrix = null;
        double[][] scale = wishartStatistics.getScaleMatrix();
        if (scale != null) {
            this.priorInverseScaleMatrix = new SymmetricMatrix(scale).inverse();
        }
    }

    /**
     * Rescales {@code dArray} in place: entry (i, j) is multiplied by
     * {@code sqrt(inv[i][i]) * sqrt(inv[j][j])}, where {@code inv} is the inverse of {@code dArray}
     * computed before any entry is modified.
     */
    private void normalizeToGetInverseCorrelation(double[][] dArray) {
        double[][] inverse = new SymmetricMatrix(dArray).inverse().toComponents();
        double[] rootDiagonal = new double[this.dim];
        for (int k = 0; k < this.dim; ++k) {
            rootDiagonal[k] = Math.sqrt(inverse[k][k]);
        }
        for (int row = 0; row < this.dim; ++row) {
            for (int col = 0; col < this.dim; ++col) {
                dArray[row][col] = rootDiagonal[row] * rootDiagonal[col] * dArray[row][col];
            }
        }
    }

    /**
     * Creates the operator.
     *
     * @param conjugateWishartStatisticsProvider source of the Wishart sufficient statistics
     * @param matrixParameterInterface parameter to update; when {@code null} the provider's
     *        precision parameter is used instead
     * @param wishartStatistics prior; if it is a {@link WishartGammalDistributionModel} it is
     *        re-read on every {@link #doOperation()} call
     * @param wishartStatistics2 optional working prior, may be {@code null}
     * @param d operator weight, passed to {@code setWeight}
     */
    public CorrelationMatrixGibbsOperator(ConjugateWishartStatisticsProvider conjugateWishartStatisticsProvider, MatrixParameterInterface matrixParameterInterface, WishartStatistics wishartStatistics, WishartStatistics wishartStatistics2, double d) {
        this.conjugateWishartProvider = conjugateWishartStatisticsProvider;
        this.inverseCorrelation = matrixParameterInterface != null ? matrixParameterInterface : this.conjugateWishartProvider.getPrecisionParameter();
        this.setupWishartStatistics(wishartStatistics);
        this.priorStatistics = this.setupStatistics(wishartStatistics);
        if (wishartStatistics instanceof WishartGammalDistributionModel) {
            this.wishartIsModel = true;
            this.priorModel = (WishartGammalDistributionModel)wishartStatistics;
        }
        if (wishartStatistics2 != null) {
            this.workingStatistics = this.setupStatistics(wishartStatistics2);
        }
        this.setWeight(d);
        this.dim = this.inverseCorrelation.getRowDimension();
    }

    /**
     * Sets the path weight used to scale the observation count in {@link #doOperation()}.
     *
     * @throws IllegalArgumentException if {@code d} lies outside {@code [0, 1]}
     */
    @Override
    public void setPathParameter(double d) {
        if (d < 0.0 || d > 1.0) {
            throw new IllegalArgumentException("Illegal path weight of " + d);
        }
        this.pathWeight = d;
    }

    public int getStepCount() {
        return 1;
    }

    /**
     * Draws one rescaling factor per dimension: {@code sqrt(c_ii / (2 g))}, where {@code c_ii} is
     * the current i-th diagonal entry of {@code inverseCorrelation} and {@code g} is a fresh
     * {@code GammaDistribution.nextGamma((dim + 1) / 2, 1.0)} draw.
     */
    private double[] getDiagRescaleMatrix() {
        double[] factors = new double[this.dim];
        for (int i = 0; i < this.dim; ++i) {
            double gammaDraw = GammaDistribution.nextGamma((double)(this.dim + 1) / 2.0, 1.0);
            factors[i] = Math.sqrt(this.inverseCorrelation.getParameterValue(i, i) / (2.0 * gammaDraw));
        }
        return factors;
    }

    /**
     * Multiplies entry (i, j) of the flattened {@code dim x dim} matrix {@code dArray} in place by
     * the product of the i-th and j-th factors returned by {@link #getDiagRescaleMatrix()}.
     */
    private void rescaleOuterProduct(double[] dArray) {
        double[] factors = this.getDiagRescaleMatrix();
        for (int i = 0; i < this.dim; ++i) {
            for (int j = 0; j < this.dim; ++j) {
                int index = j * this.dim + i;
                dArray[index] = factors[i] * factors[j] * dArray[index];
            }
        }
    }

    /**
     * Fetches the provider's sufficient statistics, rescales the flattened scale matrix, copies it
     * row by row into {@code dArray} (overwriting its content) and records the provider's degrees
     * of freedom in {@code numberObservations}.
     */
    private void incrementOuterProductWithRescale(double[][] dArray, ConjugateWishartStatisticsProvider conjugateWishartStatisticsProvider) {
        WishartSufficientStatistics sufficientStatistics = conjugateWishartStatisticsProvider.getWishartStatistics();
        double[] flatOuterProduct = sufficientStatistics.getScaleMatrix();
        // NOTE(review): the array handed out by the provider is rescaled in place; confirm the
        // provider does not keep a reference to it.
        this.rescaleOuterProduct(flatOuterProduct);
        double df = sufficientStatistics.getDf();
        int size = dArray.length;
        for (int row = 0; row < size; ++row) {
            System.arraycopy(flatOuterProduct, row * size, dArray[row], 0, size);
        }
        this.numberObservations = df;
    }

    /**
     * Computes the scale matrix handed to the Wishart sampler: the inverse of the rescaled outer
     * product plus, when present, the prior inverse scale matrix. Also updates
     * {@code numberObservations} as a side effect.
     *
     * @throws IllegalStateException if the matrices cannot be combined because of mismatched
     *         dimensions
     */
    private double[][] getOperationScaleMatrixAndSetObservationCount2() {
        double[][] outerProduct = new double[this.dim][this.dim];
        this.numberObservations = 0.0;
        this.incrementOuterProductWithRescale(outerProduct, this.conjugateWishartProvider);
        final Matrix scaleMatrix;
        try {
            SymmetricMatrix inverseScale = new SymmetricMatrix(outerProduct);
            if (this.priorInverseScaleMatrix != null) {
                inverseScale = this.priorInverseScaleMatrix.add(inverseScale);
            }
            scaleMatrix = inverseScale.inverse();
        }
        catch (IllegalDimension illegalDimension) {
            // Fail loudly with the cause attached instead of printing the trace and then
            // dereferencing a null matrix.
            throw new IllegalStateException("Prior inverse scale matrix and outer product have incompatible dimensions", illegalDimension);
        }
        return scaleMatrix.toComponents();
    }

    /**
     * Performs one Gibbs update of {@code inverseCorrelation}.
     *
     * @return always {@code 0.0}
     */
    @Override
    public double doOperation() {
        if (this.wishartIsModel) {
            // The prior is itself a model, so pick up its current values.
            this.setupWishartStatistics(this.priorModel);
            this.priorStatistics = this.setupStatistics(this.priorModel);
        }
        double[][] scaleMatrix = this.getOperationScaleMatrixAndSetObservationCount2();
        double observationCount = this.numberObservations;
        // NOTE(review): the path weight scales only the degrees of freedom, not the scale matrix;
        // confirm this is intended.
        double df = this.priorDf + observationCount * this.pathWeight;
        double[][] draw = WishartDistribution.nextWishart(df, scaleMatrix);
        this.normalizeToGetInverseCorrelation(draw);
        for (int i = 0; i < this.dim; ++i) {
            Parameter column = this.inverseCorrelation.getParameter(i);
            for (int j = 0; j < this.dim; ++j) {
                // Quiet writes; a single change event is fired once everything is in place.
                column.setParameterValueQuietly(j, draw[j][i]);
            }
        }
        this.inverseCorrelation.fireParameterChangedEvent();
        return 0.0;
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return CORRELATION_OPERATOR;
    }

    /**
     * Immutable holder for a degrees-of-freedom value and a matrix. Declared static because it
     * never touches the enclosing operator instance.
     */
    private static final class Statistics {
        final double degreesOfFreedom;
        final double[][] rateMatrix;

        Statistics(double d, double[][] dArray) {
            this.degreesOfFreedom = d;
            this.rateMatrix = dArray;
        }
    }
}

