/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.dtree;

import com.oracle.labs.mlrg.olcut.config.Config;
import java.util.Objects;
import org.tribuo.Dataset;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.impl.ClassifierTrainingNode;
import org.tribuo.classification.dtree.impurity.GiniIndex;
import org.tribuo.classification.dtree.impurity.LabelImpurity;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

public class CARTClassificationTrainer
extends AbstractCARTTrainer<Label> {
    @Config(description="The impurity measure used to determine split quality.")
    private LabelImpurity impurity = new GiniIndex();

    public CARTClassificationTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, LabelImpurity impurity, long seed) {
        super(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, useRandomSplitPoints, seed);
        this.impurity = impurity;
        this.postConfig();
    }

    public CARTClassificationTrainer() {
        this(Integer.MAX_VALUE);
    }

    public CARTClassificationTrainer(int maxDepth) {
        this(maxDepth, 5.0f, 0.0f, 1.0f, false, new GiniIndex(), 12345L);
    }

    public CARTClassificationTrainer(int maxDepth, float fractionFeaturesInSplit, long seed) {
        this(maxDepth, 5.0f, 0.0f, fractionFeaturesInSplit, false, new GiniIndex(), seed);
    }

    public CARTClassificationTrainer(int maxDepth, float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed) {
        this(maxDepth, 5.0f, 0.0f, fractionFeaturesInSplit, useRandomSplitPoints, new GiniIndex(), seed);
    }

    public CARTClassificationTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, LabelImpurity impurity, long seed) {
        this(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, false, impurity, seed);
    }

    protected AbstractTrainingNode<Label> mkTrainingNode(Dataset<Label> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        return new ClassifierTrainingNode(this.impurity, examples, leafDeterminer);
    }

    public String toString() {
        StringBuilder buffer = new StringBuilder();
        buffer.append("CARTClassificationTrainer(maxDepth=");
        buffer.append(this.maxDepth);
        buffer.append(",minChildWeight=");
        buffer.append(this.minChildWeight);
        buffer.append(",minImpurityDecrease=");
        buffer.append(this.minImpurityDecrease);
        buffer.append(",fractionFeaturesInSplit=");
        buffer.append(this.fractionFeaturesInSplit);
        buffer.append(",useRandomSplitPoints=");
        buffer.append(this.useRandomSplitPoints);
        buffer.append(",impurity=");
        buffer.append(this.impurity.toString());
        buffer.append(",seed=");
        buffer.append(this.seed);
        buffer.append(")");
        return buffer.toString();
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((Trainer)this);
    }
}

