Bar graphs are useful for displaying relationships between categorical data and at least one numerical variable. seaborn.countplot
is a bar plot where the dependent variable is the number of occurrences of each value of the independent (categorical) variable.
dataset: IMDB 5000 Movie Dataset
In [20]:
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Notebook-wide matplotlib defaults: 20x10 inch figures and a serif font family.
plt.rcParams.update({
    'figure.figsize': (20.0, 10.0),
    'font.family': "serif",
})
In [21]:
# Load the movie metadata CSV into a DataFrame.
movie_csv_path = '../../datasets/movie_metadata.csv'
df = pd.read_csv(movie_csv_path)
In [22]:
# Preview the first five rows of the freshly loaded table.
df.head(n=5)
Out[22]:
For the bar plot, let's look at the number of movies in each category, allowing each movie to be counted more than once.
In [23]:
# Split each movie's pipe-delimited genre string and collect the distinct genre names.
categories = {
    genre
    for genre_list in df.genres.unique()
    for genre in genre_list.split("|")
}

# One-hot encode each movie's classification.
# Membership is tested against the split token list rather than the raw string:
# a substring test would also match any genre whose name is contained in
# another one (e.g. "Music" inside "Musical") and inflate its count.
genre_tokens = df.genres.str.split("|")
for cat in categories:
    df[cat] = genre_tokens.map(lambda tokens: int(cat in tokens))

# Keep only the columns used later plus the new indicator columns.
df = df[['director_name', 'genres', 'duration'] + list(categories)]
df.head()
Out[23]:
In [24]:
# Convert from wide to long format (one row per movie/genre pair) and drop the
# rows for genres a movie does not belong to.
df = pd.melt(
    df,
    id_vars=['duration'],
    value_vars=list(categories),
    var_name='Category',
    value_name='Count',
)
# .copy() makes the filtered frame independent of the melted one, so adding the
# 'islong' column below does not trigger SettingWithCopyWarning.
df = df.loc[df.Count > 0].copy()

# Genres ranked by number of movies, most frequent first. Summing only the
# 'Count' column with the named method avoids passing the builtin `sum` to
# aggregate(), which recent pandas versions warn about.
top_categories = (
    df.groupby('Category')['Count']
    .sum()
    .sort_values(ascending=False)
    .index
)
howmany = 10

# Indicator for long movies: 1 if the runtime exceeds 100 minutes, else 0.
# A missing duration compares False and is therefore labelled 0.
df['islong'] = (df.duration > 100).astype(int)

# Keep only the `howmany` most frequent genres.
df = df.loc[df.Category.isin(top_categories[:howmany])]
In [25]:
# Preview the first five rows of the long-format table.
df.head(n=5)
Out[25]:
Basic plot
In [26]:
# One vertical bar per category; bar height is the number of rows in that category.
p = sns.countplot(x='Category', data=df)
Color by a category
In [27]:
# Split every category's bar by the islong flag, one colour per flag value.
p = sns.countplot(x='Category', hue='islong', data=df)
Make plot horizontal
In [28]:
# Putting the categorical variable on y draws horizontal bars.
p = sns.countplot(y='Category', hue='islong', data=df)
Saturation
In [29]:
# saturation=1 requests the palette colours at their full saturation.
p = sns.countplot(y='Category', hue='islong', saturation=1, data=df)
Targeting a non-default axes
In [30]:
# plt.subplots(2) creates a figure with two rows of axes; the plot is drawn on
# the second one (ax[1]) by passing it explicitly through `ax=`.
# (matplotlib.pyplot is already imported as plt in the first cell, so it is not
# re-imported here.)
fig, ax = plt.subplots(2)
sns.countplot(
    data=df,
    y='Category',
    hue='islong',
    saturation=1,
    ax=ax[1],
)
Out[30]:
Add error bars
In [31]:
# numpy is already imported as np in the first cell, so it is not re-imported here.
num_categories = df.Category.unique().size

# The xerr values (0, 7, 14, ...) are synthetic, one per category, and only
# illustrate how error bars are drawn; they are not computed from the data.
p = sns.countplot(
    data=df,
    y='Category',
    hue='islong',
    saturation=1,
    xerr=7 * np.arange(num_categories),
)
Add black bounding lines
In [32]:
# numpy is already imported as np in the first cell, so it is not re-imported here.
num_categories = df.Category.unique().size

# edgecolor/linewidth outline every bar with a 2pt black line. The xerr values
# are synthetic (0, 7, 14, ...) and only illustrate error-bar drawing.
p = sns.countplot(
    data=df,
    y='Category',
    hue='islong',
    saturation=1,
    xerr=7 * np.arange(num_categories),
    edgecolor=(0, 0, 0),
    linewidth=2,
)
Remove color fill
In [33]:
# numpy is already imported as np in the first cell, so it is not re-imported here.
num_categories = df.Category.unique().size

# fill=False leaves the bars hollow so that only the black outlines remain.
# The xerr values are synthetic (0, 7, 14, ...) and only illustrate error-bar drawing.
p = sns.countplot(
    data=df,
    y='Category',
    hue='islong',
    saturation=1,
    xerr=7 * np.arange(num_categories),
    edgecolor=(0, 0, 0),
    linewidth=2,
    fill=False,
)
In [34]:
# numpy is already imported as np in the first cell, so it is not re-imported here.
num_categories = df.Category.unique().size

# Filled bars with 2pt black outlines. The xerr values are synthetic
# (0, 7, 14, ...) and only illustrate error-bar drawing.
p = sns.countplot(
    data=df,
    y='Category',
    hue='islong',
    saturation=1,
    xerr=7 * np.arange(num_categories),
    edgecolor=(0, 0, 0),
    linewidth=2,
)
In [35]:
# Enlarge all seaborn-managed font sizes by 25% before plotting.
sns.set(font_scale=1.25)

num_categories = df.Category.unique().size
# Synthetic error bars (0, 3, 6, ...), one value per category.
error_widths = 3 * np.arange(num_categories)
p = sns.countplot(
    y='Category',
    hue='islong',
    data=df,
    saturation=1,
    xerr=error_widths,
    linewidth=2,
    edgecolor=(0, 0, 0),
)
In [36]:
# Print the docstring of sns.set to look up the theme options it accepts.
help(sns.set)
In [37]:
# Switch matplotlib's font family to "cursive" for this plot.
plt.rcParams.update({'font.family': "cursive"})

num_categories = df.Category.unique().size
# Synthetic error bars (0, 3, 6, ...), one value per category.
error_widths = 3 * np.arange(num_categories)
p = sns.countplot(
    y='Category',
    hue='islong',
    data=df,
    saturation=1,
    xerr=error_widths,
    linewidth=2,
    edgecolor=(0, 0, 0),
)
In [38]:
# sns.set() assigns rcParams['font.family'] from its `font` argument (default
# "sans-serif"), so a font family written to plt.rcParams *before* the call is
# overwritten. Passing the font to sns.set() itself makes the choice stick.
sns.set(style="white", font_scale=1.25, font='Times New Roman')

num_categories = df.Category.unique().size
# The xerr values (0, 3, 6, ...) are synthetic and only illustrate error-bar drawing.
p = sns.countplot(
    data=df,
    y='Category',
    hue='islong',
    saturation=1,
    xerr=3 * np.arange(num_categories),
    edgecolor=(0, 0, 0),
    linewidth=2,
)
In [39]:
# Dark grey background for both the axes and the figure; axis text stays black.
bg_color = (0.25, 0.25, 0.25)
theme_overrides = {
    "font.style": "normal",
    "axes.facecolor": bg_color,
    "figure.facecolor": bg_color,
    "text.color": "black",
    "xtick.color": "black",
    "ytick.color": "black",
    "axes.labelcolor": "black",
}
sns.set(rc=theme_overrides)

num_categories = df.Category.unique().size
# Synthetic error bars (0, 3, 6, ...), one value per category.
error_widths = 3 * np.arange(num_categories)
p = sns.countplot(
    y='Category',
    hue='islong',
    data=df,
    saturation=1,
    xerr=error_widths,
    linewidth=2,
    edgecolor=(0, 0, 0),
)

# Relabel the legend: title it "Duration" and replace the two islong entries
# with readable names, all rendered in white.
leg = p.get_legend()
leg.set_title("Duration")
labs = leg.texts
short_label, long_label = labs[0], labs[1]
short_label.set_text("Short")
long_label.set_text("Long")
leg.get_title().set_color('white')
for lab in labs:
    lab.set_color('white')

# Axis caption and a large italic title placed in data coordinates (x=900, y=0).
p.axes.xaxis.label.set_text("Counts")
p.text(900, 0, "Bar Plot", color='white', fontsize=95, fontstyle='italic')
In [40]:
# Write the figure that owns the last plot's axes to disk as a PNG.
barplot_figure = p.get_figure()
barplot_figure.savefig('../figures/barplot.png')