In [1]:
import cs109style
cs109style.customize_mpl()
cs109style.customize_css()
# special IPython command to prepare the notebook for matplotlib
%matplotlib inline
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt
import requests
from pattern import web
In this example we will fetch data about countries and their population from Wikipedia.
The page http://en.wikipedia.org/wiki/List_of_countries_by_past_and_future_population contains several tables covering individual countries and subcontinents for different years. We will combine the data for all countries and all years into a single pandas DataFrame and visualize the change in population for different countries.
To give you some starting points for your homework, we will also show the different sub-steps that can be taken to reach the presented solution.
In [2]:
# Download the Wikipedia page that holds the population tables.
url = 'http://en.wikipedia.org/wiki/List_of_countries_by_past_and_future_population'
# A timeout keeps the cell from hanging forever on a dead connection, and
# raise_for_status() stops us from silently parsing an HTTP error page as if
# it were the population data.
response = requests.get(url, timeout=30)
response.raise_for_status()
website_html = response.text
# Uncomment to inspect the raw HTML:
#print(website_html)
In [3]:
def get_population_html_tables(html):
    """Parse html and return html tables of wikipedia population data.

    Only the elements selected by the class "sortable wikitable" are returned.
    """
    # Step 0: look at the HTML source of the page first!
    # Step 1: web.Element(html)('table') would give us *all* the tables ...
    # Step 2: ... but we only care about the "sortable wikitable" ones,
    #         so we select by class instead.
    return web.Element(html).by_class('sortable wikitable')
# Extract the population tables and show how many were found,
# together with the HTML attributes of each one.
tables = get_population_html_tables(website_html)
print("table length: %d" % len(tables))
for html_table in tables:
    print(html_table.attributes)
In [4]:
def table_type(tbl):
    """Return the table type: the content of the table's first header cell.

    tbl is called as tbl('th') to obtain its header cells.
    """
    # To explore the other headers, print th.content for every th in tbl('th').
    header_cells = tbl('th')
    return header_cells[0].content
# Group the tables by their type.
#
# Method 1 -- a plain dict needs an explicit "have we seen this key?" check:
#
#     tables_by_type = {}
#     for tbl in tables:
#         typ = table_type(tbl)
#         if typ not in tables_by_type:
#             tables_by_type[typ] = list()
#         tables_by_type[typ].append(tbl)
#
# Method 2 (used below) -- a defaultdict inserts a default value whenever a
# new key is accessed. With defaultdict(list), the first time a table type is
# encountered an empty list is inserted for it, so we can append right away.
tables_by_type = defaultdict(list)
for html_table in tables:
    kind = table_type(html_table)
    tables_by_type[kind].append(html_table)

# Shows country or territory has 3 tables
print(tables_by_type)
In [17]:
def get_countries_population(tables):
    """Extract population data for countries from all tables and store it in dictionary.

    Parameters
    ----------
    tables : iterable of table elements
        Each table is called as ``tbl('tr')`` to get its rows. The first row
        supplies the ``th`` header cells; every following row supplies ``td``
        data cells, the first of which holds the country name inside an ``a``
        element.

    Returns
    -------
    defaultdict(dict)
        Maps country name -> {year (int): population (int)}, merged over
        *every* table that was passed in.
    """
    result = defaultdict(dict)

    for tbl in tables:
        rows = tbl('tr')
        if not rows:
            continue  # nothing to extract from an empty table

        # Only the columns whose header is a year are wanted; the other
        # columns (e.g. "% growth") are skipped. enumerate() yields
        # (index, value) pairs, so the position of each year column is kept
        # alongside the year itself.
        year_columns = [(idx, int(th.content))
                        for idx, th in enumerate(rows[0]('th'))
                        if th.content.isnumeric()]

        # Every row after the header row describes one country.
        for row in rows[1:]:
            tds = row('td')
            if not tds:
                continue  # row without data cells
            country_name = tds[0]('a')[0].content
            # Strip the thousands separators before converting to int.
            population_by_year = dict(
                (year, int(tds[idx].content.replace(',', '')))
                for idx, year in year_columns)
            # update() merges the years of this table with the years already
            # collected for the same country from previous tables.
            result[country_name].update(population_by_year)

    return result
# Parse only the per-country tables, then sanity-check the outcome:
# the number of countries found and the data collected for one of them.
country_tables = tables_by_type['Country or territory']
result = get_countries_population(country_tables)
print(len(result))
print(result[u'Canada'])
In [ ]:
# When a cell raises an error, insert a new cell below the traceback
# containing only %debug; running it takes you into the debugger
# so you can inspect what went wrong.
#%debug
In [20]:
# Create the dataframe: one row per country, one column per year.
df = pd.DataFrame.from_dict(result, orient='index')
# Sort the columns by year. sort_index(axis=1) orders the column labels and
# returns a new frame; DataFrame.sort() no longer exists in current pandas,
# and avoiding inplace=True keeps this cell safe to re-run.
df = df.sort_index(axis=1)
print(df)
In [21]:
# A tour of DataFrame indexing.
# .iloc selects by integer position and .loc selects by label; they replace
# the ambiguous .ix indexer, which has been removed from pandas.

# first two rows and first two columns, by position
subtable = df.iloc[0:2, 0:2]
print("subtable")
print(subtable)
print("")

# a single column, selected by its label (the year)
column = df[1955]
print("column")
print(column)
print("")

row = df.iloc[0]  # row 0
print("row")
print(row)
print("")

rows = df.iloc[:2]  # rows 0,1
print("rows")
print(rows)
print("")

# single element: row by position, column by label
element = df.iloc[0][1955]
print("element")
print(element)
print("")

# max along column
print("max")
print(df[1950].max())
print("")

# axes
print("axes")
print(df.axes)
print("")

# a row is a Series: its name is the row label, its index the column labels
row = df.iloc[0]
print("row info")
print(row.name)
print(row.index)
print("")

countries = df.index
print("countries")
print(countries)
print("")

# a row selected by its label
print("Austria")
print(df.loc['Austria'])
In [22]:
# Plot the population over time for a hand-picked set of countries.
plotCountries = ['Austria', 'Germany', 'United States', 'France']

for country in plotCountries:
    # .loc selects the row by its label (the country name); it replaces the
    # .ix indexer, which has been removed from pandas.
    row = df.loc[country]
    plt.plot(row.index, row, label=row.name)

# start y axis at 0 ("bottom" replaces the removed "ymin" keyword)
plt.ylim(bottom=0)
plt.xticks(rotation=70)
plt.legend(loc='best')
plt.xlabel("Year")
# NOTE(review): confirm the unit in this label against the source tables.
plt.ylabel("# people (million)")
plt.title("Population of countries")
Out[22]:
In [23]:
def plot_populous(df, year, n=5):
    """Plot the population over time of the most populous countries in `year`.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per country, one column per year.
    year : int
        Column label used to rank the countries.
    n : int, optional
        Number of countries to plot (default 5). If the frame has fewer
        rows, all of them are plotted.

    A new matplotlib figure is created on every call.
    """
    # Sort the table by the value in the `year` column, largest first.
    # sort_values() replaces the removed DataFrame.sort(column) call.
    df_by_year = df.sort_values(by=year, ascending=False)
    # head(n) never over-runs a short frame, unlike indexing range(5).
    top = df_by_year.head(n)

    plt.figure()
    for i in range(len(top)):
        row = top.iloc[i]  # positional access; replaces the removed .ix
        plt.plot(row.index, row, label=row.name)

    plt.ylim(bottom=0)  # "bottom" replaces the removed "ymin" keyword
    plt.xticks(rotation=70)
    plt.legend(loc='best')
    plt.xlabel("Year")
    plt.ylabel("# people (million)")
    plt.title("Most populous countries in %d" % year)
# One figure for a past year and one for a projected year.
for target_year in (2010, 2050):
    plot_populous(df, target_year)
In [ ]:
# Leftover debugging cell, disabled so that running the whole notebook is not
# interrupted by an interactive debugger prompt. Uncomment after a failure.
#%debug
In [23]:
In [ ]: