Neutral Class Design Pattern

This notebook uses a synthetic dataset to show that creating a separate Neutral class can help, and then applies the same idea to a real-world problem.

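On a synthetic dataset

Patients with a history of jaundice are assumed to be at risk of liver damage and are prescribed ibuprofen, while patients with a history of stomach ulcers are prescribed acetaminophen. The remaining patients are arbitrarily assigned to one of the two medications. The dataset carries two label columns: prescription, which records that arbitrary choice, and prescription_with_neutral, which labels those remaining patients as Neutral instead.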


In [39]:
import numpy as np
import pandas as pd

def create_synthetic_dataset(N, shuffle):
    # two-class labels: half ibuprofen, half acetaminophen, in random order
    prescription = np.full(N, fill_value='acetaminophen', dtype='U20')
    prescription[:N//2] = 'ibuprofen'
    np.random.shuffle(prescription)

    # three-class labels: every patient starts out as Neutral
    p_neutral = np.full(N, fill_value='Neutral', dtype='U20')

    # the first 10% of patients have a history of jaundice (liver disease)
    jaundice = np.zeros(N, dtype=bool)
    jaundice[0:N//10] = True
    prescription[0:N//10] = 'ibuprofen'
    p_neutral[0:N//10] = 'ibuprofen'

    # the last 10% of patients have a history of stomach ulcers
    ulcers = np.zeros(N, dtype=bool)
    ulcers[(9*N)//10:] = True
    prescription[(9*N)//10:] = 'acetaminophen'
    p_neutral[(9*N)//10:] = 'acetaminophen'

    df = pd.DataFrame.from_dict({
        'jaundice': jaundice,
        'ulcers': ulcers,
        'prescription': prescription,
        'prescription_with_neutral': p_neutral
    })

    # optionally shuffle the rows so that a positional split is a random split
    if shuffle:
        return df.sample(frac=1).reset_index(drop=True)
    else:
        return df

create_synthetic_dataset(10, False)


Out[39]:
   jaundice   prescription  prescription_with_neutral  ulcers
0      True      ibuprofen                  ibuprofen   False
1     False  acetaminophen                    Neutral   False
2     False  acetaminophen                    Neutral   False
3     False      ibuprofen                    Neutral   False
4     False      ibuprofen                    Neutral   False
5     False      ibuprofen                    Neutral   False
6     False  acetaminophen                    Neutral   False
7     False  acetaminophen                    Neutral   False
8     False      ibuprofen                    Neutral   False
9     False  acetaminophen              acetaminophen    True

In [40]:
df = create_synthetic_dataset(1000, shuffle=True)

from sklearn import linear_model

features = ['jaundice', 'ulcers']
ntrain = 8*len(df)//10 # first 80% of the shuffled rows for training, the rest for evaluation
for label in ['prescription', 'prescription_with_neutral']:
    lm = linear_model.LogisticRegression()
    # .loc slicing is inclusive, so rows 0..ntrain-1 form the training set
    lm = lm.fit(df.loc[:ntrain-1, features], df[label][:ntrain])
    acc = lm.score(df.loc[ntrain:, features], df[label][ntrain:])
    print('label={} accuracy={}'.format(label, acc))


label=prescription accuracy=0.555
label=prescription_with_neutral accuracy=1.0
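One way to see where these two accuracies come from is to compute the best accuracy any classifier could reach from the two input features alone, by predicting the most common label for each combination of feature values. The following is a minimal, self-contained sketch under that framing. The helper names make_labels and best_achievable_accuracy are hypothetical and not part of the notebook, and the data is rebuilt here with a seeded generator rather than taken from the cell above.

```python
import numpy as np
import pandas as pd

def make_labels(n, seed=0):
    """Rebuild the synthetic data: 10% jaundice, 10% ulcers, 80% neither."""
    rng = np.random.default_rng(seed)
    jaundice = np.zeros(n, dtype=bool)
    ulcers = np.zeros(n, dtype=bool)
    jaundice[:n // 10] = True
    ulcers[(9 * n) // 10:] = True

    # Two-class labels: an arbitrary coin flip unless a history decides it.
    two_class = rng.choice(['ibuprofen', 'acetaminophen'], size=n)
    two_class[jaundice] = 'ibuprofen'
    two_class[ulcers] = 'acetaminophen'

    # Three-class labels: undecided patients get an explicit Neutral label.
    three_class = np.where(jaundice, 'ibuprofen',
                           np.where(ulcers, 'acetaminophen', 'Neutral'))

    return pd.DataFrame({'jaundice': jaundice,
                         'ulcers': ulcers,
                         'prescription': two_class,
                         'prescription_with_neutral': three_class})

def best_achievable_accuracy(df, features, label):
    """Accuracy of always predicting the majority label per feature combination."""
    # Rows: distinct feature combinations. Columns: label values. Cells: counts.
    table = pd.crosstab([df[f] for f in features], df[label])
    # In each row, only the largest count can be predicted correctly.
    return table.max(axis=1).sum() / len(df)

df = make_labels(1000)
features = ['jaundice', 'ulcers']
acc_two = best_achievable_accuracy(df, features, 'prescription')
acc_three = best_achievable_accuracy(df, features, 'prescription_with_neutral')
print('two classes: {:.3f}  with Neutral: {:.3f}'.format(acc_two, acc_three))
```

With only two classes, the 80% of patients who have neither history carry labels that the features cannot explain, so the ceiling sits near 0.6 (the 20% of determined cases plus roughly half of the rest). With the Neutral class every feature combination maps to exactly one label, so the ceiling is 1.0. Accuracy measured on a single 200-row test split will scatter around the two-class ceiling.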

On the Natality data

Let's do this on real data. A baby with an Apgar score of 10 is healthy, and one with an Apgar score of 7 or lower requires some medical attention. What about babies with scores of 8 or 9? They are neither perfectly healthy nor in need of serious medical intervention. A 2-class model has to place its cut somewhere in that range (here, scores of 9 and above count as Healthy), whereas a 3-class model can give those babies their own Neutral class. Let's see how the model does in each case.
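To make the two labeling schemes concrete, here is an illustrative Python sketch of the thresholds that the SQL IF expressions in the queries below apply to apgar_1min. The function names are hypothetical; the notebook itself does the labeling in BigQuery.

```python
def label_two_classes(apgar_1min):
    """Two-class scheme: scores of 9 or 10 count as Healthy."""
    return 'Healthy' if apgar_1min >= 9 else 'NeedsAttention'

def label_three_classes(apgar_1min):
    """Three-class scheme: 10 is Healthy, 8 and 9 are Neutral, the rest need attention."""
    if apgar_1min == 10:
        return 'Healthy'
    if apgar_1min >= 8:
        return 'Neutral'
    return 'NeedsAttention'

# Print both labels for every valid score so the difference is visible.
for score in range(0, 11):
    print(score, label_two_classes(score), label_three_classes(score))
```

The two schemes agree on scores of 10 and on scores of 7 or lower. They differ only for 8 and 9, which the two-class scheme splits across its two labels.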

First, without the Neutral class


In [ ]:
%%bigquery
CREATE OR REPLACE MODEL mlpatterns.neutral_2classes
OPTIONS(model_type='logistic_reg', input_label_cols=['health']) AS

SELECT 
  IF(apgar_1min >= 9, 'Healthy', 'NeedsAttention') AS health,
  plurality,
  mother_age,
  gestation_weeks,
  ever_born
FROM `bigquery-public-data.samples.natality`
WHERE apgar_1min <= 10

In [41]:
%%bigquery
SELECT * FROM ML.EVALUATE(MODEL mlpatterns.neutral_2classes)


Out[41]:
precision recall accuracy f1_score log_loss roc_auc
0 0.565628 0.997893 0.565213 0.722007 0.690348 0.52722

With 3 classes (including a neutral class)


In [ ]:
%%bigquery
CREATE OR REPLACE MODEL mlpatterns.neutral_3classes
OPTIONS(model_type='logistic_reg', input_label_cols=['health']) AS

SELECT 
  IF(apgar_1min = 10, 'Healthy', IF(apgar_1min >= 8, 'Neutral', 'NeedsAttention')) AS health,
  plurality,
  mother_age,
  gestation_weeks,
  ever_born
FROM `bigquery-public-data.samples.natality`
WHERE apgar_1min <= 10

In [38]:
%%bigquery
SELECT * FROM ML.EVALUATE(MODEL mlpatterns.neutral_3classes)


Out[38]:
precision recall accuracy f1_score log_loss roc_auc
0 0.46499 0.333789 0.794872 0.296302 1.840975 0.553596
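When comparing the two evaluations, keep in mind that accuracy alone can be dominated by the largest class. The sketch below uses made-up class proportions (not the natality data) to show what a classifier that always predicts the most common class would score: its accuracy equals that class's share of the data, while recall averaged equally over three classes is 1/3. Whether ML.EVALUATE averages multiclass recall in exactly this way is an assumption here and should be checked against the BigQuery ML documentation.

```python
import numpy as np

def accuracy(y_true, y_pred):
    """Fraction of examples whose predicted label matches the true label."""
    return float(np.mean(y_true == y_pred))

def macro_recall(y_true, y_pred):
    """Per-class recall, averaged with equal weight for every class."""
    classes = np.unique(y_true)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in classes]
    return float(np.mean(recalls))

# Hypothetical label mix, chosen only for illustration.
y_true = np.array(['Neutral'] * 80 + ['Healthy'] * 10 + ['NeedsAttention'] * 10)

# A degenerate classifier that ignores its inputs and always says Neutral.
y_pred = np.full(y_true.shape, 'Neutral', dtype=y_true.dtype)

acc = accuracy(y_true, y_pred)
rec = macro_recall(y_true, y_pred)
print('accuracy={:.3f} macro_recall={:.3f}'.format(acc, rec))
```

If a multiclass evaluation shows high accuracy together with an averaged recall close to 1/3, it may be worth inspecting the per-class results (for example with ML.CONFUSION_MATRIX) before drawing conclusions from the accuracy figure.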

Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License