Website of the data story: http://ada.gregunz.io/

To view the notebook, please use this nbviewer link

Summary

Most of the code is implemented in functions outside the notebook (in the files listed below) to keep this notebook as clean as possible.

Summary of what each file does :

Final_notebook.ipynb -- (jupyter notebook)

  • Show our final progress

clean_data.py -- (python code functions)

  • Filter data we don't use
  • Clean the Gdelt data

fetch_gdelt_data.py -- (python code functions)

  • Download/Save/Load Gdelt data given dates

fetch_location.py -- (python code functions)

  • Find the country of each Action in a dataframe using longitude and latitude

fetch_source_country.py -- (python code functions)

  • Find the country of the sources (newspapers/websites)

high_level_fetch.py -- (python code function)

  • Provide functions to fetch/clean/load data in chunks (for example, per month)

load_data.py -- (python code function)

  • Contain the function to download/clean/load the entire GDELT 2.0 dataset. If the final cleaned file is already present on disk, it skips all that work and simply loads the file. This should be the only function used to load the data.

Note

All the GDELT Events (2.0) data represent 200k files, 100 GB uncompressed.

After cleaning and keeping only the information we need, we are left with a single file of 6.2 GB uncompressed.


In [28]:
import pandas as pd
import datetime
import numpy as np

Data Pipeline

This pipeline sums up all the data processing done before visualization. This includes data acquisition, data cleaning and data augmentation.

  1. We download the data from the GDELT website. There is one file every 15 minutes (96 per day). The functions in fetch_gdelt_data.py make the download easy: by providing a date (or a range of dates) we can automatically download all the files and save them.
  2. We load all the files into a single dataframe; this is done directly in fetch_gdelt_data.py after the download. (Note: if a file has already been downloaded, it is loaded automatically from storage.)
  3. We only need a few columns for our project, hence we keep only those ('EventCode', 'SOURCEURL', 'ActionGeo_CountryCode', 'ActionGeo_Lat', 'ActionGeo_Long', 'IsRootEvent', 'QuadClass', 'GoldsteinScale', 'AvgTone', 'NumMentions', 'NumSources', 'NumArticles', 'ActionGeo_Type', 'Day'); please refer to the GDELT Codebook for the details about each field. This is done in the clean_data.py file.
  4. When some values are missing or invalid (e.g. no geographic position), we remove the row (done in the clean_data.py file as well).
  5. 'ActionGeo_CountryCode' does NOT use the ISO 3166-1 country codes. Hence, to be more consistent, we find the correct country using the latitude and longitude and construct a mapping from their country codes to the ISO ones. This is done in fetch_location.py.
  6. For each event we have the source URL of the article it comes from. Unfortunately, we do not have the country the article was written in. For this reason we use this database and the top-level domain (TLD) names that represent a country to determine it. All of this is done in fetch_source_country.py; a small sketch of the TLD idea follows this list.
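As a rough illustration of step 6 (and not the actual fetch_source_country.py implementation), the sketch below guesses a source country from the country-code top-level domain of a URL. The CCTLD_TO_ISO3 table and the helper name are made up for the example; the real mapping relies on a much larger database of website locations.

from urllib.parse import urlparse

# Hypothetical excerpt of a ccTLD -> ISO 3166-1 alpha-3 table; the real project
# relies on a full database of website locations, not only on TLDs.
CCTLD_TO_ISO3 = {'ch': 'CHE', 'fr': 'FRA', 'de': 'DEU', 'uk': 'GBR'}

def source_country_from_url(url):
    '''Return the ISO alpha-3 code guessed from the URL's top-level domain,
    or None when the TLD is generic (.com, .org, ...) or unknown.'''
    tld = urlparse(url).netloc.rsplit('.', 1)[-1].lower()
    return CCTLD_TO_ISO3.get(tld)

source_country_from_url('https://www.letemps.ch/monde/some-article')  # -> 'CHE'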

Important notice

Everything below, up to the section "Post Milestone 2 work", was done before Milestone 3 and hence does not reflect 100% of what is present in the final work; we keep it for completeness.

Data Fetching


In [2]:
# might be better to import the code and not a file (to show what we've done)

from fetch_gdelt_data import *
from clean_data import clean_df

Below, we specify the date interval of the data to load into a DataFrame, downloading it first if we do not already have it locally.


In [3]:
start = datetime.date(2015, 3, 1)
end = datetime.date(2015, 4, 1)

To load and download the data, a single function call is needed. We can specify whether we want the translingual version or the English-only one.


In [4]:
test_df = fetch_df(start, end, translingual=True, v='2')

Data Cleaning and Selection

We only keep the information about the event type and location, the source URL and number of mentions, and the Goldstein scale and average tone of the event. We drop every event with missing entries and add a column containing the ISO 3166-1 alpha-3 code of the country where the event happens.


In [5]:
selected_df = clean_df(test_df)

Data visualization

To show how we can visualize the data, we plan to use folium and plotly later on.


In [6]:
import json
import branca
import folium
from folium.plugins import HeatMap
from fetch_location import get_country_from_point, get_mapping

from plotly.offline import download_plotlyjs, init_notebook_mode, iplot
from plotly.graph_objs import *
init_notebook_mode()


We load the geojson that will be used to aggregate the data by country and display it in a choropleth.


In [7]:
world_geo_path = '../data/locations/countries.geo.json'
world_json_data = json.load(open(world_geo_path, encoding="UTF-8"))

The data contains a code for each country, from which we can easily aggregate the events. However, the corresponding country name is not always the same, as it sometimes contains details at the city/state level.

For this reason, it is not easy to know which country code corresponds to which polygon of the geojson. The easiest solution we found was to test, using the longitude and latitude, in which polygon each event happened, and to build a mapping country_code -> polygon_name (a minimal sketch of this idea is shown below).
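A minimal sketch of this point-in-polygon idea, assuming shapely is available and that each feature of the geojson stores a display name under properties['name']; build_country_mapping is a hypothetical helper, not the project's fetch_location.py code.

from shapely.geometry import Point, shape

def build_country_mapping(df, geojson):
    '''Map each GDELT country code to the name of the geojson polygon containing
    one of its events, using a longitude/latitude point-in-polygon test.'''
    polygons = [(feature['properties']['name'], shape(feature['geometry']))
                for feature in geojson['features']]
    mapping = {}
    # one sample point per country code is enough to identify its polygon
    for _, row in df.drop_duplicates('ActionGeo_CountryCode').iterrows():
        point = Point(row['ActionGeo_Long'], row['ActionGeo_Lat'])  # (x=lon, y=lat)
        for name, polygon in polygons:
            if polygon.contains(point):
                mapping[row['ActionGeo_CountryCode']] = name
                break
    return mapping

# country_mapping = build_country_mapping(selected_df, world_json_data)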

Choropleth

Below we compute the different metrics multiplied by their "importance" (here, the number of mentions); this could be done differently.


In [8]:
selected_df.loc[:,'pondered_GoldsteinScale'] = selected_df.loc[:,'GoldsteinScale'] * selected_df.loc[:,'NumMentions']
selected_df.loc[:,'pondered_AvgTone'] = selected_df.loc[:,'AvgTone'] * selected_df.loc[:,'NumMentions']

Once we have the mapping, we can give a score to each country based on a chosen metric (average tone of the news toward the event, Goldstein scale, etc.) and map the index of each country (country_code) to the name of the polygon that represents it.


In [9]:
chosen_metric = 'pondered_GoldsteinScale'

In [10]:
scores = selected_df.groupby('Country_Code')[chosen_metric].agg('mean')

rate_min = min(scores)
rate_max = max(scores)

# color scale from min rate to max rate
color_scale = branca.colormap.linear.RdYlGn.scale(rate_min, rate_max)
color_scale = color_scale.to_step(n=8)

def style_function(country):
    if country['id'] in scores.index.values: 
        # country is in the dataframe
        score = scores.loc[country['id']].mean()
        return {
            'fillOpacity': 0.8,
            'weight': 0.5,
            'color': 'black',
            'fillColor': color_scale(score)
        }
    else:
        # country is not in the dataframe, hence we put its color as black
        return {
            'fillOpacity': 0.2,
            'weight': 0.2,
            'color': 'black',
            'fillColor': 'black'
        }
def highlight_function(i):
    return {
                'weight': 2,
                'fillOpacity': .2
            }
world_map = folium.Map([25, 0], tiles='', zoom_start=2)
g = folium.GeoJson(world_json_data, style_function=style_function, highlight_function=highlight_function).add_to(world_map)

color_scale.caption = ' '.join(chosen_metric.split('_'))
color_scale.add_to(world_map)

del style_function
del highlight_function
world_map


Out[10]:

Here we plotted the pondered (weighted) Goldstein scale for each country over the specified interval of time. We multiplied the Goldstein scale of each event by the number of times its news source was mentioned during the first 15 minutes, to give more weight to more important events.

This can be interpreted as follows: reddish countries have had events that shook their stability, while greener countries have had more events that were beneficial to their stability.

Heatmap

The function below will be useful later on, when we need to select events conditionally on one of their features.


In [11]:
def select_events(df, feature, selector):
    '''Example of use : select_events(selected_df, 'EventCode', lambda x: x[:2] == '08')'''
    return df[df[feature].apply(selector)]

We can use the select_events function defined above to fine-tune the events we want to plot and show them on a heatmap.

Using the EventCode feature, we will show every event that involves a threat toward someone.


In [24]:
selector = lambda x: x[:3] == '130'
selected_events = select_events(selected_df, 'EventCode', selector)

In [25]:
positions = selected_events.groupby(['ActionGeo_Lat', 'ActionGeo_Long'])['Day'].agg(['count'])

In [26]:
# One row per (lat, long) pair: latitude, longitude, number of events
points = positions.reset_index()[['ActionGeo_Lat', 'ActionGeo_Long', 'count']].values

In [27]:
m_h = folium.Map([25, 0], tiles='stamentoner', zoom_start=2)

HeatMap(points.tolist()).add_to(m_h)
m_h


Out[27]:

Here we can see that there is a higher density of events involving threats in Europe compared to North America.

However, heatmaps can be a bit ineffective when looking at the map at a small scale (zoomed out). In the future, we will rather use bubble maps; a quick, hypothetical sketch of such a map follows.
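This sketch is not part of the original notebook: it only illustrates what a plotly bubble map of the same threat-related events could look like, reusing the positions aggregation computed above (Scattergeo comes from the plotly.graph_objs wildcard import earlier).

# Sketch only: reuse the (lat, long) -> count aggregation from the heatmap above.
bubbles = positions.reset_index()

bubble_fig = {
    'data': [Scattergeo(
        lat=bubbles['ActionGeo_Lat'].tolist(),
        lon=bubbles['ActionGeo_Long'].tolist(),
        mode='markers',
        # bubble size grows with the number of events at that location
        marker={'size': (2 * np.sqrt(bubbles['count'])).tolist()},
    )],
    'layout': {'title': 'Threat-related events (bubble map sketch)'},
}

iplot(bubble_fig)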

Sankey Diagram

We saw during the data exploration that most of the sources come with a URL. From this, we tried to determine which country the news came from. Given a country X, the goal was to answer the following questions: "Which countries write about the events happening in X?" and "Where do the events covered by the news written in X happen?".

To answer those questions, we will use a Sankey diagram.

Here, we will focus on Switzerland. We keep the data where the source of the event comes from Switzerland and group by the country where the event happens (right_t). We also keep the data where the event happens in Switzerland and group by the country the news comes from (left_t).


In [16]:
# Group by and then aggregate for the left part and the right part of the Sankey diagram
left_t = selected_df[selected_df['Country_Name'] == 'Switzerland'][['Country_Source', 'Day']].groupby('Country_Source').agg(['count'])
right_t = selected_df[selected_df['Country_Source'] == 'Switzerland'][['Country_Name', 'Day']].groupby('Country_Name').agg(['count'])

From those groupbys, we keep only the top ten countries in each group.


In [17]:
left = left_t['Day']['count'].sort_values(ascending=False)[:10].reset_index()
right = right_t['Day']['count'].sort_values(ascending=False)[:10].reset_index()

We then merge the two groups together so that we can plot them in a Sankey diagram.


In [18]:
# Merge the two parts (in a rather clumsy way) so that plotly can be used
data = right.reset_index()
data['Source'] = 'Switzerland '
data['Target'] = data['Country_Name'] + '  '
data['Value'] = data['count']
data['Label'] = data['Country_Name']

data2 = left.reset_index()
data2['Source'] = data2['Country_Source'] 
data2['Target'] = 'Switzerland '
data2['Value'] = data2['count']
data2['Label'] = data2['Country_Source']

l = np.concatenate([data['Source'].values, data2['Source'].values, data2['Target'].values, data['Target'].values], axis=0)

d = dict([(y,x) for x,y in enumerate(sorted(set(l)))])
data3 = pd.concat([data, data2])[['Source', 'Target', 'Value', 'Label']]
data3['Target'] = data3['Target'].map(d)
data3['Source'] = data3['Source'].map(d)

In [19]:
trace1 = {
  "domain": {
    "x": [0, 1], 
    "y": [0, 1]
  }, 
  "link": {
    "label": data3['Label'].values,
    "source": data3['Source'].values, 
    "target": data3['Target'].values, 
    "value": data3['Value'].values
  }, 
  "node": {"label": list(sorted(set(l)))}, 
  "type": "sankey"
}

layout =  dict(
    title = "Sankey Diagram",
    height = 1000,
    width = 1000,
    font = dict(
      size = 10
    ),    
)


fig = dict(data=[trace1], layout=layout)
iplot(fig)


Here we have, on the left, the countries whose news talk about events happening in Switzerland. NOENTRY corresponds to websites whose location we do not know yet. We are still scraping the web for more data and should be done in a week or two. The first country on the left (not counting NOENTRY) is Switzerland, which is not surprising.

On the right side, we have the countries where the events our newspapers/websites write about happen. Here, Switzerland is again the most frequent. The following ones will most likely be countries with conflicts and neighboring countries.

Over a larger time span, we could see more trends, and select only a part of the data to show interest depending on the event type, for example.

Post Milestone 2 work

More data cleaning

Although we had many powerful functions to fetch and clean the data, we were still unable to fetch, clean and store all the data GDELT 2.0 provides in an efficient manner.

To solve this problem, we mostly put together all the pieces of code we already had. We also added a cleaning step that converts country names to their ISO 3166-1 alpha-3 codes. Finally, we updated our code for mapping a source to its country (the country in which the news was written) to a second version. We now achieve a coverage of 77%, i.e. 77% of the news get a source country assigned; the remaining 23% are mostly international websites or websites not considered online newspapers, and these unassigned news are discarded for this analysis.
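A minimal sketch of the country-name-to-code conversion, using pycountry (which the final notebook imports below); the matching logic in the actual cleaning code may differ.

import pycountry as pyc

def to_alpha3(country_name):
    '''Return the ISO 3166-1 alpha-3 code for a country name, or None when
    pycountry cannot resolve it.'''
    try:
        return pyc.countries.lookup(country_name).alpha_3
    except LookupError:
        return None

to_alpha3('Switzerland')  # -> 'CHE'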

A single command now loads the entire dataset. The original data files total more than 100 GB; our final data file is 6.2 GB.

WARNING: the first time the command is executed, it takes a very long time to download and clean all the data from GDELT. At the end, the function creates the above-mentioned 6.2 GB file, and subsequent loads take only a few seconds as long as the file is present in the data folder.


In [ ]:
from load_data import load_data
df = load_data()

Final visualization

First, let us note that we dropped the idea of using folium and are now only using plot.ly.

We also created a website hosted on GitHub Pages to write the data story and display everything in a nice way. Thankfully, plot.ly allows us to easily export the visualizations to JavaScript so they can be included in the website.
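For reference, the commented-out print(plotly.offline.plot(..., output_type='div')) cells further down are what produce these embeddable snippets; the toy example below only illustrates the call with a throwaway figure.

from plotly.offline import plot
from plotly.graph_objs import Bar

# Throwaway figure, only to illustrate the export call.
example_fig = {'data': [Bar(x=['a', 'b'], y=[1, 2])], 'layout': {'title': 'demo'}}

# Returns an HTML <div> string (without plotly.js itself) that can be pasted
# into the GitHub Pages data story.
html_div = plot(example_fig, include_plotlyjs=False, output_type='div')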


In [170]:
import viz
import datetime

import numpy as np
import pandas as pd
import pycountry as pyc

from sklearn import cluster
from sklearn import metrics
from sklearn import linear_model
from tqdm import tqdm_notebook as tqdm
from imp import reload
from matplotlib import pyplot as plt
from viz import world_map_figure

import plotly
import plotly.plotly as py
from plotly.graph_objs import Choropleth, Bar, Scatter
from plotly.offline import init_notebook_mode, iplot
from IPython.display import display, HTML

In [2]:
init_notebook_mode(connected=True)


MOST FREQUENT COUNTRIES


In [3]:
most_common_countries = pd.read_csv('../data/viz/most_common_countries.csv', header=None, names=['Country', 'Count'])

In [4]:
show_top = 10

In [5]:
most_common_countries['Percentage'] = most_common_countries['Count'] / most_common_countries['Count'].sum()

In [433]:
most_common_countries_data_plot = [Bar(
    x=most_common_countries['Country'][:show_top],
    y=most_common_countries['Percentage'][:show_top]
)]

most_common_countries_fig = {
    'data': most_common_countries_data_plot,
    'layout': {
        'title': 'Most Common Countries',
        'paper_bgcolor': 'rgba(0, 0, 0, 0)',
        'plot_bgcolor': 'rgba(0, 0, 0, 0)',
        'xaxis': {
            'title': 'Country',            
        },
        'yaxis': {
            'title': 'Proportion of total countries',           
        }
    }
}
    
iplot(most_common_countries_fig)



In [524]:
#print(plotly.offline.plot(most_common_countries_fig, include_plotlyjs=False, output_type='div'))

MOST FREQUENT UNKNOWN WEBSITES


In [8]:
most_unknown_websites = pd.read_csv('../data/viz/most_unknown_websites.csv', header=None, names=['URL', 'Count'])

In [9]:
show_top = 10

In [10]:
most_unknown_websites['Percentage'] = most_unknown_websites['Count'] / most_unknown_websites['Count'].sum()

In [502]:
most_unknown_websites_data_plot = [Bar(
    x=most_unknown_websites['URL'][:show_top],
    y=most_unknown_websites['Percentage'][:show_top],
    )]

most_unknown_websites_fig = {
    'data': most_unknown_websites_data_plot,
    'layout': {
        'title': 'Most Frequent Unknown Websites',
        'paper_bgcolor': 'rgba(0, 0, 0, 0)',
        'plot_bgcolor': 'rgba(0, 0, 0, 0)',
        'xaxis': {
            'title': 'Website',            
        },
        'yaxis': {
            'title': 'Proportion of total websites',           
        }
    }
}

iplot(most_unknown_websites_fig)



In [525]:
#print(plotly.offline.plot(most_unknown_websites_fig, include_plotlyjs=False, output_type='div'))

Load all data


In [13]:
def select_events(df, feature, selector):
    '''Example of use : select_events(selected_df, 'EventCode', lambda x: x[:2] == '08')'''
    return df[df[feature].apply(selector)]

In [14]:
all_cca = [c.alpha_3 for c in pyc.countries]
all_cca_set = set(all_cca)

In [15]:
start_date = datetime.datetime(2015, 3, 1)
end_date = datetime.datetime(2017, 12, 1)

n_months = (end_date - start_date).days * 12 // 365

dates = []
for i in range(n_months):
    index = start_date.month - 1 + i
    month = index % 12 + 1
    year = start_date.year + index // 12
    date = "{}_{:02d}".format(year, month)
    dates.append(date)
    
dates_set = set(dates)

In [16]:
df = pd.read_csv('../data/final_data.csv', encoding='utf-8')

In [17]:
df = select_events(df, 'Target_CountryCode', lambda x: x in all_cca_set)

In [18]:
df = select_events(df, 'Source_CountryCode', lambda x: x in all_cca_set)

In [19]:
df['Year_Month'] = df['Day'].apply(str).apply(lambda x: x[:4] + '_' + x[4:6])

In [20]:
df = select_events(df, 'Year_Month', lambda x: x in dates_set)

WORLD MAP


In [22]:
reload(viz)


Out[22]:
<module 'viz' from 'C:\\Users\\Greg\\Programming\\Python\\ada2017\\project\\src\\viz.py'>

In [275]:
colorscale_perso = [[0.0, 'rgb(165,0,38)'], [0.1111111111111111, 'rgb(215,48,39)'], [0.2222222222222222, 'rgb(244,109,67)'], [0.3333333333333333, 'rgb(253,174,97)'], [0.4444444444444444, 'rgb(254,224,144)'], [0.5555555555555556, 'rgb(224,243,248)'], [0.6666666666666666, 'rgb(171,217,233)'], [0.7777777777777778, 'rgb(116,173,209)'], [0.8888888888888888, 'rgb(69,117,180)'], [1.0, 'rgb(49,54,149)']]

colorscale_perso1 = [[0.0, '0066CC'], [1, 'FFFFFF']]#, [0.2, 'F9DBBD'], [0.3, 'rgb(253,174,97)'], [0.4, 'rgb(254,224,144)'], [0.5, 'rgb(224,243,248)'], [0.6, 'rgb(171,217,233)'], [0.7, 'rgb(116,173,209)'], [0.8, 'rgb(69,117,180)'], [0.9, 'rgb(49,54,149)'], [1.0, 'rgb(49,54,149)']]

default_colorscale = [[0, "rgb(5, 10, 172)"], [0.35, "rgb(40, 60, 190)"], [0.5, "rgb(70, 100, 245)"],
            [0.6, "rgb(90, 120, 245)"], [0.7, "rgb(106, 137, 247)"], [1, "rgb(220, 220, 220)"]]

AVG TONE of Target_CountryCode per month


In [24]:
# Pivot on countries and take the median AvgTone for each month

df_tone_target = pd.pivot_table(df, values='AvgTone', index=['Target_CountryCode'], columns=['Year_Month'], aggfunc=np.median)

In [277]:
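# Robust color-scale bounds: the largest per-month minimum and the smallest
# per-month maximum, so that a few extreme values do not saturate the scale.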
zmin_tone_target = df_tone_target.min().max()
zmax_tone_target = df_tone_target.max().min()

zmin_tone_target, zmax_tone_target


Out[277]:
(-3.8285918479571897, 1.1235955056179798)

In [290]:
figure_tone_target = world_map_figure(title='Average Tone Evolution',
                                      title_colorscale='Median<br>Average Tone',
                                      frames_title=dates,
                                      df=df_tone_target.dropna().reset_index(),
                                      locations_col='Target_CountryCode',
                                      txt_fn=lambda code: pyc.countries.get(alpha_3=code).name,
                                      zmin=zmin_tone_target,
                                      zmax=zmax_tone_target, 
                                      colorscale=colorscale_perso1)

iplot(figure_tone_target, validate=False)



In [526]:
#print(plotly.offline.plot(figure_tone_target, include_plotlyjs=False, output_type='div'))

GOLDSTEIN of Target_CountryCode per month


In [283]:
df['GoldsteinScalePondered'] = df['GoldsteinScale'] * df['NumMentions']

In [358]:
# Pivot on countries and take the mean GoldsteinScale for each month

df_gs_target = pd.pivot_table(df, values='GoldsteinScale', index=['Target_CountryCode'], columns=['Year_Month'], aggfunc=np.mean)

In [359]:
zmin_gs_target = df_gs_target.min().median()
zmax_gs_target = df_gs_target.max().median()

zmin_gs_target, zmax_gs_target


Out[359]:
(-2.001569300557773, 2.7349056603773563)

In [361]:
figure_gs_target = world_map_figure(title='Goldstein Scale Evolution',
                                    title_colorscale='Median Pondered <br> Goldstein Scale',
                                    frames_title=dates,
                                    df=df_gs_target.dropna().reset_index(),
                                    locations_col='Target_CountryCode',
                                    txt_fn=lambda code: pyc.countries.get(alpha_3=code).name,
                                    zmin=zmin_gs_target,
                                    zmax=zmax_gs_target, 
                                    colorscale=colorscale_perso1)

iplot(figure_gs_target, validate=False)



In [527]:
#print(plotly.offline.plot(figure_gs_target, include_plotlyjs=False, output_type='div'))


In [292]:
def tone_focus_on(df, code):
    df_target = select_events(df, 'Target_CountryCode', lambda x: x == code)
    return pd.pivot_table(df_target, values='AvgTone', index=['Source_CountryCode'], columns=['Year_Month'], aggfunc=np.mean).dropna()

In [293]:
df_tone_usa = tone_focus_on(df, 'USA')

In [297]:
zmin_tone_usa = df_tone_usa.min().max()
zmax_tone_usa = df_tone_usa.max().min()

zmin_tone_usa, zmax_tone_usa


Out[297]:
(-3.4250309954824449, 0.46150679794867328)

In [298]:
reload(viz)


Out[298]:
<module 'viz' from 'C:\\Users\\Greg\\Programming\\Python\\ada2017\\project\\src\\viz.py'>

In [299]:
figure_tone_us = world_map_figure(title='AvgTone toward/against USA - Evolution',
                                    title_colorscale='Median<br>AvgTone',
                                    frames_title=dates,
                                    df=df_tone_usa.reset_index(),
                                    locations_col='Source_CountryCode',
                                    txt_fn=lambda code: pyc.countries.get(alpha_3=code).name,
                                    zmin=zmin_tone_usa,
                                    zmax=zmax_tone_usa, 
                                    colorscale=colorscale_perso1)

iplot(figure_tone_us, validate=False)



In [528]:
#print(plotly.offline.plot(figure_tone_us, include_plotlyjs=False, output_type='div'))

In [337]:
def approx(y, degree=6):
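    # Build a polynomial design matrix [1, t, t^2, ...] over the time index and
    # fit a Lasso regression; the prediction is the smoothed series and the
    # coefficients summarize its shape.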
    X = np.arange(len(y))
    X = X.reshape(X.shape + (1,))
    X = np.concatenate([X ** i for i in range(degree)], axis=1)
    #X -= X.mean()
    #X /= X.std()
    smoothing_model = linear_model.Lasso()
    smoothing_model.fit(X, y)
    return smoothing_model.predict(X), smoothing_model.coef_

In [323]:
def colors(c):
    if c == 'France':
        return 'rgb(22, 96, 167)'
    elif c == 'Switzerland':
        return 'rgb(205, 12, 24)'
    else:
        return 'rgb(0, 0, 0)'

In [324]:
def build_trace(x, y, name, polynomial_approx, mode='lines'):    
    if polynomial_approx:
        y, _ = approx(y)

    return Scatter(
        x = x,
        y = y,
        mode = mode,
        name = name + (' (approx)' if polynomial_approx else ''),
        line = {
            'color': colors(name)
        }
    )

In [493]:
def traces_to_fig(title, traces, xaxis='', yaxis=''):
    return {
        'data': traces,
        'layout': {
            'title': title,
            'paper_bgcolor': 'rgba(0, 0, 0, 0)',
            'plot_bgcolor': 'rgba(0, 0, 0, 0)',
            'xaxis': {
                'title': xaxis,            
            },
            'yaxis': {
                'title': yaxis,           
            }
        },
        
    }

In [500]:
def trends_to_fig(df, countries, title, xaxis, yaxis, centered=False, polynomial_approx=False):  
    y = df.copy()
    
    if centered:
        y -= y.mean()
    
    traces = [build_trace([d.replace('_', '/') for d in dates], y.loc[c], pyc.countries.get(alpha_3=c).name, poly) for c in countries for poly in set([False, polynomial_approx])]

    fig = traces_to_fig(title, traces, xaxis, yaxis)

    return fig

In [501]:
some_countries = ['CHE', 'FRA', 'MEX']
title = 'Trends in the Average Tone used to report on the events happening in the USA (common trend removed) (with approximation)'
xaxis = 'Date'
yaxis = 'Average Tone'
fig_trends = trends_to_fig(df_tone_usa, some_countries, title, xaxis, yaxis, centered=True, polynomial_approx=True)
iplot(fig_trends)


C:\Users\Greg\Anaconda3\envs\py36\lib\site-packages\sklearn\linear_model\coordinate_descent.py:491: ConvergenceWarning:

Objective did not converge. You might want to increase the number of iterations. Fitting data with very small alpha may cause precision problems.


In [529]:
#print(plotly.offline.plot(fig_trends, include_plotlyjs=False, output_type='div'))

In [80]:
def df_to_weights(df, countries):
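    # For each requested country: standardize its tone series, subtract the
    # standardized mean series over all countries (the common "bias"), and keep
    # the polynomial coefficients of its Lasso approximation as a feature vector.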
    dict_ = {}
    
    bias = df.mean()
    bias -= bias.mean()
    bias /= bias.std()
    
    for code, y in zip(df.index, df.values):
        if code in countries:
            y -= y.mean()
            y /= y.std()
            y -= bias
            _, weights = approx(y, degree=4)
            dict_[code] = weights
    return dict_

In [400]:
def many_df_to_weights(df_list):
    all_weights = {}
    all_countries = set()
    
    for df in df_list:
        if len(all_countries) == 0:
            all_countries = set(df.index)
        else:
            all_countries = all_countries & set(df.index)

    for df in df_list:
        new_weights = df_to_weights(df, all_countries)
        for code in new_weights:
            if code in all_countries:
                w = new_weights[code]
                if code in all_weights:
                    all_weights[code] = np.append(all_weights[code], w)
                else:
                    all_weights[code] = w
                    
    codes = np.array(list(all_weights.keys()))
    weights = np.array(list(all_weights.values()))
    
    return codes, weights

In [401]:
def many_df_to_labels(df_list):
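    # Concatenate the polynomial-coefficient features across the given dataframes
    # and split the countries into two groups with spectral clustering.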
    codes, weights = many_df_to_weights(df_list)
    
    clusters = cluster.SpectralClustering(n_clusters=2)
    labels = clusters.fit(weights).labels_
    
    return codes, labels

In [403]:
all_clusters_df = None
for c in ['USA', 'CHN', 'FRA']:
    codes, cluster_idx = many_df_to_labels([tone_focus_on(df, c)])
    clusters_df = pd.DataFrame(cluster_idx, index=codes, columns=[c])
    if all_clusters_df is None:
        all_clusters_df = clusters_df
    else:
        all_clusters_df = pd.concat([all_clusters_df, clusters_df], axis=1, join='inner')

In [404]:
all_clusters_df.shape


Out[404]:
(110, 3)

In [422]:
reload(viz)


Out[422]:
<module 'viz' from 'C:\\Users\\Greg\\Programming\\Python\\ada2017\\project\\src\\viz.py'>

In [522]:
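# Discrete two-color scale: cluster label 0 is drawn in red, label 1 in blue.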
cs = [
        [0, 'rgb(255, 0, 0)'],
        [0.1, 'rgb(255, 0, 0)'],

        [0.1, 'rgb(0, 0, 255)'],
        [1.0, 'rgb(0, 0, 255)']
    ]

map_clusters = world_map_figure(title='Spectral Clustering',
                                title_colorscale='colorbar',
                                frames_title=['CHN', 'FRA', 'USA'],
                                df=all_clusters_df.reset_index(),
                                locations_col='index',
                                txt_fn=lambda code: pyc.countries.get(alpha_3=code).name,
                                zmin=0,
                                zmax=1, 
                                colorscale=cs,
                                showscale=False,
                               )

iplot(map_clusters, validate=False)



In [530]:
#print(plotly.offline.plot(map_clusters, include_plotlyjs=False, output_type='div'))

In [504]:
codes, cluster_idx = many_df_to_labels([tone_focus_on(df, c) for c in ['USA', 'RUS', 'FRA', 'UKR']])
clusters_df_grouped = pd.DataFrame(cluster_idx, index=codes, columns=['USA RUS FRA UKR'])

In [519]:
map_clusters_grouped = world_map_figure(title='Spectral Clustering<br>(aggregated with USA RUS FRA UKR)',
                                title_colorscale='colorbar',
                                frames_title=['USA RUS FRA UKR'],
                                df=clusters_df_grouped.reset_index(),
                                locations_col='index',
                                txt_fn=lambda code: pyc.countries.get(alpha_3=code).name,
                                zmin=0,
                                zmax=1, 
                                colorscale=cs,
                                showscale=False,
                               )

iplot(map_clusters_grouped, validate=False)



In [531]:
#print(plotly.offline.plot(map_clusters_grouped, include_plotlyjs=False, output_type='div'))

In [92]:
def get_silhouette(X, i):
    clusters = cluster.SpectralClustering(n_clusters=i)
    clusters.fit(X)
    labels = clusters.labels_
    return metrics.silhouette_score(X, labels, metric='euclidean')

In [186]:
source_and_target_countries = list(set(df['Source_CountryCode'].values) & set(df['Target_CountryCode'].values))

In [206]:
all_weights = []
for code in tqdm(source_and_target_countries):
    df_tone = tone_focus_on(df, code)
    weights = many_df_to_weights([df_tone])[1]
    all_weights.append(weights)
    #print(weights.shape)
    #if len(weights) > 5:
    #    sil = np.array([get_silhouette(weights, i) for i in range(2, 6)])
    #    all_silhouettes.append(sil)






In [418]:
all_silhouettes = []
for w, c in tqdm(list(zip(all_weights, source_and_target_countries))):
    if len(w) > 100:
        sil = np.array([get_silhouette(w, i) for i in range(2, 6)])
        all_silhouettes.append((c, sil))




In [517]:
x = list(range(2, 9))
df_tone = tone_focus_on(df, 'AUS')
y = np.array([get_silhouette(many_df_to_weights([df_tone])[1], i) for i in x])
del df_tone

In [518]:
traces = [Scatter(x=x, y=y)]
fig = traces_to_fig('Silhouette', traces, 'Number of clusters', 'Silhouette score')

iplot(fig)



In [532]:
#print(plotly.offline.plot(fig, include_plotlyjs=False, output_type='div'))