Functions

First class functions

Functions behave like any other object, such as an int or a list. You can:

  • use functions as arguments to other functions
  • store functions as dictionary values
  • return a function from another function

This leads to many powerful ways to use functions.


In [1]:
def square(x):
    """Square of x."""
    return x*x

def cube(x):
    """Cube of x."""
    return x*x*x

def root(x):
    """Square root of x."""
    return x**.5

In [2]:
# create a dictionary of functions
funcs = {
    'square': square,
    'cube': cube,
    'root': root,
}

In [3]:
x = 2

print square(x)
print cube(x)
print root(x)


4
8
1.41421356237

In [4]:
# print function name and output, sorted by function name
for func_name in sorted(funcs):
    print func_name, funcs[func_name](x)


cube 8
root 1.41421356237
square 4

Functions can be passed in as arguments


In [5]:
def derivative(x, f, h=0.01):
    """ Approximate the derivative of f at x using a central difference with step h """
    
    return (f(x+h) - f(x-h))/(2*h)
$$ f(x) = 3x^2 + 5x + 3$$

In [6]:
def some_func(x):    
    return 3*x**2 + 5*x + 3

In [7]:
derivative(2, some_func) # passing in function f


Out[7]:
16.999999999999815
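
As a check on this result, the exact derivative can be worked out by hand. For a quadratic, the central difference used in `derivative` is algebraically exact, so the small discrepancy from 17 comes from floating-point rounding rather than from the step size:

$$ f'(x) = 6x + 5, \qquad f'(2) = 17 $$

$$ \frac{f(x+h) - f(x-h)}{2h} = \frac{3\left[(x+h)^2 - (x-h)^2\right] + 5\left[(x+h) - (x-h)\right]}{2h} = \frac{12xh + 10h}{2h} = 6x + 5 $$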

Functions can also be returned by functions


In [8]:
import time

def sum_squares(n):
    """ Sum of the squares from 0 to n-1 """
    
    s = sum([x*x for x in range(n)])
    return s

def timer(f,n):
    """ time how long it takes to evaluate function """
    
    start = time.clock()
    result = f(n)   
    elapsed = time.clock() - start
    return result, elapsed

In [9]:
n = 1000000
timer(sum_squares, n)


Out[9]:
(333332833333500000, 0.13836700000000002)
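
The `timer` example above passes a function in as an argument. For a function that is actually built and returned by another function, here is a minimal sketch. The names `make_power`, `timed`, `to_square`, and `to_cube` are illustrative and not part of the original notebook, and `time.time()` is used on the assumption that the code may run on a recent Python 3, where `time.clock()` is no longer available.

```python
import time

def make_power(p):
    """Return a new function that raises its argument to the power p."""
    def power(x):
        return x ** p
    return power

def timed(f):
    """Return a wrapped version of f that also reports elapsed wall-clock time."""
    def wrapper(*args):
        start = time.time()
        result = f(*args)
        elapsed = time.time() - start
        return result, elapsed
    return wrapper

# make_power returns a function, which can then be called like any other
to_square = make_power(2)
to_cube = make_power(3)
print(to_square(4))   # 16
print(to_cube(2))     # 8

# timed returns a new function that wraps to_cube
timed_cube = timed(to_cube)
result, elapsed = timed_cube(3)
print(result)         # 27
```

Each call to `make_power` produces an independent function that remembers its own value of `p`.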

Higher order functions

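  • A higher-order function (HOF) is a function that takes another function as an input argument or returns a function
  • The most familiar are `map` and `filter`.
  • Custom functions that take a function as an argument, such as `derivative` and `timer` above, are also HOFs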

In [10]:
# The map function applies a function to each member of a collection
# map(aFunction, aSequence)

map(square, range(10))


Out[10]:
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

In [11]:
# The filter function applies a predicate to each member of a collection,
# retaining only those members where the predicate is True

def is_even(x):
    return x % 2 == 0

filter(is_even, range(10))


Out[11]:
[0, 2, 4, 6, 8]

In [12]:
# It is common to combine map and filter

map(square, filter(is_even, range(10)))


Out[12]:
[0, 4, 16, 36, 64]

In [13]:
# The reduce function reduces a collection using a binary operator to combine items two at a time

def my_add(x, y):
    return x + y

# another implementation of the sum function - like a running total
reduce(my_add, [1,2,3,4,5])


Out[13]:
15
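
The cells above rely on Python 2, where `map` and `filter` return lists and `reduce` is a builtin. As a sketch that should run on both Python 2.6+ and Python 3, `reduce` can be imported from `functools` and the results of `map`/`filter` wrapped in `list()`. The helper `running_totals` is a hypothetical name added here only to show the intermediate values that `reduce` carries along.

```python
from functools import reduce  # builtin in Python 2, lives in functools in Python 3

def my_add(x, y):
    return x + y

def running_totals(func, xs):
    """Return every intermediate value reduce would compute, left to right."""
    totals = []
    acc = None
    for i, x in enumerate(xs):
        # the first item seeds the accumulator; later items are combined with it
        acc = x if i == 0 else func(acc, x)
        totals.append(acc)
    return totals

# list() makes the result a list on Python 3, where map and filter are lazy
evens_squared = list(map(lambda x: x * x, filter(lambda x: x % 2 == 0, range(10))))
print(evens_squared)                             # [0, 4, 16, 36, 64]
print(running_totals(my_add, [1, 2, 3, 4, 5]))   # [1, 3, 6, 10, 15]
print(reduce(my_add, [1, 2, 3, 4, 5]))           # 15
```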

In [14]:
def custom_sum(xs, func):
    """Returns the sum of xs after a user specified transform."""
    
    return sum(map(func, xs))

xs = range(10)
print custom_sum(xs, square)
print custom_sum(xs, cube)
print custom_sum(xs, root)


285
2025
19.306000526

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

EXERCISE TIME!

1) Using map, write a Python program to calculate the length of each element in the list ['Donald','Ted','Hilary','Joe','Bernie'].


In [17]:
## SOLUTION - 1
map(len,['Donald','Ted','Hilary','Joe','Bernie'])


Out[17]:
[6, 3, 6, 3, 6]

2) Using reduce and map, write a python program to find the largest element in the list of integers, floats or strings (that are numbers).
For example: [2, '3', 4.0, 2, -1, '10', 9, -4.3, 8, 7, 11, 3]. Should return 11.

Hint:

  1. How can you compare `'10'` to `8`? Should the types be the same?
  2. You can use `reduce` to find the maximum in a similar way as we used it to find the `sum` in the example above.

In [18]:
## SOLUTION - 2

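# int() turns '3' and '10' into numbers but also truncates floats (-4.3 becomes -4);
# that does not change the maximum here, but float() is the safer conversion in general
reduce(max, (map (int, [2, '3', 4.0, 2, -1, '10', 9, -4.3, 8, 7, 11, 3])))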


Out[18]:
11

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

Anonymous functions

  • When using functional style, there is often the need to create small specific functions that perform a limited task as input to a HOF such as map or filter.
  • In such cases, these functions are often written as anonymous or **lambda** functions.

If you find it hard to understand what a lambda function is doing, it should probably be rewritten as a regular function.


In [19]:
# Using standard functions
n = 10

def square(x):
    return x*x

square(n)


Out[19]:
100

In [20]:
map(square, range(n))


Out[20]:
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

In [21]:
# Using lambda function

sqr = lambda x: x*x

sqr(n)


Out[21]:
100

In [22]:
map(sqr, range(n))


Out[22]:
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

In [23]:
# what does this function do?

s1 = reduce(lambda x, y: x+y, map(lambda x: x**2, range(1,10)))
print(s1)


285

In [24]:
# functional expressions and lambdas are cool
# but can be difficult to read when over-used
# Here is a more comprehensible version

s2 = sum(x**2 for x in range(1, 10))
print(s2)


285
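
A place where a short lambda often does stay readable is the `key` argument of the builtin `sorted`, which is itself a higher-order function. The sketch below reuses the list of names from the earlier exercise; the variable names are illustrative.

```python
names = ['Donald', 'Ted', 'Hilary', 'Joe', 'Bernie']

# sort by length; sorted() is stable, so equal-length names keep their original order
by_length = sorted(names, key=lambda name: len(name))
print(by_length)        # ['Ted', 'Joe', 'Donald', 'Hilary', 'Bernie']

# sort by the last letter of each name
by_last_letter = sorted(names, key=lambda name: name[-1])
print(by_last_letter)   # ['Donald', 'Ted', 'Joe', 'Bernie', 'Hilary']
```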

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

EXERCISE TIME!

Rewrite the following as a list comprehension, i.e. one liner without using map or filter


In [25]:
ans = map(lambda x: x*x, filter(lambda x: x%2 == 0, range(10)))
print ans


[0, 4, 16, 36, 64]

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


In [26]:
## SOLUTION

ans = [x*x for x in range(10) if x%2 == 0]
print ans


[0, 4, 16, 36, 64]

Recursion

  • A recursive function is one that calls itself
  • Recursion is a natural way to express divide-and-conquer algorithms
  • However, recursive functions can be computationally inefficient, and their use in Python is quite rare in practice

Recursive functions generally have:

  • a set of base cases
    • the answer is obvious
    • can be returned immediately
  • a set of recursive cases
    • which are split into smaller pieces
    • each of which is given to the same function called recursively

Examples

Factorial:

$$ n! = n\times(n-1)\times(n-2)\times...\times2\times1$$

For example, $$4! = 4\times3\times2\times1 = 24 $$


In [27]:
def fact(n):
    """Returns the factorial of n."""
    
    # base case
    if n==0:
        return 1
    
    # recursive case
    else:
        return n * fact(n-1)

In [28]:
[(n,fact(n)) for n in range(1,10)]


Out[28]:
[(1, 1),
 (2, 2),
 (3, 6),
 (4, 24),
 (5, 120),
 (6, 720),
 (7, 5040),
 (8, 40320),
 (9, 362880)]

Fibonacci sequence:

$$F_n = F_{n-1} + F_{n-2}, \qquad F_0 = F_1 = 1$$

Output is:

$$1, 1, 2, 3, 5, 8, 13, 21, ...$$

In [29]:
def fib1(n):
    """Fib with recursion."""

    # base case
    if n==0 or n==1:
        return 1
    # recursive case
    else:
        return fib1(n-1) + fib1(n-2)

In [30]:
[(i,fib1(i)) for i in range(20)]


Out[30]:
[(0, 1),
 (1, 1),
 (2, 2),
 (3, 3),
 (4, 5),
 (5, 8),
 (6, 13),
 (7, 21),
 (8, 34),
 (9, 55),
 (10, 89),
 (11, 144),
 (12, 233),
 (13, 377),
 (14, 610),
 (15, 987),
 (16, 1597),
 (17, 2584),
 (18, 4181),
 (19, 6765)]

In [31]:
# In Python, a more efficient version that does not use recursion is

def fib2(n):
    """Fib without recursion."""
    a, b = 0, 1
    for i in range(1, n+1):
        a, b = b, a+b
    return b

In [32]:
[(i,fib2(i)) for i in range(20)]


Out[32]:
[(0, 1),
 (1, 1),
 (2, 2),
 (3, 3),
 (4, 5),
 (5, 8),
 (6, 13),
 (7, 21),
 (8, 34),
 (9, 55),
 (10, 89),
 (11, 144),
 (12, 233),
 (13, 377),
 (14, 610),
 (15, 987),
 (16, 1597),
 (17, 2584),
 (18, 4181),
 (19, 6765)]

In [33]:
# Note that the recursive version is much slower than the non-recursive version

%timeit fib1(20)
%timeit fib2(20)


100 loops, best of 3: 2.87 ms per loop
1000000 loops, best of 3: 1.7 µs per loop

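The recursive version is slower because it makes many duplicate function calls: `fib1(3)`, for instance, is evaluated once for `fib1(5)` and again for `fib1(4)`. For example:

fib1(5) -> fib1(4), fib1(3)
fib1(4) -> fib1(3), fib1(2)
fib1(3) -> fib1(2), fib1(1)
fib1(2) -> fib1(1), fib1(0)
fib1(1) -> 1
fib1(0) -> 1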
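
One common way to avoid the duplicate calls listed above is to cache results that have already been computed (memoization). The following is a minimal sketch, assuming the same base cases as `fib1`; `count_calls_naive`, `fib_memo`, and `_cache` are names introduced here for illustration.

```python
def count_calls_naive(n, counter):
    """fib1-style recursion that also counts how many calls are made."""
    counter[0] += 1
    if n == 0 or n == 1:
        return 1
    return count_calls_naive(n - 1, counter) + count_calls_naive(n - 2, counter)

_cache = {0: 1, 1: 1}  # base cases, same convention as fib1

def fib_memo(n):
    """Recursive Fibonacci that stores each result so it is computed only once."""
    if n not in _cache:
        _cache[n] = fib_memo(n - 1) + fib_memo(n - 2)
    return _cache[n]

calls = [0]
print(count_calls_naive(20, calls))  # 10946
print(calls[0])                      # 21891 calls for the naive recursion
print(fib_memo(20))                  # 10946, with each value computed once
```

The cache trades a small amount of memory for the removal of repeated work, while keeping the recursive structure of the definition.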


In [35]:
# Recursion is used to show off the divide-and-conquer paradigm

def quick_sort(xs):
    """ Classic quick sort """

    # base case
    if xs == []:
        return xs
    # recursive case
    else:
        pivot = xs[0] # choose starting pivot to be on the left
        less_than = [x for x in xs[1:] if x <= pivot]
        more_than = [x for x in xs[1:] if x > pivot]
        
        return quick_sort(less_than) + [pivot] + quick_sort(more_than)

In [36]:
xs = [11,3,1,4,1,5,9,2,6,5,3,5,9,0,10,4,3,7,4,5,8,-1]
print quick_sort(xs)


[-1, 0, 1, 1, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 7, 8, 9, 9, 10, 11]

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

EXERCISE TIME!

Euclid's algorithm for finding the greatest common divisor of two numbers is

gcd(a, 0) = a
gcd(a, b) = gcd(b, a modulo b)
  1. What is the greatest common divisor of `17384` and `1928`? Write the `gcd(a,b)` function.
  2. Write a function to calculate the least common multiple, `lcm(a,b)`
  3. What is the least common multiple of `17384` and `1928`? Hint: Google it!

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


In [2]:
from __future__ import division

In [1]:
## SOLUTION

def gcd(a,b):
    if b == 0:
        return a
    else:
        return gcd(b, a % b)

print gcd(17384,1928)

def lcm(a,b):
    return a*b//gcd(a,b)  # floor division keeps the result an integer

print lcm(17384,1928)


8
4189544

Iterators

  • Iterators represent streams of values.
  • An iterator produces its next value each time you call `next()` on it.
  • Because only one value is produced at a time, they use very little memory.
  • Iterators are very helpful for working with data sets too large to fit into RAM.

In [252]:
# Iterators can be created from sequences with the built-in function iter()

xs = [1,2,3]
x_iter = iter(xs)

print x_iter.next() # python "remembers" where the pointer is
print x_iter.next()
print x_iter.next()
print x_iter.next()


1
2
3
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-252-35c80bc0fe4b> in <module>()
      7 print x_iter.next()
      8 print x_iter.next()
----> 9 print x_iter.next()

StopIteration: 

In [253]:
# Most commonly, iterators are used (automatically) within a for loop
# which terminates when it encounters a StopIteration exception

x_iter = iter(xs)
for x in x_iter:
    print x


1
2
3
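
What the `for` loop does with an iterator can be sketched by hand: call `next()` repeatedly until `StopIteration` is raised. The builtin `next(it)` is used here because it works in both Python 2.6+ and Python 3, whereas the `.next()` method exists only in Python 2. `manual_for` is a hypothetical helper name.

```python
def manual_for(iterable):
    """Collect items the way a for loop would, using next() explicitly."""
    it = iter(iterable)
    items = []
    while True:
        try:
            items.append(next(it))
        except StopIteration:
            break  # a for loop stops silently at this point
    return items

xs = [1, 2, 3]
print(manual_for(xs))      # [1, 2, 3]

# An iterator is consumed as it is read: a second pass yields nothing.
x_iter = iter(xs)
first_pass = list(x_iter)
second_pass = list(x_iter)
print(first_pass)          # [1, 2, 3]
print(second_pass)         # []
```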

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

EXERCISE TIME!

Starting with range(1, 20), make a list of the squares of each odd number in the following ways

  • Using map and filter. This can be done in one line if you use lambda (twice!). Remember that you need to create functions to pass the list elements to: one checks whether the element is odd, the other squares it.
  • Using a list comprehension
  • With a for loop

The answer should be [1, 9, 25, 49, 81, 121, 169, 225, 289, 361]

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


In [43]:
## SOLUTION

# using a for loop
for_list = []
for i in range(1,20):
    if i % 2 != 0:
        for_list.append(i**2)
print for_list
      
# using list comprehension
print [i**2 for i in range(1,20) if i % 2 != 0]

# using map and filter
def sqr(x):
    return x**2

def is_odd(x):
    return x % 2 != 0

print map(sqr, filter(is_odd, range(1,20)))

# using lambda
print map(lambda a: a**2, (filter(lambda x: x % 2 !=0, range(1,20))))


[1, 9, 25, 49, 81, 121, 169, 225, 289, 361]
[1, 9, 25, 49, 81, 121, 169, 225, 289, 361]
[1, 9, 25, 49, 81, 121, 169, 225, 289, 361]
[1, 9, 25, 49, 81, 121, 169, 225, 289, 361]

Review Problems

Q1. Rewrite the factorial function so that it does not use recursion. Hint: consider using reduce since you're multiplying consecutive numbers.

Here is the original code:

def fact(n):
    """Returns the factorial of n."""
    # base case
    if n==0:
        return 1
    # recursive case
    else:
        return n * fact(n-1)

for i in range(1,11):
    print fact(i),

In [44]:
## SOLUTION
def fact1(n):
    """Returns the factorial of n."""
    return reduce(lambda x, y: x*y, range(1, n+1), 1)  # initial value 1 makes fact1(0) return 1

for i in range(1,11):
    print fact1(i),


1 2 6 24 120 720 5040 40320 362880 3628800

Q2. Write a function, normalize(x) that takes in a vector x and outputs a normalized vector x_norm in the following way:

$$ X^{normed}_{i} = \frac{X_{i} - \mu}{\sigma} $$

where,

$$ \mu = \frac{1}{n}\sum_{i=1}^{n}X_{i} $$

$$ \sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(X_{i} - \mu)^2 $$

Each $X_i$ is a single data point from the input list. For example, an input list x = [1,2,3,4] should output x_norm = [-1.3416407865,-0.4472135955,0.4472135955,1.3416407865].

Note that the sum (and hence the mean) of the new list should be 0 and the standard deviation should be 1.0. This is why the procedure is called normalizing; it is also called standardizing.


In [47]:
## SOLUTION
def normalize(x):
    
    mean_x = sum(x)/float(len(x))  # float() avoids integer division in Python 2
    
    std_ = 0
    for i in x:
        std_ += (i-mean_x)**2

    std_ = std_/len(x)

    x_prime = []
    for i in x:
        x_prime.append((i - mean_x)/(std_**0.5))
    
    return x_prime
    

x = [1,2,3,4.]
normalize(x)


Out[47]:
[-1.3416407864998738,
 -0.4472135954999579,
 0.4472135954999579,
 1.3416407864998738]
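
A quick way to check the two properties stated in the question (mean 0, standard deviation 1) is to recompute them on the output. The sketch below restates the normalization compactly so that it is self-contained. `mean_std` and `normalize_check` are illustrative names, and the population standard deviation (dividing by n) is assumed, as in the formula above.

```python
def mean_std(xs):
    """Return the mean and population standard deviation of xs."""
    n = float(len(xs))
    mu = sum(xs) / n
    var = sum((x - mu) ** 2 for x in xs) / n
    return mu, var ** 0.5

def normalize_check(xs):
    """Subtract the mean and divide by the standard deviation."""
    mu, sigma = mean_std(xs)
    return [(x - mu) / sigma for x in xs]

x_norm = normalize_check([1, 2, 3, 4])
mu, sigma = mean_std(x_norm)

# compare with a tolerance, since floating-point results are rarely exact
print(abs(mu) < 1e-12)           # True
print(abs(sigma - 1.0) < 1e-12)  # True
```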

Q3. Matrix Multiplication and list comprehension

Rewrite the matrix multiplication code to use list comprehension instead of nested for loops


In [48]:
## SOLUTION

def dot_product(A, B):
    
    rows = len(A)
    shared = len(B)
    cols = len(B[0])
    
    return [[sum(A[i][k]*B[k][j] for k in range(shared)) for j in range(cols)] for i in range(rows)]

# A and B must be defined before the call; these small matrices are only an example
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]

dot_product(A,B)


Out[48]:
[[19, 22], [43, 50]]

Q4.

Write a program to merge two dictionaries together. Assume the dictionaries have the same keys. For example:

a = {
    "key1" : 1,
    "key3" : "snafu",
    "key2" : 5,
    "key5" : 7,
    "key4" : 0,
    }

b = {
    "key2" : 6,
    "key1" : 8,
    "key4" : "bar",
    "key3" : 9,
    "key5" : "foo"
    }

becomes

c = {
     'key1': (1, 8),
     'key2': (5, 6),
     'key3': ('snafu', 9),
     'key4': (0, 'bar'),
     'key5': (7, 'foo')
}

Hint: See the collections module for a helper function.


In [49]:
## SOLUTION

from collections import OrderedDict

a = {
    "key1" : 1,
    "key3" : "snafu",
    "key2" : 5,
    "key5" : 7,
    "key4" : 0,
    }

b = {
    "key2" : 6,
    "key1" : 8,
    "key4" : "bar",
    "key3" : 9,
    "key5" : "foo"
    }

a = OrderedDict(sorted(a.items()))
b = OrderedDict(sorted(b.items()))

# pair up the values key by key; both OrderedDicts iterate in the same sorted key order
c = {}
for k, pair in zip(a.keys(), zip(a.values(), b.values())):
    c[k] = pair
    
c


Out[49]:
{'key1': (1, 8),
 'key2': (5, 6),
 'key3': ('snafu', 9),
 'key4': (0, 'bar'),
 'key5': (7, 'foo')}
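
Since both dictionaries are assumed to share the same keys, the merge can also be sketched without sorting, by looking each key up in both dictionaries. `merge_pairs` is a hypothetical helper name, not part of the original solution.

```python
def merge_pairs(a, b):
    """Map each key of a to the tuple (a[key], b[key]); assumes b has the same keys."""
    return {k: (a[k], b[k]) for k in a}

a = {"key1": 1, "key3": "snafu", "key2": 5, "key5": 7, "key4": 0}
b = {"key2": 6, "key1": 8, "key4": "bar", "key3": 9, "key5": "foo"}

c = merge_pairs(a, b)
print(c["key3"])   # ('snafu', 9)
```

This version raises a `KeyError` if a key of `a` is missing from `b`, which makes a violation of the same-keys assumption visible rather than silently pairing mismatched values.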

Q5. Pascal's triangle

  1. Write a function `pascal(c,r)` which takes in a column `c` and row `r` (both start indexing at 0) and returns the value at that position in Pascal's triangle.
  2. Print the first 10 rows of the triangle. Here are the first 6:
     1
    1 1
   1 2 1
  1 3 3 1
 1 4 6 4 1 
1 5 10 10 5 1 
...

Do this recursively.


In [50]:
## SOLUTION

def pascal(c,r):
    
    assert 0 <= c <= r, "Bad parameters - column must be between 0 and row"
    
    # base case: the edges of the triangle (indexing starts at 0)
    if c == 0 or c == r:
        return 1
    # recursive case: sum of the two entries in the row above
    else:
        return pascal(c-1, r-1) + pascal(c, r-1)


depth = 10

for row in range(depth):
    for col in range(row+1):
         print pascal(col, row),
    print


1
1 1
1 2 1
1 3 3 1
1 4 6 4 1
1 5 10 10 5 1
1 6 15 20 15 6 1
1 7 21 35 35 21 7 1
1 8 28 56 70 56 28 8 1
1 9 36 84 126 126 84 36 9 1

Bonus (Hard!) Questions

Q1. The four adjacent digits in the 1000-digit number that have the greatest product are 9 × 9 × 8 × 9 = 5832.

73167176531330624919225119674426574742355349194934
96983520312774506326239578318016984801869478851843
85861560789112949495459501737958331952853208805511
12540698747158523863050715693290963295227443043557
66896648950445244523161731856403098711121722383113
62229893423380308135336276614282806444486645238749
30358907296290491560440772390713810515859307960866
70172427121883998797908792274921901699720888093776
65727333001053367881220235421809751254540594752243
52584907711670556013604839586446706324415722155397
53697817977846174064955149290862569321978468622482
83972241375657056057490261407972968652414535100474
82166370484403199890008895243450658541227588666881
16427171479924442928230863465674813919123162824586
17866458359124566529476545682848912883142607690042
24219022671055626321111109370544217506941658960408
07198403850962455444362981230987879927244284909188
84580156166097919133875499200524063689912560717606
05886116467109405077541002256983155200055935729725
71636269561882670428252483600823257530420752963450

Write a program to find the thirteen adjacent digits in the 1000-digit number that have the greatest product. What is the value of this product? (Euler problem #8)

The answer should be 23514624000.


In [ ]:
BIG_NUM = 7316717653133062491922511967442657474235534919493496983520312774506326239578318016984801869478851843858615607891129494954595017379583319528532088055111254069874715852386305071569329096329522744304355766896648950445244523161731856403098711121722383113622298934233803081353362766142828064444866452387493035890729629049156044077239071381051585930796086670172427121883998797908792274921901699720888093776657273330010533678812202354218097512545405947522435258490771167055601360483958644670632441572215539753697817977846174064955149290862569321978468622482839722413756570560574902614079729686524145351004748216637048440319989000889524345065854122758866688116427171479924442928230863465674813919123162824586178664583591245665294765456828489128831426076900422421902267105562632111110937054421750694165896040807198403850962455444362981230987879927244284909188845801561660979191338754992005240636899125607176060588611646710940507754100225698315520005593572972571636269561882670428252483600823257530420752963450
DIGITS = str(BIG_NUM)  # convert once instead of on every loop iteration
NUM_ADJACENT = 13
max_num = 0  # largest product found so far

def product_consecutive(ind):
    ''' Calculate the product of NUM_ADJACENT consecutive digits starting at index ind'''
    prod = 1
    for digit in DIGITS[ind:ind + NUM_ADJACENT]:
        prod *= int(digit)
    return prod


# only consider start positions that leave room for a full window of NUM_ADJACENT digits
for ind in range(len(DIGITS) - NUM_ADJACENT + 1):
    current_product = product_consecutive(ind)
    if current_product > max_num:
        max_num = current_product

print(max_num)

Q2.

A string of consecutive successes is known as a success run. Write a function that returns the counts for runs of length $k$ for each $k$ observed in a dictionary. Hint: check out the itertools library.

For example: if the trials were [0, 1, 0, 1, 1, 0, 0, 0, 0, 1], the function should return

{1: 2, 2: 1}

In [40]:
## SOLUTION

from itertools import groupby

def calc_runs_dict(trials):
    """
    Returns the counts for runs of length k for each k observed in a dictionary
    """

    runs = {}

    # group consecutive equal values and record (value, run length) for each group
    grouped_counts = [(k, sum(1 for _ in g)) for k, g in groupby(trials)]

    # keep only the runs of successes (1's) and tally how often each length occurs
    for value, length in grouped_counts:
        if value == 1:
            runs[length] = runs.get(length, 0) + 1

    return runs

trials = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1]    
calc_runs_dict(trials)


Out[40]:
{1: 2, 2: 1}
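
The same counts can be sketched with `collections.Counter`, which tallies the run lengths directly. `calc_runs_counter` is a name introduced here for illustration, and it assumes that successes are coded as 1.

```python
from collections import Counter
from itertools import groupby

def calc_runs_counter(trials):
    """Count how many success runs of each length occur in trials."""
    # groupby yields (value, group) for each stretch of equal consecutive values;
    # each group is consumed by sum() before the next one is requested
    run_lengths = (
        sum(1 for _ in group)
        for value, group in groupby(trials)
        if value == 1
    )
    return dict(Counter(run_lengths))

trials = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1]
print(calc_runs_counter(trials))   # {1: 2, 2: 1}
```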
