Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Overview

In this notebook, we explore the problem of fairness generalization. That is, given that we've trained a model to satisfy the fairness constraints on training data, will fairness also hold on testing. We present the approach proposed by [CotterEtAl2018]. Constrained optimization can be viewed as a two-player game where one player minimizes the training error and the second player minimizes the constraint violation. The main idea of [CotterEtAl2018] is that if each player optimizes over the same dataset, then the fairness generalization error will have both the error due to the model complexity and the constraints and a better approach is for each player to use separate data. For example, as we will show in this notebook, we can allow the first player to use the first half of the training dataset and the second player to use the second half of the training dataset. Since the two players see different datasets, the generalization guarantees for fairness will be decoupled from the model complexity and thus can be substantially improved, especially when the model is complex. We show that this is indeed the case on the [Communities and Crime dataset].


In [1]:
import math
import random
import numpy as np
import pandas as pd
import warnings
from six.moves import xrange
import tensorflow.compat.v1 as tf
import tensorflow_constrained_optimization as tfco
import matplotlib.pyplot as plt

tf.disable_eager_execution()

warnings.filterwarnings('ignore')
%matplotlib inline

Reading and processing dataset.

We load and download the [Communities and Crime dataset] and do some pre-processing. We impute the missing values for each feature by the average of that feature over the training set. We then construct eight protected groups, two for each race based on whether the percentage of that race in a community was above or below median. We also convert the continuous label to a binary one.

The fairness goal is to make sure that the rate at which we falsely predict a neighborhood is violent for any of the protected groups is no higher than that of the overall dataset.


In [2]:
column_names = "state,county,community,communityname,fold,population,householdsize,racePctblack,racePctWhite,racePctAsian,racePctHisp,agePct12t21,agePct12t29,agePct16t24,agePct65up,numbUrban,pctUrban,medIncome,pctWWage,pctWFarmSelf,pctWInvInc,pctWSocSec,pctWPubAsst,pctWRetire,medFamInc,perCapInc,whitePerCap,blackPerCap,indianPerCap,AsianPerCap,OtherPerCap,HispPerCap,NumUnderPov,PctPopUnderPov,PctLess9thGrade,PctNotHSGrad,PctBSorMore,PctUnemployed,PctEmploy,PctEmplManu,PctEmplProfServ,PctOccupManu,PctOccupMgmtProf,MalePctDivorce,MalePctNevMarr,FemalePctDiv,TotalPctDiv,PersPerFam,PctFam2Par,PctKids2Par,PctYoungKids2Par,PctTeen2Par,PctWorkMomYoungKids,PctWorkMom,NumIlleg,PctIlleg,NumImmig,PctImmigRecent,PctImmigRec5,PctImmigRec8,PctImmigRec10,PctRecentImmig,PctRecImmig5,PctRecImmig8,PctRecImmig10,PctSpeakEnglOnly,PctNotSpeakEnglWell,PctLargHouseFam,PctLargHouseOccup,PersPerOccupHous,PersPerOwnOccHous,PersPerRentOccHous,PctPersOwnOccup,PctPersDenseHous,PctHousLess3BR,MedNumBR,HousVacant,PctHousOccup,PctHousOwnOcc,PctVacantBoarded,PctVacMore6Mos,MedYrHousBuilt,PctHousNoPhone,PctWOFullPlumb,OwnOccLowQuart,OwnOccMedVal,OwnOccHiQuart,RentLowQ,RentMedian,RentHighQ,MedRent,MedRentPctHousInc,MedOwnCostPctInc,MedOwnCostPctIncNoMtg,NumInShelters,NumStreet,PctForeignBorn,PctBornSameState,PctSameHouse85,PctSameCity85,PctSameState85,LemasSwornFT,LemasSwFTPerPop,LemasSwFTFieldOps,LemasSwFTFieldPerPop,LemasTotalReq,LemasTotReqPerPop,PolicReqPerOffic,PolicPerPop,RacialMatchCommPol,PctPolicWhite,PctPolicBlack,PctPolicHisp,PctPolicAsian,PctPolicMinor,OfficAssgnDrugUnits,NumKindsDrugsSeiz,PolicAveOTWorked,LandArea,PopDens,PctUsePubTrans,PolicCars,PolicOperBudg,LemasPctPolicOnPatr,LemasGangUnitDeploy,LemasPctOfficDrugUn,PolicBudgPerPop,ViolentCrimesPerPop".split(",")
df = pd.read_csv("http://archive.ics.uci.edu/ml/machine-learning-databases/communities/communities.data",  header=None, names=column_names, na_values=['?'])

LABEL_COLUMN = 'label'

CATEGORICAL_COLUMNS = [
    'racePctblack_cat', 'racePctAsian_cat', 'racePctWhite_cat',
    'racePctHisp_cat'
]


PROTECTED_COLUMNS = ['racePctblack_cat_low', 'racePctblack_cat_high',
                     'racePctAsian_cat_low', 'racePctAsian_cat_high',
                     'racePctWhite_cat_low', 'racePctWhite_cat_high',
                     'racePctHisp_cat_low', 'racePctHisp_cat_high']


CONTINUOUS_LABEL_COLUMN = 'ViolentCrimesPerPop'

BINARY_LABEL_COLUMN = 'label'

EXCLUDED_COLUMNS = [
    'state', 'county', 'community', 'communityname', 'ViolentCrimesPerPop'
]

def _train_test_split(df, test_frac=0.33, seed=42):
    np.random.seed(seed)
    perm = np.random.permutation(df.index)
    n = len(df)

    test_end = int(test_frac * n)
    test_df = df.iloc[perm[:test_end]]
    train_df = df.iloc[perm[test_end:]]

    return train_df, test_df

def add_binary_label(input_df):
    quantile = input_df[CONTINUOUS_LABEL_COLUMN].quantile(0.7)
    input_df[BINARY_LABEL_COLUMN] = input_df.apply(
      lambda row: 1 if row[CONTINUOUS_LABEL_COLUMN] > quantile else 0, axis=1)


def add_protected_categories(input_df):
    # Bucketize race percentages
    input_df['racePctblack_cat'] = pd.qcut(
      input_df['racePctblack'], 2, labels=['low', 'high'])
    input_df['racePctAsian_cat'] = pd.qcut(
      input_df['racePctAsian'], 2, labels=['low', 'high'])
    input_df['racePctWhite_cat'] = pd.qcut(
      input_df['racePctWhite'], 2, labels=['low', 'high'])
    input_df['racePctHisp_cat'] = pd.qcut(
      input_df['racePctHisp'], 2, labels=['low', 'high'])

add_binary_label(df)
add_protected_categories(df)
df = pd.get_dummies(df, columns=CATEGORICAL_COLUMNS)

FEATURE_NAMES = [
      name for name in df.keys()
      if name not in [LABEL_COLUMN] + EXCLUDED_COLUMNS
  ]


train_df, test_df = _train_test_split(df)

for column in FEATURE_NAMES:
    train_mean = train_df[column].mean()
    train_df[column].fillna(train_mean, inplace=True)
    test_df[column].fillna(train_mean, inplace=True)

np.random.seed(12345)
train_df['SPLIT_0'] = np.random.randint(0, 2, train_df.shape[0])
train_df['SPLIT_1'] = train_df['SPLIT_0'].apply(lambda row: 1 - row)

Model.

We use a 1-hidden layer Neural Network with ReLU activations and 10 hidden units. We show that even with a relatively simple model, we still can see substantial improvements in fairness generalization performance when using the approach of dataset splitting.

In the following code, we initialize the placeholders and model. In build_train_op, we set up the constrained optimization problem. We create a rate context for the entire dataset, and compute the overall false positive rate as the positive prediction rate on the negatively labeled subset. We then construct a constraint for each of the protected groups based on the difference between the false positive rates of the protected group and that of the overall dataset.

For the non-split approach, we use the typical rate_context, but for the split dataset approach, we use split_rate_context which allows us to conveniently input two separate datasets and then later in optimization, it will know which examples to use for the objective and which examples to use for the constraint.

We then construct a minimization problem using RateMinimizationProblem and use the ProxyLagrangianOptimizerV1 as the solver. build_train_op initializes a training operation which will later be used to actually train the model.


In [3]:
def _construct_model(model_name, input_tensor, hidden_units=10):
    with tf.variable_scope('model_name', reuse=True):
        hidden = tf.layers.dense(
            inputs=input_tensor,
            units=hidden_units,
            activation=tf.nn.relu,
            reuse=tf.AUTO_REUSE,
            name=model_name + "_hidden")
        output = tf.layers.dense(
            inputs=hidden,
            units=1,
            activation=None,
            reuse=tf.AUTO_REUSE,
            name=model_name + "_outputs")
        return output

class Model(object):
    def __init__(self,
                 model_name,
                feature_names,
                hidden_units=10,
                gen_split=False,
                fpr_max_diff=0):
        tf.random.set_random_seed(123)
        self.feature_names = feature_names
        self.fpr_max_diff = fpr_max_diff
        num_features = len(self.feature_names)
        self.gen_split = gen_split
        if self.gen_split:
            self.features_split_0 = tf.placeholder(
                tf.float32, shape=(None, num_features), name='split_0_features_placeholder')
            self.features_split_1 = tf.placeholder(
                tf.float32, shape=(None, num_features), name='split_1_features_placeholder')
            self.split_0_labels = tf.placeholder(
                tf.float32, shape=(None, 1), name='split_0_labels_placeholder')
            self.split_1_labels = tf.placeholder(
                tf.float32, shape=(None, 1), name='split_1_labels_placeholder')
            self.split_0_predictions = _construct_model(
                model_name, self.features_split_0, hidden_units=hidden_units)
            self.split_1_predictions = _construct_model(
                model_name, self.features_split_1, hidden_units=hidden_units)
            self.protected_split_0 = [tf.placeholder(tf.float32, shape=(None, 1), name=attribute+"_placeholder0") for attribute in PROTECTED_COLUMNS]
            self.protected_split_1 = [tf.placeholder(tf.float32, shape=(None, 1), name=attribute+"_placeholder1") for attribute in PROTECTED_COLUMNS]


        self.features_placeholder = tf.placeholder(
            tf.float32, shape=(None, num_features), name='features_placeholder')
        self.labels_placeholder = tf.placeholder(
            tf.float32, shape=(None, 1), name='labels_placeholder')
        self.predictions_tensor = _construct_model(
            model_name, self.features_placeholder, hidden_units=hidden_units)
        self.protected_placeholders = [tf.placeholder(tf.float32, shape=(None, 1), name=attribute+"_placeholder") for attribute in PROTECTED_COLUMNS]

    def build_train_op(self,
                       learning_rate,
                       unconstrained=False):
        if self.gen_split:
            ctx = tfco.split_rate_context(self.split_0_predictions, self.split_1_predictions, self.split_0_labels, self.split_1_labels)
            negative_slice = ctx.subset(self.split_0_labels <= 0, self.split_1_labels <= 0) 
        else:
            ctx = tfco.rate_context(self.predictions_tensor, self.labels_placeholder)
            negative_slice = ctx.subset(self.labels_placeholder <= 0)
            
        overall_fpr = tfco.positive_prediction_rate(negative_slice)
        constraints = []
        if not unconstrained:
            for i in range(len(PROTECTED_COLUMNS)):
                if self.gen_split:
                    slice_fpr = tfco.positive_prediction_rate(
                        ctx.subset(
                            (self.split_0_labels < 0) & (self.protected_split_0[i] > 0),
                            (self.split_1_labels < 0) & (self.protected_split_1[i] > 0)))

                else:
                    slice_fpr = tfco.positive_prediction_rate(
                        ctx.subset((self.protected_placeholders[i] > 0) & (self.labels_placeholder < 0)))
                constraints.append(slice_fpr <= overall_fpr + self.fpr_max_diff)
          
        error = tfco.error_rate(ctx)
        mp = tfco.RateMinimizationProblem(error, constraints)
        opt = tfco.ProxyLagrangianOptimizerV1(tf.train.AdamOptimizer(learning_rate))
        self.train_op = opt.minimize(mp)
        return self.train_op
  
    def feed_dict_helper(self, dataframe, train=False):
        feed_dict = {}
        if self.gen_split and train:
            feed_dict[self.features_split_0] = dataframe[dataframe['SPLIT_0'] > 0][self.feature_names]
            feed_dict[self.features_split_1] = dataframe[dataframe['SPLIT_1'] > 0][self.feature_names]
            feed_dict[self.split_0_labels] = dataframe[dataframe['SPLIT_0'] > 0][[LABEL_COLUMN]]
            feed_dict[self.split_1_labels] = dataframe[dataframe['SPLIT_1'] > 0][[LABEL_COLUMN]]
            for i, protected_attribute in enumerate(PROTECTED_COLUMNS):
                feed_dict[self.protected_split_0[i]] = dataframe[dataframe['SPLIT_0'] > 0][[protected_attribute]]
                feed_dict[self.protected_split_1[i]] = dataframe[dataframe['SPLIT_1'] > 0][[protected_attribute]]

        elif self.gen_split and not train:
            feed_dict[self.features_split_0] = dataframe[self.feature_names]
            feed_dict[self.features_split_1] = dataframe[self.feature_names]
            feed_dict[self.split_0_labels] = dataframe[[LABEL_COLUMN]]
            feed_dict[self.split_1_labels] = dataframe[[LABEL_COLUMN]]
            for i, protected_attribute in enumerate(PROTECTED_COLUMNS):
                feed_dict[self.protected_split_0[i]] = dataframe[[protected_attribute]]
                feed_dict[self.protected_split_1[i]] = dataframe[[protected_attribute]]

        feed_dict[self.features_placeholder] = dataframe[self.feature_names]
        feed_dict[self.labels_placeholder] = dataframe[[LABEL_COLUMN]]
        for i, protected_attribute in enumerate(PROTECTED_COLUMNS):
            feed_dict[self.protected_placeholders[i]] = dataframe[[protected_attribute]]
            
        return feed_dict

Training.

Below is the function which performs the training of our constrained optimization problem. Each call to the function does one epoch through the dataset and then yields the training and testing predictions.


In [4]:
def training_generator(model,
                       train_df,
                       test_df,
                       minibatch_size,
                       num_iterations_per_loop=1,
                       num_loops=1):
    random.seed(31337)
    num_rows = train_df.shape[0]
    minibatch_size = min(minibatch_size, num_rows)
    permutation = list(range(train_df.shape[0]))
    random.shuffle(permutation)

    session = tf.Session()
    session.run((tf.global_variables_initializer(),
               tf.local_variables_initializer()))

    minibatch_start_index = 0
    for n in xrange(num_loops):
        for _ in xrange(num_iterations_per_loop):
            minibatch_indices = []
            while len(minibatch_indices) < minibatch_size:
                minibatch_end_index = (
                minibatch_start_index + minibatch_size - len(minibatch_indices))
                if minibatch_end_index >= num_rows:
                    minibatch_indices += range(minibatch_start_index, num_rows)
                    minibatch_start_index = 0
                else:
                    minibatch_indices += range(minibatch_start_index, minibatch_end_index)
                    minibatch_start_index = minibatch_end_index

            session.run(
                  model.train_op,
                  feed_dict=model.feed_dict_helper(
                      train_df.iloc[[permutation[ii] for ii in minibatch_indices]], train=True))
            
        train_predictions = session.run(
            model.predictions_tensor,
            feed_dict=model.feed_dict_helper(train_df))
        test_predictions = session.run(
            model.predictions_tensor,
            feed_dict=model.feed_dict_helper(test_df))

        yield (train_predictions, test_predictions)

Computing accuracy and fairness metrics.


In [5]:
def error_rate(predictions, labels):
    signed_labels = (
      (labels > 0).astype(np.float32) - (labels <= 0).astype(np.float32))
    numerator = (np.multiply(signed_labels, predictions) <= 0).sum()
    denominator = predictions.shape[0]
    return float(numerator) / float(denominator)

def positive_prediction_rate(predictions, subset):
    numerator = np.multiply((predictions > 0).astype(np.float32),
                          (subset > 0).astype(np.float32)).sum()
    denominator = (subset > 0).sum()
    return float(numerator) / float(denominator)

def fpr(df):
    fp = sum((df['predictions'] >= 0.0) & (df[LABEL_COLUMN] < 0.5))
    ln = sum(df[LABEL_COLUMN] < 0.5)
    return float(fp) / float(ln)

def _get_error_rate_and_constraints(df, fpr_max_diff):
    error_rate_local = error_rate(df[['predictions']], df[[LABEL_COLUMN]])
    overall_fpr = fpr(df)
    return error_rate_local, overall_fpr, [fpr(df[df[protected_attribute] > 0.5]) - (overall_fpr + fpr_max_diff) for protected_attribute in PROTECTED_COLUMNS]

def _get_exp_error_rate_constraints(cand_dist, error_rates_vector,
                                    overall_fpr_vector, constraints_matrix):
    expected_error_rate = np.dot(cand_dist, error_rates_vector)
    expected_overall_fpr = np.dot(cand_dist, overall_fpr_vector)
    expected_constraints = np.matmul(cand_dist, constraints_matrix)
    return expected_error_rate, expected_overall_fpr, expected_constraints


def get_iterate_metrics(cand_dist, best_cand_index, error_rate_vector,
                        overall_fpr_vector, constraints_matrix):
    metrics = {}
    exp_error_rate, exp_overall_fpr, exp_constraints = _get_exp_error_rate_constraints(
      cand_dist, error_rate_vector, overall_fpr_vector, constraints_matrix)
    metrics['m_stochastic_error_rate'] = exp_error_rate
    metrics['m_stochastic_overall_fpr'] = exp_overall_fpr
    metrics['m_stochastic_max_constraint_violation'] = max(exp_constraints)
    for i, constraint in enumerate(exp_constraints):
        metrics['m_stochastic_constraint_violation_%d' % i] = constraint
    metrics['best_error_rate'] = error_rate_vector[best_cand_index]
    metrics['last_error_rate'] = error_rate_vector[-1]
    metrics['t_stochastic_error_rate'] = sum(error_rate_vector) / len(
      error_rate_vector)
    metrics['best_overall_fpr'] = overall_fpr_vector[best_cand_index]
    metrics['last_overall_fpr'] = overall_fpr_vector[-1]
    metrics['t_stochastic_overall_fpr'] = sum(overall_fpr_vector) / len(
      overall_fpr_vector)
    avg_constraints = []
    best_constraints = []
    last_constraints = []
    for constraint_iterates in np.transpose(constraints_matrix):
        avg_constraint = sum(constraint_iterates) / len(constraint_iterates)
        avg_constraints.append(avg_constraint)
        best_constraints.append(constraint_iterates[best_cand_index])
        last_constraints.append(constraint_iterates[-1])
    metrics['best_max_constraint_violation'] = max(best_constraints)
    for i, constraint in enumerate(best_constraints):
        metrics['best_constraint_violation_%d' % i] = constraint
    metrics['last_max_constraint_violation'] = max(last_constraints)
    for i, constraint in enumerate(last_constraints):
        metrics['last_constraint_violation_%d' % i] = constraint
    metrics['t_stochastic_max_constraint_violation'] = max(avg_constraints)
    for i, constraint in enumerate(avg_constraints):
        metrics['t_stochastic_constraint_violation_%d' % i] = constraint
    return metrics

def training_helper(model,
                    train_df,
                    test_df,
                    minibatch_size,
                    num_iterations_per_loop=1,
                    num_loops=1):
    train_objective_vector = []
    train_constraints_loss_matrix = []
    train_error_rate_vector = []
    train_overall_fpr_vector = []
    train_constraints_matrix = []
    test_error_rate_vector = []
    test_overall_fpr_vector = []
    test_constraints_matrix = []
    for train, test in training_generator(
        model, train_df, test_df, minibatch_size, num_iterations_per_loop, num_loops):
        train_df['predictions'] = train
        test_df['predictions'] = test

        if model.gen_split:
            train_error_rate_split0, train_overall_fpr0, train_constraints_split0 = _get_error_rate_and_constraints(train_df[train_df['SPLIT_0'] > 0], model.fpr_max_diff)
            train_error_rate_split1, train_overall_fpr1, train_constraints_split1 = _get_error_rate_and_constraints(train_df[train_df['SPLIT_1'] > 0], model.fpr_max_diff)
            train_error_rate_vector.append(train_error_rate_split0)
            train_constraints_matrix.append(train_constraints_split1)
            train_constraints_loss_matrix.append(train_constraints_split1)
            train_overall_fpr_vector.append((train_overall_fpr0 + train_overall_fpr1) / 2)
        else:
            train_error_rate, train_overall_fpr, train_constraints = _get_error_rate_and_constraints(train_df, model.fpr_max_diff)
            train_error_rate_vector.append(train_error_rate)
            train_overall_fpr_vector.append(train_overall_fpr)
            train_constraints_matrix.append(train_constraints)

        test_error_rate, test_overall_fpr, test_constraints = _get_error_rate_and_constraints(
            test_df, model.fpr_max_diff)
        test_error_rate_vector.append(test_error_rate)
        test_overall_fpr_vector.append(test_overall_fpr)
        test_constraints_matrix.append(test_constraints)

    cand_dist = tfco.find_best_candidate_distribution(
      train_error_rate_vector, train_constraints_matrix, epsilon=0.001)
    best_cand_index = tfco.find_best_candidate_index(
      train_error_rate_vector, train_constraints_matrix)
    train_metrics = get_iterate_metrics(
      cand_dist, best_cand_index, train_error_rate_vector,
      train_overall_fpr_vector, train_constraints_matrix)
    test_metrics = get_iterate_metrics(
      cand_dist, best_cand_index, test_error_rate_vector,
      test_overall_fpr_vector, test_constraints_matrix)

    return (train_metrics, test_metrics)

Baseline without constraints.

We now declare the model, build the training op, and then perform the training. We use a neural network with 10 hidden units as the underlying classifier, and train using the ADAM optimizer with learning rate 0.01, with minibatch size of 100 over 500 epochs. We first train without fairness constraints to show the baseline performance. We see that without training fair fairness, we obtain a high fairness violation.


In [6]:
model = Model("baseline_unconstrained", FEATURE_NAMES, hidden_units=10, gen_split=False)
model.build_train_op(0.01, unconstrained=True)

results = training_helper(
      model,
      train_df,
      test_df,
      100,
      num_iterations_per_loop=14,
      num_loops=100)

In [7]:
print("Train Error", results[0]["last_error_rate"])
print("Train Violation", results[0]["last_max_constraint_violation"])
print()
print("Test Error", results[1]["last_error_rate"])
print("Test Violation", results[1]["last_max_constraint_violation"])


Train Error 0.10479041916167664
Train Violation 0.08842700774812448

Test Error 0.135258358662614
Test Violation 0.1786933268832809

Baseline with constraints single training dataset.

We now show train with the constraints and show the performance of the m-stochastic classifier as used in [CotterEtAl2018].

We see a fairly large difference between training fairness violation and testing fairness violation.


In [8]:
model = Model("single_dataset", FEATURE_NAMES, hidden_units=10, gen_split=False)
model.build_train_op(0.01, unconstrained=False)

results = training_helper(
      model,
      train_df,
      test_df,
      100,
      num_iterations_per_loop=14,
      num_loops=100)

In [9]:
print("Train Error", results[0]["m_stochastic_error_rate"])
print("Train Violation", results[0]["m_stochastic_max_constraint_violation"])
print()
print("Test Error", results[1]["m_stochastic_error_rate"])
print("Test Violation", results[1]["m_stochastic_max_constraint_violation"])


Train Error 0.2252994011976048
Train Violation 0.03470052884024105

Test Error 0.23100303951367782
Test Violation 0.025235830041119085

Putting it together: training with constraints using dataset splitting.

We now show what happens when we split the training dataset into two halves, one for minimizing the loss and the other for enforcing the fairness constraints.

We see a substantial improvement in generalization (that is, the difference between training fairness violation and testing fairness violation).


In [10]:
model = Model("split_dataset", FEATURE_NAMES, hidden_units=10, gen_split=True)
model.build_train_op(0.01, unconstrained=False)

results = training_helper(
      model,
      train_df,
      test_df,
      100,
      num_iterations_per_loop=14,
      num_loops=100)

In [11]:
print("Train Error", results[0]["m_stochastic_error_rate"])
print("Train Violation", results[0]["m_stochastic_max_constraint_violation"])
print()
print("Test Error", results[1]["m_stochastic_error_rate"])
print("Test Violation", results[1]["m_stochastic_max_constraint_violation"])


Train Error 0.24011713030746706
Train Violation 0.010979308226803331

Test Error 0.22796352583586627
Test Violation 0.004205971673519847