In this Jupyter notebook we generate a plot for Wikipedia that illustrates the graphs of a few probability density functions of the Dirichlet distribution, corresponding to different parameter vectors $\alpha$.
In [1]:
import numpy as np
import matplotlib.tri as tri
import scipy.stats as st
The Dirichlet distribution (here with three components) is defined on the open simplex $\{(x_1, x_2, x_3)\:|\: x_1+x_2+x_3=1,\ x_k\in(0,1)\}$.
The triple $(x_1, x_2, x_3)$ is interpreted as the barycentric coordinates of a point in a planar triangle.
We take an equilateral triangle and subdivide it uniformly and recursively by 1-to-4 splits (each triangle is divided into four subtriangles by joining the midpoints of its edges):
In [2]:
def cartesian2baric(point, M, dist=1.e-15):
    """Return the barycentric coordinates of a 2D point, clipped into the open simplex.

    Parameters
    ----------
    point : tuple, list or 1-D numpy array of two floats
        Cartesian coordinates (x, y) of the point.
    M : numpy array of shape (3, 3)
        Matrix of the transformation from homogeneous cartesian coordinates
        (x, y, 1) to barycentric coordinates.
    dist : float, optional
        Each barycentric coordinate is clipped to [dist, 1 - dist], so that
        points on the triangle's edges or vertices are mapped into the open
        simplex. Clipping is done per coordinate, so the sum may differ from 1
        by at most a few multiples of ``dist``.

    Returns
    -------
    numpy.ndarray of shape (3,)
        The clipped barycentric coordinates.
    """
    # Append the homogeneous coordinate 1. np.append works for tuples, lists and
    # arrays alike. The former ``point + (1,)`` raised TypeError for lists and
    # broadcast (instead of appending) for numpy arrays.
    homogeneous = np.append(np.asarray(point, dtype=float), 1.0)
    baric = np.dot(M, homogeneous)
    return np.clip(baric, dist, 1.0 - dist)  # force the point into the open simplex
def uniftriang(vertices, subdiv_level=7):
    """Build a uniform refinement of the triangulation of ``vertices``.

    Parameters
    ----------
    vertices : 2-D numpy array
        Column 0 holds the x coordinates and column 1 the y coordinates of the
        points to triangulate (in this notebook, the three triangle vertices).
    subdiv_level : int, optional
        Number of recursive refinement steps passed to
        ``UniformTriRefiner.refine_triangulation`` (each step is a 1-to-4 split).

    Returns
    -------
    matplotlib.tri.Triangulation
        The refined triangulation. ``.triangles`` holds the simplices and
        ``.x`` / ``.y`` hold the cartesian coordinates of its vertices.
    """
    coarse = tri.Triangulation(vertices[:, 0], vertices[:, 1])
    refiner = tri.UniformTriRefiner(coarse)
    return refiner.refine_triangulation(subdiv=subdiv_level)
Define the vertices of an equilateral triangle, subdivide it, and compute the barycentric coordinates of the triangulation points:
In [3]:
# Vertices of an equilateral triangle with unit side, one (x, y) pair per row.
tri_vertices = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2]])
# Transformation matrix from barycentric to (homogeneous) cartesian coordinates.
# Its columns are the triangle vertices, each extended with a trailing 1.
# Deriving it from tri_vertices keeps the two consistent if the triangle changes.
A = np.vstack((tri_vertices.T, np.ones(len(tri_vertices))))
# The inverse maps homogeneous cartesian coordinates (x, y, 1) to barycentric ones.
invA = np.linalg.inv(A)
triangul = uniftriang(tri_vertices)
# Barycentric coordinates of every vertex of the refined triangulation.
baric_coords = [cartesian2baric(point, invA) for point in zip(triangul.x, triangul.y)]
We represent the surface defined by a Dirichlet probability density function as a Plotly `Mesh3d` trace:
In [4]:
import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools as tls
In [5]:
def plotly_triangular_mesh(vertices, simplices, intensities=None, colorscale="Viridis",
                           flatshading=False, showscale=False, reversescale=False, plot_edges=False):
    """Build a Plotly ``mesh3d`` trace (as a plain dict) for a triangulated surface.

    Parameters
    ----------
    vertices : numpy array of shape (n_vertices, 3)
        Coordinates (x, y, z) of the vertices of the triangulation.
    simplices : numpy array of shape (n_simplices, 3)
        Vertex indices of the subtriangles of the triangulation.
    intensities : None, callable, or list / tuple / numpy array, optional
        Values used to color the surface.
        - ``None``: the z coordinates are used.
        - callable: evaluated as ``intensities(x, y, z)`` on the coordinate arrays.
        - sequence: used as given (one value per vertex, like the default z).
    colorscale, flatshading, showscale, reversescale : optional
        Passed through unchanged to the trace.
    plot_edges : bool, optional
        Currently ignored. Kept so existing calls keep working.

    Returns
    -------
    dict
        A ``mesh3d`` trace definition that can be added to a figure.

    Raises
    ------
    ValueError
        If ``intensities`` is not None, a callable, or a list/tuple/numpy array.
    """
    x, y, z = vertices.T
    I, J, K = simplices.T
    if intensities is None:
        intensity = z
    elif callable(intensities):
        intensity = intensities(x, y, z)
    elif isinstance(intensities, (list, tuple, np.ndarray)):
        intensity = intensities  # intensities are given explicitly
    else:
        raise ValueError("intensities can be either a function or a list, tuple, np.array")
    return dict(type='mesh3d',
                x=x,
                y=y,
                z=z,
                colorscale=colorscale,
                reversescale=reversescale,
                intensity=intensity,
                flatshading=flatshading,
                i=I,
                j=J,
                k=K,
                name='',
                showscale=showscale
                )
Define a list of parameters $\alpha$ for the Dirichlet distributions to be plotted:
In [6]:
# Dirichlet parameter vectors laid out as a grid: one subplot per entry.
alpha = [
    [(1.3, 1.3, 1.3), (3, 3, 3), (7, 7, 7)],
    [(2, 6, 11), (14, 9, 5), (6, 2, 6)],
]
m, n = len(alpha), len(alpha[0])  # number of subplot rows and columns
In [7]:
# One 3D scene per parameter vector. The specs grid is built from m and n so it
# always matches the shape of alpha (it was previously hard-coded as 2 x 3).
fig = tls.make_subplots(rows=m, cols=n, vertical_spacing=0.0075, horizontal_spacing=0.025,
                        specs=[[{'is_3d': True} for _ in range(n)] for _ in range(m)])
In [8]:
# Scene names in row-major order: scene1, ..., scene{m*n}, arranged as an m x n grid.
scenes = [['scene' + str(i * n + j + 1) for j in range(n)]
          for i in range(m)]
scenes
Out[8]:
In [9]:
# Common styling for the x, y and z axes of every 3D scene.
axis = dict(showbackground=True,
            backgroundcolor="rgb(230, 230,230)",
            gridcolor="rgb(255, 255, 255)",
            zerolinecolor="rgb(255, 255, 255)",
            tickfont=dict(size=11),
            # title.font replaces the deprecated ``titlefont`` attribute,
            # which newer Plotly releases no longer accept.
            title=dict(font=dict(size=12)))
scene = dict(xaxis=dict(axis),
             yaxis=dict(axis),
             zaxis=dict(axis),
             # z axis drawn at a quarter of the x/y extent
             aspectratio=dict(x=1, y=1, z=0.25))
fig.update_scenes(scene);
In [10]:
# Plotly colorscale for the density surfaces, going from light to dark:
# 11 evenly spaced stops at 0.0, 0.1, ..., 1.0.
pl_deep = [[k / 10, 'rgb({}, {}, {})'.format(*rgb)]
           for k, rgb in enumerate([(253, 253, 204), (201, 235, 177), (145, 216, 163),
                                    (102, 194, 163), (81, 168, 162), (72, 141, 157),
                                    (64, 117, 152), (61, 90, 146), (65, 64, 123),
                                    (55, 44, 80), (39, 26, 44)])]
In [11]:
for i in range(m):
    for j in range(n):
        X = st.dirichlet(np.array(alpha[i][j]))
        # Density evaluated at every vertex of the refined triangulation.
        C = [X.pdf(b) for b in baric_coords]
        zmax = max(C)
        surf_vertices = np.vstack((triangul.x, triangul.y, C)).T  # vertices of the surface triangulation
        trace = plotly_triangular_mesh(surf_vertices, triangul.triangles, intensities=None, colorscale=pl_deep)
        fig.add_trace(trace, row=i+1, col=j+1)
        # Restrict the tick update to this subplot's scene. Without row/col, every
        # scene ended up with the tick values of the last density plotted.
        fig.update_scenes({'zaxis': {'tickvals': [round(zmax/2, 1), round(zmax, 1)]}},
                          row=i+1, col=j+1)

# Build the parameter list in the title from alpha so it cannot drift out of sync.
alpha_labels = [', '.join(str(a) for a in row) for row in alpha]
fig.layout.update(title='Dirichlet distribution over an open 2-simplex'
                        '<br> alpha=' + ', <br>'.join(alpha_labels) + ' ',
                  font=dict(family='Georgia, serif',
                            size=14),
                  margin=dict(t=135),
                  height=800,
                  width=900,
                  showlegend=False
                  );
In [13]:
# Wrap the figure in an interactive FigureWidget and display it as the cell output.
# NOTE(review): FigureWidget presumably needs ipywidgets support in the frontend to render — confirm.
fw = go.FigureWidget(fig)
fw
In [14]:
from IPython.display import IFrame

# Embed the copy of the figure hosted on plot.ly (user empet, plot 13886).
IFrame(src='https://plot.ly/~empet/13886/', height=800, width=900)
Out[14]:
In [ ]: