/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.operators;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.evomodel.continuous.IntegratedMultivariateTraitLikelihood;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.GammaDistributionModel;
import dr.inference.model.MatrixParameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.GammaDistribution;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.logging.Logger;

public class TraitRateGibbsOperator
extends SimpleMCMCOperator
implements GibbsOperator {
    private static final String GIBBS_OPERATOR = "traitRateGibbsOperator";
    private final MutableTreeModel treeModel;
    private final MatrixParameter precisionMatrixParameter;
    private final AbstractMultivariateTraitLikelihood traitModel;
    private final GammaDistributionModel ratePriorModel;
    private final GammaDistribution ratePrior;
    private final ArbitraryBranchRates branchRateModel;
    private final int dim;
    private final String traitName;
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{AttributeRule.newDoubleRule("weight"), new ElementRule(AbstractMultivariateTraitLikelihood.class), new ElementRule(ArbitraryBranchRates.class), new ElementRule(DistributionLikelihood.class)};

        @Override
        public String getParserName() {
            return TraitRateGibbsOperator.GIBBS_OPERATOR;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            double d = xMLObject.getDoubleAttribute("weight");
            AbstractMultivariateTraitLikelihood abstractMultivariateTraitLikelihood = (AbstractMultivariateTraitLikelihood)xMLObject.getChild(AbstractMultivariateTraitLikelihood.class);
            ArbitraryBranchRates arbitraryBranchRates = (ArbitraryBranchRates)xMLObject.getChild(ArbitraryBranchRates.class);
            DistributionLikelihood distributionLikelihood = (DistributionLikelihood)xMLObject.getChild(DistributionLikelihood.class);
            GammaDistributionModel gammaDistributionModel = null;
            GammaDistribution gammaDistribution = null;
            if (distributionLikelihood.getDistribution() instanceof GammaDistributionModel) {
                gammaDistributionModel = (GammaDistributionModel)distributionLikelihood.getDistribution();
            } else if (distributionLikelihood.getDistribution() instanceof GammaDistribution) {
                gammaDistribution = (GammaDistribution)distributionLikelihood.getDistribution();
            } else {
                throw new XMLParseException("Currently only works with a GammaDistributionModel or GammaDistribution");
            }
            boolean bl = arbitraryBranchRates.usingReciprocal();
            if (!bl) {
                throw new XMLParseException("Gibbs sampling of rates only works with reciprocal rates under an ArbitraryBranchRates model");
            }
            TraitRateGibbsOperator traitRateGibbsOperator = new TraitRateGibbsOperator(abstractMultivariateTraitLikelihood, arbitraryBranchRates, gammaDistributionModel, gammaDistribution);
            traitRateGibbsOperator.setWeight(d);
            return traitRateGibbsOperator;
        }

        @Override
        public String getParserDescription() {
            return "This element returns a multivariate Gibbs operator on traits for possible all nodes.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    public TraitRateGibbsOperator(AbstractMultivariateTraitLikelihood abstractMultivariateTraitLikelihood, ArbitraryBranchRates arbitraryBranchRates, GammaDistributionModel gammaDistributionModel, GammaDistribution gammaDistribution) {
        boolean bl;
        this.traitModel = abstractMultivariateTraitLikelihood;
        this.treeModel = abstractMultivariateTraitLikelihood.getTreeModel();
        this.precisionMatrixParameter = (MatrixParameter)abstractMultivariateTraitLikelihood.getDiffusionModel().getPrecisionParameter();
        this.traitName = abstractMultivariateTraitLikelihood.getTraitName();
        this.branchRateModel = arbitraryBranchRates;
        this.ratePriorModel = gammaDistributionModel;
        this.ratePrior = gammaDistribution;
        this.dim = this.treeModel.getMultivariateNodeTrait(this.treeModel.getRoot(), this.traitName).length;
        boolean bl2 = gammaDistributionModel == null;
        boolean bl3 = bl = gammaDistribution == null;
        if (abstractMultivariateTraitLikelihood instanceof IntegratedMultivariateTraitLikelihood) {
            throw new RuntimeException("Only implemented for a SampledMultivariateTraitLikelihood");
        }
        if (bl && bl2 || !bl && !bl2) {
            throw new RuntimeException("Can only provide one prior density in TraitRateGibbsOperation");
        }
        boolean bl4 = arbitraryBranchRates.usingReciprocal();
        if (!bl4) {
            throw new RuntimeException("ArbitraryBranchRates in TraitRateGibbsOperator must use reciprocal rates");
        }
        Logger.getLogger("dr.evomodel").info("Using Gibbs operator and trait rates");
    }

    public int getStepCount() {
        return 1;
    }

    private void sampleRateForNode(NodeRef nodeRef, double[][] dArray, double d, double d2) {
        NodeRef nodeRef2 = this.treeModel.getParent(nodeRef);
        double[] dArray2 = this.treeModel.getMultivariateNodeTrait(nodeRef, this.traitName);
        double[] dArray3 = this.treeModel.getMultivariateNodeTrait(nodeRef2, this.traitName);
        double d3 = this.branchRateModel.getBranchRate(this.treeModel, nodeRef) / this.traitModel.getRescaledBranchLengthForPrecision(nodeRef);
        for (int i = 0; i < this.dim; ++i) {
            int n = i;
            dArray2[n] = dArray2[n] - dArray3[i];
        }
        double d4 = 0.0;
        for (int i = 0; i < this.dim; ++i) {
            for (int j = 0; j < this.dim; ++j) {
                d4 += dArray2[i] * dArray[i][j] * dArray2[j];
            }
        }
        double d5 = d + 0.5 * (double)this.dim;
        double d6 = d2 + 0.5 * d4 * d3;
        double d7 = GammaDistribution.nextGamma(d5, 1.0 / d6);
        this.branchRateModel.setBranchRate(this.treeModel, nodeRef, 1.0 / d7);
    }

    @Override
    public double doOperation() {
        double d;
        double d2;
        double[][] dArray = this.precisionMatrixParameter.getParameterAsMatrix();
        if (this.ratePriorModel != null) {
            d2 = this.ratePriorModel.getShape();
            d = 1.0 / this.ratePriorModel.getScale();
        } else {
            d2 = this.ratePrior.getShape();
            d = 1.0 / this.ratePrior.getScale();
        }
        for (int i = 0; i < this.treeModel.getNodeCount(); ++i) {
            NodeRef nodeRef = this.treeModel.getNode(i);
            if (nodeRef == this.treeModel.getRoot()) continue;
            this.sampleRateForNode(nodeRef, dArray, d2, d);
        }
        return 0.0;
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return GIBBS_OPERATOR;
    }
}

