/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.lazy;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.neighboursearch.CoverTree;
import weka.core.neighboursearch.LinearNNSearch;
import weka.core.neighboursearch.NearestNeighbourSearch;

public class IBk
extends AbstractClassifier
implements OptionHandler,
UpdateableClassifier,
WeightedInstancesHandler,
TechnicalInformationHandler,
AdditionalMeasureProducer {
    static final long serialVersionUID = -3080186098777067172L;
    protected Instances m_Train;
    protected int m_NumClasses;
    protected int m_ClassType;
    protected int m_kNN;
    protected int m_kNNUpper;
    protected boolean m_kNNValid;
    protected int m_WindowSize;
    protected int m_DistanceWeighting;
    protected boolean m_CrossValidate;
    protected boolean m_MeanSquared;
    public static final int WEIGHT_NONE = 1;
    public static final int WEIGHT_INVERSE = 2;
    public static final int WEIGHT_SIMILARITY = 4;
    public static final Tag[] TAGS_WEIGHTING = new Tag[]{new Tag(1, "No distance weighting"), new Tag(2, "Weight by 1/distance"), new Tag(4, "Weight by 1-distance")};
    protected NearestNeighbourSearch m_NNSearch = new LinearNNSearch();
    protected double m_NumAttributesUsed;

    public IBk(int k) {
        this.init();
        this.setKNN(k);
    }

    public IBk() {
        this.init();
    }

    public String globalInfo() {
        return "K-nearest neighbours classifier. Can select appropriate value of K based on cross-validation. Can also do distance weighting.\n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "D. Aha and D. Kibler");
        result.setValue(TechnicalInformation.Field.YEAR, "1991");
        result.setValue(TechnicalInformation.Field.TITLE, "Instance-based learning algorithms");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "6");
        result.setValue(TechnicalInformation.Field.PAGES, "37-66");
        return result;
    }

    public String KNNTipText() {
        return "The number of neighbours to use.";
    }

    public void setKNN(int k) {
        this.m_kNN = k;
        this.m_kNNUpper = k;
        this.m_kNNValid = false;
    }

    public int getKNN() {
        return this.m_kNN;
    }

    public String windowSizeTipText() {
        return "Gets the maximum number of instances allowed in the training pool. The addition of new instances above this value will result in old instances being removed. A value of 0 signifies no limit to the number of training instances.";
    }

    public int getWindowSize() {
        return this.m_WindowSize;
    }

    public void setWindowSize(int newWindowSize) {
        this.m_WindowSize = newWindowSize;
    }

    public String distanceWeightingTipText() {
        return "Gets the distance weighting method used.";
    }

    public SelectedTag getDistanceWeighting() {
        return new SelectedTag(this.m_DistanceWeighting, TAGS_WEIGHTING);
    }

    public void setDistanceWeighting(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_WEIGHTING) {
            this.m_DistanceWeighting = newMethod.getSelectedTag().getID();
        }
    }

    public String meanSquaredTipText() {
        return "Whether the mean squared error is used rather than mean absolute error when doing cross-validation for regression problems.";
    }

    public boolean getMeanSquared() {
        return this.m_MeanSquared;
    }

    public void setMeanSquared(boolean newMeanSquared) {
        this.m_MeanSquared = newMeanSquared;
    }

    public String crossValidateTipText() {
        return "Whether hold-one-out cross-validation will be used to select the best k value.";
    }

    public boolean getCrossValidate() {
        return this.m_CrossValidate;
    }

    public void setCrossValidate(boolean newCrossValidate) {
        this.m_CrossValidate = newCrossValidate;
    }

    public String nearestNeighbourSearchAlgorithmTipText() {
        return "The nearest neighbour search algorithm to use (Default: weka.core.neighboursearch.LinearNNSearch).";
    }

    public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
        return this.m_NNSearch;
    }

    public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) {
        this.m_NNSearch = nearestNeighbourSearchAlgorithm;
    }

    public int getNumTraining() {
        return this.m_Train.numInstances();
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        this.m_NumClasses = instances.numClasses();
        this.m_ClassType = instances.classAttribute().type();
        this.m_Train = new Instances(instances, 0, instances.numInstances());
        if (this.m_WindowSize > 0 && instances.numInstances() > this.m_WindowSize) {
            this.m_Train = new Instances(this.m_Train, this.m_Train.numInstances() - this.m_WindowSize, this.m_WindowSize);
        }
        this.m_NumAttributesUsed = 0.0;
        for (int i = 0; i < this.m_Train.numAttributes(); ++i) {
            if (i == this.m_Train.classIndex() || !this.m_Train.attribute(i).isNominal() && !this.m_Train.attribute(i).isNumeric()) continue;
            this.m_NumAttributesUsed += 1.0;
        }
        this.m_NNSearch.setInstances(this.m_Train);
        this.m_kNNValid = false;
    }

    public void updateClassifier(Instance instance) throws Exception {
        if (!this.m_Train.equalHeaders(instance.dataset())) {
            throw new Exception("Incompatible instance types\n" + this.m_Train.equalHeadersMsg(instance.dataset()));
        }
        if (instance.classIsMissing()) {
            return;
        }
        this.m_Train.add(instance);
        this.m_NNSearch.update(instance);
        this.m_kNNValid = false;
        if (this.m_WindowSize > 0 && this.m_Train.numInstances() > this.m_WindowSize) {
            boolean deletedInstance = false;
            while (this.m_Train.numInstances() > this.m_WindowSize) {
                this.m_Train.delete(0);
                deletedInstance = true;
            }
            if (deletedInstance) {
                this.m_NNSearch.setInstances(this.m_Train);
            }
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_Train.numInstances() == 0) {
            throw new Exception("No training instances!");
        }
        if (this.m_WindowSize > 0 && this.m_Train.numInstances() > this.m_WindowSize) {
            this.m_kNNValid = false;
            boolean deletedInstance = false;
            while (this.m_Train.numInstances() > this.m_WindowSize) {
                this.m_Train.delete(0);
            }
            if (deletedInstance) {
                this.m_NNSearch.setInstances(this.m_Train);
            }
        }
        if (!this.m_kNNValid && this.m_CrossValidate && this.m_kNNUpper >= 1) {
            this.crossValidate();
        }
        this.m_NNSearch.addInstanceInfo(instance);
        Instances neighbours = this.m_NNSearch.kNearestNeighbours(instance, this.m_kNN);
        double[] distances = this.m_NNSearch.getDistances();
        double[] distribution = this.makeDistribution(neighbours, distances);
        return distribution;
    }

    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(8);
        newVector.addElement(new Option("\tWeight neighbours by the inverse of their distance\n\t(use when k > 1)", "I", 0, "-I"));
        newVector.addElement(new Option("\tWeight neighbours by 1 - their distance\n\t(use when k > 1)", "F", 0, "-F"));
        newVector.addElement(new Option("\tNumber of nearest neighbours (k) used in classification.\n\t(Default = 1)", "K", 1, "-K <number of neighbors>"));
        newVector.addElement(new Option("\tMinimise mean squared error rather than mean absolute\n\terror when using -X option with numeric prediction.", "E", 0, "-E"));
        newVector.addElement(new Option("\tMaximum number of training instances maintained.\n\tTraining instances are dropped FIFO. (Default = no window)", "W", 1, "-W <window size>"));
        newVector.addElement(new Option("\tSelect the number of nearest neighbours between 1\n\tand the k value specified using hold-one-out evaluation\n\ton the training data (use when k > 1)", "X", 0, "-X"));
        newVector.addElement(new Option("\tThe nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).\n", "A", 0, "-A"));
        return newVector.elements();
    }

    public void setOptions(String[] options) throws Exception {
        String knnString = Utils.getOption('K', options);
        if (knnString.length() != 0) {
            this.setKNN(Integer.parseInt(knnString));
        } else {
            this.setKNN(1);
        }
        String windowString = Utils.getOption('W', options);
        if (windowString.length() != 0) {
            this.setWindowSize(Integer.parseInt(windowString));
        } else {
            this.setWindowSize(0);
        }
        if (Utils.getFlag('I', options)) {
            this.setDistanceWeighting(new SelectedTag(2, TAGS_WEIGHTING));
        } else if (Utils.getFlag('F', options)) {
            this.setDistanceWeighting(new SelectedTag(4, TAGS_WEIGHTING));
        } else {
            this.setDistanceWeighting(new SelectedTag(1, TAGS_WEIGHTING));
        }
        this.setCrossValidate(Utils.getFlag('X', options));
        this.setMeanSquared(Utils.getFlag('E', options));
        String nnSearchClass = Utils.getOption('A', options);
        if (nnSearchClass.length() != 0) {
            String[] nnSearchClassSpec = Utils.splitOptions(nnSearchClass);
            if (nnSearchClassSpec.length == 0) {
                throw new Exception("Invalid NearestNeighbourSearch algorithm specification string.");
            }
            String className = nnSearchClassSpec[0];
            nnSearchClassSpec[0] = "";
            this.setNearestNeighbourSearchAlgorithm((NearestNeighbourSearch)Utils.forName(NearestNeighbourSearch.class, className, nnSearchClassSpec));
        } else {
            this.setNearestNeighbourSearchAlgorithm(new LinearNNSearch());
        }
        Utils.checkForRemainingOptions(options);
    }

    public String[] getOptions() {
        String[] options = new String[11];
        int current = 0;
        options[current++] = "-K";
        options[current++] = "" + this.getKNN();
        options[current++] = "-W";
        options[current++] = "" + this.m_WindowSize;
        if (this.getCrossValidate()) {
            options[current++] = "-X";
        }
        if (this.getMeanSquared()) {
            options[current++] = "-E";
        }
        if (this.m_DistanceWeighting == 2) {
            options[current++] = "-I";
        } else if (this.m_DistanceWeighting == 4) {
            options[current++] = "-F";
        }
        options[current++] = "-A";
        options[current++] = this.m_NNSearch.getClass().getName() + " " + Utils.joinOptions(this.m_NNSearch.getOptions());
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    public Enumeration enumerateMeasures() {
        if (this.m_CrossValidate) {
            Enumeration enm = this.m_NNSearch.enumerateMeasures();
            Vector measures = new Vector();
            while (enm.hasMoreElements()) {
                measures.add(enm.nextElement());
            }
            measures.add("measureKNN");
            return measures.elements();
        }
        return this.m_NNSearch.enumerateMeasures();
    }

    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equals("measureKNN")) {
            return this.m_kNN;
        }
        return this.m_NNSearch.getMeasure(additionalMeasureName);
    }

    public String toString() {
        if (this.m_Train == null) {
            return "IBk: No model built yet.";
        }
        if (!this.m_kNNValid && this.m_CrossValidate) {
            this.crossValidate();
        }
        String result = "IB1 instance-based classifier\nusing " + this.m_kNN;
        switch (this.m_DistanceWeighting) {
            case 2: {
                result = result + " inverse-distance-weighted";
                break;
            }
            case 4: {
                result = result + " similarity-weighted";
            }
        }
        result = result + " nearest neighbour(s) for classification\n";
        if (this.m_WindowSize != 0) {
            result = result + "using a maximum of " + this.m_WindowSize + " (windowed) training instances\n";
        }
        return result;
    }

    protected void init() {
        this.setKNN(1);
        this.m_WindowSize = 0;
        this.m_DistanceWeighting = 1;
        this.m_CrossValidate = false;
        this.m_MeanSquared = false;
    }

    protected double[] makeDistribution(Instances neighbours, double[] distances) throws Exception {
        int i;
        double total = 0.0;
        double[] distribution = new double[this.m_NumClasses];
        if (this.m_ClassType == 1) {
            for (i = 0; i < this.m_NumClasses; ++i) {
                distribution[i] = 1.0 / (double)Math.max(1, this.m_Train.numInstances());
            }
            total = (double)this.m_NumClasses / (double)Math.max(1, this.m_Train.numInstances());
        }
        for (i = 0; i < neighbours.numInstances(); ++i) {
            double weight;
            Instance current = neighbours.instance(i);
            distances[i] = distances[i] * distances[i];
            distances[i] = Math.sqrt(distances[i] / this.m_NumAttributesUsed);
            switch (this.m_DistanceWeighting) {
                case 2: {
                    weight = 1.0 / (distances[i] + 0.001);
                    break;
                }
                case 4: {
                    weight = 1.0 - distances[i];
                    break;
                }
                default: {
                    weight = 1.0;
                }
            }
            weight *= current.weight();
            try {
                switch (this.m_ClassType) {
                    case 1: {
                        int n = (int)current.classValue();
                        distribution[n] = distribution[n] + weight;
                        break;
                    }
                    case 0: {
                        distribution[0] = distribution[0] + current.classValue() * weight;
                    }
                }
            }
            catch (Exception ex) {
                throw new Error("Data has no class attribute!");
            }
            total += weight;
        }
        if (total > 0.0) {
            Utils.normalize(distribution, total);
        }
        return distribution;
    }

    protected void crossValidate() {
        try {
            int i;
            if (this.m_NNSearch instanceof CoverTree) {
                throw new Exception("CoverTree doesn't support hold-one-out cross-validation. Use some other NN method.");
            }
            double[] performanceStats = new double[this.m_kNNUpper];
            double[] performanceStatsSq = new double[this.m_kNNUpper];
            for (int i2 = 0; i2 < this.m_kNNUpper; ++i2) {
                performanceStats[i2] = 0.0;
                performanceStatsSq[i2] = 0.0;
            }
            this.m_kNN = this.m_kNNUpper;
            for (i = 0; i < this.m_Train.numInstances(); ++i) {
                if (this.m_Debug && i % 50 == 0) {
                    System.err.print("Cross validating " + i + "/" + this.m_Train.numInstances() + "\r");
                }
                Instance instance = this.m_Train.instance(i);
                Instances neighbours = this.m_NNSearch.kNearestNeighbours(instance, this.m_kNN);
                double[] origDistances = this.m_NNSearch.getDistances();
                for (int j = this.m_kNNUpper - 1; j >= 0; --j) {
                    double[] convertedDistances = new double[origDistances.length];
                    System.arraycopy(origDistances, 0, convertedDistances, 0, origDistances.length);
                    double[] distribution = this.makeDistribution(neighbours, convertedDistances);
                    double thisPrediction = Utils.maxIndex(distribution);
                    if (this.m_Train.classAttribute().isNumeric()) {
                        thisPrediction = distribution[0];
                        double err = thisPrediction - instance.classValue();
                        int n = j;
                        performanceStatsSq[n] = performanceStatsSq[n] + err * err;
                        int n2 = j;
                        performanceStats[n2] = performanceStats[n2] + Math.abs(err);
                    } else if (thisPrediction != instance.classValue()) {
                        int n = j;
                        performanceStats[n] = performanceStats[n] + 1.0;
                    }
                    if (j < 1) continue;
                    neighbours = this.pruneToK(neighbours, convertedDistances, j);
                }
            }
            for (i = 0; i < this.m_kNNUpper; ++i) {
                if (this.m_Debug) {
                    System.err.print("Hold-one-out performance of " + (i + 1) + " neighbors ");
                }
                if (this.m_Train.classAttribute().isNumeric()) {
                    if (!this.m_Debug) continue;
                    if (this.m_MeanSquared) {
                        System.err.println("(RMSE) = " + Math.sqrt(performanceStatsSq[i] / (double)this.m_Train.numInstances()));
                        continue;
                    }
                    System.err.println("(MAE) = " + performanceStats[i] / (double)this.m_Train.numInstances());
                    continue;
                }
                if (!this.m_Debug) continue;
                System.err.println("(%ERR) = " + 100.0 * performanceStats[i] / (double)this.m_Train.numInstances());
            }
            double[] searchStats = performanceStats;
            if (this.m_Train.classAttribute().isNumeric() && this.m_MeanSquared) {
                searchStats = performanceStatsSq;
            }
            double bestPerformance = Double.NaN;
            int bestK = 1;
            for (int i3 = 0; i3 < this.m_kNNUpper; ++i3) {
                if (!Double.isNaN(bestPerformance) && !(bestPerformance > searchStats[i3])) continue;
                bestPerformance = searchStats[i3];
                bestK = i3 + 1;
            }
            this.m_kNN = bestK;
            if (this.m_Debug) {
                System.err.println("Selected k = " + bestK);
            }
            this.m_kNNValid = true;
        }
        catch (Exception ex) {
            throw new Error("Couldn't optimize by cross-validation: " + ex.getMessage());
        }
    }

    public Instances pruneToK(Instances neighbours, double[] distances, int k) {
        if (neighbours == null || distances == null || neighbours.numInstances() == 0) {
            return null;
        }
        if (k < 1) {
            k = 1;
        }
        int currentK = 0;
        for (int i = 0; i < neighbours.numInstances(); ++i) {
            double currentDist = distances[i];
            if (++currentK <= k || currentDist == distances[i - 1]) continue;
            neighbours = new Instances(neighbours, 0, --currentK);
            break;
        }
        return neighbours;
    }

    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5928 $");
    }

    public static void main(String[] argv) {
        IBk.runClassifier(new IBk(), argv);
    }
}

