/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.preorder;

import beagle.Beagle;
import dr.evolution.alignment.PatternList;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.EvolutionaryProcessDelegate;
import dr.evomodel.treedatalikelihood.ProcessOnTreeDelegate;
import dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate;
import dr.inference.model.Model;
import dr.math.matrixAlgebra.WrappedVector;
import java.util.List;

/**
 * Base class for BEAGLE-backed discrete-trait delegates that compute pre-order partials and,
 * from them, per-edge first and second differentials via
 * {@code Beagle#calculateEdgeDifferentials}.
 *
 * <p>Each {@link #simulate(int[], int, int)} call seeds the root pre-order partial with the root
 * state frequencies, asks BEAGLE to propagate pre-order partials, and refreshes the cached
 * per-edge {@link #gradient}. The gradient and Hessian are exposed as {@code BRANCH}-intent
 * tree traits. Subclasses decide which differential matrices are loaded into BEAGLE by
 * implementing {@link #cacheDifferentialMassMatrix(Tree, boolean)}.
 */
public abstract class AbstractDiscreteTraitDelegate
extends ProcessSimulationDelegate.AbstractDelegate {
    private static final String GRADIENT_TRAIT_NAME = "Gradient";
    private static final String HESSIAN_TRAIT_NAME = "Hessian";
    // NOTE(review): DEBUG_TRANSPOSE and USE_CACHE are not referenced in this class; they are
    // kept only to mirror the original source. Remove them if they are confirmed dead.
    private static final boolean DEBUG_TRANSPOSE = false;
    private static final boolean USE_CACHE = true;
    /** Guards the bookkeeping counters reported by {@link #toString()}. */
    private static final boolean COUNT_TOTAL_OPERATIONS = true;
    /** Buffer index passed to BEAGLE to mean "no buffer" (scale slots, cumulative scale index). */
    private static final int NO_BUFFER = -1;
    /** Number of ints written per node operation by {@link #vectorizeNodeOperations}. */
    private static final int OPERATION_TUPLE_SIZE = 7;

    protected final BeagleDataLikelihoodDelegate likelihoodDelegate;
    protected final Beagle beagle;
    protected EvolutionaryProcessDelegate evolutionaryProcessDelegate;
    protected final SiteRateModel siteRateModel;
    protected final PatternList patternList;
    protected final int patternCount;
    protected final int stateCount;
    protected final int categoryCount;
    /** Pre-order partial buffers start at this index (offset + node number). */
    private final int preOrderPartialOffset;
    /**
     * Per-edge first differentials, refreshed on every {@link #simulate} call. Entry {@code k}
     * belongs to the k-th non-root node in increasing node-number order.
     */
    protected final double[] gradient;
    /**
     * True once {@link #cacheDifferentialMassMatrix} has run since the last model change or
     * restore.
     */
    private boolean substitutionProcessKnown;
    private long simulateCount = 0L;
    private long getTraitCount = 0L;
    private long updatePrePartialCount = 0L;

    /**
     * Creates a delegate bound to the BEAGLE instance of {@code likelihoodDelegate} and registers
     * itself as a model and restore listener on that delegate.
     *
     * @param name               name forwarded to the superclass
     * @param tree               tree whose non-root nodes define the edges
     * @param likelihoodDelegate likelihood delegate; asserted to have pre-order enabled
     */
    public AbstractDiscreteTraitDelegate(String name, Tree tree, BeagleDataLikelihoodDelegate likelihoodDelegate) {
        super(name, tree);
        this.likelihoodDelegate = likelihoodDelegate;
        this.beagle = likelihoodDelegate.getBeagleInstance();
        assert (this.likelihoodDelegate.isUsePreOrder());
        this.evolutionaryProcessDelegate = likelihoodDelegate.getEvolutionaryProcessDelegate();
        this.siteRateModel = likelihoodDelegate.getSiteRateModel();
        this.patternList = likelihoodDelegate.getPatternList();
        this.patternCount = this.patternList.getPatternCount();
        this.stateCount = this.patternList.getDataType().getStateCount();
        this.categoryCount = this.siteRateModel.getCategoryCount();
        // Pre-order partials live after the delegate's own partial buffers in the same instance.
        this.preOrderPartialOffset = likelihoodDelegate.getPartialBufferCount();
        this.gradient = new double[tree.getNodeCount() - 1];
        likelihoodDelegate.addModelListener(this);
        likelihoodDelegate.addModelRestoreListener(this);
        this.substitutionProcessKnown = false;
    }

    /**
     * Debug helper. Prints a flattened block of {@code categoryCount} square
     * {@code stateCount x stateCount} matrices to stderr, row by row.
     */
    private void printMatrix(double[] matrices) {
        final double[] row = new double[stateCount];
        for (int category = 0; category < categoryCount; ++category) {
            System.err.println("\nRate = " + category);
            for (int r = 0; r < stateCount; ++r) {
                final int offset = (category * stateCount + r) * stateCount;
                System.arraycopy(matrices, offset, row, 0, stateCount);
                System.err.println(new WrappedVector.Raw(row.clone()));
            }
        }
    }

    /**
     * Debug helper. Prints a transition matrix and its BEAGLE transpose.
     *
     * <p>{@code operations[4]} is the left-child matrix index of the first tuple in the layout
     * written by {@link #vectorizeNodeOperations}.
     */
    private void debugMatrixTranspose(int[] operations) {
        final double[] matrices = new double[stateCount * stateCount * categoryCount];
        final int sourceIndex = operations[4];
        beagle.getTransitionMatrix(sourceIndex, matrices);
        printMatrix(matrices);
        // NOTE(review): hard-coded destination buffer 1; confirm it is safe to overwrite.
        final int destinationIndex = 1;
        beagle.transposeTransitionMatrices(new int[]{sourceIndex}, new int[]{destinationIndex}, 1);
        beagle.getTransitionMatrix(destinationIndex, matrices);
        printMatrix(matrices);
    }

    /**
     * Seeds the root pre-order partial, runs BEAGLE's pre-order pass, and refreshes
     * {@link #gradient}.
     *
     * @param operations      flattened operation tuples (presumably built by
     *                        {@link #vectorizeNodeOperations})
     * @param operationCount  number of tuples in {@code operations}
     * @param rootNodeNumber  node number of the root
     */
    @Override
    public void simulate(int[] operations, int operationCount, int rootNodeNumber) {
        simulateRoot(rootNodeNumber);
        beagle.updatePrePartials(operations, operationCount, NO_BUFFER);
        getNodeDerivatives(tree, gradient, null);
        if (COUNT_TOTAL_OPERATIONS) {
            ++simulateCount;
            updatePrePartialCount += operationCount;
        }
    }

    @Override
    public void setupStatistics() {
        throw new UnsupportedOperationException("Not used (?) with BEAGLE");
    }

    /**
     * Writes the root state frequencies into the root's pre-order partial buffer, replicated
     * once per (pattern, rate category) block.
     */
    @Override
    protected void simulateRoot(int rootNodeNumber) {
        final double[] frequencies = evolutionaryProcessDelegate.getRootStateFrequencies();
        final int blockCount = patternCount * categoryCount;
        final double[] rootPartials = new double[stateCount * blockCount];
        for (int block = 0; block < blockCount; ++block) {
            System.arraycopy(frequencies, 0, rootPartials, block * stateCount, stateCount);
        }
        beagle.setPartials(getPreOrderPartialIndex(rootNodeNumber), rootPartials);
    }

    @Override
    protected void simulateNode(int n, int n2, int n3, int n4, int n5) {
        throw new UnsupportedOperationException("Not used with BEAGLE");
    }

    /** @return name under which the gradient trait is registered; subclasses may override */
    protected String getGradientTraitName() {
        return GRADIENT_TRAIT_NAME;
    }

    /** @return name under which the Hessian trait is registered; subclasses may override */
    protected String getHessianTraitName() {
        return HESSIAN_TRAIT_NAME;
    }

    /** Registers the gradient and Hessian as {@code BRANCH}-intent double-array traits. */
    @Override
    protected void constructTraits(TreeTraitProvider.Helper helper) {
        helper.addTrait(new TreeTrait.DA() {

            @Override
            public String getTraitName() {
                return AbstractDiscreteTraitDelegate.this.getGradientTraitName();
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public double[] getTrait(Tree tree, NodeRef node) {
                return AbstractDiscreteTraitDelegate.this.getGradient(node);
            }
        });
        helper.addTrait(new TreeTrait.DA() {

            @Override
            public String getTraitName() {
                return AbstractDiscreteTraitDelegate.this.getHessianTraitName();
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public double[] getTrait(Tree tree, NodeRef node) {
                return AbstractDiscreteTraitDelegate.this.getHessian(tree, node);
            }
        });
    }

    /**
     * Computes a fresh per-edge Hessian. This always re-caches the differential matrices,
     * including second-order ones.
     */
    private double[] getHessian(Tree tree, NodeRef node) {
        simulationProcess.cacheSimulatedTraits(node);
        final double[] hessian = new double[tree.getNodeCount() - 1];
        getNodeDerivatives(tree, null, hessian);
        return hessian;
    }

    /**
     * Returns a defensive copy of the cached gradient. It first asks the simulation process to
     * cache simulated traits (presumably running {@link #simulate} when stale).
     */
    private double[] getGradient(NodeRef node) {
        if (COUNT_TOTAL_OPERATIONS) {
            ++getTraitCount;
        }
        simulationProcess.cacheSimulatedTraits(node);
        return gradient.clone();
    }

    /**
     * Loads the differential matrices used by the derivative passes into BEAGLE.
     *
     * @param tree             the tree being evaluated
     * @param cacheSquaredTerm true when second-order (Hessian) matrices are also needed
     */
    protected abstract void cacheDifferentialMassMatrix(Tree tree, boolean cacheSquaredTerm);

    /**
     * Runs BEAGLE's edge-differential computation over every non-root node.
     *
     * <p>Edges are enumerated in increasing node-number order with the root skipped. Entry
     * {@code k} of each output array therefore belongs to the k-th non-root node.
     *
     * @param tree              tree whose non-root nodes define the edges
     * @param firstDerivatives  output for the first-derivative pass (may be {@code null})
     * @param secondDerivatives output for the Hessian, or {@code null} to skip the second pass
     */
    private void getNodeDerivatives(Tree tree, double[] firstDerivatives, double[] secondDerivatives) {
        final int nodeCount = tree.getNodeCount();
        final int edgeCount = nodeCount - 1;
        final int[] postOrderIndices = new int[edgeCount];
        final int[] preOrderIndices = new int[edgeCount];
        final int[] firstMatrixIndices = new int[edgeCount];
        final int[] secondMatrixIndices = new int[edgeCount];
        final boolean wantHessian = secondDerivatives != null;

        // A Hessian request always re-caches, because the cheaper gradient-only cache may lack
        // the second-order matrices.
        if (!substitutionProcessKnown || wantHessian) {
            cacheDifferentialMassMatrix(tree, wantHessian);
            substitutionProcessKnown = true;
        }

        int edge = 0;
        for (int nodeNumber = 0; nodeNumber < nodeCount; ++nodeNumber) {
            if (tree.isRoot(tree.getNode(nodeNumber))) {
                continue;
            }
            postOrderIndices[edge] = getPostOrderPartialIndex(nodeNumber);
            preOrderIndices[edge] = getPreOrderPartialIndex(nodeNumber);
            firstMatrixIndices[edge] = getFirstDerivativeMatrixBufferIndex(nodeNumber);
            secondMatrixIndices[edge] = getSecondDerivativeMatrixBufferIndex(nodeNumber);
            ++edge;
        }

        // NOTE(review): this is the second output of the first-derivative pass. In the BEAGLE API
        // it appears to hold squared first-derivative sums; it is subtracted from the
        // second-order pass below. Confirm against the BEAGLE documentation.
        final double[] firstPassSquared = wantHessian ? new double[secondDerivatives.length] : null;
        beagle.calculateEdgeDifferentials(postOrderIndices, preOrderIndices, firstMatrixIndices,
                new int[]{0}, edgeCount, null, firstDerivatives, firstPassSquared);

        if (wantHessian) {
            beagle.calculateEdgeDifferentials(postOrderIndices, preOrderIndices, secondMatrixIndices,
                    new int[]{0}, edgeCount, null, secondDerivatives, null);
            for (int i = 0; i < secondDerivatives.length; ++i) {
                secondDerivatives[i] -= firstPassSquared[i];
            }
        }
    }

    /** @return BEAGLE buffer index of the first-derivative (infinitesimal) matrix for a node */
    protected int getFirstDerivativeMatrixBufferIndex(int nodeNumber) {
        return evolutionaryProcessDelegate.getInfinitesimalMatrixBufferIndex(nodeNumber);
    }

    /** @return BEAGLE buffer index of the second-derivative (squared infinitesimal) matrix for a node */
    protected int getSecondDerivativeMatrixBufferIndex(int nodeNumber) {
        return evolutionaryProcessDelegate.getInfinitesimalSquaredMatrixBufferIndex(nodeNumber);
    }

    /** Any model change invalidates the cached differential matrices. */
    @Override
    public void modelChangedEvent(Model model, Object object, int index) {
        substitutionProcessKnown = false;
    }

    /** A model restore invalidates the cached differential matrices. */
    @Override
    public void modelRestored(Model model) {
        substitutionProcessKnown = false;
    }

    /**
     * Flattens node operations into {@link #OPERATION_TUPLE_SIZE}-int tuples for
     * {@code updatePrePartials}.
     *
     * <p>Each tuple is written in this order:
     * <ol>
     *   <li>pre-order index of the left child (destination)</li>
     *   <li>{@code -1}</li>
     *   <li>{@code -1}</li>
     *   <li>pre-order index of the parent</li>
     *   <li>left child's matrix</li>
     *   <li>post-order index of the right child</li>
     *   <li>right child's matrix</li>
     * </ol>
     * The two {@code -1} slots are presumably BEAGLE's scale write/read buffers, which are
     * unused here.
     *
     * @return number of tuples written
     */
    @Override
    public int vectorizeNodeOperations(List<ProcessOnTreeDelegate.NodeOperation> operations, int[] out) {
        int k = 0;
        for (ProcessOnTreeDelegate.NodeOperation op : operations) {
            final int left = op.getLeftChild();
            final int right = op.getRightChild();
            out[k++] = getPreOrderPartialIndex(left);
            out[k++] = NO_BUFFER;
            out[k++] = NO_BUFFER;
            out[k++] = getPreOrderPartialIndex(op.getNodeNumber());
            out[k++] = evolutionaryProcessDelegate.getMatrixIndex(left);
            out[k++] = getPostOrderPartialIndex(right);
            out[k++] = evolutionaryProcessDelegate.getMatrixIndex(right);
        }
        return operations.size();
    }

    @Override
    public int getSingleOperationSize() {
        return OPERATION_TUPLE_SIZE;
    }

    /** @return post-order partial buffer index for a node, as assigned by the likelihood delegate */
    private int getPostOrderPartialIndex(int nodeNumber) {
        return likelihoodDelegate.getPartialBufferIndex(nodeNumber);
    }

    /** @return pre-order partial buffer index for a node (offset + node number) */
    private int getPreOrderPartialIndex(int nodeNumber) {
        return preOrderPartialOffset + nodeNumber;
    }

    /** @return the operation counters, one per tab-indented line */
    @Override
    public String toString() {
        return "\tsimulateCount = " + simulateCount
                + "\n\tgetTraitCount = " + getTraitCount
                + "\n\tupPrePartialCount = " + updatePrePartialCount + "\n";
    }
}

