Merging and Splitting Iterators

The chain() function takes several iterators as arguments and returns a single iterator that produces the contents of all of the inputs as though they came from a single iterator.


In [2]:
from itertools import *

for i in chain([1, 2, 3], ['a', 'b', 'c']):
    print(i, end=' ')
print()


1 2 3 a b c 

If the iterables to be combined are not all known in advance, or need to be evaluated lazily, chain.from_iterable() can be used to construct the chain instead.


In [3]:
from itertools import *


def make_iterables_to_chain():
    yield [1, 2, 3]
    yield ['a', 'b', 'c']


for i in chain.from_iterable(make_iterables_to_chain()):
    print(i, end=' ')
print()


1 2 3 a b c 

The built-in function zip() returns an iterator that combines the elements of several iterators into tuples.


In [4]:
for i in zip([1, 2, 3], ['a', 'b', 'c']):
    print(i)


(1, 'a')
(2, 'b')
(3, 'c')

zip() stops as soon as any one of the input iterators is exhausted, so the result is only as long as the shortest input. To process all of the inputs, even if the iterators produce different numbers of values, use zip_longest(). By default, zip_longest() substitutes None for any missing values. Use the fillvalue argument to supply a different substitute value.


In [5]:
from itertools import *

r1 = range(3)
r2 = range(2)

print('zip stops early:')
print(list(zip(r1, r2)))

r1 = range(3)
r2 = range(2)

print('\nzip_longest processes all of the values:')
print(list(zip_longest(r1, r2)))


zip stops early:
[(0, 0), (1, 1)]

zip_longest processes all of the values:
[(0, 0), (1, 1), (2, None)]
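
The fillvalue argument mentioned above is not exercised by the previous cell. As a minimal sketch, reusing the same two ranges and an arbitrarily chosen placeholder string '-':

```python
from itertools import zip_longest

r1 = range(3)
r2 = range(2)

# fillvalue replaces the default None wherever the shorter input runs out
padded = list(zip_longest(r1, r2, fillvalue='-'))
print(padded)  # [(0, 0), (1, 1), (2, '-')]
```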

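The islice() function returns an iterator that produces selected items from the input iterator, by index. It takes the same arguments as the slice operator for lists: start, stop, and step. The start and step arguments are optional.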


In [7]:
from itertools import *

print('Stop at 5:')
for i in islice(range(100), 5):
    print(i, end=' ')
print('\n')

print('Start at 5, Stop at 10:')
for i in islice(range(100), 5, 10):
    print(i, end=' ')
print('\n')

print('By tens to 100:')
for i in islice(range(100), 0, 100, 10):
    print(i, end=' ')
print('\n')


Stop at 5:
0 1 2 3 4 

Start at 5, Stop at 10:
5 6 7 8 9 

By tens to 100:
0 10 20 30 40 50 60 70 80 90 
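
Because islice() reads its input lazily, it can also be applied to an iterator that never ends. The sketch below assumes count() (covered later in this document) as the unbounded input; the stop value is what makes the result finite.

```python
from itertools import count, islice

# count() produces integers forever, so the stop argument (10) is required
# for the slice to terminate; the step of 2 keeps every other value.
evens = list(islice(count(), 0, 10, 2))
print(evens)  # [0, 2, 4, 6, 8]
```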

The tee() function returns several independent iterators (defaults to 2) based on a single original input.

tee() has semantics similar to the Unix tee utility, which repeats the values it reads from its input and writes them to a named file and standard output. The iterators returned by tee() can be used to feed the same set of data into multiple algorithms to be processed in parallel.


In [8]:
from itertools import *

r = islice(count(), 5)
i1, i2 = tee(r)

print('i1:', list(i1))
print('i2:', list(i2))


i1: [0, 1, 2, 3, 4]
i2: [0, 1, 2, 3, 4]

The new iterators created by tee() share their input, so the original iterator should not be used after the new ones are created. If values are consumed from the original input, the new iterators will not produce those values.


In [12]:
from itertools import *

r = islice(count(), 5)
i1, i2 = tee(r)

print('r:', end=' ')
for i in r:
    print(i, end=' ')
    if i > 1:
        break
print()

print('i1:', list(i1))
print('i2:', list(i2))


r: 0 1 2 
i1: [3, 4]
i2: [3, 4]

Converting Inputs

The built-in map() function returns an iterator that calls a function on the values in the input iterators, and returns the results. It stops when any input iterator is exhausted.


In [14]:
def times_two(x):
    return 2 * x


def multiply(x, y):
    return (x, y, x * y)


print('Doubles:')
for i in map(times_two, range(5)):
    print(i)

print('\nMultiples:')
r1 = range(5)
r2 = range(5, 10)
for i in map(multiply, r1, r2):
    print('{:d} * {:d} = {:d}'.format(*i))

print('\nStopping:')
r1 = range(5)
r2 = range(2)
for i in map(multiply, r1, r2):
    print(i)


Doubles:
0
2
4
6
8

Multiples:
0 * 5 = 0
1 * 6 = 6
2 * 7 = 14
3 * 8 = 24
4 * 9 = 36

Stopping:
(0, 0, 0)
(1, 1, 1)

The starmap() function is similar to map(), but instead of constructing a tuple from multiple iterators, it splits up the items in a single iterator as arguments to the mapping function using the * syntax.


In [15]:
from itertools import *

values = [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9)]

for i in starmap(lambda x, y: (x, y, x * y), values):
    print('{} * {} = {}'.format(*i))


0 * 5 = 0
1 * 6 = 6
2 * 7 = 14
3 * 8 = 24
4 * 9 = 36
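
One way to see the relationship between the two functions is to run the same computation through both. This is an illustrative sketch only; operator.mul and the variable names are choices made for the example.

```python
from itertools import starmap
import operator

pairs = [(0, 5), (1, 6), (2, 7)]

# starmap() unpacks each tuple into positional arguments: mul(0, 5), ...
via_starmap = list(starmap(operator.mul, pairs))

# map() needs the unpacking done by hand to get the same effect
via_map = list(map(lambda args: operator.mul(*args), pairs))

print(via_starmap)             # [0, 6, 14]
print(via_starmap == via_map)  # True
```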

Producing New Values

The count() function returns an iterator that produces consecutive integers, indefinitely. The first number can be passed as an argument (the default is zero). There is no upper bound argument (see the built-in range() for more control over the result set).


In [16]:
from itertools import *

for i in zip(count(1), ['a', 'b', 'c']):
    print(i)


(1, 'a')
(2, 'b')
(3, 'c')

The start and step arguments to count() can be any numerical values that can be added together.


In [17]:
import fractions
from itertools import *

start = fractions.Fraction(1, 3)
step = fractions.Fraction(1, 3)

for i in zip(count(start, step), ['a', 'b', 'c']):
    print('{}: {}'.format(*i))


1/3: a
2/3: b
1: c

The cycle() function returns an iterator that repeats the contents of the iterable it is given indefinitely. Since it has to remember the entire contents of the input iterator, it may consume quite a bit of memory if the iterator is long.


In [18]:
from itertools import *

for i in zip(range(7), cycle(['a', 'b', 'c'])):
    print(i)


(0, 'a')
(1, 'b')
(2, 'c')
(3, 'a')
(4, 'b')
(5, 'c')
(6, 'a')

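The repeat() function returns an iterator that produces the same value each time it is accessed. The iterator keeps returning the value forever unless the optional times argument is provided to limit it, as the second argument (5) does in the example.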


In [19]:
from itertools import *

for i in repeat('over-and-over', 5):
    print(i)


over-and-over
over-and-over
over-and-over
over-and-over
over-and-over
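
repeat() is often useful for supplying an invariant value alongside values from other iterators. The following sketch, with names chosen for illustration, pairs it with zip() and map(); in the second case the unbounded repeat(2) is safe only because map() stops when range(5) is exhausted.

```python
from itertools import count, repeat

# Number a fixed value: zip() stops when repeat() has produced 3 items
numbered = list(zip(count(), repeat('over-and-over', 3)))
print(numbered)

# Supply a constant first argument to a two-argument function
doubled = list(map(lambda x, y: x * y, repeat(2), range(5)))
print(doubled)  # [0, 2, 4, 6, 8]
```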

Filtering

The dropwhile() function returns an iterator that produces elements of the input iterator after a condition becomes false for the first time.

dropwhile() does not filter every item of the input; after the condition is false the first time, all of the remaining items in the input are returned.


In [20]:
from itertools import *


def should_drop(x):
    print('Testing:', x)
    return x < 1


for i in dropwhile(should_drop, [-1, 0, 1, 2, -2]):
    print('Yielding:', i)


Testing: -1
Testing: 0
Testing: 1
Yielding: 1
Yielding: 2
Yielding: -2

The opposite of dropwhile() is takewhile(). It returns an iterator that produces items from the input iterator as long as the test function returns true, and stops at the first item for which it returns false.


In [23]:
from itertools import *


def should_take(x):
    return x < 2


for i in takewhile(should_take, [-1, 0, 1, 2, -2]):
    print('Yielding:', i)


Yielding: -1
Yielding: 0
Yielding: 1

The built-in function filter() returns an iterator that includes only items for which the test function returns true.

filter() is different from dropwhile() and takewhile() in that every item is tested before it is returned.


In [24]:
from itertools import *


def check_item(x):
    print('Testing:', x)
    return x < 1


for i in filter(check_item, [-1, 0, 1, 2, -2]):
    print('Yielding:', i)


Testing: -1
Yielding: -1
Testing: 0
Yielding: 0
Testing: 1
Testing: 2
Testing: -2
Yielding: -2

filterfalse() returns an iterator that includes only items where the test function returns false.
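
As a short sketch, here is filterfalse() applied to the same input and the same test as the filter() example, without the tracing output. The items it produces are exactly the ones filter() skipped.

```python
from itertools import filterfalse


def check_item(x):
    return x < 1


# Keep only the items for which check_item() is false
kept = list(filterfalse(check_item, [-1, 0, 1, 2, -2]))
print(kept)  # [1, 2]
```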

compress() offers another way to filter the contents of an iterable. Instead of calling a function, it uses the values in another iterable to indicate when to accept a value and when to ignore it.

The first argument is the data iterable to process and the second is a selector iterable producing Boolean values indicating which elements to take from the data input (a true value causes the value to be produced, a false value causes it to be ignored).


In [25]:
from itertools import *

every_third = cycle([False, False, True])
data = range(1, 10)

for i in compress(data, every_third):
    print(i, end=' ')
print()


3 6 9 

Grouping Data

The groupby() function returns an iterator that produces groups of values organized by a common key. This example illustrates grouping related values based on an attribute.

groupby() starts a new group every time the key value changes, so it only collects items that are adjacent in the input. The input sequence therefore needs to be sorted on the key value in order for the groupings to work out as expected.


In [26]:
import functools
from itertools import *
import operator
import pprint


@functools.total_ordering
class Point:

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return '({}, {})'.format(self.x, self.y)

    def __eq__(self, other):
        return (self.x, self.y) == (other.x, other.y)

    def __gt__(self, other):
        return (self.x, self.y) > (other.x, other.y)


# Create a dataset of Point instances
data = list(map(Point,
                cycle(islice(count(), 3)),
                islice(count(), 7)))
print('Data:')
pprint.pprint(data, width=35)
print()

# Try to group the unsorted data based on X values
print('Grouped, unsorted:')
for k, g in groupby(data, operator.attrgetter('x')):
    print(k, list(g))
print()

# Sort the data
data.sort()
print('Sorted:')
pprint.pprint(data, width=35)
print()

# Group the sorted data based on X values
print('Grouped, sorted:')
for k, g in groupby(data, operator.attrgetter('x')):
    print(k, list(g))
print()


Data:
[(0, 0),
 (1, 1),
 (2, 2),
 (0, 3),
 (1, 4),
 (2, 5),
 (0, 6)]

Grouped, unsorted:
0 [(0, 0)]
1 [(1, 1)]
2 [(2, 2)]
0 [(0, 3)]
1 [(1, 4)]
2 [(2, 5)]
0 [(0, 6)]

Sorted:
[(0, 0),
 (0, 3),
 (0, 6),
 (1, 1),
 (1, 4),
 (2, 2),
 (2, 5)]

Grouped, sorted:
0 [(0, 0), (0, 3), (0, 6)]
1 [(1, 1), (1, 4)]
2 [(2, 2), (2, 5)]
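
When the data does not need a custom class, a common pattern is to sort and group with the same key function. This is a minimal sketch using a made-up word list and len as the key.

```python
from itertools import groupby

words = ['pear', 'fig', 'apple', 'kiwi', 'plum', 'grape', 'date']

# Sort by the grouping key first so equal keys are adjacent
by_length = sorted(words, key=len)

# Each group is a lazy iterator, so copy it into a list before moving on
groups = {k: list(g) for k, g in groupby(by_length, key=len)}
print(groups)
```

Because sorted() is stable, words of the same length keep their original relative order inside each group.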

Combining Inputs

The accumulate() function processes the input iterable, passing the accumulated result so far and the next item to a function and producing each return value in turn. The first item of the input is produced unchanged. The default function used to combine the two values adds them, so accumulate() can be used to produce the cumulative sum of a series of numerical inputs.


In [27]:
from itertools import *

print(list(accumulate(range(5))))
print(list(accumulate('abcde')))


[0, 1, 3, 6, 10]
['a', 'ab', 'abc', 'abcd', 'abcde']

It is possible to combine accumulate() with any other function that takes two input values to achieve different results.


In [28]:
from itertools import *


def f(a, b):
    print(a, b)
    return b + a + b


print(list(accumulate('abcde', f)))


a b
bab c
cbabc d
dcbabcd e
['a', 'bab', 'cbabc', 'dcbabcd', 'edcbabcde']
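
Two-argument functions from the standard library also work. The sketch below, using sample values chosen for illustration, passes the built-in max() to track a running maximum and operator.mul to build running products.

```python
from itertools import accumulate
import operator

values = [3, 1, 4, 1, 5, 9, 2, 6]

# Each output is the largest value seen so far
running_max = list(accumulate(values, max))
print(running_max)  # [3, 3, 4, 4, 5, 9, 9, 9]

# Each output is the product of all inputs so far: 1!, 2!, 3!, ...
running_product = list(accumulate(range(1, 6), operator.mul))
print(running_product)  # [1, 2, 6, 24, 120]
```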

Nested for loops that iterate over multiple sequences can often be replaced with product(), which produces a single iterable whose values are the Cartesian product of the set of input values.


In [30]:
from itertools import *
import pprint

FACE_CARDS = ('J', 'Q', 'K', 'A')
SUITS = ('H', 'D', 'C', 'S')

DECK = list(
    product(
        chain(range(2, 11), FACE_CARDS),
        SUITS,
    )
)

for card in DECK:
    print('{:>2}{}'.format(*card), end=' ')
    if card[1] == SUITS[-1]:
        print()


 2H  2D  2C  2S 
 3H  3D  3C  3S 
 4H  4D  4C  4S 
 5H  5D  5C  5S 
 6H  6D  6C  6S 
 7H  7D  7C  7S 
 8H  8D  8C  8S 
 9H  9D  9C  9S 
10H 10D 10C 10S 
 JH  JD  JC  JS 
 QH  QD  QC  QS 
 KH  KD  KC  KS 
 AH  AD  AC  AS 

The values produced by product() are tuples, with the members taken from each of the iterables passed in as arguments, in the order they are passed. The first tuple returned includes the first value from each iterable. The last iterable passed to product() varies fastest, like the innermost loop of a nested for loop, followed by the next to last, and so on. The result is that the return values are ordered by the first iterable, then the next iterable, etc.
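
The ordering can be checked by comparing product() against the nested loops it replaces. This is a small sketch with two made-up inputs.

```python
from itertools import product

letters = 'ab'
numbers = [1, 2, 3]

# The explicit nested loops: the inner loop (numbers) varies fastest
nested = []
for letter in letters:
    for number in numbers:
        nested.append((letter, number))

# product() yields the same tuples in the same order
flat = list(product(letters, numbers))

print(flat)
print(flat == nested)  # True
```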

The permutations() function produces every possible ordering of the given length built from the items of the input iterable. If no length is given, it defaults to producing the full set of permutations of all of the items.


In [31]:
from itertools import *


def show(iterable):
    first = None
    for i, item in enumerate(iterable, 1):
        if first != item[0]:
            if first is not None:
                print()
            first = item[0]
        print(''.join(item), end=' ')
    print()


print('All permutations:\n')
show(permutations('abcd'))

print('\nPairs:\n')
show(permutations('abcd', r=2))


All permutations:

abcd abdc acbd acdb adbc adcb 
bacd badc bcad bcda bdac bdca 
cabd cadb cbad cbda cdab cdba 
dabc dacb dbac dbca dcab dcba 

Pairs:

ab ac ad 
ba bc bd 
ca cb cd 
da db dc 

To limit the values to unique combinations rather than permutations, use combinations(). As long as the members of the input are unique, the output will not include any repeated values.


In [33]:
from itertools import *


def show(iterable):
    first = None
    for i, item in enumerate(iterable, 1):
        if first != item[0]:
            if first is not None:
                print()
            first = item[0]
        print(''.join(item), end=' ')
    print()


print('Unique pairs:\n')
show(combinations('abcd', r=2))


Unique pairs:

ab ac ad 
bc bd 
cd 

While combinations() does not repeat individual input elements, sometimes it is useful to consider combinations that do include repeated elements. For those cases, use combinations_with_replacement().


In [34]:
from itertools import *


def show(iterable):
    first = None
    for i, item in enumerate(iterable, 1):
        if first != item[0]:
            if first is not None:
                print()
            first = item[0]
        print(''.join(item), end=' ')
    print()


print('Unique pairs:\n')
show(combinations_with_replacement('abcd', r=2))


Unique pairs:

aa ab ac ad 
bb bc bd 
cc cd 
dd 
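
To compare the three functions side by side, the following sketch counts the pairs each one produces for the same four-item input used in the examples above. The counts match the number of pairs printed in those examples.

```python
from itertools import (
    combinations,
    combinations_with_replacement,
    permutations,
)

items = 'abcd'

# Order matters, no element reused: 4 * 3 pairs
n_perm = len(list(permutations(items, 2)))

# Order ignored, no element reused: half as many pairs
n_comb = len(list(combinations(items, 2)))

# Order ignored, elements may repeat: adds aa, bb, cc, dd
n_cwr = len(list(combinations_with_replacement(items, 2)))

print(n_perm, n_comb, n_cwr)  # 12 6 10
```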
