/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evomodel.substmodel.GlmSubstitutionModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.AbstractLogAdditiveSubstitutionModelGradient;
import dr.inference.distribution.GeneralizedLinearModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Parameter;
import dr.util.Author;
import dr.util.Citation;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public abstract class AbstractGlmSubstitutionModelGradient
extends AbstractLogAdditiveSubstitutionModelGradient {
    protected final GeneralizedLinearModel glm;
    private final ParameterMap parameterMap;
    private static final Citation CITATION = new Citation(new Author[]{new Author("AF", "Magee"), new Author("AJ", "Holbrook"), new Author("JE", "Pekar"), new Author("IW", "Caviedes-Solis"), new Author("FA", "Matsen"), new Author("G", "Baele"), new Author("JO", "Wertheim"), new Author("X", "Ji"), new Author("P", "Lemey"), new Author("MA", "Suchard")}, "Random-effects substitution models for phylogenetics via scalable gradient approximations", "", Citation.Status.IN_PREPARATION);

    public AbstractGlmSubstitutionModelGradient(String string, TreeDataLikelihood treeDataLikelihood, BeagleDataLikelihoodDelegate beagleDataLikelihoodDelegate, GlmSubstitutionModel glmSubstitutionModel, AbstractLogAdditiveSubstitutionModelGradient.ApproximationMode approximationMode) {
        super(string, treeDataLikelihood, beagleDataLikelihoodDelegate, glmSubstitutionModel, approximationMode);
        this.glm = glmSubstitutionModel.getGeneralizedLinearModel();
        this.parameterMap = this.makeParameterMap(this.glm);
    }

    String getType() {
        return "fixed";
    }

    ParameterMap makeParameterMap(final GeneralizedLinearModel generalizedLinearModel) {
        final ArrayList<Integer> arrayList = new ArrayList<Integer>();
        final ArrayList<Integer> arrayList2 = new ArrayList<Integer>();
        CompoundParameter compoundParameter = new CompoundParameter("fixedEffects");
        boolean bl = generalizedLinearModel.getNumberOfFixedEffects() > 1;
        for (int i = 0; i < generalizedLinearModel.getNumberOfFixedEffects(); ++i) {
            Parameter parameter = generalizedLinearModel.getFixedEffect(i);
            if (bl) {
                compoundParameter.addParameter(parameter);
            }
            for (int j = 0; j < parameter.getDimension(); ++j) {
                arrayList.add(i);
                arrayList2.add(j);
            }
            if (generalizedLinearModel.getFixedEffectIndicator(i) == null) continue;
            throw new IllegalArgumentException("GLM fixed effects gradients do not currently work with indicator variables");
        }
        final CompoundParameter compoundParameter2 = bl ? compoundParameter : generalizedLinearModel.getFixedEffect(0);
        return new ParameterMap(){

            @Override
            public double[] getCovariateColumn(int n) {
                return generalizedLinearModel.getDesignMatrix((Integer)arrayList.get(n)).getColumnValues((Integer)arrayList2.get(n));
            }

            @Override
            public Parameter getParameter() {
                return compoundParameter2;
            }
        };
    }

    @Override
    public Parameter getParameter() {
        return this.parameterMap.getParameter();
    }

    @Override
    protected double preProcessNormalization(double[] dArray, double[] dArray2, boolean bl) {
        return 0.0;
    }

    @Override
    double processSingleGradientDimension(int n, double[] dArray, double[] dArray2, double[] dArray3, boolean bl, double d) {
        double[] dArray4 = this.parameterMap.getCovariateColumn(n);
        return this.calculateCovariateDifferential(dArray2, dArray, dArray4, dArray3, bl);
    }

    private double calculateCovariateDifferential(double[] dArray, double[] dArray2, double[] dArray3, double[] dArray4, boolean bl) {
        double d;
        double d2;
        int n;
        int n2;
        double d3 = 0.0;
        double d4 = 0.0;
        int n3 = 0;
        for (n2 = 0; n2 < this.stateCount; ++n2) {
            for (n = n2 + 1; n < this.stateCount; ++n) {
                if ((d2 = (d = dArray3[n3++]) * dArray[this.index(n2, n)]) == 0.0) continue;
                d4 += dArray2[this.index(n2, n)] * d2;
                d4 -= dArray2[this.index(n2, n2)] * d2;
                d4 += this.correction(n2, n, dArray2) * d2;
                d4 -= this.correction(n2, n2, dArray2) * d2;
                d3 += d2 * dArray4[n2];
            }
        }
        for (n2 = 0; n2 < this.stateCount; ++n2) {
            for (n = n2 + 1; n < this.stateCount; ++n) {
                if ((d2 = (d = dArray3[n3++]) * dArray[this.index(n, n2)]) == 0.0) continue;
                d4 += dArray2[this.index(n, n2)] * d2;
                d4 -= dArray2[this.index(n, n)] * d2;
                d4 += this.correction(n, n2, dArray2) * d2;
                d4 -= this.correction(n, n2, dArray2) * d2;
                d3 += d2 * dArray4[n];
            }
        }
        if (bl) {
            for (n2 = 0; n2 < this.stateCount; ++n2) {
                for (n = 0; n < this.stateCount; ++n) {
                    d4 -= dArray2[this.index(n2, n)] * dArray[this.index(n2, n)] * d3;
                }
            }
        }
        return d4;
    }

    int index(int n, int n2) {
        return n * this.stateCount + n2;
    }

    @Override
    public LogColumn[] getColumns() {
        return Loggable.getColumnsFromReport(this, "gradient report");
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.FRAMEWORK;
    }

    @Override
    public String getDescription() {
        return "Using linear-time differential calculations for all substitution generator elements";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CITATION);
    }

    static interface ParameterMap {
        public double[] getCovariateColumn(int var1);

        public Parameter getParameter();
    }
}

