Source: the data corresponds to the retail_sales example CSV (example_retail_sales.csv).
In [1]:
retail_datafile = '../datasets/example_retail_sales.csv'
In [2]:
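try:
    from prophet import Prophet    # package name since Prophet 1.0
except ImportError:
    from fbprophet import Prophet  # older releases used the fbprophet name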
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize']=(20,10)
plt.style.use('ggplot')
In [3]:
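sales_df = pd.read_csv(retail_datafile, parse_dates=['ds'])  # parse ds as datetimes so it lines up with Prophet's forecast dates when joining later
sales_df['y_orig'] = sales_df['y']  # keep a copy of the original data; you'll see why shortly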
# log-transform y
sales_df['y'] = np.log(sales_df['y'])
sales_df.tail()
Out[3]:
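Because the model is fit on log(y), every value Prophet returns (yhat, yhat_lower, yhat_upper) is on the log scale, which is why the untransformed y is kept in y_orig. A minimal sketch with made-up values (not the retail data) showing the round trip: np.exp undoes np.log, and because exp is strictly increasing, exponentiating a lower and upper bound keeps them in the same order.

```python
import numpy as np
import pandas as pd

# Made-up sales figures, for illustration only (not the retail dataset).
df = pd.DataFrame({"y": [100.0, 250.0, 400.0]})
df["y_orig"] = df["y"]        # keep the untransformed values
df["y"] = np.log(df["y"])     # the log-scale values the model would see

# Back-transform: exp is the inverse of log, so this recovers y_orig.
recovered = np.exp(df["y"])
print(np.allclose(recovered, df["y_orig"]))  # True

# exp is strictly increasing, so log-scale bounds stay ordered after exponentiating.
lower_log, upper_log = np.log(140.0), np.log(160.0)
print(np.exp(lower_log) < np.exp(upper_log))  # True
```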
In [4]:
model = Prophet()
model.fit(sales_df[['ds','y']])
# create 12 months of future data; 'MS' (month start) matches the first-of-month dates in this dataset
future_data = model.make_future_dataframe(periods=12, freq='MS')
# predict over the historical dates plus the 12 future months
forecast_data = model.predict(future_data)
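By default, make_future_dataframe returns the historical ds values followed by `periods` new dates spaced at `freq`. The sketch below approximates that with plain pandas, using a hypothetical helper `future_dates` (not part of Prophet) and a short made-up monthly history. It assumes, as in the example CSV, that each date falls on the first of the month, which is why it uses the 'MS' (month start) alias; a month-end alias would put every future date on the last day of the month instead.

```python
import pandas as pd

def future_dates(history_ds, periods, freq="MS"):
    """Hypothetical helper: history dates followed by `periods` future dates at `freq`."""
    history_ds = pd.to_datetime(pd.Series(history_ds))
    last = history_ds.max()
    # date_range includes the start date, so ask for one extra step and drop it.
    new = pd.date_range(start=last, periods=periods + 1, freq=freq)
    new = new[new > last][:periods]
    return pd.DataFrame({"ds": list(history_ds) + list(new)})

history = ["2015-10-01", "2015-11-01", "2015-12-01"]
out = future_dates(history, periods=3)
print(out["ds"].dt.strftime("%Y-%m-%d").tolist())
# ['2015-10-01', '2015-11-01', '2015-12-01', '2016-01-01', '2016-02-01', '2016-03-01']
```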
In [5]:
forecast_data[['ds','yhat','yhat_lower','yhat_upper']].tail()
Out[5]:
In [6]:
model.plot(forecast_data);
While this is a nice chart, it is kind of ‘busy’ for me. I also prefer to view the original data first, with the forecast appended to the end (this should make more sense in a minute).
First, we need to combine the data and index it appropriately before plotting. For the purposes of this article, we are only interested in the ‘yhat’, ‘yhat_lower’ and ‘yhat_upper’ columns from Prophet's forecast dataset. Note: there are more pythonic ways to do these steps, but I’m breaking them out for ease of understanding.
In [ ]:
#model.plot_components(forecast_data);
In [7]:
sales_df.set_index('ds', inplace=True)
forecast_data.set_index('ds', inplace=True)
viz_df = sales_df.join(forecast_data[['yhat', 'yhat_lower','yhat_upper']], how = 'outer')
del viz_df['y']
viz_df.tail()
Out[7]:
The y_orig column is null because we don't know the future values.
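An outer join keeps every date from both frames, so dates that exist only in the forecast get NaN in y_orig. The join only lines up rows when both indexes hold the same type: read_csv leaves ds as text unless it is asked to parse dates, whereas Prophet's forecast ds is datetime64, so mismatched types can keep the rows from matching. A toy sketch with made-up numbers:

```python
import numpy as np
import pandas as pd

# Made-up actuals (two months) and forecast (same two months plus one future month).
actual = pd.DataFrame({"ds": ["2016-01-01", "2016-02-01"], "y_orig": [100.0, 110.0]})
forecast = pd.DataFrame({
    "ds": pd.to_datetime(["2016-01-01", "2016-02-01", "2016-03-01"]),
    "yhat": np.log([101.0, 109.0, 118.0]),
})

# Convert the text dates first so both indexes are datetime64 and rows align.
actual["ds"] = pd.to_datetime(actual["ds"])
viz = actual.set_index("ds").join(forecast.set_index("ds"), how="outer")

print(viz["y_orig"].isna().tolist())  # [False, False, True]
print(len(viz))  # 3
```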
Next, let's break down how to visualize this data in a little more detail than Prophet does by default.
Begin by getting the last date in the original dataset.
In [8]:
sales_df.index = pd.to_datetime(sales_df.index)
last_date = sales_df.index[-1]
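sales_df.index[-1] returns the final row's date, which is the latest date only when the rows are in chronological order, as they are in this dataset. If that ordering isn't guaranteed, Index.max() returns the latest date regardless of row order. A small sketch with a deliberately unsorted, made-up index:

```python
import pandas as pd

# Made-up, deliberately out-of-order dates.
idx = pd.to_datetime(["2016-02-01", "2016-03-01", "2016-01-01"])

print(idx[-1].date())    # 2016-01-01 (position-based: just the final row)
print(idx.max().date())  # 2016-03-01 (value-based: the latest date)
```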
In [11]:
from datetime import timedelta

def plot_data(func_df, end_date):
    # Step back 4 weeks from the last actual date. With monthly data, the mask below then keeps
    # the last actual row as well as the forecast rows, so the charted lines connect.
    end_date = end_date - timedelta(weeks=4)
    mask = (func_df.index > end_date)  # set up a mask to pull out the rows drawn as the forecast
    predict_df = func_df.loc[mask]     # using the mask, create a new dataframe with just those rows

    # Now...plot everything
    fig, ax1 = plt.subplots()
    ax1.plot(func_df.index, func_df['y_orig'], label='Actual Sales')
    # yhat and its bounds are on the log scale, so exponentiate them back to dollars
    ax1.plot(predict_df.index, np.exp(predict_df['yhat']), color='black', linestyle=':', label='Forecasted Sales')
    ax1.fill_between(predict_df.index, np.exp(predict_df['yhat_upper']), np.exp(predict_df['yhat_lower']), alpha=0.5, color='darkgray')
    ax1.set_title('Sales (Orange) vs Sales Forecast (Black)')
    ax1.set_ylabel('Dollar Sales')
    ax1.set_xlabel('Date')
    # the lines are labelled when plotted, so the legend text needs no editing afterwards
    ax1.legend()
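The masking step inside plot_data can be checked on its own. The sketch below, with made-up month-start dates, shows that stepping back four weeks (28 days) from the last actual date and keeping rows strictly after that point selects the last actual row plus every forecast row, which is what lets the dotted forecast line start where the actual line ends. It relies on consecutive dates being at least 28 days apart, as they are for monthly data.

```python
from datetime import timedelta

import pandas as pd

# Made-up index: three actual months followed by two forecast months.
index = pd.to_datetime(["2016-01-01", "2016-02-01", "2016-03-01",
                        "2016-04-01", "2016-05-01"])
df = pd.DataFrame({"yhat": range(5)}, index=index)
last_date = pd.Timestamp("2016-03-01")  # last date with actual data

cutoff = last_date - timedelta(weeks=4)  # 2016-02-02
predict_df = df.loc[df.index > cutoff]

print(predict_df.index.strftime("%Y-%m-%d").tolist())
# ['2016-03-01', '2016-04-01', '2016-05-01']
```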
In [13]:
viz_df
Out[13]:
In [14]:
plot_data(viz_df, last_date)
In [16]:
plot_data(viz_df.loc['2010-01-01':], last_date)
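viz_df.loc['2010-01-01':] relies on pandas slicing a DatetimeIndex with a date string, which keeps every row on or after that date (label-based slices include their start label). A small sketch with made-up yearly data; the slice behaves most predictably on a sorted index, which the outer join above produces.

```python
import pandas as pd

# Made-up data: first day of each year from 2008 to 2012.
idx = pd.to_datetime(["2008-01-01", "2009-01-01", "2010-01-01",
                      "2011-01-01", "2012-01-01"])
df = pd.DataFrame({"y_orig": [1.0, 2.0, 3.0, 4.0, 5.0]}, index=idx)

recent = df.loc["2010-01-01":]  # label-based slice: includes 2010-01-01 itself
print(recent.index.year.tolist())  # [2010, 2011, 2012]
```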