/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import java.io.EOFException;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.xgboost.BinomialLogisticRegression;
import org.jpmml.xgboost.Dart;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.GBTree;
import org.jpmml.xgboost.GeneralizedLinearRegression;
import org.jpmml.xgboost.LinearRegression;
import org.jpmml.xgboost.Loadable;
import org.jpmml.xgboost.LogisticRegression;
import org.jpmml.xgboost.MultinomialLogisticRegression;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.PoissonRegression;
import org.jpmml.xgboost.XGBoostDataInput;
import org.jpmml.xgboost.XGBoostEncoder;
import org.jpmml.xgboost.visitors.TreeModelCompactor;

/**
 * An XGBoost learner read from a binary model stream.
 *
 * <p>Holds the learner header values, the objective function and the booster, and converts them
 * into a PMML document via {@link #encodePMML(FieldName, List, FeatureMap, Map)}.
 */
public class Learner
implements Loadable {
    private float base_score;
    private int num_features;
    private int num_class;
    private int contain_extra_attrs;
    private int contain_eval_metrics;
    private ObjFunction obj;
    private GBTree gbtree;
    private Map<String, String> attributes = null;
    private String[] metrics = null;

    /**
     * Reads this learner from {@code input}.
     *
     * <p>Reads, in order:
     * <ol>
     *   <li>the header: base score, feature count, class count, the extra-attributes flag,
     *       the eval-metrics flag and 29 reserved values;</li>
     *   <li>the objective function name;</li>
     *   <li>the booster name and the booster itself;</li>
     *   <li>the optional trailing sections: the attribute map (if flagged), the Poisson
     *       {@code max_delta_step} value (Poisson objective only) and the eval metric names
     *       (if flagged).</li>
     * </ol>
     *
     * @param input the model data source
     * @throws IOException if reading from {@code input} fails
     * @throws IllegalArgumentException if the objective function or booster name is not supported
     */
    @Override
    public void load(XGBoostDataInput input) throws IOException {
        this.base_score = input.readFloat();
        this.num_features = input.readInt();
        this.num_class = input.readInt();
        this.contain_extra_attrs = input.readInt();
        this.contain_eval_metrics = input.readInt();
        input.readReserved(29);

        // num_class is read before the objective because the multi-class objective needs it
        this.obj = createObjFunction(input.readString(), this.num_class);

        this.gbtree = createGBTree(input.readString());
        this.gbtree.load(input);

        if (this.contain_extra_attrs != 0) {
            this.attributes = input.readStringMap();
        }
        if (this.obj instanceof PoissonRegression) {
            skipMaxDeltaStep(input);
        }
        if (this.contain_eval_metrics != 0) {
            this.metrics = input.readStringVector();
        }
    }

    // Maps an XGBoost objective name to its ObjFunction; throws IllegalArgumentException if unknown.
    private static ObjFunction createObjFunction(String name, int numClass) {
        switch (name) {
            // "reg:squarederror" is the newer XGBoost name for the same objective as "reg:linear"
            case "reg:linear":
            case "reg:squarederror":
                return new LinearRegression();
            case "reg:logistic":
                return new LogisticRegression();
            case "reg:gamma":
            case "reg:tweedie":
                return new GeneralizedLinearRegression();
            case "count:poisson":
                return new PoissonRegression();
            case "binary:logistic":
                return new BinomialLogisticRegression();
            case "multi:softmax":
            case "multi:softprob":
                return new MultinomialLogisticRegression(numClass);
            default:
                throw new IllegalArgumentException(name);
        }
    }

    // Maps an XGBoost booster name to an (unloaded) GBTree; throws IllegalArgumentException if unknown.
    private static GBTree createGBTree(String name) {
        switch (name) {
            case "gbtree":
                return new GBTree();
            case "dart":
                return new Dart();
            default:
                throw new IllegalArgumentException(name);
        }
    }

    // Consumes the Poisson "max_delta_step" string; its value is not used by the converter.
    private static void skipMaxDeltaStep(XGBoostDataInput input) throws IOException {
        try {
            input.readString();
        } catch (EOFException ignored) {
            // Deliberately tolerated: the read is best-effort, presumably because some model
            // files end before this field (NOTE(review): confirm which XGBoost versions omit it)
        }
    }

    /**
     * Converts this learner into a PMML document.
     *
     * @param targetField the name of the target field; {@code "_target"} is used when {@code null}
     * @param targetCategories the target categories, handed to the objective function's label encoding
     * @param featureMap the feature definitions used to encode the input features
     * @param options conversion options, see {@link #encodeMiningModel(Map, Schema)}; may be {@code null}
     * @return the PMML document
     */
    public PMML encodePMML(FieldName targetField, List<String> targetCategories, FeatureMap featureMap, Map<String, ?> options) {
        XGBoostEncoder encoder = new XGBoostEncoder();

        FieldName effectiveTargetField = (targetField != null) ? targetField : FieldName.create("_target");

        Label label = this.obj.encodeLabel(effectiveTargetField, targetCategories, encoder);
        List<Feature> features = featureMap.encodeFeatures(encoder);
        Schema schema = new Schema(label, features);

        MiningModel miningModel = this.encodeMiningModel(options, schema);
        return encoder.encodePMML(miningModel);
    }

    /**
     * Encodes the booster as a PMML mining model.
     *
     * <p>Recognized options:
     * <ul>
     *   <li>{@code "compact"} ({@link Boolean}): when {@code true}, the resulting model is passed
     *       through {@link TreeModelCompactor};</li>
     *   <li>{@code "ntree_limit"} ({@link Integer}): passed on to the booster's encoding.</li>
     * </ul>
     * A {@code null} map is treated like an empty one: every option is passed on as {@code null}.
     *
     * @param options conversion options; may be {@code null}
     * @param schema the label and features of the model
     * @return the mining model, with algorithm name {@code "XGBoost (<booster algorithm name>)"}
     * @throws ClassCastException if an option value is not of the documented type
     */
    public MiningModel encodeMiningModel(Map<String, ?> options, Schema schema) {
        Boolean compact = (Boolean)getOption(options, "compact");
        Integer ntreeLimit = (Integer)getOption(options, "ntree_limit");

        MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, ntreeLimit, schema)
            .setAlgorithmName("XGBoost (" + this.gbtree.getAlgorithmName() + ")");

        if (Boolean.TRUE.equals(compact)) {
            TreeModelCompactor visitor = new TreeModelCompactor();
            visitor.applyTo(miningModel);
        }
        return miningModel;
    }

    // Null-safe option lookup: a missing map behaves like a map without the key.
    private static Object getOption(Map<String, ?> options, String key) {
        return (options != null) ? options.get(key) : null;
    }

    /**
     * Returns the feature count read from the learner header.
     *
     * @return the number of features
     */
    public int num_features() {
        return this.num_features;
    }

    /**
     * Returns the class count read from the learner header.
     *
     * @return the number of classes
     */
    public int num_class() {
        return this.num_class;
    }

    /**
     * Returns the objective function selected while loading.
     *
     * @return the objective function, or {@code null} if {@link #load(XGBoostDataInput)} has not run
     */
    public ObjFunction obj() {
        return this.obj;
    }
}

