In [1]:
import cs109style
cs109style.customize_mpl()
cs109style.customize_css()
# special IPython command to prepare the notebook for matplotlib
%matplotlib inline
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt
import requests
from pattern import web
In this example we will fetch data about countries and their population from Wikipedia.
http://en.wikipedia.org/wiki/List_of_countries_by_past_and_future_population has several tables covering individual countries and subcontinents over different years. We will combine the data for all countries and all years into a single pandas DataFrame and visualize how the population of different countries changes over time.
To give you some starting points for your homework, we will also show the different sub-steps that can be taken to reach the presented solution.
In [6]:
# Download the raw HTML of the Wikipedia page once; the following cells only parse it.
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_past_and_future_population'
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Fail loudly on 4xx/5xx instead of silently parsing an error page into zero tables.
response.raise_for_status()
website_html = response.text
In [3]:
def get_population_html_tables(html):
    """Parse html and return the html tables that hold Wikipedia population data.

    Parameters
    ----------
    html : str
        Raw HTML of the Wikipedia page.

    Returns
    -------
    list
        Every ``<table>`` element (as parsed by ``web.Element``) whose ``class``
        attribute is exactly ``"wikitable sortable"``, in document order.
    """
    dom = web.Element(html)
    # Only the "wikitable sortable" tables contain data. Use .get(): a table
    # without any class attribute would make a ['class'] lookup raise KeyError.
    return [tbl for tbl in dom.by_tag('table')
            if tbl.attributes.get('class') == "wikitable sortable"]
# Parse the downloaded page and report which data tables were kept.
tables = get_population_html_tables(website_html)
print("table length: %d" % len(tables))
for html_table in tables:
    print(html_table.attributes)
In [ ]:
def table_type(tbl):
    """Return the content of the second header cell of `tbl`, used as its type key.

    The data tables are told apart by the label of their second ``<th>``
    (for example ``'Country or territory'``).

    Returns None when the table has fewer than two ``<th>`` cells, rather than
    raising IndexError, so a malformed table cannot abort the grouping loop.
    """
    headers = [th.content for th in tbl.by_tag('th')]
    return headers[1] if len(headers) > 1 else None
# Bucket the tables by their type key (the second header cell), so tables
# with the same layout end up in the same list.
tables_by_type = defaultdict(list)  # a missing key starts out as an empty list
for html_table in tables:
    type_key = table_type(html_table)
    tables_by_type[type_key].append(html_table)
print(tables_by_type)
In [4]:
def get_countries_population(tables):
    """Extract per-country population figures from the given html tables.

    Each table is expected to have header cells whose year columns are
    labelled with plain integers (e.g. ``1950``). Every data row is expected
    to carry the country name as a link in its second ``<td>``, and the
    population figures (with comma thousands separators) in the year columns.

    Parameters
    ----------
    tables : iterable
        ``<table>`` elements, e.g. as returned by ``get_population_html_tables``.

    Returns
    -------
    collections.defaultdict(dict)
        ``{country_name: {year: population / 1000.0, ...}, ...}``. Figures for
        the same country from several tables are merged into one inner dict.
        NOTE(review): the /1000 scaling yields millions only if the source
        tables report thousands of people -- confirm against the page.

    Robustness: tables without any year column are skipped. So are rows that
    are too short to hold every year column or have no link in the name cell.
    Cells whose text is not an integer are left out, so that year is simply
    missing for that country instead of the whole parse aborting.
    """
    result = defaultdict(dict)
    for tbl in tables:
        # Each table looks a little different, so locate the data columns by
        # header text: a column stores data iff its header is a year.
        headers = [th.content.strip() for th in tbl.by_tag('th')]
        year_columns = [(idx, int(header))
                        for idx, header in enumerate(headers)
                        if header.isdigit()]
        if not year_columns:
            continue  # no data columns in this table
        last_idx = max(idx for idx, _ in year_columns)

        for row in tbl.by_tag('tr'):
            cells = row.by_tag('td')  # look the cells up once per row
            # Header rows have no <td>; also skip rows missing the name cell
            # or any of the year columns.
            if len(cells) <= max(1, last_idx):
                continue
            links = cells[1].by_tag('a')
            if not links:
                continue
            # Keep the name as text: encoding it to ASCII with 'ignore' would
            # silently drop accented letters from country names.
            countryname = links[0].content
            countrydata = {}
            for idx, year in year_columns:
                raw = cells[idx].content.replace(',', '').strip()
                try:
                    countrydata[year] = int(raw) / 1000.0
                except ValueError:
                    continue  # placeholder / non-numeric cell: leave year missing
            result[countryname].update(countrydata)
    return result
# Only the tables whose second header is 'Country or territory' hold per-country data.
country_tables = tables_by_type['Country or territory']
result = get_countries_population(country_tables)
print(result)
In [8]:
# Build the DataFrame: one row per country (outer keys of `result`),
# one column per year (inner keys).
df = pd.DataFrame.from_dict(result, orient='index')
# Order the year columns chronologically. DataFrame.sort was removed from
# pandas; sort_index(axis=1) sorts by the column labels as it did.
df = df.sort_index(axis=1)
print(df)
In [9]:
# A tour of the basic ways to select data from `df` (rows are countries,
# columns are years). `.iloc` selects by integer position, while `.loc` and `[]`
# select by label. The old `.ix` guessed between the two and was removed in pandas 1.0.
subtable = df.iloc[0:2, 0:2]
print("subtable")
print(subtable)
print("")
column = df[1955]  # a column, by its label (the year)
print("column")
print(column)
print("")
row = df.iloc[0]  # row 0, by position
print("row")
print(row)
print("")
rows = df.iloc[:2]  # rows 0,1
print("rows")
print(rows)
print("")
element = df[1955].iloc[0]  # element: column 1955, row 0
print("element")
print(element)
print("")
# max along column
print("max")
print(df[1950].max())
print("")
# axes
print("axes")
print(df.axes)
print("")
row = df.iloc[0]
print("row info")
print(row.name)   # the row label, i.e. the country
print(row.index)  # the column labels, i.e. the years
print("")
countries = df.index
print("countries")
print(countries)
print("")
print("Austria")
print(df.loc['Austria'])  # a row, by its label
In [10]:
# Population over time for a few hand-picked countries.
plotCountries = ['Austria', 'Germany', 'United States', 'France']
for country in plotCountries:
    row = df.loc[country]  # label-based row lookup (.ix was removed from pandas)
    plt.plot(row.index, row, label=row.name)
# Set the y limit only after every line is drawn: ylim()/set_ylim() switches
# y autoscaling off. Calling it inside the loop therefore froze the top of the
# axis at the first (smallest) country's range and clipped the larger ones.
plt.ylim(bottom=0)  # start y axis at 0 (`bottom` is the primary keyword; `ymin` is an alias)
plt.xticks(rotation=70)
plt.legend(loc='best')
plt.xlabel("Year")
plt.ylabel("# people (million)")
plt.title("Population of countries")
Out[10]:
In [11]:
def plot_populous(df, year, n=5):
    """Plot the population history of the `n` most populous countries in `year`.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per country, one column per year.
    year : int
        Column label used to rank the countries (largest value first).
    n : int, optional
        How many countries to plot (default 5). If `df` has fewer rows,
        all of them are plotted.

    Draws on a new matplotlib figure and returns None.
    """
    # Rank the countries by their value in the `year` column. DataFrame.sort
    # was removed from pandas; sort_values(column) replaces it.
    top = df.sort_values(year, ascending=False).head(n)
    plt.figure()
    for _, row in top.iterrows():
        plt.plot(row.index, row, label=row.name)
    # Set the y limit after all lines are drawn: ylim() turns y autoscaling
    # off, so setting it inside the loop could clip later, taller lines.
    plt.ylim(bottom=0)
    plt.xticks(rotation=70)
    plt.legend(loc='best')
    plt.xlabel("Year")
    plt.ylabel("# people (million)")
    plt.title("Most populous countries in %d" % year)
# One figure per reference year.
for ranking_year in (2010, 2050):
    plot_populous(df, ranking_year)
In [ ]: