In [118]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
Chinese restaurant process simulation (link):
In [133]:
class Table(object):
    """One table in the Chinese-restaurant-process simulation.

    Attributes
    ----------
    table_id :
        Identifier supplied by the caller (``Restaurant`` passes
        consecutive integers starting at 1).
    seated : int
        Number of guests at this table; starts at 0.
    """

    def __init__(self, table_id):
        self.seated = 0
        self.table_id = table_id

    def add_guest(self):
        """Seat one more guest at this table (increments ``seated``)."""
        self.seated = self.seated + 1
class Restaurant(object):
    """State of a Chinese restaurant process (CRP) simulation.

    Guests arrive one at a time. With ``n`` guests already seated, the next
    guest joins existing table ``k`` with probability ``n_k / (n + alpha)``
    and opens a new table with probability ``alpha / (n + alpha)``.
    ``alpha=1.0`` (the default) reproduces the original 1/(n+1) rule.

    Parameters
    ----------
    alpha : float, optional
        Concentration parameter; must be positive. Larger values open new
        tables more often. Defaults to 1.0.
    rng : object, optional
        Source of randomness exposing ``choice(a, p=...)`` such as a
        ``numpy.random.Generator`` or ``numpy.random.RandomState``.
        Defaults to the global ``np.random`` module, so ``np.random.seed``
        controls the simulation.

    Raises
    ------
    ValueError
        If ``alpha`` is not positive.
    """

    def __init__(self, alpha=1.0, rng=None):
        if alpha <= 0:
            raise ValueError("alpha must be positive, got {}".format(alpha))
        # Global side effect: seaborn styles apply to every later plot in
        # the session, not only to this restaurant's figures.
        sns.set_style("darkgrid")
        self.alpha = alpha
        self._rng = np.random if rng is None else rng
        self.tables = []  # Table objects in creation order (ids 1, 2, ...)
        self.guests = 0   # total guests seated across all tables
        self.plots = []   # Axes created by plot_restaurant, oldest first

    def __iter__(self):
        """Iterate over the tables in creation order."""
        return iter(self.tables)

    def __len__(self):
        """Return the number of occupied tables."""
        return len(self.tables)

    def add_table(self, new_id):
        """Open a table with id ``new_id`` and seat one guest at it."""
        new_table = Table(new_id)
        new_table.add_guest()
        self.tables.append(new_table)
        self.guests += 1

    def add_guest(self):
        """Seat one arriving guest according to the CRP seating rule."""
        if not self.tables:
            # The very first guest always opens table 1.
            self.add_table(1)
            return
        denom = self.guests + self.alpha
        # Existing tables first (weight = occupancy), "new table" last.
        probs = [table.seated / denom for table in self.tables]
        probs.append(self.alpha / denom)
        # Sample an index instead of the objects themselves: no mixed object
        # array and no 'new' sentinel string to special-case.
        chosen = self._rng.choice(len(probs), p=probs)
        if chosen == len(self.tables):
            self.add_table(len(self.tables) + 1)
        else:
            self.tables[chosen].add_guest()
            self.guests += 1

    def plot_restaurant(self):
        """Draw the current seating in a new 5x5 figure.

        Tables are laid out row-major on the smallest square grid that holds
        them. Each table is labelled with its id and head-count and ringed by
        one dot per guest (at most 50). The Axes is appended to
        ``self.plots``.
        """
        # A fresh figure every call: the old plt.clf() + plt.figure() pair
        # either left an empty extra figure or wiped the previous plot.
        fig, ax = plt.subplots(figsize=(5, 5))
        # ceil (not floor) so every table fits inside the grid/limits.
        box_size = max(int(np.ceil(np.sqrt(len(self.tables)))), 1)
        # Spacing unit for margins and the guest ring. It is clamped to 1 so
        # a single table still gets non-degenerate limits and a visible ring.
        scale = max(box_size - 1, 1)
        margin = 0.25 * scale
        # Same limits on both axes keep the square figure undistorted.
        limits = (-margin, (box_size - 1) + margin)
        ax.set_xlim(limits)
        ax.set_ylim(limits)
        ring_radius = 0.2 * scale
        for table in self.tables:
            table_x = (table.table_id - 1) % box_size
            table_y = (table.table_id - 1) // box_size
            ax.scatter(table_x, table_y, s=5000, c='#fec44f')
            text = 'Table {}\n{} people'.format(table.table_id, table.seated)
            ax.annotate(text, xy=(table_x, table_y),
                        horizontalalignment='center',
                        verticalalignment='center')
            # Spread guests evenly around the table. Cap at 50 dots so large
            # tables stay legible; the label carries the exact count.
            placings = min(table.seated, 50)
            angles = 2 * np.pi * np.arange(placings) / max(placings, 1)
            ax.scatter(table_x + ring_radius * np.cos(angles),
                       table_y + ring_radius * np.sin(angles),
                       c='#d95f0e')
        ax.set_title('{} guests at {} tables'.format(self.guests,
                                                    len(self.tables)))
        # Grid coordinates carry no meaning, so hide the tick labels.
        ax.tick_params(labelbottom=False, labelleft=False)
        self.plots.append(ax)
In [134]:
# Simulate N_GUESTS arrivals and draw the final seating. The seed makes
# "Restart Kernel -> Run All" reproduce the same figure.
SEED = 42
N_GUESTS = 10000

np.random.seed(SEED)
r = Restaurant()
for _ in range(N_GUESTS):
    r.add_guest()
r.plot_restaurant()