/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.evaluation;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.IntToDoubleFunction;
import java.util.function.ToDoubleFunction;
import java.util.logging.Logger;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.math.la.DenseMatrix;

/**
 * A confusion matrix over {@link Label}s, tabulated from a list of predictions.
 * <p>
 * Counts are stored in a {@link DenseMatrix} indexed by the label ids of the supplied
 * {@link ImmutableOutputInfo}: the row is the predicted label id and the column is the
 * true label id.
 * <p>
 * Labels outside the domain contribute zero to {@link #tp(Label)}, {@link #fp(Label)},
 * {@link #fn(Label)} and {@link #confusion(Label, Label)}.
 */
public final class LabelConfusionMatrix
implements ConfusionMatrix<Label> {
    private static final Logger logger = Logger.getLogger(LabelConfusionMatrix.class.getName());
    /** Label domain; its ids index the rows and columns of {@link #cm}. */
    private final ImmutableOutputInfo<Label> domain;
    /** Number of predictions tabulated. */
    private final int total;
    /** Number of times each label occurred as the ground truth. */
    private final Map<Label, Double> occurrences;
    /** Every label seen either as a ground truth or as a prediction. */
    private final Set<Label> observed;
    /** Counts indexed as (predicted label id, true label id). */
    private final DenseMatrix cm;
    /** Optional display order for toString/toHTML; {@code null} means domain order. */
    private List<Label> labelOrder;

    /**
     * Builds a confusion matrix using the model's output domain.
     *
     * @param model       the model whose output id info defines the label domain
     * @param predictions the predictions to tabulate
     * @throws IllegalArgumentException if a prediction has the unknown ground truth, or uses a
     *                                  label outside the domain
     */
    public LabelConfusionMatrix(Model<Label> model, List<Prediction<Label>> predictions) {
        this(model.getOutputIDInfo(), predictions);
    }

    /**
     * Builds a confusion matrix over the given label domain.
     *
     * @param domain      the label domain; its ids index the matrix
     * @param predictions the predictions to tabulate
     * @throws IllegalArgumentException if a prediction has the unknown ground truth, or uses a
     *                                  label outside the domain
     */
    public LabelConfusionMatrix(ImmutableOutputInfo<Label> domain, List<Prediction<Label>> predictions) {
        this.domain = domain;
        this.total = predictions.size();
        this.cm = new DenseMatrix(domain.size(), domain.size());
        this.occurrences = new HashMap<Label, Double>();
        this.observed = new HashSet<Label>();
        this.tabulate(predictions);
    }

    /**
     * Adds one count per prediction at (predicted id, true id), and records ground-truth
     * occurrences and observed labels.
     */
    private void tabulate(List<Prediction<Label>> predictions) {
        for (Prediction<Label> prediction : predictions) {
            Label y = prediction.getExample().getOutput();
            Label p = prediction.getOutput();
            if (y.getLabel().equals("LABEL##UNKNOWN")) {
                throw new IllegalArgumentException("Prediction with unknown ground truth. Unable to evaluate.");
            }
            // Resolve both ids before touching any state so a bad label fails cleanly.
            int iy = this.getIDOrThrow(y);
            int ip = this.getIDOrThrow(p);
            this.occurrences.merge(y, 1.0, Double::sum);
            this.observed.add(y);
            this.observed.add(p);
            this.cm.add(ip, iy, 1.0);
        }
    }

    @Override
    public ImmutableOutputInfo<Label> getDomain() {
        return this.domain;
    }

    /** Returns the total number of predictions tabulated. */
    @Override
    public double support() {
        return this.total;
    }

    /** Returns how many predictions had {@code label} as their ground truth (0 if none). */
    @Override
    public double support(Label label) {
        return this.occurrences.getOrDefault(label, 0.0);
    }

    /** Returns the diagonal count for {@code cls}, or 0 if it is outside the domain. */
    @Override
    public double tp(Label cls) {
        return this.compute(cls, i -> this.cm.get(i, i));
    }

    /** Returns predictions of {@code cls} whose truth was another label (row sum minus diagonal). */
    @Override
    public double fp(Label cls) {
        return this.compute(cls, i -> this.cm.rowSum(i) - this.cm.get(i, i));
    }

    /** Returns instances of {@code cls} predicted as another label (column sum minus diagonal). */
    @Override
    public double fn(Label cls) {
        return this.compute(cls, i -> this.cm.columnSum(i) - this.cm.get(i, i));
    }

    /**
     * Sums every cell whose row and column both differ from {@code cls}'s id.
     * For a label outside the domain no row/column is excluded, so the sum covers every cell.
     */
    @Override
    public double tn(Label cls) {
        int n = this.domain.size();
        int i = this.domain.getID(cls);
        double sum = 0.0;
        for (int j = 0; j < n; ++j) {
            if (j == i) continue;
            for (int k = 0; k < n; ++k) {
                if (k == i) continue;
                sum += this.cm.get(j, k);
            }
        }
        return sum;
    }

    /**
     * Returns how many instances of {@code trueClass} were predicted as {@code predicted}.
     * Returns 0 when either label is outside the domain, matching {@link #tp(Label)} and friends.
     */
    @Override
    public double confusion(Label predicted, Label trueClass) {
        int i = this.domain.getID(predicted);
        int j = this.domain.getID(trueClass);
        if (i < 0 || j < 0) {
            logger.fine(() -> "Unknown Label " + (i < 0 ? predicted : trueClass));
            return 0.0;
        }
        return this.cm.get(i, j);
    }

    /** Applies {@code getter} to {@code cls}'s id, or returns 0 if the label is not in the domain. */
    private double compute(Label cls, IntToDoubleFunction getter) {
        int id = this.domain.getID(cls);
        if (id < 0) {
            logger.fine(() -> "Unknown Label " + cls);
            return 0.0;
        }
        return getter.applyAsDouble(id);
    }

    /** Returns the domain id of {@code key}, throwing if it is not part of the domain. */
    private int getIDOrThrow(Label key) {
        int id = this.domain.getID(key);
        if (id < 0) {
            throw new IllegalArgumentException("Unknown label: " + key);
        }
        return id;
    }

    /**
     * Sets the row/column order used by {@link #toString()} and {@link #toHTML()}.
     * The list is copied; pass {@code null} to revert to domain order.
     */
    public void setLabelOrder(List<Label> labelOrder) {
        this.labelOrder = labelOrder == null ? null : new ArrayList<Label>(labelOrder);
    }

    /**
     * Returns the labels to render: the configured order (or domain order), de-duplicated and
     * restricted to observed labels. Works on a fresh copy so {@link #labelOrder} is never mutated.
     */
    private List<Label> printableLabels() {
        LinkedHashSet<Label> labels = this.labelOrder == null
                ? new LinkedHashSet<Label>(this.domain.getDomain())
                : new LinkedHashSet<Label>(this.labelOrder);
        labels.retainAll(this.observed);
        return new ArrayList<Label>(labels);
    }

    /** Escapes the characters that are significant in HTML text and attribute content. */
    private static String escapeHtml(String text) {
        StringBuilder out = new StringBuilder(text.length());
        for (int idx = 0; idx < text.length(); ++idx) {
            char c = text.charAt(idx);
            switch (c) {
                case '<': out.append("&lt;"); break;
                case '>': out.append("&gt;"); break;
                case '&': out.append("&amp;"); break;
                case '"': out.append("&quot;"); break;
                case '\'': out.append("&#39;"); break;
                default: out.append(c);
            }
        }
        return out.toString();
    }

    /**
     * Renders the matrix as fixed-width text: one row per true label, one column per predicted
     * label, restricted to observed labels.
     */
    @Override
    public String toString() {
        List<Label> labels = this.printableLabels();
        // Start at 0 (not MIN_VALUE) so an empty matrix still yields a valid format width.
        int maxLen = 0;
        for (Label label : labels) {
            maxLen = Math.max(label.getLabel().length(), maxLen);
            // Every cell in column/row is bounded by some true-label count, so this fits all counts.
            maxLen = Math.max(String.format(" %,d", (int) this.support(label)).length(), maxLen);
        }
        StringBuilder sb = new StringBuilder();
        String trueLabelFormat = String.format("%%-%ds", maxLen + 2);
        String predictedLabelFormat = String.format("%%%ds", maxLen + 2);
        String countFormat = String.format("%%,%dd", maxLen + 2);
        sb.append(String.format(trueLabelFormat, ""));
        for (Label predictedLabel : labels) {
            sb.append(String.format(predictedLabelFormat, predictedLabel.getLabel()));
        }
        sb.append('\n');
        for (Label trueLabel : labels) {
            sb.append(String.format(trueLabelFormat, trueLabel.getLabel()));
            for (Label predictedLabel : labels) {
                int confusion = (int) this.confusion(predictedLabel, trueLabel);
                sb.append(String.format(countFormat, confusion));
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Renders the matrix as an HTML table with per-cell counts, row percentages and a per-row
     * total, restricted to observed labels. Label text is HTML-escaped.
     */
    public String toHTML() {
        List<Label> labels = this.printableLabels();
        StringBuilder sb = new StringBuilder();
        sb.append("<table>\n");
        // The header spans one column per printed predicted label plus the "Total" column.
        sb.append(String.format("<tr><th>True Label</th><th style=\"text-align:center\" colspan=\"%d\">Predicted Labels</th></tr>%n", labels.size() + 1));
        sb.append("<tr><th></th>");
        for (Label predictedLabel : labels) {
            sb.append("<th style=\"text-align:right\">").append(escapeHtml(String.valueOf(predictedLabel))).append("</th>");
        }
        sb.append("<th style=\"text-align:right\">Total</th>");
        sb.append("</tr>\n");
        for (Label trueLabel : labels) {
            sb.append("<tr><th>").append(escapeHtml(String.valueOf(trueLabel))).append("</th>");
            double count = this.support(trueLabel);
            for (Label predictedLabel : labels) {
                double tlmc = this.confusion(predictedLabel, trueLabel);
                // Labels that were only ever predicted have no true instances; avoid 0/0 = NaN.
                double percent = count > 0.0 ? tlmc / count * 100.0 : 0.0;
                sb.append("<td style=\"text-align:right\">").append(String.format("%,d (%.1f%%)", (int) tlmc, percent)).append("</td>");
            }
            // Format the total like the cell counts instead of as a raw double ("5.0").
            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) count)).append("</td>");
            sb.append("</tr>\n");
        }
        sb.append("</table>");
        return sb.toString();
    }
}

