In [118]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

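Chinese restaurant process simulation (link): each arriving guest joins an occupied table with probability proportional to the number of guests already seated there, or opens a new table with probability proportional to 1.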


In [133]:
class Table(object):
    """A single restaurant table: an identifier plus a running head count."""

    def __init__(self, table_id):
        # Tables start empty; guests are seated one at a time via add_guest().
        self.table_id, self.seated = table_id, 0

    def add_guest(self):
        """Seat exactly one more guest at this table."""
        self.seated = self.seated + 1

class Restaurant(object):
    """Simulation of a Chinese restaurant process (CRP).

    Guests arrive one at a time. When ``n`` guests are already seated, the
    next guest joins existing table ``k`` with probability
    ``n_k / (n + alpha)`` or opens a new table with probability
    ``alpha / (n + alpha)``, where ``n_k`` is the head count at table ``k``.

    Parameters
    ----------
    alpha : float, optional
        Concentration parameter; must be positive. Larger values favour
        opening new tables. The default of 1 reproduces the seating rule
        this class has always used.

    Raises
    ------
    ValueError
        If ``alpha`` is not positive.

    Randomness comes from numpy's global RNG, so call ``np.random.seed``
    beforehand for a reproducible run.
    """

    def __init__(self, alpha=1.0):
        # `not alpha > 0` also rejects NaN.
        if not alpha > 0:
            raise ValueError("alpha must be positive, got {!r}".format(alpha))
        sns.set_style("darkgrid")
        self.alpha = float(alpha)
        self.tables = []
        # Total number of seated guests; refreshed by add_guest().
        self.guests = 0
        # Axes of every figure drawn by plot_restaurant(), oldest first.
        self.plots = []

    def __iter__(self):
        """Iterate over the tables in the order they were opened."""
        return iter(self.tables)

    def __len__(self):
        """Number of tables opened so far."""
        return len(self.tables)

    def add_table(self, new_id):
        """Open a table with id ``new_id`` and seat one guest at it."""
        new_table = Table(new_id)
        new_table.add_guest()
        self.tables.append(new_table)

    def _seating_probabilities(self):
        """Return the CRP seating distribution for the next guest.

        Element ``k`` is the probability of joining ``self.tables[k]``; the
        final element is the probability of opening a new table.
        """
        total_guests = sum(table.seated for table in self.tables)
        denom = total_guests + self.alpha
        probs = [table.seated / denom for table in self.tables]
        probs.append(self.alpha / denom)
        return probs

    def add_guest(self):
        """Seat one arriving guest according to the CRP seating rule."""
        if not self.tables:
            # The first guest always opens table 1; no random draw is needed.
            self.add_table(1)
        else:
            probs = self._seating_probabilities()
            # Pick an outcome index rather than an object: the last index
            # (== number of existing tables) means "open a new table".
            choice = np.random.choice(len(probs), p=probs)
            if choice == len(self.tables):
                self.add_table(len(self.tables) + 1)
            else:
                self.tables[choice].add_guest()
        self.guests = sum(table.seated for table in self.tables)

    @staticmethod
    def _grid_size(n_tables):
        """Side length of the square grid used to lay out ``n_tables`` tables.

        Uses the ceiling of the square root so every table fits inside the
        grid (the floor would push e.g. table 10 of 10 off-screen). It is
        never less than 2, because the axis-limit and ring-radius formulas in
        plot_restaurant() collapse to zero width for a 1x1 grid.
        """
        return max(2, int(np.ceil(np.sqrt(n_tables))))

    def plot_restaurant(self):
        """Draw the tables on a square grid, with guests as dots around each.

        Each table is a large marker annotated with its id and head count.
        Up to 50 of its guests are drawn evenly spaced on a ring around it.
        The Axes is appended to ``self.plots``.
        """
        # One explicit figure. A plt.clf() followed by plt.figure() would leave
        # an extra, empty figure behind.
        fig, ax = plt.subplots(figsize=(5, 5))

        # Use a square box to avoid skewing the axes. The view grows with the
        # grid so the fixed-size table markers keep their on-screen spacing.
        box_size = self._grid_size(len(self.tables))
        limits = (-0.25 - (box_size - 2) * 0.25,
                  box_size - 0.75 + (box_size - 2) * 0.25)
        ax.set_xlim(limits)
        ax.set_ylim(limits)

        # The ring radius scales with the view so it stays just outside each
        # table marker.
        ring_radius = 0.2 + (box_size - 2) * 0.2
        for table in self.tables:
            table_x = (table.table_id - 1) % box_size
            table_y = (table.table_id - 1) // box_size
            ax.scatter(table_x, table_y, s=5000, c='#fec44f')
            text = 'Table {}\n{} people'.format(table.table_id, table.seated)
            ax.annotate(text, xy=(table_x, table_y),
                        horizontalalignment='center',
                        verticalalignment='center')

            # Plot people equally spaced around the table, max 50.
            placings = min(table.seated, 50)
            if placings:
                angles = 2 * np.pi * np.arange(placings) / placings
                ax.scatter(table_x + ring_radius * np.cos(angles),
                           table_y + ring_radius * np.sin(angles),
                           c='#d95f0e')

        ax.set_xticklabels([])
        ax.set_yticklabels([])
        self.plots.append(ax)

In [134]:
# Simulation settings. Seed numpy's global RNG (which Restaurant.add_guest
# draws from) so the seating pattern is reproducible on Restart & Run All.
SEED = 0
N_GUESTS = 10000

np.random.seed(SEED)
r = Restaurant()
for _ in range(N_GUESTS):
    r.add_guest()

r.plot_restaurant()


<matplotlib.figure.Figure at 0x31bf8ba8>