/*
 * Decompiled with CFR 0.152.
 */
package dmLab.experiment.classification;

import dmLab.array.Array;
import dmLab.array.FArray;
import dmLab.array.functions.DiscFunctions;
import dmLab.array.loader.File2Array;
import dmLab.array.saver.Array2File;
import dmLab.classifier.Classifier;
import dmLab.classifier.Prediction;
import dmLab.classifier.adx.ADXClassifier;
import dmLab.classifier.bayesNet.BayesNetClassifier;
import dmLab.classifier.ensemble.EnsembleClassifier;
import dmLab.classifier.hyperPipes.HyperPipesClassifier;
import dmLab.classifier.j48.J48Classifier;
import dmLab.classifier.knn.KNNClassifier;
import dmLab.classifier.logistic.LogisticClassifier;
import dmLab.classifier.nb.NBClassifier;
import dmLab.classifier.randomForest.RandomForestClassifier;
import dmLab.classifier.ripper.RipperClassifier;
import dmLab.classifier.rnd.RNDClassifier;
import dmLab.classifier.sliq.SliqClassifier;
import dmLab.classifier.svm.SVMClassifier;
import dmLab.experiment.classification.ClassificationParams;
import dmLab.utils.ArrayUtils;
import dmLab.utils.cmatrix.AccuracyMeasure;
import dmLab.utils.cmatrix.ConfusionMatrix;
import java.io.IOException;
import java.util.Random;

public final class ClassificationBody {
    public ClassificationParams classParams;
    public Classifier classifier;
    public ConfusionMatrix resultConfMatrix;
    public FArray inArray;
    private FArray trainArray;
    private FArray testArray;
    private File2Array file2Container;
    private Array2File array2file;
    public float learningTime;
    public float testingTime;
    public float experimentTime;
    public double[] accArray;
    public double[] wAccArray;
    private Prediction[] predictions;
    private DiscFunctions selectFunctions$49f14b5d;

    /*
     * WARNING - void declaration
     */
    public ClassificationBody(Random random) {
        void var1_1;
        this.selectFunctions$49f14b5d = new DiscFunctions((Random)var1_1);
        this.file2Container = new File2Array();
        this.array2file = new Array2File();
        this.cleanStats();
    }

    public final void cleanStats() {
        this.learningTime = 0.0f;
        this.testingTime = 0.0f;
        this.experimentTime = 0.0f;
    }

    public final float run() {
        Float f;
        ClassificationBody classificationBody = null;
        classificationBody = this;
        if (classificationBody.classifier == null && !classificationBody.createClassifier()) {
            f = null;
        } else if (!classificationBody.loadArrays()) {
            f = null;
        } else if (classificationBody.classParams == null || !classificationBody.classifier.params.check(classificationBody.inArray)) {
            f = null;
        } else {
            if (classificationBody.classParams.validationType == 1 || classificationBody.classParams.validationType == 3) {
                classificationBody.multipleTrainTest(classificationBody.inArray);
            } else if (classificationBody.classParams.validationType == 2) {
                classificationBody.multipleCV(classificationBody.inArray);
            }
            f = Float.valueOf(classificationBody.resultConfMatrix.calcMeasure(AccuracyMeasure.WACC));
        }
        return f.floatValue();
    }

    /*
     * WARNING - void declaration
     */
    public final boolean setParameters(ClassificationParams classParams) {
        void var1_1;
        this.classParams = classParams;
        return var1_1.check(null);
    }

    /*
     * WARNING - void declaration
     */
    public final boolean loadParameters(String paramsFileName) {
        this.classParams = new ClassificationParams();
        if (!this.classParams.load("", paramsFileName)) {
            void var1_1;
            System.err.println("Error loading configuration file. File: " + (String)var1_1);
            return false;
        }
        if (this.classParams.verbose) {
            System.out.println(this.classParams.toString());
        }
        return this.classParams.check(null);
    }

    public final boolean createClassifier() {
        if (this.classParams.classifier == Classifier.ENSEMBLE) {
            this.classifier = new EnsembleClassifier();
        } else if (this.classParams.classifier == Classifier.RND) {
            this.classifier = new RNDClassifier();
        } else if (this.classParams.classifier == Classifier.J48) {
            this.classifier = new J48Classifier();
        } else if (this.classParams.classifier == Classifier.ADX) {
            this.classifier = new ADXClassifier();
        } else if (this.classParams.classifier == Classifier.SLIQ) {
            this.classifier = new SliqClassifier();
        } else if (this.classParams.classifier == Classifier.RF) {
            this.classifier = new RandomForestClassifier();
        } else if (this.classParams.classifier == Classifier.NB) {
            this.classifier = new NBClassifier();
        } else if (this.classParams.classifier == Classifier.KNN) {
            this.classifier = new KNNClassifier();
        } else if (this.classParams.classifier == Classifier.RIPPER) {
            this.classifier = new RipperClassifier();
        } else if (this.classParams.classifier == Classifier.SVM) {
            this.classifier = new SVMClassifier();
        } else if (this.classParams.classifier == Classifier.BNET) {
            this.classifier = new BayesNetClassifier();
        } else if (this.classParams.classifier == Classifier.HP) {
            this.classifier = new HyperPipesClassifier();
        } else if (this.classParams.classifier == Classifier.LOGISTIC) {
            this.classifier = new LogisticClassifier();
        } else {
            System.err.println("Error creating the classifier.");
            return false;
        }
        if (!this.classifier.params.load(this.classParams.classifierCfgPATH, this.classifier.label)) {
            return false;
        }
        this.classifier.init();
        this.classifier.setTempPath(this.classParams.resFilesPATH);
        if (this.classParams.verbose) {
            System.out.println(this.classifier.params.toString());
        }
        this.cleanStats();
        return true;
    }

    public final boolean loadArrays() {
        if (this.inArray != null) {
            return true;
        }
        this.inArray = new FArray();
        if (this.classParams.verbose) {
            System.out.println("Loading Input Table...");
        }
        if (!this.file2Container.load(this.inArray, String.valueOf(this.classParams.inputFilesPATH) + this.classParams.inputFileName)) {
            return false;
        }
        if (!this.inArray.checkDecisionValues()) {
            return false;
        }
        if (this.classParams.debug) {
            System.out.println(" ### DEBUG ### ");
            System.out.println(this.inArray.toString());
        }
        if (this.classParams.verbose) {
            System.out.println("Input table has been loaded.");
        }
        if (this.classParams.validationType == 3) {
            this.trainArray = this.inArray;
            if (this.classParams.verbose) {
                System.out.println("Loading Testing Table...");
            }
            this.testArray = new FArray();
            this.testArray.dictionary = this.trainArray.dictionary.clone();
            this.testArray.setDecValues(this.trainArray.getDecValues());
            this.testArray.setDecAttrIdx(this.trainArray.getDecAttrIdx());
            this.file2Container.load(this.testArray, String.valueOf(this.classParams.inputFilesPATH) + this.classParams.testFileName);
            if (this.classParams.verbose) {
                System.out.println("Testing table has been loaded.");
            }
        }
        return true;
    }

    /*
     * WARNING - void declaration
     */
    private boolean split(FArray inputArray, int[] splitMask) {
        void var1_1;
        void var2_2;
        if (this.classParams.verbose) {
            System.out.println("Splitting Input Table...");
        }
        if (splitMask == null) {
            if (this.classParams.splitType == 1) {
                splitMask = this.selectFunctions$49f14b5d.getSplitMaskRandom(inputArray, this.classParams.splitRatio);
            } else if (this.classParams.splitType == 2) {
                splitMask = this.selectFunctions$49f14b5d.getSplitMaskUniform(inputArray, this.classParams.splitRatio);
            } else {
                System.err.println("classParams.splitType does not equal to SPLIT_RANDOM or SPLIT_UNIFORM.");
                return false;
            }
        }
        Array[] trainTestArrays = DiscFunctions.split(inputArray, (int[])var2_2);
        this.trainArray = (FArray)trainTestArrays[0];
        this.testArray = (FArray)var1_1[1];
        if (this.classParams.verbose) {
            System.out.println("Input table has been splitted.");
        }
        return true;
    }

    /*
     * WARNING - void declaration
     */
    private boolean savePredictionArray(FArray array, String fileName) {
        void var2_2;
        void var3_3;
        FArray predictionArray = array.clone();
        String[] decValues = array.getDecValuesStr();
        int[] scoreIndex = new int[decValues.length];
        boolean saveScores = this.predictions[0].hasScores();
        if (saveScores) {
            int i = 0;
            while (i < decValues.length) {
                String scoreAttrName = "score_" + decValues[i];
                DiscFunctions.addAttribute(predictionArray, scoreAttrName);
                scoreIndex[i] = predictionArray.getColIndex(scoreAttrName);
                predictionArray.attributes[scoreIndex[i]].type = (short)2;
                ++i;
            }
        }
        DiscFunctions.addAttribute(predictionArray, "prediction");
        int predictionIndex = predictionArray.getColIndex("prediction");
        int rows = predictionArray.rowsNumber();
        int j = 0;
        while (j < rows) {
            if (saveScores) {
                int i = 0;
                while (i < decValues.length) {
                    predictionArray.writeValue(scoreIndex[i], j, (float)this.predictions[j].getScore(i));
                    ++i;
                }
            }
            predictionArray.writeValueStr(predictionIndex, j, this.predictions[j].getLabel());
            ++j;
        }
        this.array2file.setFormat(0);
        this.array2file.saveFile(predictionArray, String.valueOf(fileName) + "_pred");
        this.array2file.setFormat(2);
        this.array2file.saveFile((Array)var3_3, String.valueOf(var2_2) + "_pred");
        return true;
    }

    /*
     * WARNING - void declaration
     */
    public final ConfusionMatrix multipleCV(FArray inputArray) {
        void var2_2;
        if (this.classParams.verbose) {
            System.out.println("Running multCV...");
        }
        ConfusionMatrix cMatrix = new ConfusionMatrix(inputArray.getColNames(true)[inputArray.getDecAttrIdx()], inputArray.getDecValuesStr());
        FArray inputArrayOriginal = inputArray.clone();
        int repetitions = this.classParams.repetitions;
        this.accArray = new double[repetitions];
        this.wAccArray = new double[repetitions];
        int i = 0;
        while (i < this.classParams.repetitions) {
            double start = System.currentTimeMillis();
            String label = "_rep" + Integer.toString(i + 1);
            ConfusionMatrix singleMatrix = this.singleCV(inputArray, label);
            this.experimentTime = (float)((double)this.experimentTime + ((double)System.currentTimeMillis() - start) / 1000.0);
            cMatrix.add(singleMatrix);
            this.accArray[i] = singleMatrix.calcMeasure(AccuracyMeasure.ACC);
            this.wAccArray[i] = singleMatrix.calcMeasure(AccuracyMeasure.WACC);
            String experimentLabel = String.valueOf(this.classParams.label) + Integer.toString(i + 1);
            if (this.classParams.savePredictionResult) {
                this.savePredictionArray(inputArrayOriginal, String.valueOf(this.classParams.resFilesPATH) + "//" + experimentLabel);
            }
            if (this.classParams.verbose) {
                System.out.println();
                System.out.println("##### CV " + Integer.toString(i + 1) + " RESULT #####");
                System.out.println(singleMatrix.toString());
                System.out.println(singleMatrix.statsToString(4));
            }
            System.gc();
            ++i;
        }
        this.resultConfMatrix = cMatrix;
        this.normalizeResults();
        return var2_2;
    }

    /*
     * WARNING - void declaration
     */
    public final ConfusionMatrix singleCV(FArray inputArray, String label) {
        void var3_3;
        if (this.classParams.verbose) {
            System.out.println("Running single CV...");
        }
        ConfusionMatrix cMatrix = new ConfusionMatrix(inputArray.getColNames(true)[inputArray.getDecAttrIdx()], inputArray.getDecValuesStr());
        int cvFolds = this.classParams.folds;
        int rows = inputArray.rowsNumber();
        int[] cvTable = new int[rows];
        int[] splitMask = new int[rows];
        int n = cvFolds;
        int[] nArray = cvTable;
        ArrayUtils arrayUtils = this.selectFunctions$49f14b5d.arrayUtils;
        int n2 = 0;
        while (n2 < nArray.length) {
            nArray[n2] = (int)(arrayUtils.random.nextFloat() * (float)n);
            ++n2;
        }
        this.predictions = new Prediction[rows];
        int i = 0;
        while (i < cvFolds) {
            int j = 0;
            while (j < rows) {
                splitMask[j] = cvTable[j] == i ? 0 : 1;
                ++j;
            }
            this.split(inputArray, splitMask);
            this.singleTrainTest(this.trainArray, this.testArray);
            this.learningTime = (float)((double)this.learningTime + this.classifier.getLearningTime());
            this.testingTime = (float)((double)this.testingTime + this.classifier.getTestingTime());
            cMatrix.add(this.classifier.getConfusionMatrix());
            Prediction[] singlePrediction = this.classifier.getPredictions();
            int k = 0;
            int j2 = 0;
            while (j2 < rows) {
                if (cvTable[j2] == i) {
                    this.predictions[j2] = singlePrediction[k++];
                }
                ++j2;
            }
            String experimentLabel = String.valueOf(this.classParams.label) + label + "_fold" + Integer.toString(i + 1);
            if (this.classParams.saveClassifier) {
                try {
                    this.classifier.saveDefinition(this.classParams.resFilesPATH, experimentLabel);
                }
                catch (IOException e) {
                    System.err.println("Error saving classifier.");
                    e.printStackTrace();
                }
            }
            ++i;
        }
        this.resultConfMatrix = cMatrix;
        return var3_3;
    }

    /*
     * WARNING - void declaration
     */
    private ConfusionMatrix multipleTrainTest(FArray inputArray) {
        void var3_3;
        int repetitions = 0;
        if (this.classParams.validationType == 3) {
            if (this.classParams.verbose) {
                System.out.println("MultTrainTest based on testing set...");
            }
            repetitions = 1;
        } else if (this.classParams.validationType == 1) {
            if (this.classParams.verbose) {
                System.out.println("MultTrainTest based on splitting of input set...");
            }
            repetitions = this.classParams.repetitions;
        }
        this.accArray = new double[repetitions];
        this.wAccArray = new double[repetitions];
        ConfusionMatrix cMatrix = new ConfusionMatrix(inputArray.getColNames(true)[inputArray.getDecAttrIdx()], inputArray.getDecValuesStr());
        int i = 0;
        while (i < repetitions) {
            if (this.classParams.validationType == 1) {
                this.split(inputArray, null);
            }
            if (this.classParams.debug) {
                System.out.println(" ### DEBUG ### ");
                System.out.println(this.trainArray.toString());
                System.out.println(this.testArray.toString());
            }
            double start = System.currentTimeMillis();
            this.singleTrainTest(this.trainArray, this.testArray);
            this.experimentTime = (float)((double)this.experimentTime + ((double)System.currentTimeMillis() - start) / 1000.0);
            this.learningTime = (float)((double)this.learningTime + this.classifier.getLearningTime());
            this.testingTime = (float)((double)this.testingTime + this.classifier.getTestingTime());
            ConfusionMatrix singleMatrix = this.classifier.getConfusionMatrix();
            cMatrix.add(singleMatrix);
            this.accArray[i] = singleMatrix.calcMeasure(AccuracyMeasure.ACC);
            this.wAccArray[i] = singleMatrix.calcMeasure(AccuracyMeasure.WACC);
            String experimentLabel = String.valueOf(this.classParams.label) + Integer.toString(i + 1);
            if (this.classParams.saveClassifier) {
                try {
                    this.classifier.saveDefinition(this.classParams.resFilesPATH, experimentLabel);
                }
                catch (IOException e) {
                    System.err.println("Error saving classifier.");
                    e.printStackTrace();
                }
            }
            this.predictions = this.classifier.getPredictions();
            if (this.classParams.savePredictionResult) {
                this.savePredictionArray(this.testArray, String.valueOf(this.classParams.resFilesPATH) + "//" + experimentLabel);
            }
            if (this.classParams.verbose) {
                System.out.println("\n##### SPLIT " + Integer.toString(i + 1) + " RESULT #####");
                System.out.println(singleMatrix.toString());
                System.out.println(singleMatrix.statsToString(4));
            }
            ++i;
        }
        this.resultConfMatrix = cMatrix;
        this.normalizeResults();
        return var3_3;
    }

    /*
     * WARNING - void declaration
     */
    public final ConfusionMatrix singleTrainTest(FArray trainArray, FArray testArray) {
        void var1_1;
        ConfusionMatrix cMatrix;
        void var2_2;
        this.classifier.train(trainArray);
        if (this.classParams.debug) {
            System.out.println("### DEBUG ### ");
            System.out.println(this.classifier.toString());
        }
        this.classifier.test((FArray)var2_2);
        this.resultConfMatrix = cMatrix = this.classifier.getConfusionMatrix();
        return var1_1;
    }

    /*
     * WARNING - void declaration
     */
    private void normalizeResults() {
        void var1_1;
        float repetitions = this.classParams.repetitions;
        if (this.classParams.validationType == 2) {
            repetitions *= (float)this.classParams.folds;
        }
        this.experimentTime /= repetitions;
        this.learningTime /= repetitions;
        this.testingTime /= var1_1;
    }

    /*
     * WARNING - void declaration
     */
    public final String toStringCMatrix() {
        void var1_1;
        if (this.resultConfMatrix == null) {
            return "";
        }
        StringBuffer tmp = new StringBuffer();
        tmp.append(this.resultConfMatrix.toString(true, true, false, "\t")).append("\n");
        tmp.append(this.resultConfMatrix.statsToString(4)).append("\n");
        return var1_1.toString();
    }
}

