In [2]:
import cs109style
cs109style.customize_mpl()
cs109style.customize_css()
# special IPython command to prepare the notebook for matplotlib
%matplotlib inline
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt
import requests
from pattern import web
In this example we will fetch data about countries and their population from Wikipedia.
http://en.wikipedia.org/wiki/List_of_countries_by_past_and_future_population contains several tables covering individual countries and subcontinents for different years. We will combine the data for all countries and all years into a single pandas DataFrame and visualize the change in population for different countries.
To give you some starting points for your homework, we will also show the different sub-steps that can be taken to reach the presented solution.
In [1]:
# Download the Wikipedia page once; the later cells parse this string.
url = 'http://en.wikipedia.org/wiki/List_of_countries_by_past_and_future_population'
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Fail loudly on an HTTP error instead of handing an error page to the parser.
response.raise_for_status()
website_html = response.text
# While exploring, print website_html here to look at the raw markup.
In [12]:
def get_population_html_tables(html):
    """Return the html tables holding the Wikipedia population data.

    The markup is wrapped in a pattern.web Element and the tables are
    selected with by_class('wikitable sortable').

    Steps taken to arrive at this query:
      0. look at the html source of the page
      1. get all tables with dom('table')
      2. narrow the selection down to the tables we care about
    """
    return web.Element(html).by_class('wikitable sortable')
# Extract the population tables from the downloaded page and report how many
# were found.
tables = get_population_html_tables(website_html)
print("table length: %d" % len(tables))

# Show the HTML attributes of every table, one table per line.
for t in tables:
    print(t.attributes)
In [22]:
def table_type(tbl):
    """Extract the table type: the content of the table's second 'th' cell.

    tbl is called with 'th' to obtain its header cells; the cell at index 1
    carries the label that tells the tables apart.
    """
    header_cells = tbl('th')
    second_header = header_cells[1]
    return second_header.content
# Group the tables by type, first with a plain dict ...
table_by_type = {}
for tbl in tables:
    typ = table_type(tbl)
    # setdefault stores an empty list the first time a type shows up
    table_by_type.setdefault(typ, []).append(tbl)

# ... and then with the equivalent defaultdict version: a defaultdict inserts
# its default value (here an empty list) whenever a new key is accessed.
tables_by_type = defaultdict(list)
for tbl in tables:
    typ = table_type(tbl)
    tables_by_type[typ].append(tbl)

print(tables_by_type)
In [34]:
def get_countries_population(tables):
    """Extract population data for countries from all tables and store it in a dictionary.

    Parameters
    ----------
    tables : iterable of table elements
        Each table is called as tbl('tr') to get its rows; a row is called
        with 'th' / 'td' to get its cells, and a cell with 'a' to get its links.

    Returns
    -------
    defaultdict(dict)
        Maps country name -> {year: population}. Values found for the same
        country in several tables are merged into one dict.

    Raises
    ------
    IndexError
        If the second cell of a data row contains no link to read the
        country name from.
    ValueError
        If a population cell is not an integer once ',' has been removed.

    Tables without rows, and rows with too few 'td' cells (for example a
    repeated header row), are skipped.
    """
    result = defaultdict(dict)

    for tbl in tables:
        # Query the rows once: the first row is the header, the rest is data.
        table_rows = tbl('tr')
        if not table_rows:
            continue  # nothing to read from a table without rows

        # 1. step: find the header cells that hold a year and remember
        #    which column each of them sits in
        header_cells = table_rows[0]('th')
        year_columns = [(idx, int(cell.content))
                        for idx, cell in enumerate(header_cells)
                        if cell.content.isnumeric()]
        year_indices = [idx for idx, year in year_columns]
        years = [year for idx, year in year_columns]
        print(years)
        print(year_indices)

        # The country name is read from cell 1, so a data row needs at least
        # two cells and must reach the right-most year column.
        min_cells = max(year_indices + [1]) + 1

        # 2. step: walk the data rows and merge every country's values
        #    into the single result dict
        for row in table_rows[1:]:
            tds = row('td')
            if len(tds) < min_cells:
                continue  # header-only or truncated row: no data to read
            country_name = tds[1]('a')[0].content
            population_by_year = [int(tds[colidx].content.replace(',', ''))
                                  for colidx in year_indices]
            result[country_name].update(zip(years, population_by_year))
    return result
# Only the per-country tables are of interest; select them by their type label.
country_tables = tables_by_type[u'Country or territory']
result = get_countries_population(country_tables)
print("%d %s" % (len(result), "Countries extracted"))

# Spot-check one country to see its year -> population mapping.
print(result[u'Canada'])
In [33]:
# Create the dataframe: the keys of `result` (countries) become the rows and
# the keys of the inner dicts (years) become the columns.
# sort_index(axis=1) orders the columns by their label, i.e. by year. It
# replaces DataFrame.sort(axis=1, inplace=True), which current pandas no
# longer provides, and returns a new frame, so re-running this cell always
# starts from `result`.
df = pd.DataFrame.from_dict(result, orient='index').sort_index(axis=1)
print(df)
In [35]:
# Tour of the common ways to pull data out of the dataframe.
# .iloc selects by position and .loc by label; together they replace the
# .ix indexer, which current pandas no longer provides.

# block of cells by position: first two rows, first two columns
subtable = df.iloc[0:2, 0:2]
print("subtable")
print(subtable)
print("")

# one column, selected by its label (a year)
column = df[1955]
print("column")
print(column)
print("")

row = df.iloc[0]  # row 0
print("row")
print(row)
print("")

rows = df.iloc[:2]  # rows 0,1
print("rows")
print(rows)
print("")

# single element: column picked by label, row picked by position
element = df[1955].iloc[0]  # element
print("element")
print(element)
print("")

# max along column
print("max")
print(df[1950].max())
print("")

# axes
print("axes")
print(df.axes)
print("")

# a row is a Series: its name is the row label, its index the column labels
row = df.iloc[0]
print("row info")
print(row.name)
print(row.index)
print("")

countries = df.index
print("countries")
print(countries)
print("")

# one row, selected by its label (a country name)
print("Austria")
print(df.loc['Austria'])
In [36]:
plotCountries = ['Austria', 'Germany', 'United States', 'France']

# One line per country. .loc selects the row by its label (the country name);
# it replaces the .ix indexer, which current pandas no longer provides.
for country in plotCountries:
    row = df.loc[country]
    plt.plot(row.index, row, label=row.name)

plt.ylim(ymin=0)  # start y axis at 0
plt.xticks(rotation=70)
plt.legend(loc='best')
plt.xlabel("Year")
# NOTE(review): the label says millions; confirm the unit of the scraped values.
plt.ylabel("# people (million)")
# The trailing ';' keeps the notebook from echoing the returned Text object.
plt.title("Population of countries");
Out[36]:
In [37]:
def plot_populous(df, year, n=5):
    """Plot the population over time of the most populous countries in a year.

    Parameters
    ----------
    df : pandas.DataFrame
        Expected to hold one row per country and one column per year.
    year : int
        Label of the column used to rank the countries.
    n : int, optional
        Number of top-ranked countries to draw (default 5). If df has fewer
        rows, all of them are drawn.

    A new figure is created on every call; nothing is returned.
    """
    # sort table depending on data value in year column, largest first;
    # sort_values replaces DataFrame.sort, which current pandas no longer has
    df_by_year = df.sort_values(by=year, ascending=False)

    plt.figure()
    # head(n) stops at the end of the table instead of indexing past it
    for name, row in df_by_year.head(n).iterrows():
        plt.plot(row.index, row, label=name)

    plt.ylim(ymin=0)
    plt.xticks(rotation=70)
    plt.legend(loc='best')
    plt.xlabel("Year")
    plt.ylabel("# people (million)")
    plt.title("Most populous countries in %d" % year)
# Draw one figure per ranking year.
for ranking_year in (2010, 2050):
    plot_populous(df, ranking_year)
In [ ]: