Huge Monty Hall Bayesian Network

authors:
Jacob Schreiber [jmschreiber91@gmail.com]
Nicholas Farn [nicholasfarn@gmail.com]

Lets expand the Bayesian network for the monty hall problem in order to make sure that training with all types of wild types works properly.


In [1]:
import math
from pomegranate import *

We'll create the discrete distribution for our friend first.


In [2]:
friend = DiscreteDistribution( { True: 0.5, False: 0.5 } )

The emissions for our guest are completely random.


In [3]:
guest = ConditionalProbabilityTable(
	[[ True, 'A', 0.50 ],
	 [ True, 'B', 0.25 ],
	 [ True, 'C', 0.25 ],
	 [ False, 'A', 0.0 ],
	 [ False, 'B', 0.7 ],
	 [ False, 'C', 0.3 ]], [friend] )

Then the distribution for the remaining cars.


In [4]:
remaining = DiscreteDistribution( { 0: 0.1, 1: 0.7, 2: 0.2, } )

The probability of whether the prize is randomized is dependent on the number of remaining cars.


In [5]:
randomize = ConditionalProbabilityTable(
	[[ 0, True , 0.05 ],
     [ 0, False, 0.95 ],
     [ 1, True , 0.8 ],
     [ 1, False, 0.2 ],
     [ 2, True , 0.5 ],
     [ 2, False, 0.5 ]], [remaining] )

Now the conditional probability table for the prize. This is dependent on the guest's friend and whether or not it is randomized.


In [6]:
prize = ConditionalProbabilityTable(
	[[ True, True, 'A', 0.3 ],
	 [ True, True, 'B', 0.4 ],
	 [ True, True, 'C', 0.3 ],
	 [ True, False, 'A', 0.2 ],
	 [ True, False, 'B', 0.4 ],
	 [ True, False, 'C', 0.4 ],
	 [ False, True, 'A', 0.1 ],
	 [ False, True, 'B', 0.9 ],
	 [ False, True, 'C', 0.0 ],
	 [ False, False, 'A', 0.0 ],
	 [ False, False, 'B', 0.4 ],
	 [ False, False, 'C', 0.6]], [randomize, friend] )

Finally we can create the conditional probability table for our Monty. This is dependent on the guest and the prize.


In [7]:
monty = ConditionalProbabilityTable(
	[[ 'A', 'A', 'A', 0.0 ],
	 [ 'A', 'A', 'B', 0.5 ],
	 [ 'A', 'A', 'C', 0.5 ],
	 [ 'A', 'B', 'A', 0.0 ],
	 [ 'A', 'B', 'B', 0.0 ],
	 [ 'A', 'B', 'C', 1.0 ],
	 [ 'A', 'C', 'A', 0.0 ],
	 [ 'A', 'C', 'B', 1.0 ],
	 [ 'A', 'C', 'C', 0.0 ],
	 [ 'B', 'A', 'A', 0.0 ],
	 [ 'B', 'A', 'B', 0.0 ],
	 [ 'B', 'A', 'C', 1.0 ],
	 [ 'B', 'B', 'A', 0.5 ],
	 [ 'B', 'B', 'B', 0.0 ],
	 [ 'B', 'B', 'C', 0.5 ],
	 [ 'B', 'C', 'A', 1.0 ],
	 [ 'B', 'C', 'B', 0.0 ],
	 [ 'B', 'C', 'C', 0.0 ],
	 [ 'C', 'A', 'A', 0.0 ],
	 [ 'C', 'A', 'B', 1.0 ],
	 [ 'C', 'A', 'C', 0.0 ],
	 [ 'C', 'B', 'A', 1.0 ],
	 [ 'C', 'B', 'B', 0.0 ],
	 [ 'C', 'B', 'C', 0.0 ],
	 [ 'C', 'C', 'A', 0.5 ],
	 [ 'C', 'C', 'B', 0.5 ],
	 [ 'C', 'C', 'C', 0.0 ]], [guest, prize] )

Now we can create the states for our bayesian network.


In [8]:
s0 = State( friend, name="friend")
s1 = State( guest, name="guest" )
s2 = State( prize, name="prize" )
s3 = State( monty, name="monty" )
s4 = State( remaining, name="remaining" )
s5 = State( randomize, name="randomize" )

Now we'll create our bayesian network with an instance of BayesianNetwork, then add the possible states.


In [9]:
network = BayesianNetwork( "test" )
network.add_states(s0, s1, s2, s3, s4, s5)

Then the possible transitions.


In [10]:
network.add_transition( s0, s1 )
network.add_transition( s1, s3 )
network.add_transition( s2, s3 )
network.add_transition( s4, s5 )
network.add_transition( s5, s2 )
network.add_transition( s0, s2 )

With a "bake" to finalize the structure of our network.


In [11]:
network.bake()

Now let's create our network from the following data.


In [12]:
data = [[ True,  'A', 'A', 'C', 1, True  ],
		[ True,  'A', 'A', 'C', 0, True  ],
		[ False, 'A', 'A', 'B', 1, False ],
		[ False, 'A', 'A', 'A', 2, False ],
		[ False, 'A', 'A', 'C', 1, False ],
		[ False, 'B', 'B', 'B', 2, False ],
		[ False, 'B', 'B', 'C', 0, False ],
		[ True,  'C', 'C', 'A', 2, True  ],
		[ True,  'C', 'C', 'C', 1, False ],
		[ True,  'C', 'C', 'C', 0, False ],
		[ True,  'C', 'C', 'C', 2, True  ],
		[ True,  'C', 'B', 'A', 1, False ]]

network.fit( data )


Out[12]:
{
    "class" : "BayesianNetwork",
    "name" : "test",
    "structure" : [
        [],
        [
            0
        ],
        [
            5,
            0
        ],
        [
            1,
            2
        ],
        [],
        [
            4
        ]
    ],
    "states" : [
        {
            "class" : "State",
            "distribution" : {
                "class" : "Distribution",
                "dtype" : "bool",
                "name" : "DiscreteDistribution",
                "parameters" : [
                    {
                        "True" : 0.5833333333333334,
                        "False" : 0.4166666666666667
                    }
                ],
                "frozen" : false
            },
            "name" : "friend",
            "weight" : 1.0
        },
        {
            "class" : "State",
            "distribution" : {
                "class" : "Distribution",
                "name" : "ConditionalProbabilityTable",
                "table" : [
                    [
                        "True",
                        "A",
                        "0.2857142857142857"
                    ],
                    [
                        "True",
                        "B",
                        "0.0"
                    ],
                    [
                        "True",
                        "C",
                        "0.7142857142857143"
                    ],
                    [
                        "False",
                        "A",
                        "0.6"
                    ],
                    [
                        "False",
                        "B",
                        "0.4"
                    ],
                    [
                        "False",
                        "C",
                        "0.0"
                    ]
                ],
                "dtypes" : [
                    "bool",
                    "str",
                    "float"
                ],
                "parents" : [
                    {
                        "class" : "Distribution",
                        "dtype" : "bool",
                        "name" : "DiscreteDistribution",
                        "parameters" : [
                            {
                                "True" : 0.5833333333333334,
                                "False" : 0.4166666666666667
                            }
                        ],
                        "frozen" : false
                    }
                ]
            },
            "name" : "guest",
            "weight" : 1.0
        },
        {
            "class" : "State",
            "distribution" : {
                "class" : "Distribution",
                "name" : "ConditionalProbabilityTable",
                "table" : [
                    [
                        "True",
                        "True",
                        "A",
                        "0.5"
                    ],
                    [
                        "True",
                        "True",
                        "B",
                        "0.0"
                    ],
                    [
                        "True",
                        "True",
                        "C",
                        "0.5"
                    ],
                    [
                        "True",
                        "False",
                        "A",
                        "0.3333333333333333"
                    ],
                    [
                        "True",
                        "False",
                        "B",
                        "0.3333333333333333"
                    ],
                    [
                        "True",
                        "False",
                        "C",
                        "0.3333333333333333"
                    ],
                    [
                        "False",
                        "True",
                        "A",
                        "0.0"
                    ],
                    [
                        "False",
                        "True",
                        "B",
                        "0.3333333333333333"
                    ],
                    [
                        "False",
                        "True",
                        "C",
                        "0.6666666666666666"
                    ],
                    [
                        "False",
                        "False",
                        "A",
                        "0.6"
                    ],
                    [
                        "False",
                        "False",
                        "B",
                        "0.4"
                    ],
                    [
                        "False",
                        "False",
                        "C",
                        "0.0"
                    ]
                ],
                "dtypes" : [
                    "bool",
                    "bool",
                    "str",
                    "float"
                ],
                "parents" : [
                    {
                        "class" : "Distribution",
                        "name" : "ConditionalProbabilityTable",
                        "table" : [
                            [
                                "0",
                                "False",
                                "0.6666666666666666"
                            ],
                            [
                                "0",
                                "True",
                                "0.3333333333333333"
                            ],
                            [
                                "1",
                                "False",
                                "0.8"
                            ],
                            [
                                "1",
                                "True",
                                "0.2"
                            ],
                            [
                                "2",
                                "False",
                                "0.5"
                            ],
                            [
                                "2",
                                "True",
                                "0.5"
                            ]
                        ],
                        "dtypes" : [
                            "int",
                            "bool",
                            "float"
                        ],
                        "parents" : [
                            {
                                "class" : "Distribution",
                                "dtype" : "int",
                                "name" : "DiscreteDistribution",
                                "parameters" : [
                                    {
                                        "0" : 0.25,
                                        "1" : 0.4166666666666667,
                                        "2" : 0.3333333333333333
                                    }
                                ],
                                "frozen" : false
                            }
                        ]
                    },
                    {
                        "class" : "Distribution",
                        "dtype" : "bool",
                        "name" : "DiscreteDistribution",
                        "parameters" : [
                            {
                                "True" : 0.5833333333333334,
                                "False" : 0.4166666666666667
                            }
                        ],
                        "frozen" : false
                    }
                ]
            },
            "name" : "prize",
            "weight" : 1.0
        },
        {
            "class" : "State",
            "distribution" : {
                "class" : "Distribution",
                "name" : "ConditionalProbabilityTable",
                "table" : [
                    [
                        "A",
                        "A",
                        "A",
                        "0.2"
                    ],
                    [
                        "A",
                        "A",
                        "B",
                        "0.2"
                    ],
                    [
                        "A",
                        "A",
                        "C",
                        "0.6"
                    ],
                    [
                        "A",
                        "B",
                        "A",
                        "0.3333333333333333"
                    ],
                    [
                        "A",
                        "B",
                        "B",
                        "0.3333333333333333"
                    ],
                    [
                        "A",
                        "B",
                        "C",
                        "0.3333333333333333"
                    ],
                    [
                        "A",
                        "C",
                        "A",
                        "0.3333333333333333"
                    ],
                    [
                        "A",
                        "C",
                        "B",
                        "0.3333333333333333"
                    ],
                    [
                        "A",
                        "C",
                        "C",
                        "0.3333333333333333"
                    ],
                    [
                        "B",
                        "A",
                        "A",
                        "0.3333333333333333"
                    ],
                    [
                        "B",
                        "A",
                        "B",
                        "0.3333333333333333"
                    ],
                    [
                        "B",
                        "A",
                        "C",
                        "0.3333333333333333"
                    ],
                    [
                        "B",
                        "B",
                        "A",
                        "0.0"
                    ],
                    [
                        "B",
                        "B",
                        "B",
                        "0.5"
                    ],
                    [
                        "B",
                        "B",
                        "C",
                        "0.5"
                    ],
                    [
                        "B",
                        "C",
                        "A",
                        "0.3333333333333333"
                    ],
                    [
                        "B",
                        "C",
                        "B",
                        "0.3333333333333333"
                    ],
                    [
                        "B",
                        "C",
                        "C",
                        "0.3333333333333333"
                    ],
                    [
                        "C",
                        "A",
                        "A",
                        "0.3333333333333333"
                    ],
                    [
                        "C",
                        "A",
                        "B",
                        "0.3333333333333333"
                    ],
                    [
                        "C",
                        "A",
                        "C",
                        "0.3333333333333333"
                    ],
                    [
                        "C",
                        "B",
                        "A",
                        "1.0"
                    ],
                    [
                        "C",
                        "B",
                        "B",
                        "0.0"
                    ],
                    [
                        "C",
                        "B",
                        "C",
                        "0.0"
                    ],
                    [
                        "C",
                        "C",
                        "A",
                        "0.25"
                    ],
                    [
                        "C",
                        "C",
                        "B",
                        "0.0"
                    ],
                    [
                        "C",
                        "C",
                        "C",
                        "0.75"
                    ]
                ],
                "dtypes" : [
                    "str",
                    "str",
                    "str",
                    "float"
                ],
                "parents" : [
                    {
                        "class" : "Distribution",
                        "name" : "ConditionalProbabilityTable",
                        "table" : [
                            [
                                "True",
                                "B",
                                "0.0"
                            ],
                            [
                                "True",
                                "C",
                                "0.7142857142857143"
                            ],
                            [
                                "True",
                                "A",
                                "0.2857142857142857"
                            ],
                            [
                                "False",
                                "B",
                                "0.4"
                            ],
                            [
                                "False",
                                "C",
                                "0.0"
                            ],
                            [
                                "False",
                                "A",
                                "0.6"
                            ]
                        ],
                        "dtypes" : [
                            "bool",
                            "str",
                            "float"
                        ],
                        "parents" : [
                            {
                                "class" : "Distribution",
                                "dtype" : "bool",
                                "name" : "DiscreteDistribution",
                                "parameters" : [
                                    {
                                        "True" : 0.5833333333333334,
                                        "False" : 0.4166666666666667
                                    }
                                ],
                                "frozen" : false
                            }
                        ]
                    },
                    {
                        "class" : "Distribution",
                        "name" : "ConditionalProbabilityTable",
                        "table" : [
                            [
                                "False",
                                "True",
                                "B",
                                "0.3333333333333333"
                            ],
                            [
                                "False",
                                "True",
                                "C",
                                "0.6666666666666666"
                            ],
                            [
                                "False",
                                "True",
                                "A",
                                "0.0"
                            ],
                            [
                                "False",
                                "False",
                                "B",
                                "0.4"
                            ],
                            [
                                "False",
                                "False",
                                "C",
                                "0.0"
                            ],
                            [
                                "False",
                                "False",
                                "A",
                                "0.6"
                            ],
                            [
                                "True",
                                "True",
                                "B",
                                "0.0"
                            ],
                            [
                                "True",
                                "True",
                                "C",
                                "0.5"
                            ],
                            [
                                "True",
                                "True",
                                "A",
                                "0.5"
                            ],
                            [
                                "True",
                                "False",
                                "B",
                                "0.3333333333333333"
                            ],
                            [
                                "True",
                                "False",
                                "C",
                                "0.3333333333333333"
                            ],
                            [
                                "True",
                                "False",
                                "A",
                                "0.3333333333333333"
                            ]
                        ],
                        "dtypes" : [
                            "bool",
                            "bool",
                            "str",
                            "float"
                        ],
                        "parents" : [
                            {
                                "class" : "Distribution",
                                "name" : "ConditionalProbabilityTable",
                                "table" : [
                                    [
                                        "0",
                                        "False",
                                        "0.6666666666666666"
                                    ],
                                    [
                                        "0",
                                        "True",
                                        "0.3333333333333333"
                                    ],
                                    [
                                        "1",
                                        "False",
                                        "0.8"
                                    ],
                                    [
                                        "1",
                                        "True",
                                        "0.2"
                                    ],
                                    [
                                        "2",
                                        "False",
                                        "0.5"
                                    ],
                                    [
                                        "2",
                                        "True",
                                        "0.5"
                                    ]
                                ],
                                "dtypes" : [
                                    "int",
                                    "bool",
                                    "float"
                                ],
                                "parents" : [
                                    {
                                        "class" : "Distribution",
                                        "dtype" : "int",
                                        "name" : "DiscreteDistribution",
                                        "parameters" : [
                                            {
                                                "0" : 0.25,
                                                "1" : 0.4166666666666667,
                                                "2" : 0.3333333333333333
                                            }
                                        ],
                                        "frozen" : false
                                    }
                                ]
                            },
                            {
                                "class" : "Distribution",
                                "dtype" : "bool",
                                "name" : "DiscreteDistribution",
                                "parameters" : [
                                    {
                                        "True" : 0.5833333333333334,
                                        "False" : 0.4166666666666667
                                    }
                                ],
                                "frozen" : false
                            }
                        ]
                    }
                ]
            },
            "name" : "monty",
            "weight" : 1.0
        },
        {
            "class" : "State",
            "distribution" : {
                "class" : "Distribution",
                "dtype" : "int",
                "name" : "DiscreteDistribution",
                "parameters" : [
                    {
                        "0" : 0.25,
                        "1" : 0.4166666666666667,
                        "2" : 0.3333333333333333
                    }
                ],
                "frozen" : false
            },
            "name" : "remaining",
            "weight" : 1.0
        },
        {
            "class" : "State",
            "distribution" : {
                "class" : "Distribution",
                "name" : "ConditionalProbabilityTable",
                "table" : [
                    [
                        "0",
                        "True",
                        "0.3333333333333333"
                    ],
                    [
                        "0",
                        "False",
                        "0.6666666666666666"
                    ],
                    [
                        "1",
                        "True",
                        "0.2"
                    ],
                    [
                        "1",
                        "False",
                        "0.8"
                    ],
                    [
                        "2",
                        "True",
                        "0.5"
                    ],
                    [
                        "2",
                        "False",
                        "0.5"
                    ]
                ],
                "dtypes" : [
                    "int",
                    "bool",
                    "float"
                ],
                "parents" : [
                    {
                        "class" : "Distribution",
                        "dtype" : "int",
                        "name" : "DiscreteDistribution",
                        "parameters" : [
                            {
                                "0" : 0.25,
                                "1" : 0.4166666666666667,
                                "2" : 0.3333333333333333
                            }
                        ],
                        "frozen" : false
                    }
                ]
            },
            "name" : "randomize",
            "weight" : 1.0
        }
    ]
}

We can see the results below. Lets look at the distribution for our Friend first.


In [13]:
print(friend)


{
    "class" :"Distribution",
    "dtype" :"bool",
    "name" :"DiscreteDistribution",
    "parameters" :[
        {
            "True" :0.5833333333333334,
            "False" :0.4166666666666667
        }
    ],
    "frozen" :false
}

Then our Guest.


In [14]:
print(guest)


True	B	0.0
True	C	0.7142857142857143
True	A	0.2857142857142857
False	B	0.4
False	C	0.0
False	A	0.6

Now the remaining cars.


In [15]:
print(remaining)


{
    "class" :"Distribution",
    "dtype" :"int",
    "name" :"DiscreteDistribution",
    "parameters" :[
        {
            "0" :0.25,
            "1" :0.4166666666666667,
            "2" :0.3333333333333333
        }
    ],
    "frozen" :false
}

And the probability the prize is randomized.


In [16]:
print(randomize)


0	False	0.6666666666666666
0	True	0.3333333333333333
1	False	0.8
1	True	0.2
2	False	0.5
2	True	0.5

Now the distribution of the Prize.


In [17]:
print(prize)


False	True	B	0.3333333333333333
False	True	C	0.6666666666666666
False	True	A	0.0
False	False	B	0.4
False	False	C	0.0
False	False	A	0.6
True	True	B	0.0
True	True	C	0.5
True	True	A	0.5
True	False	B	0.3333333333333333
True	False	C	0.3333333333333333
True	False	A	0.3333333333333333

And finally our Monty.


In [18]:
print(monty)


B	B	B	0.5
B	B	C	0.5
B	B	A	0.0
B	C	B	0.3333333333333333
B	C	C	0.3333333333333333
B	C	A	0.3333333333333333
B	A	B	0.3333333333333333
B	A	C	0.3333333333333333
B	A	A	0.3333333333333333
C	B	B	0.0
C	B	C	0.0
C	B	A	1.0
C	C	B	0.0
C	C	C	0.75
C	C	A	0.25
C	A	B	0.3333333333333333
C	A	C	0.3333333333333333
C	A	A	0.3333333333333333
A	B	B	0.3333333333333333
A	B	C	0.3333333333333333
A	B	A	0.3333333333333333
A	C	B	0.3333333333333333
A	C	C	0.3333333333333333
A	C	A	0.3333333333333333
A	A	B	0.2
A	A	C	0.6
A	A	A	0.2