Another example from the web of using the FB Prophet package

Source: url. The data corresponds to the retail_sales example CSV.

This is part 2, focused on tweaking the model.


In [1]:
retail_datafile = '../datasets/example_retail_sales.csv'

In [2]:
from fbprophet import Prophet  # in Prophet 1.0+ the package is named prophet: from prophet import Prophet
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (20, 10)
plt.style.use('ggplot')

In [3]:
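# parse 'ds' as datetimes so it lines up with Prophet's datetime output when joining later
sales_df = pd.read_csv(retail_datafile, parse_dates=['ds'])
sales_df['y_orig'] = sales_df['y']  # save a copy of the original data; you'll see why shortly
# log-transform y
sales_df['y'] = np.log(sales_df['y'])
sales_df.tail()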


Out[3]:
             ds          y  y_orig
288  2016-01-01  12.901537  400928
289  2016-02-01  12.932543  413554
290  2016-03-01  13.039184  460093
291  2016-04-01  13.019078  450935
292  2016-05-01  13.063507  471421
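The log transform is easy to undo later because `np.exp` inverts `np.log`. A common reason for logging sales data (my assumption about the motivation here) is that it turns multiplicative growth and seasonality into additive effects, which suits Prophet's default additive model. A minimal sketch using the five `y_orig` values copied from the output above:

```python
import numpy as np

# Values copied from the tail of sales_df shown above (y_orig, in dollars)
y_orig = np.array([400928, 413554, 460093, 450935, 471421], dtype=float)

# Forward transform, as in the In [3] cell
y_log = np.log(y_orig)

# np.exp undoes np.log, recovering the original dollar values
y_back = np.exp(y_log)

print(np.round(y_log, 6))
print(np.allclose(y_back, y_orig))  # True
```

The rounded logs should match the `y` column above, which is why keeping `y_orig` around is mostly a convenience for plotting.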

Start modeling


In [4]:
model = Prophet()
model.fit(sales_df[['ds', 'y']])
# create 12 months of future dates; 'M' means month-end, while the history uses month-start dates
future_data = model.make_future_dataframe(periods=12, freq='M')

# forecast the history plus the 12 future dates
forecast_data = model.predict(future_data)


Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.

In [5]:
forecast_data[['ds','yhat','yhat_lower','yhat_upper']].tail()


Out[5]:
             ds       yhat  yhat_lower  yhat_upper
300  2016-12-31  12.945071   12.922593   12.968187
301  2017-01-31  12.968193   12.943553   12.994479
302  2017-02-28  13.064662   13.038743   13.090464
303  2017-03-31  13.054117   13.028228   13.083777
304  2017-04-30  13.106254   13.077519   13.137530
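These forecast columns are still in log-dollars because the model was fit on `np.log(y)`. A sketch of converting them back, using the five rows copied from the output above: `np.exp` applies elementwise to a DataFrame, and because it is strictly increasing, the lower and upper bounds stay correctly ordered after the transform.

```python
import numpy as np
import pandas as pd

# The last rows of the forecast shown above, in log-dollars
forecast_tail = pd.DataFrame(
    {
        "yhat": [12.945071, 12.968193, 13.064662, 13.054117, 13.106254],
        "yhat_lower": [12.922593, 12.943553, 13.038743, 13.028228, 13.077519],
        "yhat_upper": [12.968187, 12.994479, 13.090464, 13.083777, 13.137530],
    },
    index=pd.to_datetime(
        ["2016-12-31", "2017-01-31", "2017-02-28", "2017-03-31", "2017-04-30"]
    ),
)

# Undo the log transform on every column to get dollar forecasts
forecast_dollars = np.exp(forecast_tail)
print(forecast_dollars.round(0))
```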

In [6]:
model.plot(forecast_data);


While this is a nice chart, it is kind of ‘busy’ for me. I also like to view my forecasts with the original data first and the forecasts appended to the end (this ‘might’ make sense in a minute).

First, we need to get our data combined and indexed appropriately to start plotting. We are only interested (at least for the purposes of this article) in the ‘yhat’, ‘yhat_lower’ and ‘yhat_upper’ columns from the Prophet forecast dataset. Note: there are much more pythonic ways to do these steps, but I’m breaking them out for ease of understanding.


In [ ]:
#model.plot_components(forecast_data);

In [7]:
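sales_df.set_index('ds', inplace=True)
forecast_data.set_index('ds', inplace=True)
# outer join keeps every date from both frames: history rows get y_orig, future rows get NaN
viz_df = sales_df.join(forecast_data[['yhat', 'yhat_lower', 'yhat_upper']], how='outer')
del viz_df['y']  # drop the log-transformed y; only y_orig is needed for plotting
viz_df.tail()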


Out[7]:
            y_orig       yhat  yhat_lower  yhat_upper
ds
2016-12-31     NaN  12.945071   12.922593   12.968187
2017-01-31     NaN  12.968193   12.943553   12.994479
2017-02-28     NaN  13.064662   13.038743   13.090464
2017-03-31     NaN  13.054117   13.028228   13.083777
2017-04-30     NaN  13.106254   13.077519   13.137530

The y_orig column is null because we don't know the future values.
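The join in In [7], NaN rows included, can be reproduced on toy data with a single chained expression, the kind of "more pythonic" version mentioned earlier. The frames below are hypothetical stand-ins for `sales_df` and `forecast_data`: the `y_orig` numbers are copied from the data, the forecast numbers are made up.

```python
import numpy as np
import pandas as pd

# Toy stand-in for sales_df: three observed months
history = pd.DataFrame({
    "ds": pd.to_datetime(["2016-03-01", "2016-04-01", "2016-05-01"]),
    "y_orig": [460093.0, 450935.0, 471421.0],
})
history["y"] = np.log(history["y_orig"])

# Toy stand-in for forecast_data: the same three months plus two future months
# (forecast values are made up for illustration)
forecast = pd.DataFrame({
    "ds": pd.to_datetime(
        ["2016-03-01", "2016-04-01", "2016-05-01", "2016-06-01", "2016-07-01"]
    ),
    "yhat": [13.04, 13.03, 13.08, 13.06, 13.08],
    "yhat_lower": [13.02, 13.01, 13.06, 13.04, 13.06],
    "yhat_upper": [13.06, 13.05, 13.10, 13.08, 13.10],
})

# One chained expression instead of set_index(inplace=True) + join + del
viz = (
    history.set_index("ds")
    .drop(columns="y")
    .join(forecast.set_index("ds")[["yhat", "yhat_lower", "yhat_upper"]], how="outer")
)
print(viz)
```

The two future months have no observed sales, so `y_orig` is NaN there, just as in the real `viz_df`.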

Below is a breakdown of how to visualize this data in a little more detail than Prophet does by default.

Begin by getting the last date in the original dataset.


In [8]:
sales_df.index = pd.to_datetime(sales_df.index)
last_date = sales_df.index[-1]

In [11]:
from datetime import timedelta

def plot_data(func_df, end_date):
    # Step back ~4 weeks from the last observed date so the "predicted" slice also
    # includes the last actual month; that lets the forecast line connect to the actuals.
    end_date = end_date - timedelta(weeks=4)
    mask = (func_df.index > end_date)  # rows after the cutoff: last actual month + forecasts
    predict_df = func_df.loc[mask]     # dataframe holding just those rows
    # Now plot everything, converting the log-scale forecasts back to dollars with np.exp
    fig, ax1 = plt.subplots()
    ax1.plot(func_df.y_orig, label='Actual Sales')
    ax1.plot(np.exp(predict_df.yhat), color='black', linestyle=':', label='Forecasted Sales')
    ax1.fill_between(predict_df.index, np.exp(predict_df['yhat_upper']),
                     np.exp(predict_df['yhat_lower']), alpha=0.5, color='darkgray')
    ax1.set_title('Sales (Orange) vs Sales Forecast (Black)')
    ax1.set_ylabel('Dollar Sales')
    ax1.set_xlabel('Date')
    # the legend picks up the label= set on each line above
    ax1.legend()
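The cutoff logic in `plot_data` can be checked on a toy month-start index (values hypothetical). Subtracting four weeks from the last observed date and keeping rows strictly after it retains the last actual month as the first "predicted" row. For monthly data an equivalent and arguably clearer rule is `index >= last_date`; that alternative is my suggestion, not the original author's.

```python
from datetime import timedelta

import pandas as pd

# Toy month-start index: five observed months (Jan-May 2016) then three forecast months
idx = pd.date_range("2016-01-01", periods=8, freq="MS")
df = pd.DataFrame({"step": range(8)}, index=idx)
last_date = pd.Timestamp("2016-05-01")  # last observed month

# Same rule as plot_data: step back four weeks, keep rows strictly after the cutoff
cutoff = last_date - timedelta(weeks=4)
predict_df = df.loc[df.index > cutoff]

# Alternative rule (suggestion): keep rows on or after the last observed date
predict_df_alt = df.loc[df.index >= last_date]

print(predict_df.index[0])                            # 2016-05-01 00:00:00
print(predict_df.index.equals(predict_df_alt.index))  # True
```

The two rules agree whenever no row falls between the cutoff and `last_date`, which holds for monthly data.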

In [13]:
viz_df


Out[13]:
y_orig yhat yhat_lower yhat_upper
ds
1992-01-01 146376.0 11.874747 11.854859 11.894563
1992-02-01 147079.0 11.891276 11.869607 11.909054
1992-03-01 159336.0 12.011923 11.993578 12.033317
1992-04-01 163669.0 12.000119 11.981204 12.018980
1992-05-01 170068.0 12.055784 12.036544 12.075837
1992-06-01 168663.0 12.032517 12.012758 12.053107
1992-07-01 169890.0 12.037031 12.017323 12.058126
1992-08-01 170364.0 12.061187 12.041897 12.081546
1992-09-01 164617.0 11.997632 11.979330 12.017966
1992-10-01 173655.0 12.039979 12.020624 12.060597
1992-11-01 171547.0 12.047887 12.028075 12.066417
1992-12-01 208838.0 12.207260 12.187456 12.227095
1993-01-01 153221.0 11.960038 11.940140 11.979353
1993-02-01 150087.0 11.945889 11.926188 11.965732
1993-03-01 170439.0 12.084882 12.065958 12.105847
1993-04-01 176456.0 12.073754 12.054098 12.093468
1993-05-01 182231.0 12.129287 12.109366 12.147717
1993-06-01 181535.0 12.108931 12.090544 12.128783
1993-07-01 183682.0 12.112923 12.091647 12.131549
1993-08-01 183318.0 12.137985 12.119185 12.158088
1993-09-01 177406.0 12.074482 12.055428 12.094284
1993-10-01 182737.0 12.112660 12.092245 12.130864
1993-11-01 187443.0 12.123727 12.104297 12.144571
1993-12-01 224540.0 12.282309 12.261702 12.302200
1994-01-01 161349.0 12.032085 12.012764 12.052036
1994-02-01 162841.0 12.028248 12.007762 12.046540
1994-03-01 192319.0 12.157697 12.137539 12.177807
1994-04-01 189569.0 12.147374 12.128195 12.166446
1994-05-01 194927.0 12.202822 12.183232 12.222423
1994-06-01 197946.0 12.184078 12.163224 12.205169
... ... ... ... ...
2014-12-01 501232.0 13.137585 13.118307 13.157730
2015-01-01 397252.0 12.881165 12.860706 12.899020
2015-02-01 386935.0 12.884120 12.864747 12.905354
2015-03-01 444110.0 13.000705 12.980476 13.019388
2015-04-01 438217.0 12.987883 12.968549 13.008046
2015-05-01 462615.0 13.039964 13.019736 13.060016
2015-06-01 448229.0 13.022062 13.002051 13.041851
2015-07-01 457710.0 13.021578 13.002040 13.039352
2015-08-01 456340.0 13.044898 13.025893 13.064373
2015-09-01 430917.0 12.977990 12.959415 12.996866
2015-10-01 444959.0 13.004558 12.983742 13.023885
2015-11-01 444507.0 13.018680 12.998983 13.038165
2015-12-01 518253.0 13.171971 13.151416 13.192021
2016-01-01 400928.0 12.913038 12.893216 12.932344
2016-02-01 413554.0 12.926129 12.907918 12.945903
2016-03-01 460093.0 13.043560 13.022865 13.061371
2016-04-01 450935.0 13.028318 13.010294 13.048533
2016-05-01 471421.0 13.080655 13.060637 13.099672
2016-05-31 NaN 13.058162 13.037392 13.077521
2016-06-30 NaN 13.057050 13.036503 13.076663
2016-07-31 NaN 13.081158 13.062024 13.101362
2016-08-31 NaN 13.014234 12.992829 13.034020
2016-09-30 NaN 13.036696 13.017422 13.057148
2016-10-31 NaN 13.054147 13.032693 13.077056
2016-11-30 NaN 13.206271 13.181678 13.227129
2016-12-31 NaN 12.945071 12.922593 12.968187
2017-01-31 NaN 12.968193 12.943553 12.994479
2017-02-28 NaN 13.064662 13.038743 13.090464
2017-03-31 NaN 13.054117 13.028228 13.083777
2017-04-30 NaN 13.106254 13.077519 13.137530

305 rows × 4 columns
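Notice that `viz_df` contains both 2016-05-01 (the last observed month) and 2016-05-31 (the first forecast row). The In [4] cell asked for a monthly frequency, and the output shows it produced month-end dates, while the history uses month-start dates. The helper below is hypothetical, a simplified sketch that to my understanding resembles how `make_future_dataframe` builds its dates. It shows the two grids side by side, using pandas offset objects rather than string aliases (pandas 2.2 deprecated the month-end alias `'M'` in favor of `'ME'`).

```python
import pandas as pd

def future_dates(last_date, periods, freq):
    # Hypothetical helper (not Prophet's code): generate `periods` dates after last_date
    dates = pd.date_range(start=last_date, periods=periods + 1, freq=freq)
    dates = dates[dates > last_date]  # drop last_date itself if it falls on the grid
    return dates[:periods]

last_date = pd.Timestamp("2016-05-01")  # last observed month in the retail data

month_end = future_dates(last_date, 12, pd.offsets.MonthEnd())
month_start = future_dates(last_date, 12, pd.offsets.MonthBegin())

print(month_end[0], month_end[-1])      # 2016-05-31 00:00:00 2017-04-30 00:00:00
print(month_start[0], month_start[-1])  # 2016-06-01 00:00:00 2017-05-01 00:00:00
```

The month-end grid reproduces the 2016-05-31 to 2017-04-30 range seen above, so its first "future" row is effectively May again. A month-start frequency (`'MS'`) would continue the history's grid and avoid the near-duplicate rows in the join.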


In [14]:
plot_data(viz_df, last_date)



In [16]:
plot_data(viz_df.loc['2010-01-01':], last_date)
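The zoomed chart relies on pandas partial-string slicing: on a sorted `DatetimeIndex`, `.loc['2010-01-01':]` returns every row from that date onward. The outer join in In [7] sorts the union of dates, so `viz_df` qualifies. A toy illustration (data hypothetical):

```python
import pandas as pd

# Toy monthly frame spanning 2008-2012 (values are just a counter)
idx = pd.date_range("2008-01-01", "2012-12-01", freq="MS")
df = pd.DataFrame({"value": range(len(idx))}, index=idx)

# Partial-string slicing on a sorted DatetimeIndex: everything from 2010-01-01 on
recent = df.loc["2010-01-01":]

print(recent.index[0], len(recent))  # 2010-01-01 00:00:00 36
```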


