package edu.stanford.nlp.scoref;

import edu.stanford.nlp.scoref.PairwiseModel;
import edu.stanford.nlp.scoref.SimpleLinearClassifier;
import edu.stanford.nlp.stats.Counter;
import java.util.Map;

/**
 * Mention-ranking {@link PairwiseModel} trained with a max-margin objective in which every
 * {@link ErrorType} has its own cost.
 *
 * <p>Two costing modes are supported (selected via {@link Builder#multiplicativeCost(boolean)}):
 * <ul>
 *   <li>multiplicative (the builder default): each update uses the shared loss
 *       {@code SimpleLinearClassifier.maxMargin(1.0)} and hands the error type's cost to
 *       {@code classifier.learn} as its third argument;</li>
 *   <li>non-multiplicative: each update uses a per-error-type loss
 *       {@code SimpleLinearClassifier.maxMargin(cost)} and hands {@code 1.0} as that argument.</li>
 * </ul>
 */
public class MaxMarginMentionRanker extends PairwiseModel {
    /** Per-error-type losses indexed by {@link ErrorType#id}; read in non-multiplicative mode. */
    private final SimpleLinearClassifier.Loss[] losses;
    /** Shared {@code maxMargin(1.0)} loss; read in multiplicative mode. */
    private final SimpleLinearClassifier.Loss loss;
    /** Cost per error type, indexed by {@link ErrorType#id}. */
    public final double[] costs;
    /** Whether costs are applied multiplicatively (see the class documentation). */
    public final boolean multiplicativeCost;

    /** Builder for {@link MaxMarginMentionRanker}. */
    public static class Builder extends PairwiseModel.Builder {
        /** Costs in {@link ErrorType#id} order. */
        private double[] costs;
        private boolean multiplicativeCost;

        /**
         * Creates a builder with default costs {@code {1.2, 1.2, 0.5, 1.0}} (in
         * {@link ErrorType#id} order) and multiplicative costing enabled.
         *
         * @param str forwarded unchanged to the {@code PairwiseModel.Builder} constructor
         * @param metaFeatureExtractor forwarded unchanged to the {@code PairwiseModel.Builder}
         *     constructor
         */
        public Builder(String str, MetaFeatureExtractor metaFeatureExtractor) {
            super(str, metaFeatureExtractor);
            this.costs = new double[]{1.2d, 1.2d, 0.5d, 1.0d};
            this.multiplicativeCost = true;
        }

        /**
         * Sets the per-error-type costs.
         *
         * @param d cost for {@link ErrorType#FN}
         * @param d2 cost for {@link ErrorType#FN_PRON}
         * @param d3 cost for {@link ErrorType#FL}
         * @param d4 cost for {@link ErrorType#WL}
         * @return this builder
         */
        public Builder setCosts(double d, double d2, double d3, double d4) {
            this.costs = new double[]{d, d2, d3, d4};
            return this;
        }

        /**
         * Chooses between multiplicative ({@code true}) and non-multiplicative ({@code false})
         * costing; see the {@link MaxMarginMentionRanker} class documentation.
         *
         * @param z whether costs are applied multiplicatively
         * @return this builder
         */
        public Builder multiplicativeCost(boolean z) {
            this.multiplicativeCost = z;
            return this;
        }

        /** Builds a ranker from the current builder settings. */
        @Override
        public MaxMarginMentionRanker build() {
            return new MaxMarginMentionRanker(this);
        }
    }

    /**
     * Kinds of ranking error; {@link #id} indexes {@link MaxMarginMentionRanker#costs}.
     *
     * <p>NOTE(review): the abbreviations presumably mean false new (FN), false new on a pronoun
     * (FN_PRON), false link (FL) and wrong link (WL) — confirm against the training code.
     */
    public enum ErrorType {
        FN(0),
        FN_PRON(1),
        FL(2),
        WL(3);

        /** Index of this error type's entry in the cost and loss arrays. */
        public final int id;

        ErrorType(int i) {
            this.id = i;
        }
    }

    /**
     * Returns a new {@link Builder}.
     *
     * @param str forwarded to {@link Builder#Builder(String, MetaFeatureExtractor)}
     * @param metaFeatureExtractor forwarded to {@link Builder#Builder(String, MetaFeatureExtractor)}
     * @return a builder with default settings
     */
    public static Builder newBuilder(String str, MetaFeatureExtractor metaFeatureExtractor) {
        return new Builder(str, metaFeatureExtractor);
    }

    /**
     * Creates a ranker from the given builder's settings.
     *
     * @param builder source of the costs, costing mode and base-model settings
     */
    public MaxMarginMentionRanker(Builder builder) {
        super(builder);
        // Defensive copy: rankers built from one Builder must not share a public mutable array.
        this.costs = builder.costs.clone();
        this.multiplicativeCost = builder.multiplicativeCost;
        ErrorType[] errorTypes = ErrorType.values();
        this.losses = new SimpleLinearClassifier.Loss[errorTypes.length];
        // Built in every mode: learn() reads these when multiplicativeCost is false, so they must
        // never be left null (filling them only in multiplicative mode handed learn() a null loss).
        for (ErrorType errorType : errorTypes) {
            this.losses[errorType.id] = SimpleLinearClassifier.maxMargin(this.costs[errorType.id]);
        }
        this.loss = SimpleLinearClassifier.maxMargin(1.0d);
    }

    /**
     * Performs one max-margin update from a pair of examples.
     *
     * <p>The classifier is trained on the feature difference {@code f(example2) - f(example)}
     * with label {@code 1.0}. NOTE(review): presumably {@code example} is the preferred (gold)
     * candidate and {@code example2} the competing one — confirm against callers.
     *
     * @param example example whose features are subtracted
     * @param example2 example whose feature counter receives the difference (modified in place)
     * @param map per-mention compressed features, forwarded to {@code meta.getFeatures}
     * @param compressor forwarded to {@code meta.getFeatures}
     * @param errorType selects the cost (multiplicative mode) or the loss (otherwise)
     */
    public void learn(Example example, Example example2, Map<Integer, CompressedFeatureVector> map,
            Compressor<String> compressor, ErrorType errorType) {
        Counter<String> subtracted = this.meta.getFeatures(example, map, compressor);
        Counter<String> difference = this.meta.getFeatures(example2, map, compressor);
        for (Map.Entry<String, Double> entry : subtracted.entrySet()) {
            difference.decrementCount(entry.getKey(), entry.getValue().doubleValue());
        }
        if (this.multiplicativeCost) {
            // Unit-margin loss; the error type's cost is passed as the third argument.
            this.classifier.learn(difference, 1.0d, this.costs[errorType.id], this.loss);
        } else {
            // Margin comes from the error type's cost; third argument fixed at 1.0.
            this.classifier.learn(difference, 1.0d, 1.0d, this.losses[errorType.id]);
        }
    }
}
