/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.dtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.impl.InvertedFeature;
import org.tribuo.classification.dtree.impl.TreeFeature;
import org.tribuo.classification.dtree.impurity.LabelImpurity;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.util.Util;

/**
 * A training-time decision tree node for classification.
 *
 * <p>Each node holds an inverted (feature-major) view of the examples that reach it, stored as one
 * {@link TreeFeature} per feature id. Building a node searches the supplied feature ids for a
 * threshold split that lowers the label impurity. If one is found, the node partitions its inverted
 * data between two children and returns the children that still need building.
 *
 * <p>This class is runtime-only and refuses Java serialization (see {@code writeObject}).
 */
public class ClassifierTrainingNode
extends AbstractTrainingNode<Label> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(ClassifierTrainingNode.class.getName());

    // Per-thread scratch buffers used by splitAtBest when merging example index arrays, so each
    // split reuses the same containers instead of allocating new ones.
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(16));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(16));
    private static final ThreadLocal<IntArrayContainer> mergeBufferThree = ThreadLocal.withInitial(() -> new IntArrayContainer(16));

    // Inverted data for the examples reaching this node, indexed by feature id. It is set to null
    // once the node has been built, so the arrays are not retained by the finished tree.
    private transient ArrayList<TreeFeature> data;
    private final ImmutableOutputInfo<Label> labelIDMap;
    private final ImmutableFeatureMap featureIDMap;
    private final LabelImpurity impurity;
    // Weighted example count per label id at this node.
    private final float[] weightedLabelCounts;
    // Sum of weightedLabelCounts.
    private final float weightSum;

    /**
     * Constructs the root node of a tree by inverting the supplied dataset.
     *
     * @param impurity The impurity measure used to score splits.
     * @param examples The training dataset.
     * @param leafDeterminer The leaf-creation parameters, passed to the superclass.
     */
    public ClassifierTrainingNode(LabelImpurity impurity, Dataset<Label> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        this(impurity, ClassifierTrainingNode.invertData(examples), examples.size(), 0, examples.getFeatureIDMap(), examples.getOutputIDInfo(), leafDeterminer);
    }

    /**
     * Constructs a node, deriving its label counts, weight sum and impurity from the data.
     *
     * <p>The counts are read from the first feature. This presumes every TreeFeature carries the same
     * node-level label totals, which holds for the root because invertData records a value for every
     * (feature, example) pair. TODO confirm for other uses.
     */
    private ClassifierTrainingNode(LabelImpurity impurity, ArrayList<TreeFeature> data, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        super(depth, numExamples, leafDeterminer);
        this.data = data;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.weightedLabelCounts = data.get(0).getWeightedLabelCounts();
        this.weightSum = Util.sum(this.weightedLabelCounts);
        this.impurityScore = impurity.impurity(this.weightedLabelCounts);
    }

    /**
     * Constructs a child node whose label counts, weight sum and impurity the parent has already
     * computed while choosing its split.
     */
    private ClassifierTrainingNode(LabelImpurity impurity, ArrayList<TreeFeature> data, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap, AbstractTrainingNode.LeafDeterminer leafDeterminer, float[] weightedLabelCounts, float weightSum, double impurityScore) {
        super(depth, numExamples, leafDeterminer);
        this.data = data;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.weightedLabelCounts = weightedLabelCounts;
        this.weightSum = weightSum;
        this.impurityScore = impurityScore;
    }

    /**
     * Returns the total weight of the examples at this node.
     *
     * @return The sum of the weighted label counts.
     */
    public float getWeightSum() {
        return this.weightSum;
    }

    /**
     * Returns the impurity of this node's label distribution.
     *
     * @return The impurity score.
     */
    public double getImpurity() {
        return this.impurityScore;
    }

    /**
     * Searches {@code featureIDs} for a split of this node and applies it if it is worth taking.
     * The node's inverted data is released afterwards.
     *
     * @param featureIDs The feature ids to consider.
     * @param rng The random source, used only when {@code useRandomSplitPoints} is true.
     * @param useRandomSplitPoints If true, one random threshold is tried per feature; otherwise every
     *     threshold is tried.
     * @return The children that still need building. The list is empty if the node was not split or
     *     both children became leaves.
     */
    public List<AbstractTrainingNode<Label>> buildTree(int[] featureIDs, SplittableRandom rng, boolean useRandomSplitPoints) {
        return useRandomSplitPoints ? this.buildRandomTree(featureIDs, rng) : this.buildGreedyTree(featureIDs);
    }

    /**
     * Finds the split by trying every threshold between adjacent entries of each feature.
     *
     * @param featureIDs The feature ids to consider.
     * @return The children that still need building.
     */
    private List<AbstractTrainingNode<Label>> buildGreedyTree(int[] featureIDs) {
        int numLabels = this.weightedLabelCounts.length;
        int bestID = -1;
        double bestSplitValue = 0.0;
        double bestScore = this.getImpurity();
        float[] lessThanCountsOfBest = new float[numLabels];
        float[] greaterThanCountsOfBest = new float[numLabels];
        float[] lessThanCounts = new float[numLabels];
        float[] greaterThanCounts = new float[numLabels];
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = this.data.get(featureIDs[i]).getFeature();
            Arrays.fill(lessThanCounts, 0.0f);
            System.arraycopy(this.weightedLabelCounts, 0, greaterThanCounts, 0, numLabels);
            // Sweep the thresholds, moving one entry at a time from the right-hand side to the left.
            // The midpoint threshold (and the early break in splitAtBest) assume ascending value order.
            for (int j = 0; j < feature.size() - 1; j++) {
                InvertedFeature f = feature.get(j);
                float[] featureCounts = f.getWeightedLabelCounts();
                Util.inPlaceAdd(lessThanCounts, featureCounts);
                Util.inPlaceSubtract(greaterThanCounts, featureCounts);
                double score = this.candidateScore(lessThanCounts, greaterThanCounts);
                if (score < bestScore) {
                    bestID = i;
                    bestScore = score;
                    System.arraycopy(lessThanCounts, 0, lessThanCountsOfBest, 0, numLabels);
                    System.arraycopy(greaterThanCounts, 0, greaterThanCountsOfBest, 0, numLabels);
                    bestSplitValue = (f.value + feature.get(j + 1).value) / 2.0;
                }
            }
        }
        return this.finishSplit(featureIDs, bestID, bestScore, bestSplitValue, lessThanCountsOfBest, greaterThanCountsOfBest);
    }

    /**
     * Finds the split by choosing one random threshold per feature and keeping the feature whose
     * threshold scores best.
     *
     * <p>For each feature, {@code rng.nextInt(feature.size() - 1)} picks the gap between two adjacent
     * entries. Features with fewer than two entries have no gap and are skipped without consuming
     * the RNG.
     *
     * @param featureIDs The feature ids to consider.
     * @param rng The random source for threshold selection.
     * @return The children that still need building.
     */
    public List<AbstractTrainingNode<Label>> buildRandomTree(int[] featureIDs, SplittableRandom rng) {
        int numLabels = this.weightedLabelCounts.length;
        int bestID = -1;
        double bestSplitValue = 0.0;
        double bestScore = this.getImpurity();
        float[] lessThanCountsOfBest = new float[numLabels];
        float[] greaterThanCountsOfBest = new float[numLabels];
        float[] lessThanCounts = new float[numLabels];
        float[] greaterThanCounts = new float[numLabels];
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = this.data.get(featureIDs[i]).getFeature();
            if (feature.size() < 2) {
                // No gap between entries to split on; rng.nextInt would reject a non-positive bound.
                continue;
            }
            Arrays.fill(lessThanCounts, 0.0f);
            System.arraycopy(this.weightedLabelCounts, 0, greaterThanCounts, 0, numLabels);
            int splitIdx = rng.nextInt(feature.size() - 1);
            for (int j = 0; j <= splitIdx; j++) {
                float[] countsBelowOrEqual = feature.get(j).getWeightedLabelCounts();
                Util.inPlaceAdd(lessThanCounts, countsBelowOrEqual);
                Util.inPlaceSubtract(greaterThanCounts, countsBelowOrEqual);
            }
            double score = this.candidateScore(lessThanCounts, greaterThanCounts);
            if (score < bestScore) {
                bestID = i;
                bestScore = score;
                System.arraycopy(lessThanCounts, 0, lessThanCountsOfBest, 0, numLabels);
                System.arraycopy(greaterThanCounts, 0, greaterThanCountsOfBest, 0, numLabels);
                bestSplitValue = (feature.get(splitIdx).value + feature.get(splitIdx + 1).value) / 2.0;
            }
        }
        return this.finishSplit(featureIDs, bestID, bestScore, bestSplitValue, lessThanCountsOfBest, greaterThanCountsOfBest);
    }

    /**
     * Scores a candidate split. The score is the sum of both sides' weighted impurities divided by
     * this node's weight sum.
     *
     * <p>Returns {@link Double#POSITIVE_INFINITY}, which is never accepted as an improvement, when
     * either side's weighted impurity is not above 1e-10.
     *
     * <p>NOTE(review): if {@code impurityWeighted} is impurity times weight, this guard also rejects
     * splits that produce a pure child, not just empty ones. Confirm against LabelImpurity whether
     * that is intended.
     */
    private double candidateScore(float[] lessThanCounts, float[] greaterThanCounts) {
        double lessThanScore = this.impurity.impurityWeighted(lessThanCounts);
        double greaterThanScore = this.impurity.impurityWeighted(greaterThanCounts);
        if (!(lessThanScore > 1.0E-10) || !(greaterThanScore > 1.0E-10)) {
            return Double.POSITIVE_INFINITY;
        }
        return (lessThanScore + greaterThanScore) / this.weightSum;
    }

    /**
     * Applies the best split found by a search if it decreases the weighted impurity by at least the
     * leaf determiner's scaled minimum. The node's inverted data is then released.
     *
     * @return The children that still need building. The list is empty if no split was applied.
     */
    private List<AbstractTrainingNode<Label>> finishSplit(int[] featureIDs, int bestID, double bestScore, double bestSplitValue, float[] lessThanCountsOfBest, float[] greaterThanCountsOfBest) {
        double impurityDecrease = this.weightSum * (this.getImpurity() - bestScore);
        List<AbstractTrainingNode<Label>> output;
        if (bestID != -1 && impurityDecrease >= this.leafDeterminer.getScaledMinImpurityDecrease()) {
            output = this.splitAtBest(featureIDs, bestID, bestSplitValue, lessThanCountsOfBest, greaterThanCountsOfBest);
        } else {
            output = Collections.emptyList();
        }
        this.data = null;
        return output;
    }

    /**
     * Records the chosen split on this node and creates its two children. Each child becomes either
     * a leaf or a new training node over its share of the inverted data.
     *
     * @param featureIDs The feature ids that were searched.
     * @param bestID The index into {@code featureIDs} of the chosen feature.
     * @param bestSplitValue The threshold. Values strictly below it go to the less-than-or-equal child.
     * @param lessThanCounts The weighted label counts for the left child.
     * @param greaterThanCounts The weighted label counts for the right child.
     * @return The non-leaf children, which still need building.
     */
    private List<AbstractTrainingNode<Label>> splitAtBest(int[] featureIDs, int bestID, double bestSplitValue, float[] lessThanCounts, float[] greaterThanCounts) {
        this.splitID = featureIDs[bestID];
        this.split = true;
        this.splitValue = bestSplitValue;

        float lessThanWeightSum = Util.sum(lessThanCounts);
        double lessThanImpurityScore = this.impurity.impurity(lessThanCounts);
        float greaterThanWeightSum = Util.sum(greaterThanCounts);
        double greaterThanImpurityScore = this.impurity.impurity(greaterThanCounts);
        boolean shouldMakeLessThanLeaf = this.shouldMakeLeaf(lessThanImpurityScore, lessThanWeightSum);
        boolean shouldMakeGreaterThanLeaf = this.shouldMakeLeaf(greaterThanImpurityScore, greaterThanWeightSum);

        if (shouldMakeLessThanLeaf && shouldMakeGreaterThanLeaf) {
            // Both sides are leaves, so the inverted data does not need partitioning.
            this.lessThanOrEqual = this.createLeaf(lessThanImpurityScore, lessThanCounts);
            this.greaterThan = this.createLeaf(greaterThanImpurityScore, greaterThanCounts);
            return Collections.emptyList();
        }

        // Gather the indices of the examples going left. Merge the index arrays of every entry below
        // the threshold, ping-ponging between two scratch containers.
        IntArrayContainer lessThanIndices = mergeBufferOne.get();
        lessThanIndices.size = 0;
        IntArrayContainer buffer = mergeBufferTwo.get();
        buffer.size = 0;
        for (InvertedFeature f : this.data.get(this.splitID)) {
            if (!(f.value < this.splitValue)) {
                break;
            }
            IntArrayContainer.merge(lessThanIndices, f.indices(), buffer);
            IntArrayContainer tmp = lessThanIndices;
            lessThanIndices = buffer;
            buffer = tmp;
        }

        IntArrayContainer secondBuffer = mergeBufferThree.get();
        secondBuffer.grow(lessThanIndices.size);
        ArrayList<TreeFeature> lessThanData = new ArrayList<>(this.data.size());
        ArrayList<TreeFeature> greaterThanData = new ArrayList<>(this.data.size());
        for (TreeFeature feature : this.data) {
            Pair<TreeFeature, TreeFeature> halves = feature.split(lessThanIndices, buffer, secondBuffer);
            lessThanData.add(halves.getA());
            greaterThanData.add(halves.getB());
        }

        List<AbstractTrainingNode<Label>> output = new ArrayList<>(2);
        if (shouldMakeLessThanLeaf) {
            this.lessThanOrEqual = this.createLeaf(lessThanImpurityScore, lessThanCounts);
        } else {
            ClassifierTrainingNode lessThanChild = new ClassifierTrainingNode(this.impurity, lessThanData, lessThanIndices.size, this.depth + 1, this.featureIDMap, this.labelIDMap, this.leafDeterminer, lessThanCounts, lessThanWeightSum, lessThanImpurityScore);
            this.lessThanOrEqual = lessThanChild;
            output.add(lessThanChild);
        }
        if (shouldMakeGreaterThanLeaf) {
            this.greaterThan = this.createLeaf(greaterThanImpurityScore, greaterThanCounts);
        } else {
            ClassifierTrainingNode greaterThanChild = new ClassifierTrainingNode(this.impurity, greaterThanData, this.numExamples - lessThanIndices.size, this.depth + 1, this.featureIDMap, this.labelIDMap, this.leafDeterminer, greaterThanCounts, greaterThanWeightSum, greaterThanImpurityScore);
            this.greaterThan = greaterThanChild;
            output.add(greaterThanChild);
        }
        return output;
    }

    /**
     * Builds a leaf whose output is the label with the highest normalised count. The first such label
     * wins ties. The leaf also records a score for every label.
     *
     * @param impurityScore The impurity recorded on the leaf.
     * @param weightedCounts The weighted label counts, indexed by label id.
     * @return The leaf node.
     */
    private LeafNode<Label> createLeaf(double impurityScore, float[] weightedCounts) {
        double[] normedCounts = Util.normalizeToDistribution(weightedCounts);
        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        LinkedHashMap<String, Label> counts = new LinkedHashMap<>();
        for (int i = 0; i < weightedCounts.length; i++) {
            double curCount = normedCounts[i];
            String name = this.labelIDMap.getOutput(i).getLabel();
            Label label = new Label(name, curCount);
            counts.put(name, label);
            if (curCount > maxScore) {
                maxScore = curCount;
                maxLabel = label;
            }
        }
        return new LeafNode<>(impurityScore, maxLabel, counts, true);
    }

    /**
     * Converts this training node into its final form. Split nodes go through createSplitNode;
     * unsplit nodes become a leaf built from this node's label counts.
     *
     * @return The converted node.
     */
    public Node<Label> convertTree() {
        if (this.split) {
            return this.createSplitNode();
        }
        return this.createLeaf(this.getImpurity(), this.weightedLabelCounts);
    }

    /**
     * Inverts a dataset into one TreeFeature per feature id. Every example records a value for every
     * feature, using 0.0 for features absent from its sparse vector. Each feature is then sorted and
     * has fixSize applied.
     *
     * @param examples The dataset to invert.
     * @return The inverted data, indexed by feature id.
     * @throws IllegalStateException If an example's feature ids are repeated or out of order.
     */
    private static ArrayList<TreeFeature> invertData(Dataset<Label> examples) {
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelInfo = examples.getOutputIDInfo();
        int numLabels = labelInfo.size();
        int numFeatures = featureInfos.size();
        int numExamples = examples.size();
        int[] labels = new int[numExamples];
        float[] weights = new float[numExamples];
        int k = 0;
        for (Example<Label> e : examples) {
            weights[k] = e.getWeight();
            labels[k] = labelInfo.getID(e.getOutput());
            k++;
        }
        logger.fine(() -> "Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " classes");
        ArrayList<TreeFeature> data = new ArrayList<>(numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            data.add(new TreeFeature(i, numLabels, labels, weights));
        }
        for (int i = 0; i < numExamples; i++) {
            Example<Label> e = examples.getExample(i);
            SparseVector vec = SparseVector.createSparseVector(e, featureInfos, false);
            // lastID is one past the previously seen feature id.
            int lastID = 0;
            for (VectorTuple f : vec) {
                int curID = f.index;
                // Validate before touching data. A repeated id (curID == lastID - 1) also satisfies
                // curID < lastID, so the repeat check must come first or it can never fire.
                if (curID == lastID - 1) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                if (curID < lastID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                // Features skipped by the sparse vector are recorded as explicit zeros.
                for (int j = lastID; j < curID; j++) {
                    data.get(j).observeValue(0.0, i);
                }
                data.get(curID).observeValue(f.value, i);
                lastID = curID + 1;
            }
            for (int j = lastID; j < numFeatures; j++) {
                data.get(j).observeValue(0.0, i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }
        logger.fine("Sorting features");
        data.forEach(TreeFeature::sort);
        logger.fine("Fixing InvertedFeature sizes");
        data.forEach(TreeFeature::fixSize);
        logger.fine("Built initial List<TreeFeature>");
        return data;
    }

    /**
     * Always throws. This node only exists while a tree is being trained.
     *
     * @param stream Unused.
     * @throws IOException Always, as a NotSerializableException.
     */
    private void writeObject(ObjectOutputStream stream) throws IOException {
        throw new NotSerializableException("ClassifierTrainingNode is a runtime class only, and should not be serialized.");
    }
}

