/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.lmt;

import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.Distribution;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * Split model that partitions instances on a single attribute and scores the
 * split by how much it reduces the weighted squared deviation of a per-class
 * response matrix ({@code dataZs}) under a per-class weight matrix
 * ({@code dataWs}).
 *
 * <p>Nominal attributes yield one subset per attribute value; any other
 * attribute yields a binary split at {@link #m_splitPoint}. Instances with a
 * missing value for the split attribute are not supported when computing the
 * gain.
 *
 * <p>NOTE(review): within LMT the Zs/Ws are presumably the LogitBoost working
 * responses and weights - confirm against the caller.
 */
public class ResidualSplit
extends ClassifierSplitModel {
    private static final long serialVersionUID = -5055883734183713525L;

    /** The attribute selected for the split (resolved in buildClassifier). */
    protected Attribute m_attribute;

    /** Index of the attribute to split on. */
    protected int m_attIndex;

    /** Number of instances in the data the split was built on. */
    protected int m_numInstances;

    /** Number of classes in the data the split was built on. */
    protected int m_numClasses;

    /** The data the split was built on; row i corresponds to m_dataZs[i] and m_dataWs[i]. */
    protected Instances m_data;

    /** Response values, indexed [instance][class]. */
    protected double[][] m_dataZs;

    /** Weights, indexed [instance][class]. */
    protected double[][] m_dataWs;

    /** Threshold for non-nominal attributes; set to 0.0 for nominal ones. */
    protected double m_splitPoint;

    /**
     * Creates a split on the attribute with the given index. The split is not
     * usable until {@link #buildClassifier(Instances, double[][], double[][])}
     * has been called.
     *
     * @param attIndex index of the attribute to split on
     */
    public ResidualSplit(int attIndex) {
        m_attIndex = attIndex;
    }

    /**
     * Builds the split for the given data and response/weight matrices.
     *
     * <p>For a nominal attribute the number of subsets is the number of
     * attribute values; otherwise the best threshold is searched with
     * {@link #getSplitPoint()} and the split is binary. Finally the inherited
     * {@code m_distribution} is created from the data and this model.
     *
     * @param data   the instances to split
     * @param dataZs response values, indexed [instance][class]
     * @param dataWs weights, indexed [instance][class]
     * @throws Exception if {@code data} contains no instances, or if the split
     *                   point search fails (e.g. on missing values)
     */
    public void buildClassifier(Instances data, double[][] dataZs, double[][] dataWs) throws Exception {
        m_numClasses = data.numClasses();
        m_numInstances = data.numInstances();
        if (m_numInstances == 0) {
            throw new Exception("Can't build split on 0 instances");
        }

        m_data = data;
        m_dataZs = dataZs;
        m_dataWs = dataWs;
        m_attribute = data.attribute(m_attIndex);

        if (m_attribute.isNominal()) {
            m_splitPoint = 0.0;
            m_numSubsets = m_attribute.numValues();
        } else {
            // The return value is intentionally ignored: when no threshold is
            // found the model is still built as a binary split.
            getSplitPoint();
            m_numSubsets = 2;
        }
        m_distribution = new Distribution(data, this);
    }

    /**
     * Searches for the threshold with the highest {@link #entropyGain()} and
     * stores it in {@link #m_splitPoint}.
     *
     * <p>Candidate thresholds are the midpoints between consecutive distinct
     * attribute values in sorted order.
     *
     * @return {@code true} if a threshold was selected; {@code false} if there
     *         was no candidate with a gain greater than {@code -Double.MAX_VALUE}
     * @throws Exception if the gain cannot be computed for a candidate
     */
    protected boolean getSplitPoint() throws Exception {
        // Sort a copy: m_data must keep its order because its rows are matched
        // by index with m_dataZs / m_dataWs.
        Instances sortedData = new Instances(m_data);
        sortedData.sort(sortedData.attribute(m_attIndex));

        // Collect the midpoint between each pair of neighbouring distinct values.
        double[] splitPoints = new double[m_numInstances];
        int numSplitPoints = 0;
        double last = sortedData.instance(0).value(m_attIndex);
        for (int i = 1; i < m_numInstances; ++i) {
            double current = sortedData.instance(i).value(m_attIndex);
            if (!Utils.eq(current, last)) {
                splitPoints[numSplitPoints++] = (last + current) / 2.0;
            }
            last = current;
        }

        // Evaluate every candidate; entropyGain() reads m_splitPoint, so the
        // field is temporarily set to the candidate under test.
        int bestSplit = -1;
        double bestGain = -Double.MAX_VALUE;
        for (int i = 0; i < numSplitPoints; ++i) {
            m_splitPoint = splitPoints[i];
            double gain = entropyGain();
            if (gain > bestGain) {
                bestGain = gain;
                bestSplit = i;
            }
        }

        if (bestSplit < 0) {
            return false;
        }
        m_splitPoint = splitPoints[bestSplit];
        return true;
    }

    /**
     * Computes the gain of the current split: the entropy of the full data
     * minus the summed entropy of the subsets produced by
     * {@link #whichSubset(Instance)}.
     *
     * @return the entropy of the unsplit data minus the entropy after splitting
     * @throws Exception if an instance has a missing value for the split attribute
     */
    public double entropyGain() throws Exception {
        int numSubsets = m_attribute.isNominal() ? m_attribute.numValues() : 2;

        // Assign every instance to its subset once and remember the result so
        // the rows can be distributed without a second whichSubset() pass.
        int[] subsetOf = new int[m_numInstances];
        int[] subsetSize = new int[numSubsets];
        for (int i = 0; i < m_numInstances; ++i) {
            int subset = whichSubset(m_data.instance(i));
            if (subset < 0) {
                throw new Exception("ResidualSplit: no support for splits on missing values");
            }
            subsetOf[i] = subset;
            subsetSize[subset]++;
        }

        double[][][] splitDataZs = new double[numSubsets][][];
        double[][][] splitDataWs = new double[numSubsets][][];
        for (int s = 0; s < numSubsets; ++s) {
            splitDataZs[s] = new double[subsetSize[s]][];
            splitDataWs[s] = new double[subsetSize[s]][];
        }

        // Distribute the (shared, not copied) rows over the subsets.
        int[] subsetCount = new int[numSubsets];
        for (int i = 0; i < m_numInstances; ++i) {
            int subset = subsetOf[i];
            int slot = subsetCount[subset]++;
            splitDataZs[subset][slot] = m_dataZs[i];
            splitDataWs[subset][slot] = m_dataWs[i];
        }

        double entropyOrig = entropy(m_dataZs, m_dataWs);
        double entropySplit = 0.0;
        for (int s = 0; s < numSubsets; ++s) {
            entropySplit += entropy(splitDataZs[s], splitDataWs[s]);
        }
        return entropyOrig - entropySplit;
    }

    /**
     * Computes, summed over all classes, the weighted squared deviation of the
     * response values from their weighted mean.
     *
     * <p>A class whose weights sum to zero contributes nothing. Without this
     * guard the weighted mean would be {@code 0/0 = NaN} and a single such
     * class would turn the whole result (and every gain derived from it) into
     * NaN. This assumes weights are non-negative, so a zero sum means every
     * weight is zero - TODO confirm against the caller.
     *
     * @param dataZs response values, indexed [instance][class]
     * @param dataWs weights, indexed [instance][class]
     * @return the summed weighted squared deviation
     */
    protected double entropy(double[][] dataZs, double[][] dataWs) {
        double entropy = 0.0;
        int numInstances = dataZs.length;
        for (int j = 0; j < m_numClasses; ++j) {
            double weightedSum = 0.0;
            double weightSum = 0.0;
            for (int i = 0; i < numInstances; ++i) {
                weightedSum += dataZs[i][j] * dataWs[i][j];
                weightSum += dataWs[i][j];
            }
            if (weightSum == 0.0) {
                // No weight mass for this class (also covers an empty subset).
                continue;
            }
            double mean = weightedSum / weightSum;
            for (int i = 0; i < numInstances; ++i) {
                entropy += dataWs[i][j] * Math.pow(dataZs[i][j] - mean, 2.0);
            }
        }
        return entropy;
    }

    /**
     * Checks whether the split is usable: at least two bags of the inherited
     * distribution must each hold at least {@code minNumInstances}.
     *
     * @param minNumInstances minimum per-bag amount required
     * @return {@code true} if two or more bags meet the minimum
     */
    public boolean checkModel(int minNumInstances) {
        int count = 0;
        for (int i = 0; i < m_distribution.numBags(); ++i) {
            if (m_distribution.perBag(i) >= (double) minNumInstances) {
                ++count;
            }
        }
        return count >= 2;
    }

    /**
     * Returns the left side of the split condition: the name of the split
     * attribute.
     *
     * @param data the dataset from which the attribute name is read
     * @return the attribute name
     */
    public final String leftSide(Instances data) {
        return data.attribute(m_attIndex).name();
    }

    /**
     * Returns the right side of the split condition for the given subset:
     * {@code " = value"} for nominal attributes, otherwise {@code " <= t"} for
     * subset 0 and {@code " > t"} for any other subset, where {@code t} is the
     * split point formatted with 6 decimals.
     *
     * @param index index of the subset
     * @param data  the dataset from which the attribute is read
     * @return the textual right side of the condition
     */
    public final String rightSide(int index, Instances data) {
        Attribute attribute = data.attribute(m_attIndex);
        if (attribute.isNominal()) {
            return " = " + attribute.value(index);
        }
        String operator = (index == 0) ? " <= " : " > ";
        return operator + Utils.doubleToString(m_splitPoint, 6);
    }

    /**
     * Returns the index of the subset the instance falls into.
     *
     * @param instance the instance to assign
     * @return -1 if the split attribute is missing; the attribute value index
     *         for nominal attributes; otherwise 0 if the value is smaller than
     *         or equal to the split point and 1 if it is larger
     * @throws Exception declared by the overridden signature; not thrown by
     *                   the code in this method itself
     */
    public final int whichSubset(Instance instance) throws Exception {
        if (instance.isMissing(m_attIndex)) {
            return -1;
        }
        if (instance.attribute(m_attIndex).isNominal()) {
            return (int) instance.value(m_attIndex);
        }
        return Utils.smOrEq(instance.value(m_attIndex), m_splitPoint) ? 0 : 1;
    }

    /**
     * Does nothing; use
     * {@link #buildClassifier(Instances, double[][], double[][])} instead.
     *
     * @param data ignored
     */
    public void buildClassifier(Instances data) {
        // Intentionally empty: this model cannot be built without Zs and Ws.
    }

    /**
     * Always returns {@code null}.
     *
     * <p>NOTE(review): in the j48 split-model contract {@code null} presumably
     * means the instance belongs to exactly one subset - confirm.
     *
     * @param instance ignored
     * @return {@code null}
     */
    public final double[] weights(Instance instance) {
        return null;
    }

    /**
     * Source-code generation is not implemented for this split.
     *
     * @param index ignored
     * @param data  ignored
     * @return the empty string
     */
    public final String sourceExpression(int index, Instances data) {
        return "";
    }

    /**
     * Returns the revision string.
     *
     * @return the revision extracted from the embedded revision keyword
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }
}

