In [1]:
import sys
sys.path.append('../sample/')
from metropolis_sampler import MetropolisSampler

from random import uniform, gauss
import random
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def initialize_state(dim):
    """ int -> [float]
    Draw a random starting state, uniform in [-10, 10] on each axis.
    """
    return np.array([uniform(-10, 10) for _ in range(dim)])


def markov_process(x, step_length):
    """ [float] * float -> [float]
    Propose a new state by adding an independent Gaussian step to each axis.
    """
    result = x.copy()
    for i, item in enumerate(result):
        result[i] = item + gauss(0, 1) * step_length
    return result
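
As a quick sanity check of these helpers (illustrative only, not part of the original run), the snippet below draws one starting state and one Gaussian proposal from it:

In [ ]:
x0 = initialize_state(5)          ## a random point in [-10, 10]^5
x1 = markov_process(x0, 0.1)      ## one Gaussian-perturbed proposal

print(x0)
print(x1)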

In [19]:
def N(mu, sigma):
    """ float * float -> ([float] -> float)
    """
    return lambda x: np.exp(- np.sum(np.square((x - mu) / sigma)))


## Recall that SimulatedAnnealing searches for the argmin, not the argmax.
def target_function(x):
    """ [float] -> float
    """
    return 1 * N(-5, 5)(x) + 10 * N(5, 5)(x)

def log_target_distribution(T):
    """ float -> ([float] -> float)
    """
    ## Temper the target by T: log pi_T(x) = log pi(x) / T (T = 1 leaves it unchanged).
    return lambda x: np.log(target_function(x)) / T
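
The target is an unnormalised mixture of two isotropic Gaussian bumps, centred at -5 and +5 on every axis (both with width 5), the one at +5 weighted ten times heavier. A small check, added here for illustration only, evaluates the target at the two bump centres:

In [ ]:
d = 300    ## same dimensionality as used below
for centre in (-5.0, 5.0):
    print(centre, target_function(np.full(d, centre)))
## Prints roughly 1 at the -5 centre and roughly 10 at the +5 centre;
## the far-away bump contributes almost nothing for large d.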

In [94]:
def sampling_0(iterations_0, burn_in_ratio, dim, step_length_0, T):

    def initialize_state_0():
        return initialize_state(dim)
    
    def markov_process_0(x):
        return markov_process(x, step_length_0)

    ms = MetropolisSampler(iterations_0,
                           initialize_state_0,
                           markov_process_0,
                           int(burn_in_ratio * iterations_0),
                           log=False,
                           )
    chain_0 = ms.sampling(log_target_distribution(T))
    print('Pre- Accept Ratio: {0}'.format(ms.accept_ratio))
    
    return chain_0
    

def sampling_1(chain_0, iterations_1, burn_in_ratio, dim, step_length_1, T):

    def initialize_state_1():
        return random.choice(chain_0)
    
    def markov_process_1(x):
        return markov_process(x, step_length_1)
    
    chain_1 = []
    accept_ratios = []
    ## Restart the sampler many times, each run starting from a state drawn from chain_0.
    for _ in range(1000):
        ms = MetropolisSampler(iterations_1,
                               initialize_state_1,
                               markov_process_1,
                               int(burn_in_ratio * iterations_1),
                               log=False,
                               )
        chain_1 += ms.sampling(log_target_distribution(T))
        accept_ratios.append(ms.accept_ratio)
    
    print('Accept Ratio: {0}'.format(np.mean(accept_ratios)))
        
    return chain_1
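
MetropolisSampler itself lives in ../sample/ and is not shown here. For orientation, below is a minimal sketch of the accept/reject rule such a sampler typically applies to a log-density; the actual class's interface and internals may well differ.

In [ ]:
## A sketch only, NOT the actual MetropolisSampler implementation.
def metropolis_step(x, log_target, propose):
    """ One Metropolis step with a symmetric proposal. """
    y = propose(x)
    log_alpha = log_target(y) - log_target(x)       ## log acceptance ratio
    if np.log(np.random.uniform()) < log_alpha:     ## accept with prob. min(1, exp(log_alpha))
        return y, True
    return x, False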

In [99]:
dim = 300

## Needs tuning
iterations_0 = int(10 ** 5 * 1)
burn_in_ratio = 0.5
step_length_0 = 0.1
T = 1

chain_0 = sampling_0(iterations_0, burn_in_ratio, dim, step_length_0, T)


Pre- Accept Ratio: 0.80391

In [101]:
iterations_1 = 1000
step_length_1 = 0.02


chain_1 = sampling_1(chain_0, iterations_1, burn_in_ratio, dim, step_length_1, T)

## Pool the pre-run and the re-sampled states into one chain.
chain = list(chain_0) + chain_1

print('Length of chain: {0}'.format(len(chain)))

expect = [np.mean([state[axis] for state in chain])
          for axis in range(dim)
         ]

print('Expects: {0}'.format(expect))


Accept Ratio: 0.9612099999999999
Length of chain: 51000
Expects: [-4.4063910316644677, -4.4990719932785792, -4.9876179378395991, -4.903892187179582, -4.3668632199881854, -4.6456779443746026, -4.6993672524246426, -4.539402263657232, -4.5372555513109321, -3.8887120308793817, -5.5219057690531432, -5.5432850013444295, -4.8103359992211097, -5.2276309820177342, -5.2500580900729581, -4.6963172397588693, -5.2998363980455681, -5.7230682544212037, -3.7957302608403505, -4.4737027120570536, -5.9380785232090352, -5.3764909464815149, -4.5258420413583771, -4.8893107846202213, -4.8863076935000569, -5.7396259934059994, -5.3581201130282148, -4.0790758237079565, -4.8020525251154913, -4.581549656064869, -4.8068245859169227, -4.2935447975369083, -5.9360183585696182, -4.8572998242660841, -6.5045801048150409, -4.4946673641863324, -5.1384574150901203, -4.8572261974686022, -5.1631980466657206, -4.0288853609485704, -6.0344259319946776, -4.8548504692633125, -4.4874475668001104, -5.8670605350713432, -5.7880535165067455, -5.2531402273702099, -4.5834311125354574, -5.1367556596047086, -5.3700342661106628, -5.40588407524017, -4.7086377696565753, -5.5734838968252296, -5.7795805878673168, -5.5276303397491287, -4.6189838749529351, -5.0150847937048884, -5.1734178976416088, -3.6412879613117264, -5.9509515934049908, -4.6319680279425626, -4.4654621747924557, -4.5292654853814538, -5.8062343763954827, -6.3141051154448311, -4.9265484147205685, -4.5965918482806085, -6.1120819897869483, -5.3214493663505023, -3.3262845046283784, -5.9954762590848034, -5.8062919195564122, -6.0639487992958783, -4.9809137065707016, -5.3085909329382384, -4.7730928885792876, -5.6120647735821416, -4.6254669045680483, -4.8667992239369129, -4.6863401257592354, -5.6112431817472261, -3.5964232404440248, -5.6686707243544436, -4.1899696352370777, -5.7422778827257037, -6.4547588898396659, -4.3713534919570165, -4.3087315868815956, -4.6281333637838493, -4.8715302729908094, -3.9339030360603355, -4.0125584089568029, -4.6779350070986512, -5.1690499269319776, -5.9445416770833255, -5.0729964903619678, -4.6808279680482228, -5.2407162398749438, -5.3208466546152335, -6.1528550397929331, -6.0245448002417161, -5.2441916084628151, -6.0879731356374016, -4.4061465407017204, -5.277791908865284, -5.0061722150073269, -5.0515548617501285, -4.0568578110068314, -4.9871592314979818, -5.8537147643131364, -5.4704708550733212, -4.890939253370183, -5.062884783598391, -4.5996951229681233, -5.1399419044195174, -4.451919289007046, -5.6498976228450433, -4.8365869719598278, -4.4939432074288357, -5.7842169593502462, -4.6082996404155114, -4.6470709791741269, -5.9818213599536083, -5.0101683513883675, -5.0010005090508693, -4.6095469490399106, -4.8934153475588689, -6.2743155411956089, -4.9785393456588816, -4.3759292691196974, -4.5807005275018655, -5.3784336089230971, -5.5430249979386161, -4.4742563004769025, -5.0176593897652682, -4.0654581785670896, -4.7913340029656108, -4.3719974602901877, -5.5204065847506385, -3.6255155043584355, -4.7607735140826124, -5.3659901462662472, -4.8968128438347263, -6.4994537325228228, -5.206374218283524, -5.1196286905255928, -4.447798511184315, -3.7585999553395926, -4.7527545233968995, -4.8994655158957725, -4.4781455673457238, -4.3662526816190708, -5.0010258096256015, -5.0607110971147158, -6.1808948749825863, -4.5914531711116577, -5.5752256128557072, -6.4063481005105061, -4.9281965083340786, -5.0343261351256148, -5.1929799506767873, -6.1855153776912557, -4.7340303561727355, -5.452698118351714, -4.2269454332886207, -4.7250506926788534, -5.6412164902681985, -3.8728075886647413, -6.3119620174152091, -5.656431877246221, 
-4.8732845389936177, -5.2117179238796334, -4.4632000827729605, -5.1778598729806173, -5.2196589827389053, -4.1772872904917842, -5.3040359567234656, -4.2668304467017997, -3.5257821660100754, -4.6270919298425079, -4.7574581473361572, -4.8707097443595044, -5.0852228161546407, -5.4190427525770319, -4.339407656020752, -5.7285124350620515, -5.3116080491776998, -5.3264519582651886, -5.4221549744647577, -4.3050118281991869, -4.0202692497190977, -4.4404917378872728, -5.9693180612900187, -5.0310614239096383, -5.1557709350537637, -5.0287120692258211, -4.3079696472101361, -5.7246578179104572, -4.3254089310699815, -5.1191904287287056, -5.0008661120841147, -4.719265760871334, -3.6933914612092185, -5.388300031260151, -5.3229623318136019, -3.3457884511498284, -5.5977454082139424, -4.6950176196584534, -5.9649066085464995, -4.6498672327529187, -4.5313587122629659, -4.0828979334743183, -4.1002814303489208, -5.2742605346582625, -4.3532719075011626, -5.8924051222518159, -4.3354316864961211, -4.1956129770597466, -5.1973937188694981, -5.258017576442775, -5.3141261115572469, -5.8546272490763442, -4.5448285572093363, -3.9054535142578892, -4.4209862508018478, -4.7094320216785688, -5.3528326714812611, -4.2169252558117867, -4.8175168998454962, -5.0411316471533869, -4.6596110950090095, -4.3539695586162113, -5.375495272944435, -6.9419252015804975, -4.6201458196376679, -4.8541429772233906, -4.387064000315485, -5.1680359870771371, -4.2822927246746776, -4.7775396347116184, -4.3454566234172924, -3.583493178057414, -4.3895850599729815, -4.8716678813031224, -3.5973201141822426, -5.4073304242257398, -5.496120639774027, -5.6285816658100494, -3.6081455315569992, -5.9549624949747724, -4.0920098012407786, -4.848042485973358, -5.4419920943705105, -4.8682670747435379, -5.1744587285198307, -6.2778350702173897, -5.3225604544661653, -3.46343321353124, -4.5677166319223952, -5.3921800424504953, -5.5016577852204334, -4.8019716250848736, -4.9833035574310838, -4.6092369301767473, -5.3052773792689685, -4.7725127141632395, -6.2196147245251447, -5.2064102023915195, -4.4092493939896258, -5.1028975170551201, -4.7071311933264326, -5.2205895719066371, -6.0784714501993422, -4.2192554004860936, -4.8798014338132951, -4.8669171350809473, -6.0700887743583136, -4.8378960122798027, -4.5757443409466694, -4.9473109558034301, -4.5349504383152608, -5.2650632083357731, -4.6474112782027142, -5.3099063757696481, -4.4969498869146403, -6.0643833325765035, -4.5235766749697222, -5.0586953066749256, -5.424231526761381, -4.8213761218354794, -5.4685321943025986, -4.7991475954581526, -4.7976125988822602, -5.369009585404271, -5.1439886196121032, -5.2887407887853275, -3.6926417156280857, -4.4647374785387379, -5.323842932263176, -4.3903895890229387, -5.1182140836287777]

In [ ]:
## Trace of the target value along the pooled chain.
steps = np.arange(len(chain))
targets = [target_function(state) for state in chain]


plt.plot(steps, targets)
plt.xlabel('step')
plt.ylabel('target value')
plt.show()


## One trace plot per coordinate.
for axis in range(dim):
    
    xs = [state[axis] for state in chain]

    plt.plot(steps, xs)
    plt.xlabel('step')
    plt.ylabel('x[{0}]'.format(axis))
    plt.show()
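
With dim = 300 the loop above produces one trace plot per axis; a more compact view, added here as an optional extra, is a histogram of a single coordinate of the pooled chain:

In [ ]:
axis = 0
plt.hist([state[axis] for state in chain], bins=50)
plt.xlabel('x[{0}]'.format(axis))
plt.ylabel('count')
plt.show()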

Splendid, even for dim = 10.