In [ ]:
import json
import requests
import pandas as pd
import pickle
from tqdm import tqdm_notebook as tqdm
from math import radians, cos, sin, sqrt, atan2
from IPython.display import display, HTML
from datetime import datetime
import itertools
import time
from itertools import count

In [ ]:
# Column names of the *_processed.csv files, grouped by the entity each one
# describes. The retrieval loop assigns this list to `df.columns`, replacing
# the CSV header.
columns = [
    # agency
    'agency_id',
    # service date
    'service_date_id',
    'service_date_date',
    # route
    'route_id',
    'route_short_name',
    'route_long_name',
    # trip
    'trip_id',
    'trip_headsign',
    'trip_short_name',
    # stop time
    'stop_time_id',
    'stop_time_arrival_time',
    'stop_time_departure_time',
    'stop_time_stop_sequence',
    # stop
    'stop_id',
    'stop_stop_id',
    'stop_name',
    # capacity path
    'capacity_path_id',
    'capacity_path_path',
    # capacity values
    'capacity_capacity_id',
    'capacity_capacity_capacity1st',
    'capacity_capacity_capacity2nd',
]

In [ ]:
# Input / output folders. The trailing slash is required because paths are
# built by plain string concatenation (e.g. out_dir + date + '.csv').
in_dir, out_dir = "in_data/", "out_data/"

We use the transport.opendata.ch API. Unfortunately, it enforces fairly strict rate limiting, so every request below is throttled (see `buzzergen`).


In [ ]:
def get_request_transport(params, timeout=30):
    """Query the transport.opendata.ch `/v1/connections` endpoint.

    Parameters
    ----------
    params : dict
        Query-string parameters forwarded as-is (e.g. `from`, `to`, `date`,
        `time`, `direct`, `limit`).
    timeout : float, optional
        Seconds to wait for the server before giving up (default 30).
        `requests` has no timeout by default, so without this a single
        stalled connection would block the whole retrieval loop forever.

    Returns
    -------
    requests.Response
        The raw response. The caller is responsible for decoding it.
    """
    base_url = "http://transport.opendata.ch/v1/connections"
    return requests.get(base_url, params=params, timeout=timeout)

def buzzergen(period):
    """Rate-limiting generator: each `next()` returns at most once per `period` s.

    The first `next()` sleeps roughly `period` seconds. Later calls sleep
    only for whatever remains of the current slot. If a call arrives after its
    slot has already passed, it returns immediately and the schedule restarts
    from "now", so there is never a catch-up burst of back-to-back requests.

    Parameters
    ----------
    period : float
        Minimum spacing between two releases, in seconds.

    Yields
    ------
    (int, float)
        The 0-based release counter and the `time.time()` timestamp at which
        the next release is scheduled.
    """
    nexttime = time.time() + period
    for i in count():
        now = time.time()
        tosleep = nexttime - now
        if tosleep > 0:
            time.sleep(tosleep)
            nexttime += period
        else:
            # Already late: restart the schedule instead of catching up.
            nexttime = now + period
        yield i, nexttime


# Must be created *after* buzzergen is defined, otherwise a fresh kernel raises
# NameError. At most one API request every 0.4 s.
buzz = buzzergen(0.4)

In [ ]:
# Debugging helpers: build the lists of stations retrieved from the backend and
# from the API, and print them side by side for comparison.

def list_back(trip, f, t ):
    """Return `[stop_stop_id, stop_time_stop_sequence, stop_name]` per row of `trip`.

    Rows are emitted in the DataFrame's current order. `f` and `t` are not
    used; they are kept only so existing call sites keep working.
    """
    return [
        [row['stop_stop_id'], row['stop_time_stop_sequence'], row['stop_name']]
        for _, row in trip.iterrows()
    ]
        
def list_transport(connections):
    """Flatten API connections into `[station_id, position, station_name]` rows.

    Every section with a truthy `journey` contributes its `passList` in
    order. `position` is the 0-based index within that section's pass list,
    so it restarts at 0 for each section.
    """
    rows = []
    journeys = (sec['journey'] for conn in connections for sec in conn['sections'])
    for journey in journeys:
        if not journey:
            continue
        rows.extend(
            [entry['station']['id'], pos, entry['station']['name']]
            for pos, entry in enumerate(journey['passList'])
        )
    return rows

def pretty_print(zip_val):
    """Print one backend row and one API row side by side.

    `zip_val` is a pair `(backend_row, transport_row)`, for example one item
    of `itertools.zip_longest`. A falsy side (e.g. `None`) is printed as
    blank cells. Only the first three fields of each row are shown.
    """
    blank = ["", "", ""]
    back = zip_val[0] if zip_val[0] else blank
    trans = zip_val[1] if zip_val[1] else blank
    cells = [str(v) for v in (back[0], back[1], back[2], trans[0], trans[1], trans[2])]
    print("{:<8s} {:<3s} {:20s} | {:<8s} {:<3s} {:10s}".format(*cells))


In [ ]:
def process_trip(trip, date, time):
    """Fetch the transport.opendata.ch connection corresponding to one backend trip.

    The trip's first and last stops (lowest / highest
    `stop_time_stop_sequence`) become the `from` / `to` of a direct
    connection query for `date` and `time`. The station ids of every
    returned journey are then compared, as a set, with the backend stop ids.

    Parameters
    ----------
    trip : pandas.DataFrame
        Rows of a single trip. Reads `stop_stop_id` and
        `stop_time_stop_sequence`, plus `stop_name` when a mismatch is
        printed.
    date : str
        Query date forwarded to the API (e.g. `'2017-01-30'`).
    time : object
        Departure time forwarded to the API as a query parameter.
        NOTE: this name shadows the `time` module inside this function.

    Returns
    -------
    (list of dict, dict)
        One `{station_id: passList entry}` map per returned connection, with
        the journey that matched the backend moved to index 0, followed by the
        decoded JSON response.

    Raises
    ------
    ValueError
        If the response cannot be decoded or has no `connections` key, or
        if no returned journey visits exactly the backend's set of stations.
        In the latter case both station lists are printed first.
    """

    def get_info_backend(trip):
        # Map station id -> row. Ids are converted with str(), presumably to
        # match the ids in the API response (verify). Also find the first and
        # last stop of the trip by sequence number.
        stops = {str(stop['stop_stop_id']): stop for _, stop in trip.iterrows()}

        min_seq = trip.stop_time_stop_sequence.min()
        max_seq = trip.stop_time_stop_sequence.max()

        # .item() raises ValueError if a sequence number is not unique.
        first = trip[trip.stop_time_stop_sequence == min_seq].stop_stop_id.item()
        last = trip[trip.stop_time_stop_sequence == max_seq].stop_stop_id.item()
        return stops, first, last

    def get_info_transport(connections):
        # One {station_id: passList entry} map per connection, merging the pass
        # lists of all its sections. Sections whose journey is falsy are skipped.
        pass_list = []
        for connection in connections:
            journey_stops = {}
            for section in connection['sections']:
                if section['journey']:
                    journey_stops.update(
                        {stop['station']['id']: stop for stop in section['journey']['passList']}
                    )
            pass_list.append(journey_stops)
        return pass_list

    stop_backend, start_id, end_id = get_info_backend(trip)

    params = {
        'from': start_id,
        'to': end_id,
        'date': date,
        'time': time,
        'direct': 1,
        'limit': 1,
    }

    trans_r = get_request_transport(params)

    try:
        # Decode once; the payload is reused for the return value and for the
        # mismatch dump below.
        payload = trans_r.json()
        sections_transport = get_info_transport(payload['connections'])
    except Exception as exc:
        raise ValueError(trans_r.url, exc) from exc

    # Only the *set* of stations is compared; visiting order is not checked.
    backend_ids = set(stop_backend)
    for idx, journey in enumerate(sections_transport):
        if backend_ids == set(journey):
            # Put the verified journey first, because the capacity loop reads [0].
            reordered = [journey] + sections_transport[:idx] + sections_transport[idx + 1:]
            return reordered, payload

    print("########## Error ##########")
    print(trans_r.url)
    rows = itertools.zip_longest(
        list_back(trip, start_id, end_id),
        list_transport(payload['connections']),
    )
    for row in rows:
        pretty_print(row)

    raise ValueError("Stops mismatch")


In [ ]:
# `columns`, the schema of the *_processed.csv files, is already defined in the
# configuration cell at the top of the notebook. It is deliberately not
# re-declared here, so the two copies cannot drift apart. This line only checks
# that the schema is intact before the long-running retrieval below.
assert len(columns) == len(set(columns)) == 21, "unexpected column schema"

Retrieve the capacities for each of the dates listed below.

This takes a very long time, mainly because of the API rate limiting.


In [ ]:
# One week of service dates, Monday 2017-01-30 to Sunday 2017-02-05 (ISO format).
dates = '2017-01-30 2017-01-31 2017-02-01 2017-02-02 2017-02-03 2017-02-04 2017-02-05'.split()

In [ ]:
CAPACITY_COLUMNS = ['capacity_capacity_capacity1st', 'capacity_capacity_capacity2nd']


def _has_capacity(value):
    """Return True when a capacity cell already holds a retrieved value.

    pandas reads empty CSV cells as NaN, and NaN is truthy in Python, so it is
    excluded explicitly. None, 0 and '' count as "not retrieved", matching the
    original truthiness check.
    """
    return bool(pd.notna(value) and value)


for date in dates:
    df = pd.read_csv(out_dir + date + '_processed.csv', index_col=0)
    df.columns = columns
    error = []   # trips whose capacities could not be retrieved
    frames = []  # one DataFrame per trip, concatenated once at the end
                 # (DataFrame.append is quadratic and was removed in pandas 2)

    for _, group in tqdm(df.groupby('trip_id'), desc="Trips"):
        trip = group.sort_values(['stop_time_stop_sequence'])

        try:
            # Already retrieved: keep the rows unchanged and skip both the
            # throttle and the API call.
            if any(_has_capacity(v) for v in trip[CAPACITY_COLUMNS].to_numpy().ravel()):
                frames.append(trip)
                continue

            # Departure time of the first stop (rows are sorted by sequence, so
            # this also works when sequences do not start at 0).
            # NOTE(review): assumes epoch seconds; fromtimestamp uses the
            # machine's local timezone.
            start = datetime.fromtimestamp(trip.stop_time_departure_time.iloc[0]).time()

            next(buzz)  # throttle to respect the API rate limit
            transport_stops, _ = process_trip(trip, date, start)
            journey = transport_stops[0]

            station_ids = [str(stop['stop_stop_id']) for _, stop in trip.iterrows()]
            frames.append(trip.assign(
                capacity_capacity_capacity1st=[journey[sid]['prognosis']['capacity1st'] for sid in station_ids],
                capacity_capacity_capacity2nd=[journey[sid]['prognosis']['capacity2nd'] for sid in station_ids],
            ))
        except Exception as e:
            # process_trip raises ValueError on API / stop mismatches. Other
            # failures (missing station key, bad timestamp, network error) are
            # recorded the same way so one bad trip does not stop the run.
            print("=>", e)
            error.append(trip)

    out = pd.concat(frames) if frames else pd.DataFrame(columns=columns)
    out.to_csv(out_dir + 'capacity_' + date + '.csv')
    with open(out_dir + 'capacity_error_' + date + '.pkl', 'wb') as fh:
        pickle.dump(error, fh, protocol=2)