Canonical Correlation Analysis (CCA)

Example is taken from Section 12.5.3, Machine Learning: A Probabilistic Perspective by Kevin Murphy.


In [1]:
from symgp import *
from sympy import *

from IPython.display import display, Math, Latex

Set up shapes, variables and constants

We have two observed variables x and y of shapes (D_x,1) and (D_y,1) and the latent variables z_s, z_x, z_y of shapes (L_o,1), (L_x,1) and (L_y,1).


In [2]:
# Shapes
D_x, D_y, L_o, L_x, L_y = symbols('D_x, D_y, L_o L_x L_y')

# Variables
x, y, z_s, z_x, z_y = utils.variables('x y z_{s} z_{x} z_{y}', [D_x, D_y, L_o, L_x, L_y])

# Constants
B_x, W_x, mu_x, B_y, W_y, mu_y = utils.constants('B_{x} W_{x} mu_{x} B_{y} W_{y} mu_{y}',
                                                 [(D_x,L_x), (D_x,L_o), D_x, (D_y,L_y), (D_y,L_o), D_y])
sig = symbols('\u03c3')  # Noise standard deviation

Define the model


In [3]:
# p(z_s), p(z_x), p(z_y)
p_zs = MVG([z_s],mean=ZeroMatrix(L_o,1),cov=Identity(L_o)) 
p_zx = MVG([z_x],mean=ZeroMatrix(L_x,1),cov=Identity(L_x)) 
p_zy = MVG([z_y],mean=ZeroMatrix(L_y,1),cov=Identity(L_y)) 

display(Latex(utils.matLatex(p_zs)))
display(Latex(utils.matLatex(p_zx)))
display(Latex(utils.matLatex(p_zx)))


\begin{align*} p\left(\mathbf{z_{s}}\right)&= \mathcal{N}\left(\mathbf{z_{s}};\mathbf{m}_{\mathbf{z_{s}}},\mathbf{\Sigma}_{\mathbf{z_{s}}}\right)\\ \mathbf{m}_{\mathbf{z_{s}}} &= \mathbf{0}\\ \mathbf{\Sigma}_{\mathbf{z_{s}}} &= \mathbf{I}\\ \end{align*}
\begin{align*} p\left(\mathbf{z_{x}}\right)&= \mathcal{N}\left(\mathbf{z_{x}};\mathbf{m}_{\mathbf{z_{x}}},\mathbf{\Sigma}_{\mathbf{z_{x}}}\right)\\ \mathbf{m}_{\mathbf{z_{x}}} &= \mathbf{0}\\ \mathbf{\Sigma}_{\mathbf{z_{x}}} &= \mathbf{I}\\ \end{align*}
\begin{align*} p\left(\mathbf{z_{x}}\right)&= \mathcal{N}\left(\mathbf{z_{x}};\mathbf{m}_{\mathbf{z_{x}}},\mathbf{\Sigma}_{\mathbf{z_{x}}}\right)\\ \mathbf{m}_{\mathbf{z_{x}}} &= \mathbf{0}\\ \mathbf{\Sigma}_{\mathbf{z_{x}}} &= \mathbf{I}\\ \end{align*}

In [4]:
# p(z)
p_z = p_zs*p_zx*p_zy

display(Latex(utils.matLatex(p_z)))


\begin{align*} p\left(\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}\right)&= \mathcal{N}\left(\left[\begin{smallmatrix}\mathbf{z_{s}}\\\mathbf{z_{x}}\\\mathbf{z_{y}}\end{smallmatrix}\right];\mathbf{m}_{\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}},\mathbf{\Sigma}_{\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}}\right)\\ \mathbf{m}_{\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}} &= \left[\begin{smallmatrix}\mathbf{0}\\\mathbf{0}\\\mathbf{0}\end{smallmatrix}\right]\\ \mathbf{\Sigma}_{\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}} &= \left[\begin{smallmatrix}\mathbf{I}&\mathbf{0}&\mathbf{0}\\\mathbf{0}&\mathbf{I}&\mathbf{0}\\\mathbf{0}&\mathbf{0}&\mathbf{I}\end{smallmatrix}\right]\\ \end{align*}

In [5]:
# p(x|z)
p_x_g_z = MVG([x],mean=B_x*z_x + W_x*z_s + mu_x,cov=sig**2*Identity(D_x),cond_vars=[z_x,z_s])

display(Latex(utils.matLatex(p_x_g_z)))


\begin{align*} p\left(\mathbf{x}|\mathbf{z_{x}},\mathbf{z_{s}}\right)&= \mathcal{N}\left(\mathbf{x};\mathbf{m}_{\mathbf{x}|\mathbf{z_{x}},\mathbf{z_{s}}},\mathbf{\Sigma}_{\mathbf{x}|\mathbf{z_{x}},\mathbf{z_{s}}}\right)\\ \mathbf{m}_{\mathbf{x}|\mathbf{z_{x}},\mathbf{z_{s}}} &= \mathbf{mu_{x}} + \mathbf{B_{x}} \mathbf{z_{x}} + \mathbf{W_{x}} \mathbf{z_{s}}\\ \mathbf{\Sigma}_{\mathbf{x}|\mathbf{z_{x}},\mathbf{z_{s}}} &= \sigma^{2} \mathbf{I}\\ \end{align*}

In [6]:
# p(y|z)
p_y_g_z = MVG([y],mean=B_y*z_y + W_y*z_s + mu_y,cov=sig**2*Identity(D_y),cond_vars=[z_y,z_s])

display(Latex(utils.matLatex(p_y_g_z)))


\begin{align*} p\left(\mathbf{y}|\mathbf{z_{y}},\mathbf{z_{s}}\right)&= \mathcal{N}\left(\mathbf{y};\mathbf{m}_{\mathbf{y}|\mathbf{z_{y}},\mathbf{z_{s}}},\mathbf{\Sigma}_{\mathbf{y}|\mathbf{z_{y}},\mathbf{z_{s}}}\right)\\ \mathbf{m}_{\mathbf{y}|\mathbf{z_{y}},\mathbf{z_{s}}} &= \mathbf{mu_{y}} + \mathbf{B_{y}} \mathbf{z_{y}} + \mathbf{W_{y}} \mathbf{z_{s}}\\ \mathbf{\Sigma}_{\mathbf{y}|\mathbf{z_{y}},\mathbf{z_{s}}} &= \sigma^{2} \mathbf{I}\\ \end{align*}

Obtain joint distribution p(x,y)


In [7]:
# p(v|z) (p(x,y|z_s,z_x,z_y))   We denote v = (x;y) and z = (z_s;z_x;z_y) 
p_v_g_z = p_x_g_z*p_y_g_z

display(Latex(utils.matLatex(p_v_g_z)))


\begin{align*} p\left(\mathbf{x},\mathbf{y}|\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}\right)&= \mathcal{N}\left(\left[\begin{smallmatrix}\mathbf{x}\\\mathbf{y}\end{smallmatrix}\right];\mathbf{m}_{\mathbf{x},\mathbf{y}|\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}},\mathbf{\Sigma}_{\mathbf{x},\mathbf{y}|\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}}\right)\\ \mathbf{m}_{\mathbf{x},\mathbf{y}|\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}} &= \left[\begin{smallmatrix}\mathbf{mu_{x}} + \mathbf{B_{x}} \mathbf{z_{x}} + \mathbf{W_{x}} \mathbf{z_{s}}\\\mathbf{mu_{y}} + \mathbf{B_{y}} \mathbf{z_{y}} + \mathbf{W_{y}} \mathbf{z_{s}}\end{smallmatrix}\right]\\ \mathbf{\Sigma}_{\mathbf{x},\mathbf{y}|\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}} &= \left[\begin{smallmatrix}\sigma^{2} \mathbf{I}&\mathbf{0}\\\mathbf{0}&\sigma^{2} \mathbf{I}\end{smallmatrix}\right]\\ \end{align*}

In [8]:
# p(v,z) (p(x,y,z_s,z_x,z_y))
p_v_z = p_v_g_z*p_z

display(Latex(utils.matLatex(p_v_z)))


\begin{align*} p\left(\mathbf{x},\mathbf{y},\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}\right)&= \mathcal{N}\left(\left[\begin{smallmatrix}\mathbf{x}\\\mathbf{y}\\\mathbf{z_{s}}\\\mathbf{z_{x}}\\\mathbf{z_{y}}\end{smallmatrix}\right];\mathbf{m}_{\mathbf{x},\mathbf{y},\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}},\mathbf{\Sigma}_{\mathbf{x},\mathbf{y},\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}}\right)\\ \mathbf{m}_{\mathbf{x},\mathbf{y},\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}} &= \left[\begin{smallmatrix}\mathbf{mu_{x}}\\\mathbf{mu_{y}}\\\mathbf{0}\\\mathbf{0}\\\mathbf{0}\end{smallmatrix}\right]\\ \mathbf{\Sigma}_{\mathbf{x},\mathbf{y},\mathbf{z_{s}},\mathbf{z_{x}},\mathbf{z_{y}}} &= \left[\begin{smallmatrix}\sigma^{2} \mathbf{I} + \mathbf{B_{x}} \mathbf{B_{x}}^T + \mathbf{W_{x}} \mathbf{W_{x}}^T&\mathbf{W_{x}} \mathbf{W_{y}}^T&\mathbf{W_{x}}&\mathbf{B_{x}}&\mathbf{0}\\\mathbf{W_{y}} \mathbf{W_{x}}^T&\sigma^{2} \mathbf{I} + \mathbf{B_{y}} \mathbf{B_{y}}^T + \mathbf{W_{y}} \mathbf{W_{y}}^T&\mathbf{W_{y}}&\mathbf{0}&\mathbf{B_{y}}\\\mathbf{W_{x}}^T&\mathbf{W_{y}}^T&\mathbf{I}&\mathbf{0}&\mathbf{0}\\\mathbf{B_{x}}^T&\mathbf{0}&\mathbf{0}&\mathbf{I}&\mathbf{0}\\\mathbf{0}&\mathbf{B_{y}}^T&\mathbf{0}&\mathbf{0}&\mathbf{I}\end{smallmatrix}\right]\\ \end{align*}

In [9]:
# p(v) (p(x,y))
p_v = p_v_z.marginalise([z_s,z_x,z_y])

display(Latex(utils.matLatex(p_v)))


\begin{align*} p\left(\mathbf{x},\mathbf{y}\right)&= \mathcal{N}\left(\left[\begin{smallmatrix}\mathbf{x}\\\mathbf{y}\end{smallmatrix}\right];\mathbf{m}_{\mathbf{x},\mathbf{y}},\mathbf{\Sigma}_{\mathbf{x},\mathbf{y}}\right)\\ \mathbf{m}_{\mathbf{x},\mathbf{y}} &= \left[\begin{smallmatrix}\mathbf{mu_{x}}\\\mathbf{mu_{y}}\end{smallmatrix}\right]\\ \mathbf{\Sigma}_{\mathbf{x},\mathbf{y}} &= \left[\begin{smallmatrix}\sigma^{2} \mathbf{I} + \mathbf{B_{x}} \mathbf{B_{x}}^T + \mathbf{W_{x}} \mathbf{W_{x}}^T&\mathbf{W_{x}} \mathbf{W_{y}}^T\\\mathbf{W_{y}} \mathbf{W_{x}}^T&\sigma^{2} \mathbf{I} + \mathbf{B_{y}} \mathbf{B_{y}}^T + \mathbf{W_{y}} \mathbf{W_{y}}^T\end{smallmatrix}\right]\\ \end{align*}