/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.j48;

import java.util.Random;
import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.ClassifierTree;
import weka.classifiers.trees.j48.ModelSelection;
import weka.classifiers.trees.j48.NoSplit;
import weka.core.Capabilities;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * A classifier tree that is built from part of the training data and can
 * optionally be pruned against the remaining, held-out part.
 *
 * <p>The training data is stratified into {@code numSets} folds. The tree is
 * built via {@code buildTree} with every fold except the last as the training
 * set and the last fold as the test set. When pruning is enabled, each
 * internal node is collapsed into a leaf if, on the test distribution
 * ({@code m_test}), its error as a leaf is no larger than the error of its
 * subtree.
 */
public class PruneableClassifierTree extends ClassifierTree {
    static final long serialVersionUID = -555775736857600201L;

    // NOTE: the field names below are intentionally left unchanged. A
    // serialVersionUID is declared, so renaming them would presumably break
    // deserialization of previously saved models.

    /** Whether {@link #prune()} runs after the tree has been built. */
    private boolean pruneTheTree = false;
    /** Number of stratified folds; the last fold is held out as the test set. */
    private int numSets = 3;
    /** Whether cleanup(...) is called with an empty header copy after building. */
    private boolean m_cleanup = true;
    /** Seed for the Random handed to Instances.trainCV. */
    private int m_seed = 1;

    /**
     * Creates an unbuilt tree with the given settings.
     *
     * @param toSelectLocModel model-selection object passed to the superclass
     *                         (presumably chooses each node's split)
     * @param pruneTree        whether to prune once the tree is built
     * @param num              number of folds; the last one is held out
     * @param cleanup          whether to call cleanup after building
     * @param seed             seed for the Random used when extracting the training folds
     * @throws Exception propagated from the superclass constructor
     */
    public PruneableClassifierTree(ModelSelection toSelectLocModel, boolean pruneTree, int num, boolean cleanup, int seed) throws Exception {
        super(toSelectLocModel);
        pruneTheTree = pruneTree;
        numSets = num;
        m_cleanup = cleanup;
        m_seed = seed;
    }

    /**
     * Describes the data this tree accepts: nominal, numeric and date
     * attributes, missing values, a nominal class and missing class values,
     * with a minimum of zero instances.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities caps = super.getCapabilities();
        caps.disableAll();
        // Enabled in the same order as listed here.
        Capabilities.Capability[] accepted = new Capabilities.Capability[]{
            Capabilities.Capability.NOMINAL_ATTRIBUTES,
            Capabilities.Capability.NUMERIC_ATTRIBUTES,
            Capabilities.Capability.DATE_ATTRIBUTES,
            Capabilities.Capability.MISSING_VALUES,
            Capabilities.Capability.NOMINAL_CLASS,
            Capabilities.Capability.MISSING_CLASS_VALUES
        };
        for (Capabilities.Capability capability : accepted) {
            caps.enable(capability);
        }
        caps.setMinimumNumberInstances(0);
        return caps;
    }

    /**
     * Builds the tree from the given data.
     *
     * <p>A copy of the data has its missing-class instances removed and is
     * stratified into {@code numSets} folds. The tree is built on all but the
     * last fold and tested on the last fold. It is then pruned and/or cleaned
     * up according to the constructor settings.
     *
     * @param data the training instances
     * @throws Exception if the data fails the capability check or building fails
     */
    public void buildClassifier(Instances data) throws Exception {
        getCapabilities().testWithFail(data);

        Instances working = new Instances(data);
        working.deleteWithMissingClass();

        Random rng = new Random(m_seed);
        working.stratify(numSets);

        int heldOutFold = numSets - 1;
        Instances growSet = working.trainCV(numSets, heldOutFold, rng);
        Instances pruneSet = working.testCV(numSets, heldOutFold);
        buildTree(growSet, pruneSet, false);

        if (pruneTheTree) {
            prune();
        }
        if (m_cleanup) {
            cleanup(new Instances(working, 0));
        }
    }

    /**
     * Prunes the subtree rooted at this node, bottom-up.
     *
     * <p>Children are pruned first. This node then becomes a leaf, with a
     * NoSplit over its current model's distribution, if its test error as a
     * leaf is less than or equal to the test error of its subtree.
     *
     * @throws Exception if an error estimate cannot be computed
     */
    public void prune() throws Exception {
        if (m_isLeaf) {
            return;
        }
        for (ClassifierTree child : m_sons) {
            ((PruneableClassifierTree) child).prune();
        }
        double leafErrors = errorsForLeaf();
        double subtreeErrors = errorsForTree();
        if (!Utils.smOrEq(leafErrors, subtreeErrors)) {
            return;
        }
        m_sons = null;
        m_isLeaf = true;
        // localModel() still returns the pre-collapse split model here.
        m_localModel = new NoSplit(localModel().distribution());
    }

    /**
     * Creates and builds a child node that shares this node's settings.
     * Presumably invoked by ClassifierTree while growing children (confirm).
     *
     * @param train instances used to build the child
     * @param test  instances whose distribution the child records
     * @return the newly built child
     * @throws Exception if building the child fails
     */
    protected ClassifierTree getNewTree(Instances train, Instances test) throws Exception {
        PruneableClassifierTree subtree = new PruneableClassifierTree(
            m_toSelectModel, pruneTheTree, numSets, m_cleanup, m_seed);
        subtree.buildTree(train, test, false);
        return subtree;
    }

    /**
     * Sums the test-set errors of the subtree rooted at this node.
     *
     * @return the leaf error for a leaf; otherwise the total over all branches
     * @throws Exception if a child's estimate cannot be computed
     */
    private double errorsForTree() throws Exception {
        if (m_isLeaf) {
            return errorsForLeaf();
        }
        double total = 0.0;
        for (int bag = 0; bag < m_sons.length; bag++) {
            boolean emptyBag = Utils.eq(localModel().distribution().perBag(bag), 0.0);
            if (emptyBag) {
                // This node's own distribution has no weight in this branch.
                // Charge the branch's test instances against this node's
                // maxClass() instead of consulting the child.
                total += m_test.perBag(bag)
                    - m_test.perClassPerBag(bag, localModel().distribution().maxClass());
            } else {
                total += son(bag).errorsForTree();
            }
        }
        return total;
    }

    /**
     * Test-set error if this node were a leaf that predicts the maxClass() of
     * its current model's distribution.
     *
     * @return total test weight minus the weight of that class
     * @throws Exception if the distribution cannot be queried
     */
    private double errorsForLeaf() throws Exception {
        double misclassified = m_test.total()
            - m_test.perClass(localModel().distribution().maxClass());
        return misclassified;
    }

    /** Returns this node's split model (the inherited {@code m_localModel}). */
    private ClassifierSplitModel localModel() {
        ClassifierSplitModel model = m_localModel;
        return model;
    }

    /** Returns the child at {@code index}, cast to this class. */
    private PruneableClassifierTree son(int index) {
        ClassifierTree child = m_sons[index];
        return (PruneableClassifierTree) child;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5533 $");
    }
}

