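Features:
- Automatically discovers test files whose names begin with test_.
- Within those files, automatically discovers test functions whose names begin with test_.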
# in datafuncs.py
def increment(x):
    return x + 1

# in test_datafuncs.py
import datafuncs as dfn

def test_increment():
    assert dfn.increment(2) != 3
Now, in your terminal, execute the following command:
$ py.test
This is the output you should expect:
============================= test session starts ==============================
platform darwin -- Python 3.6.1, pytest-3.0.7, py-1.4.33, pluggy-0.4.0
rootdir: /Users/ericmjl/github/tutorials/data-testing-tutorial, inifile:
collected 1 items
test_datafuncs.py F
=================================== FAILURES ===================================
________________________________ test_increment ________________________________
def test_increment():
> assert dfn.increment(2) != 3
E assert 3 != 3
E + where 3 = <function increment at 0x10eaf7378>(2)
E + where <function increment at 0x10eaf7378> = dfn.increment
test_datafuncs.py:3: AssertionError
=========================== 1 failed in 0.06 seconds ===========================
Let's break down the output for you, to make this simpler.
Firstly, the header.
============================= test session starts ==============================
platform darwin -- Python 3.6.1, pytest-3.0.7, py-1.4.33, pluggy-0.4.0
rootdir: /Users/ericmjl/github/tutorials/data-testing-tutorial, inifile:
collected 1 items
test_datafuncs.py F
- collected 1 items indicates how many test functions were written.
- test_datafuncs.py is the file that contains the tests. You are allowed to have multiple files that contain tests.
- F indicates that there was a test function that failed. The only two outputs you need to be concerned with right now are F and . (dot).

Next, let's look at the FAILURES section.
=================================== FAILURES ===================================
________________________________ test_increment ________________________________
def test_increment():
> assert dfn.increment(2) != 3
E assert 3 != 3
E + where 3 = <function increment at 0x10eaf7378>(2)
E + where <function increment at 0x10eaf7378> = dfn.increment
test_datafuncs.py:3: AssertionError
=========================== 1 failed in 0.06 seconds ===========================
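- The name of the failing test function appears in the ___ test_increment ___ header.
- The line on which the failure occurred is marked by the > (greater than) symbol.
- The details of the failure are marked by the E symbol. We use this information to figure out how a test failed.

Congratulations! You wrote your first failing test! With py.test, you have a command that automatically finds tests, executes them, and reports where they fail.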
Questions so far?
Now, go fix the test such that it works correctly.
def test_increment():
    assert dfn.increment(2) == 3
And then re-run that test.
$ py.test
If everything passes, it should look like the following terminal output.
============================= test session starts ==============================
platform darwin -- Python 3.6.1, pytest-3.0.7, py-1.4.33, pluggy-0.4.0
rootdir: /Users/ericmjl/github/tutorials/data-testing-tutorial, inifile:
collected 1 items
test_datafuncs.py .
=========================== 1 passed in 0.02 seconds ===========================
Now, if the function changes (say, by accident), you can find out by running the test suite.
Actually, let's make that change. Make any modification to the increment() function that causes the test_increment() function to fail, e.g. change the return statement to return x, or return x-1. Then, re-run the tests using the py.test command.
Finally, fix the function and confirm that the tests pass.
In datafuncs.py, we are going to implement a function called min_max_scaler(x) for your data. It should take in a numpy array and scale all of the values to be between 0 and 1 inclusive. The min value should be 0, and the max value should be 1.
First, begin by writing tests for the min-max scaler. They should check the following:
- The scaled output matches the expected array. Use the np.allclose(arr1, arr2) function to test closeness of two floating point arrays.
- The minimum value of the output is 0, and the maximum value is 1.

Note: This function is also implemented in the scikit-learn library as part of their preprocessing module. However, if you make the engineering decision not to import an entire library just to use one function, you can re-implement it on your own.
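The reason for testing closeness rather than exact equality is that floating point arithmetic rarely lands on exact values. A minimal sketch using the standard library's math.isclose, which plays a similar role for single numbers to the one np.allclose plays for arrays:

```python
import math

total = 0.1 + 0.2

# Exact comparison fails because of floating point rounding.
exactly_equal = (total == 0.3)

# Comparison within a tolerance succeeds.
close_enough = math.isclose(total, 0.3)

print(exactly_equal, close_enough)  # False True
```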
Here is a possible test for the min_max_scaler(x) function.
# in test_datafuncs.py
import numpy as np

def test_min_max_scaler():
    arr = np.array([1, 2, 3])  # set up the test with necessary variables.
    tfm = dfn.min_max_scaler(arr)  # collect the result into a variable
    # compare arrays with np.allclose; asserting on a bare `tfm == ...` array raises ValueError
    assert np.allclose(tfm, np.array([0, 0.5, 1]))  # assertion statements
    assert tfm.min() == 0
    assert tfm.max() == 1
Now, based on the specifications, write a minimal implementation of the min_max_scaler(x) function. This function should take the numpy array x and scale all of the values to between 0 and 1 inclusive, with the minimum value being 0 and the maximum value being 1.
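# in datafuncs.py
import numpy as np

def min_max_scaler(x):
    """
    Returns a numpy array with all of the original values scaled between 0 and 1.
    Assumes the data are a numpy array.
    """
    # hasattr takes the attribute name as a string; np.ndarray is the array type
    # (np.array is a function, so it cannot be used with isinstance).
    if hasattr(x, '__iter__') and not isinstance(x, np.ndarray):
        x = np.array(x)
    return (x - x.min()) / (x.max() - x.min())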
Now, let's think of a few edge cases. Where could this function fail?
# in test_datafuncs.py
import pytest
import numpy as np

def test_min_max_scaler():
    arr = np.array([1, 2, 3])  # set up the test with necessary variables.
    tfm = dfn.min_max_scaler(arr)  # collect the result into a variable
    # Correctness tests
    assert np.allclose(tfm, np.array([0, 0.5, 1]))  # assertion statements
    assert tfm.min() == 0
    assert tfm.max() == 1

    # min_max_scaler(x) should fail if an integer is passed in.
    with pytest.raises(AttributeError):
        dfn.min_max_scaler(2)

    # An empty input has no minimum, so numpy raises ValueError.
    with pytest.raises(ValueError):
        dfn.min_max_scaler([])

    # A single value gives 0 / 0, which numpy evaluates to nan (with a warning).
    assert np.isnan(dfn.min_max_scaler([15])).all()
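If you would rather have the empty and single-value cases fail loudly than return nan, one option is a stricter variant. This is only a sketch: the name min_max_scaler_strict and the choice of ValueError are assumptions, not part of the exercise. Note that because np.asarray turns a bare integer into a zero-dimensional array, an integer input here also ends up as a ValueError rather than an AttributeError.

```python
import numpy as np


def min_max_scaler_strict(x):
    """Hypothetical variant: scale to [0, 1], rejecting inputs that cannot be scaled."""
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        raise ValueError("cannot scale an empty array")
    span = x.max() - x.min()
    if span == 0:
        # Covers single values and arrays whose values are all equal.
        raise ValueError("cannot scale an array whose values are all equal")
    return (x - x.min()) / span


scaled = min_max_scaler_strict([1, 2, 3])
print(scaled)
```

The trade-off is that callers must now handle the exception, but a nan can no longer slip silently into later calculations.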
Imagine we have textual data, and we want to clean it up. There are two functions we may want to write to standardize the data:
- bag_of_words(text), which takes in the text and tokenizes the text into its set of constituent words.
- strip_punctuation(text), which strips punctuation from the text.

Design the tests first, and then implement the two functions in datafuncs.py; you may wish to write additional helper functions to manage the business logic. There's leeway in this exercise; feel free to get creative!
# in test_datafuncs.py
import string

def test_strip_punctuation():
    text = 'random. stuff; typed, in-to th`is text^line'
    t = dfn.strip_punctuation(text)
    # no punctuation character may survive
    assert set(t).isdisjoint(string.punctuation)

def test_bag_of_words():
    text = 'random stuff typed into this text line line'
    text_bagged = dfn.bag_of_words(text)
    assert len(text_bagged) == 7  # 'line' appears twice but counts once
    assert ' ' not in text_bagged
# in datafuncs.py
import string

def strip_punctuation(text):
    exclude = set(string.punctuation)
    return ''.join(s for s in text if s not in exclude)

def bag_of_words(text):
    text = strip_punctuation(text)
    # split() with no argument splits on any run of whitespace,
    # so repeated spaces do not produce empty-string "words"
    words = set(text.split())
    return words
$ py.test
============================= test session starts ==============================
platform darwin -- Python 3.6.1, pytest-3.0.7, py-1.4.33, pluggy-0.4.0
rootdir: /Users/ericmjl/github/tutorials/data-testing-tutorial, inifile:
collected 3 items
test_datafuncs.py ...
=========================== 3 passed in 0.03 seconds ===========================
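An alternative way to strip punctuation, offered here as a sketch rather than as the tutorial's method, is str.translate with a deletion table built by str.maketrans. The names strip_punctuation_translate and bag_of_words_sketch are hypothetical.

```python
import string

# Translation table that deletes every ASCII punctuation character.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


def strip_punctuation_translate(text):
    """Hypothetical alternative to strip_punctuation, using str.translate."""
    return text.translate(_PUNCT_TABLE)


def bag_of_words_sketch(text):
    """Return the set of whitespace-separated words with punctuation removed."""
    return set(strip_punctuation_translate(text).split())


sample = "hello world! This is my pleasure, and the 2nd time I've been to PyCon!"
print(strip_punctuation_translate(sample))
print(sorted(bag_of_words_sketch(sample)))
```

Note that the apostrophe counts as punctuation, so "I've" becomes "Ive"; whether that is acceptable depends on what the cleaned text is used for.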