/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

/**
 * Diverse Density (DD) multi-instance classifier, re-implemented with a
 * modified testing procedure (Maron, 1998; Maron and Lozano-Perez, 1998).
 *
 * <p>The model is a single target point {@code p} with a per-attribute scale
 * {@code s}, stored interleaved in {@link #m_Par} as
 * {@code [p_0, s_0, p_1, s_1, ...]}. An instance {@code x} is considered to be
 * in the concept with probability {@code exp(-sum_k s_k^2 (x_k - p_k)^2)}. A
 * bag gets class index 0 with probability {@code prod_j (1 - that)} and class
 * index 1 with the complement. During training, class value 1 is treated as
 * the positive class.
 *
 * <p>The bag is read from the relational attribute at index 1 (hard-coded).
 */
public class MIDD
extends Classifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = 4263507733600536168L;

    /** Filter type: normalize the training data. */
    public static final int FILTER_NORMALIZE = 0;
    /** Filter type: standardize the training data (the default). */
    public static final int FILTER_STANDARDIZE = 1;
    /** Filter type: apply neither normalization nor standardization. */
    public static final int FILTER_NONE = 2;
    /** Tags for the filter-type option, keyed by the FILTER_* constants. */
    public static final Tag[] TAGS_FILTER = new Tag[]{
        new Tag(FILTER_NORMALIZE, "Normalize training data"),
        new Tag(FILTER_STANDARDIZE, "Standardize training data"),
        new Tag(FILTER_NONE, "No normalization/standardization")};

    /** Index of the class attribute of the most recent training data. */
    protected int m_ClassIndex;
    /**
     * Learned parameters, interleaved as {@code [point_0, scale_0, point_1, scale_1, ...]};
     * {@code null} until a model has been built successfully.
     */
    protected double[] m_Par;
    /** Number of class values of the most recent training data. */
    protected int m_NumClasses;
    /** Class value of every training bag. */
    protected int[] m_Classes;
    /** Filtered training values, indexed as {@code [bag][attribute][instance]}. */
    protected double[][][] m_Data;
    /** String-free header of the bag (relational) attribute; supplies names for {@link #toString()}. */
    protected Instances m_Attributes;
    /** Normalize/Standardize filter fitted on the training instances, or null when no filter is applied. */
    protected Filter m_Filter = null;
    /** Selected filter; one of the FILTER_* constants. */
    protected int m_filterType = FILTER_STANDARDIZE;
    /** Replaces missing values; fitted on the (filtered) training instances. */
    protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();

    /**
     * Returns a description of this classifier.
     *
     * @return the description followed by the technical (citation) information
     */
    public String globalInfo() {
        return "Re-implement the Diverse Density algorithm, changes the testing procedure.\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic references for this algorithm: Maron's PhD
     * thesis, with the NIPS article added as a nested entry.
     *
     * @return the technical information
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation thesis = new TechnicalInformation(TechnicalInformation.Type.PHDTHESIS);
        thesis.setValue(TechnicalInformation.Field.AUTHOR, "Oded Maron");
        thesis.setValue(TechnicalInformation.Field.YEAR, "1998");
        thesis.setValue(TechnicalInformation.Field.TITLE, "Learning from ambiguity");
        thesis.setValue(TechnicalInformation.Field.SCHOOL, "Massachusetts Institute of Technology");

        TechnicalInformation article = thesis.add(TechnicalInformation.Type.ARTICLE);
        article.setValue(TechnicalInformation.Field.AUTHOR, "O. Maron and T. Lozano-Perez");
        article.setValue(TechnicalInformation.Field.YEAR, "1998");
        article.setValue(TechnicalInformation.Field.TITLE, "A Framework for Multiple Instance Learning");
        article.setValue(TechnicalInformation.Field.JOURNAL, "Neural Information Processing Systems");
        article.setValue(TechnicalInformation.Field.VOLUME, "10");
        return thesis;
    }

    /**
     * Lists the supported command-line options: {@code -D} (debug output) and
     * {@code -N <num>} (filter type).
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        options.addElement(new Option("\tWhether to 0=normalize/1=standardize/2=neither.\n\t(default 1=standardize)", "N", 1, "-N <num>"));
        return options.elements();
    }

    /**
     * Parses the options listed by {@link #listOptions()}. If {@code -N} is
     * absent, standardization is selected.
     *
     * @param stringArray the options to parse
     * @throws Exception if an option value cannot be parsed or is not a valid filter type
     */
    public void setOptions(String[] stringArray) throws Exception {
        this.setDebug(Utils.getFlag('D', stringArray));
        String filterOption = Utils.getOption('N', stringArray);
        int filterType = filterOption.length() != 0 ? Integer.parseInt(filterOption) : FILTER_STANDARDIZE;
        this.setFilterType(new SelectedTag(filterType, TAGS_FILTER));
    }

    /**
     * Returns the current settings in a form accepted by {@link #setOptions(String[])}.
     *
     * @return the current options
     */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (this.getDebug()) {
            options.add("-D");
        }
        options.add("-N");
        options.add(String.valueOf(this.m_filterType));
        return options.toArray(new String[options.size()]);
    }

    /**
     * Returns the tip text for the filter-type property.
     *
     * @return the tip text
     */
    public String filterTypeTipText() {
        return "The filter type for transforming the training data.";
    }

    /**
     * Returns the selected filter type.
     *
     * @return the filter type wrapped as a tag of {@link #TAGS_FILTER}
     */
    public SelectedTag getFilterType() {
        return new SelectedTag(this.m_filterType, TAGS_FILTER);
    }

    /**
     * Sets the filter type. Tags from any array other than {@link #TAGS_FILTER}
     * (compared by reference) are silently ignored.
     *
     * @param selectedTag the filter type
     */
    public void setFilterType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_FILTER) {
            this.m_filterType = selectedTag.getSelectedTag().getID();
        }
    }

    /**
     * Returns the capabilities of this classifier at the bag level.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return capabilities;
    }

    /**
     * Returns the capabilities required of the instances inside each bag.
     *
     * @return the capabilities
     */
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.disableAllClasses();
        capabilities.enable(Capabilities.Capability.NO_CLASS);
        return capabilities;
    }

    /**
     * Builds the DD model.
     *
     * <p>All bag instances are pooled, optionally normalized/standardized, and
     * their missing values replaced. The negative log-likelihood is then
     * minimised, starting once from every instance of the largest positive bags
     * (all ties are kept) with unit scales. The lowest-NLL result is kept.
     *
     * <p>If there is no positive bag, or no run yields a finite NLL below
     * {@code Double.MAX_VALUE}, no model is stored ({@link #m_Par} stays null).
     *
     * @param instances the training bags (not modified)
     * @throws Exception if the data cannot be handled or filtering/optimisation fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);

        // Discard any previous model so that a rebuild which finds no usable
        // solution cannot silently keep stale parameters.
        this.m_Par = null;

        Instances train = new Instances(instances);
        train.deleteWithMissingClass();
        this.m_ClassIndex = train.classIndex();
        this.m_NumClasses = train.numClasses();

        int numAttributes = train.attribute(1).relation().numAttributes();
        int numBags = train.numInstances();
        int[] bagSizes = new int[numBags];
        // Indices (as Integer) of the positive bags with the most instances.
        FastVector largestPositiveBags = new FastVector();
        int largestPositiveSize = 0;
        // All instances of all bags, pooled in bag order.
        Instances pooled = new Instances(train.attribute(1).relation(), 0);
        this.m_Data = new double[numBags][numAttributes][];
        this.m_Classes = new int[numBags];
        this.m_Attributes = pooled.stringFreeStructure();

        if (this.m_Debug) {
            System.out.println("Extracting data...");
        }
        for (int bag = 0; bag < numBags; ++bag) {
            Instance current = train.instance(bag);
            this.m_Classes[bag] = (int) current.classValue();
            Instances bagInstances = current.relationalValue(1);
            int size = bagInstances.numInstances();
            for (int j = 0; j < size; ++j) {
                pooled.add(bagInstances.instance(j));
            }
            bagSizes[bag] = size;

            if (this.m_Classes[bag] == 1) {
                if (size > largestPositiveSize) {
                    largestPositiveSize = size;
                    largestPositiveBags = new FastVector(1);
                    largestPositiveBags.addElement(Integer.valueOf(bag));
                } else if (size == largestPositiveSize) {
                    largestPositiveBags.addElement(Integer.valueOf(bag));
                }
            }
        }

        if (this.m_filterType == FILTER_STANDARDIZE) {
            this.m_Filter = new Standardize();
        } else if (this.m_filterType == FILTER_NORMALIZE) {
            this.m_Filter = new Normalize();
        } else {
            this.m_Filter = null;
        }
        if (this.m_Filter != null) {
            this.m_Filter.setInputFormat(pooled);
            pooled = Filter.useFilter(pooled, this.m_Filter);
        }
        this.m_Missing.setInputFormat(pooled);
        pooled = Filter.useFilter(pooled, this.m_Missing);

        // Split the pooled instances back into bags, transposed to
        // [bag][attribute][instance] for the optimiser.
        int pooledAttributes = pooled.numAttributes();
        int offset = 0;
        for (int bag = 0; bag < numBags; ++bag) {
            for (int att = 0; att < pooledAttributes; ++att) {
                double[] column = new double[bagSizes[bag]];
                for (int j = 0; j < bagSizes[bag]; ++j) {
                    column[j] = pooled.instance(offset + j).value(att);
                }
                this.m_Data[bag][att] = column;
            }
            offset += bagSizes[bag];
        }

        if (this.m_Debug) {
            System.out.println("\nIteration History...");
        }
        double[] start = new double[numAttributes * 2];
        // Every lower/upper bound is set to NaN. weka.core.Optimization is
        // understood to treat NaN as "unconstrained" -- confirm against its Javadoc.
        double[][] bounds = new double[2][start.length];
        for (int k = 0; k < start.length; ++k) {
            bounds[0][k] = Double.NaN;
            bounds[1][k] = Double.NaN;
        }

        double bestNLL = Double.MAX_VALUE;
        for (int b = 0; b < largestPositiveBags.size(); ++b) {
            int bag = (Integer) largestPositiveBags.elementAt(b);
            for (int inst = 0; inst < this.m_Data[bag][0].length; ++inst) {
                // Start at this instance's point, with unit scale on every attribute.
                for (int att = 0; att < numAttributes; ++att) {
                    start[2 * att] = this.m_Data[bag][att][inst];
                    start[2 * att + 1] = 1.0;
                }
                OptEng optimiser = new OptEng();
                double[] result = optimiser.findArgmin(start, bounds);
                // findArgmin returns null when it stops before converging; resume
                // from its current variable values.
                // NOTE(review): this loop has no upper bound on restarts.
                while (result == null) {
                    result = optimiser.getVarbValues();
                    if (this.m_Debug) {
                        System.out.println("200 iterations finished, not enough!");
                    }
                    result = optimiser.findArgmin(result, bounds);
                }
                double nll = optimiser.getMinFunction();
                if (nll < bestNLL) {
                    bestNLL = nll;
                    this.m_Par = result;
                    if (this.m_Debug) {
                        System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: " + nll);
                    }
                }
                if (this.m_Debug) {
                    System.out.println(bag + ":  -------------<Converged>--------------");
                }
            }
        }
    }

    /**
     * Computes the class distribution for a bag. Index 0 is the probability
     * that no instance lies in the learned concept,
     * {@code prod_j (1 - exp(-d_j))}, where {@code d_j} is the scaled squared
     * distance from instance {@code j} to the learned point. Index 1 is its
     * complement. The training filters are applied to the bag first.
     *
     * @param instance the bag to classify
     * @return a two-element class distribution
     * @throws Exception if no model has been built, or filtering fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_Par == null) {
            throw new Exception("MIDD: No model built yet "
                + "(buildClassifier not called, or it found no usable positive bag).");
        }
        Instances bag = instance.relationalValue(1);
        if (this.m_Filter != null) {
            bag = Filter.useFilter(bag, this.m_Filter);
        }
        bag = Filter.useFilter(bag, this.m_Missing);

        int numInstances = bag.numInstances();
        int numAttributes = bag.numAttributes();
        // Accumulated in log space: sum_j log(1 - exp(-d_j)).
        double logAllOutside = 0.0;
        for (int j = 0; j < numInstances; ++j) {
            Instance member = bag.instance(j);
            double distance = 0.0;
            for (int k = 0; k < numAttributes; ++k) {
                double diff = this.m_Par[k * 2] - member.value(k);
                distance += diff * diff * this.m_Par[k * 2 + 1] * this.m_Par[k * 2 + 1];
            }
            logAllOutside += Math.log(1.0 - Math.exp(-distance));
        }
        double[] distribution = new double[2];
        distribution[0] = Math.exp(logAllOutside);
        distribution[1] = 1.0 - distribution[0];
        return distribution;
    }

    /**
     * Describes the learned model: one line per attribute with its learned
     * point and scale.
     *
     * @return the model description, or a notice that no model is built
     */
    public String toString() {
        StringBuilder text = new StringBuilder("Diverse Density");
        if (this.m_Par == null) {
            return text.append(": No model built yet.").toString();
        }
        text.append("\nCoefficients...\nVariable       Point       Scale\n");
        for (int k = 0; k < this.m_Par.length / 2; ++k) {
            text.append(this.m_Attributes.attribute(k).name());
            text.append(' ').append(Utils.doubleToString(this.m_Par[k * 2], 12, 4));
            text.append(' ').append(Utils.doubleToString(this.m_Par[k * 2 + 1], 12, 4)).append('\n');
        }
        return text.toString();
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param stringArray the command-line options
     */
    public static void main(String[] stringArray) {
        MIDD.runClassifier(new MIDD(), stringArray);
    }

    /**
     * Optimiser for the DD negative log-likelihood over the training data in
     * {@link MIDD#m_Data}. The variables are interleaved
     * {@code [point_0, scale_0, point_1, scale_1, ...]}.
     */
    private class OptEng
    extends Optimization {
        private OptEng() {
        }

        /**
         * Scaled squared distance {@code sum_k s_k^2 (x_jk - p_k)^2} between
         * instance {@code inst} of bag {@code bag} and the point in {@code x}.
         */
        private double scaledSquaredDistance(int bag, int inst, double[] x) {
            double[][] bagData = MIDD.this.m_Data[bag];
            double sum = 0.0;
            for (int k = 0; k < bagData.length; ++k) {
                double diff = bagData[k][inst] - x[k * 2];
                sum += diff * diff * x[k * 2 + 1] * x[k * 2 + 1];
            }
            return sum;
        }

        /**
         * Negative log-likelihood of the training bags.
         *
         * <p>A negative bag contributes {@code -sum_j log(1 - exp(-d_j))}. A
         * positive bag contributes {@code -log(1 - prod_j (1 - exp(-d_j)))}.
         * Arguments of the outer logs are clamped to at least {@code m_Zero}.
         *
         * @param x the current variables
         * @return the negative log-likelihood
         */
        protected double objectiveFunction(double[] x) {
            double nll = 0.0;
            for (int i = 0; i < MIDD.this.m_Classes.length; ++i) {
                boolean positive = MIDD.this.m_Classes[i] == 1;
                int numInstances = MIDD.this.m_Data[i][0].length;
                // Positive bags only: sum_j log(1 - exp(-d_j)); may reach -Infinity.
                double logAllOutside = 0.0;
                for (int j = 0; j < numInstances; ++j) {
                    double outside = 1.0 - Math.exp(-scaledSquaredDistance(i, j, x));
                    if (positive) {
                        logAllOutside += Math.log(outside);
                    } else {
                        if (outside <= m_Zero) {
                            outside = m_Zero;
                        }
                        nll -= Math.log(outside);
                    }
                }
                if (positive) {
                    double probPositive = 1.0 - Math.exp(logAllOutside);
                    if (probPositive <= m_Zero) {
                        probPositive = m_Zero;
                    }
                    nll -= Math.log(probPositive);
                }
            }
            return nll;
        }

        /**
         * Analytic gradient of {@link #objectiveFunction(double[])}, using the
         * same {@code m_Zero} clamping.
         *
         * @param x the current variables
         * @return the gradient, same length as {@code x}
         */
        protected double[] evaluateGradient(double[] x) {
            double[] grad = new double[x.length];
            for (int i = 0; i < MIDD.this.m_Classes.length; ++i) {
                double[][] bagData = MIDD.this.m_Data[i];
                boolean positive = MIDD.this.m_Classes[i] == 1;
                int numInstances = bagData[0].length;
                double logAllOutside = 0.0;
                // Per-bag sum over instances of exp(-d_j) * d(d_j)/dx / (1 - exp(-d_j)).
                double[] numerator = new double[x.length];
                for (int j = 0; j < numInstances; ++j) {
                    double outside = 1.0 - Math.exp(-scaledSquaredDistance(i, j, x));
                    if (positive) {
                        logAllOutside += Math.log(outside);
                    }
                    if (outside <= m_Zero) {
                        outside = m_Zero;
                    }
                    for (int k = 0; k < bagData.length; ++k) {
                        double diff = x[2 * k] - bagData[k][j];
                        numerator[2 * k] += (1.0 - outside) * 2.0 * diff * x[k * 2 + 1] * x[k * 2 + 1] / outside;
                        numerator[2 * k + 1] += 2.0 * (1.0 - outside) * diff * diff * x[k * 2 + 1] / outside;
                    }
                }
                if (positive) {
                    double probPositive = 1.0 - Math.exp(logAllOutside);
                    if (probPositive <= m_Zero) {
                        probPositive = m_Zero;
                    }
                    for (int k = 0; k < bagData.length; ++k) {
                        grad[2 * k] += numerator[2 * k] * (1.0 - probPositive) / probPositive;
                        grad[2 * k + 1] += numerator[2 * k + 1] * (1.0 - probPositive) / probPositive;
                    }
                } else {
                    for (int k = 0; k < bagData.length; ++k) {
                        grad[2 * k] -= numerator[2 * k];
                        grad[2 * k + 1] -= numerator[2 * k + 1];
                    }
                }
            }
            return grad;
        }
    }
}

