In [ ]:
from os.path import join
from matplotlib import pyplot as plt
import numpy as np
import random
import os
from time import time
from bornagain import deg, angstrom, nm
import bornagain as ba

In [ ]:
# Angular window limits (presumably degrees). None of these four are read by the
# cells visible here (the detector is a RectangularDetector) — NOTE(review):
# confirm they are unused before removing them.
PHIMIN = -0.5
PHIMAX = 0.5
ALPHAMIN = 0.0
ALPHAMAX = 1.0

# Detector pixel counts along the two axes, passed to ba.RectangularDetector.
NPHI = 224
NALPHA = 224

# Number of orientation bins; get_sample maps bin i of n to a rotation of i*60/n degrees.
N_ANGLE_BINS = 120
# Upper bound on how many orientation bins carry a non-zero weight per example.
MAX_NONZEROS = 5

In [ ]:
def get_sample(angle_distribution):
    """Build the sample: spheres on a substrate, one layout per populated orientation bin.

    Parameters
    ----------
    angle_distribution : sequence of float
        Weight of each orientation bin. Bin ``i`` of ``n`` corresponds to a
        lattice rotation of ``i * 60 / n`` degrees. Only bins whose weight is
        strictly greater than 0 contribute a layout.

    Returns
    -------
    ba.MultiLayer
        An air layer carrying the layouts, on top of a substrate layer.
    """
    # Materials: (name, delta, beta) — presumably refractive-index terms, TODO confirm.
    m_ambience = ba.HomogeneousMaterial("Air", 0.0, 0.0)
    m_particle = ba.HomogeneousMaterial("Particle", 6e-4, 2e-8)
    m_substrate = ba.HomogeneousMaterial("Substrate", 6e-6, 2e-8)

    air_layer = ba.Layer(m_ambience)
    substrate_layer = ba.Layer(m_substrate)

    # A single sphere particle shared by every orientation layout.
    layout = ba.ParticleLayout()
    layout.addParticle(ba.Particle(m_particle, ba.FormFactorFullSphere(0.62*nm)))

    n_bins = len(angle_distribution)
    for bin_index, weight in enumerate(angle_distribution):
        if not weight > 0.0:
            continue
        angle = bin_index*60.0 / n_bins
        # Finite 18 nm x 18 nm lattice with a 120 deg lattice angle, rotated by
        # `angle`, 20 x 20 cells.
        interference = ba.InterferenceFunctionFinite2DLattice(
            18*nm, 18*nm, 120*deg, angle*deg, 20, 20)
        interference.setPositionVariance(10*nm)
        # The same layout object is reconfigured for every bin. This relies on
        # addLayout storing a copy (NOTE(review): confirm); otherwise every layout
        # would share the last bin's interference function and weight.
        layout.setInterferenceFunction(interference)
        layout.setWeight(weight)
        air_layer.addLayout(layout)

    multi_layer = ba.MultiLayer()
    for layer in (air_layer, substrate_layer):
        multi_layer.addLayer(layer)
    return multi_layer

In [ ]:
def get_simulation():
    """Create a GISAS simulation with the detector, resolution and beam configured.

    The sample is not attached here; the caller must set it before running.
    """
    sim = ba.GISASSimulation()

    # NPHI x NALPHA pixels spanning 13.15 x 8.38 (presumably mm — TODO confirm),
    # placed perpendicular to the direct beam at distance 300.0 with the beam
    # hitting the detector plane at (6.57, -0.99).
    rect_detector = ba.RectangularDetector(NPHI, 13.15, NALPHA, 8.38)
    rect_detector.setPerpendicularToDirectBeam(300.0, 6.57, -0.99)
    sim.setDetector(rect_detector)

    # Gaussian smearing of the detector image, same width along both axes.
    smearing = ba.ResolutionFunction2DGaussian(0.04, 0.04)
    sim.setDetectorResolutionFunction(smearing)

    wavelength, alpha_i, phi_i = 0.1*nm, 0.268*deg, 0.0*deg
    sim.setBeamParameters(wavelength, alpha_i, phi_i)
    sim.setBeamIntensity(4.0e+09)
    return sim

In [ ]:
def simulate(angle_distribution):
    """Run one simulation for the given orientation weights and return its result object."""
    sample = get_sample(angle_distribution)
    sim = get_simulation()
    sim.setSample(sample)
    sim.runSimulation()
    outcome = sim.result()
    return outcome

In [ ]:
def sim_distr(angle_distribution, save_path, index):
    """Simulate one example and store its image and orientation weights in ``save_path``.

    Writes ``data_distr<index>.npy`` (the simulated result array) followed by
    ``distr<index>.npy`` (the weights). ``np.save`` appends the ``.npy`` suffix.
    """
    image = simulate(angle_distribution).array()
    outputs = (
        ("data_distr{0}", image),
        ("distr{0}", angle_distribution),
    )
    for name_template, payload in outputs:
        np.save(join(save_path, name_template.format(index)), payload)

In [ ]:
def generate_angle_distribution(num_bins, max_non_zeros):
    """Draw a random sparse weight distribution over ``num_bins`` orientation bins.

    Draws ``k`` uniformly from ``1..max_non_zeros`` and picks ``k`` distinct bins.
    Those bins get strictly positive random weights normalised to sum to 1; every
    other bin is 0.

    Both Python's ``random`` module and NumPy's global RNG are used, so seed
    both for reproducible output.

    Returns
    -------
    (int, numpy.ndarray)
        ``k`` and the length-``num_bins`` float array of weights. Exactly ``k``
        entries are non-zero.

    Raises
    ------
    ValueError
        If ``max_non_zeros < 1`` (from ``random.randrange``) or the drawn ``k``
        exceeds ``num_bins`` (from ``random.sample``).
    """
    distr = np.zeros(num_bins)
    non_zeros = random.randrange(max_non_zeros) + 1
    while True:
        nonzero_indices = random.sample(range(num_bins), non_zeros)
        probs = np.random.random_sample((non_zeros,))
        # random_sample draws from the half-open interval [0, 1), so a single
        # draw can be exactly 0.0. That would leave fewer than `non_zeros`
        # non-zero bins, so redraw unless every weight is strictly positive
        # (this also rules out a zero total).
        if np.all(probs > 0.0):
            break
    # Only write once a valid draw exists; the indices are distinct.
    distr[nonzero_indices] = probs / np.sum(probs)
    return non_zeros, distr

In [ ]:
def generate_dataset(n_train, save_path, max_non_zeros):
    """Generate ``n_train`` random examples and save each one to ``save_path``.

    Each example draws a distribution over ``N_ANGLE_BINS`` bins, simulates it,
    and stores it under index ``0..n_train-1``. A progress line is printed after
    every 10th example.
    """
    for index in range(n_train):
        non_zeros, distr = generate_angle_distribution(N_ANGLE_BINS, max_non_zeros)
        example_number = index + 1
        if example_number % 10 == 0:
            print("Example {0:d}: {1:d} non-zero probabilities".format(example_number, non_zeros))
        sim_distr(distr, save_path, index)

In [ ]:
def create_dir(path: str):
    """Ensure ``path`` exists as a directory, creating missing parents as needed.

    Calling it again on an existing directory is a no-op.

    Raises
    ------
    FileExistsError
        If ``path`` exists but is not a directory.
    """
    # A single makedirs call avoids the race between os.path.exists() and
    # os.mkdir(). With exist_ok=True it still raises FileExistsError when a
    # non-directory already occupies `path`.
    os.makedirs(path, exist_ok=True)

In [ ]:
def seed_random_generators(python_seed=300416, numpy_seed=240201):
    """Seed Python's ``random`` module and NumPy's legacy global RNG.

    Both generators feed ``generate_angle_distribution``, so seeding both makes
    the generated datasets reproducible. The defaults keep the original fixed
    seeds; pass other values to draw an independent dataset.
    """
    random.seed(python_seed)
    np.random.seed(numpy_seed)

In [ ]:
DATA_PATH = "./ml_data"
N_TRAIN = 20
N_VAL = 10

# Seed once, before any data is drawn. The train and validation sets then come
# from one reproducible random stream.
seed_random_generators()


TRAIN_PATH = DATA_PATH + "/train"
VAL_PATH = DATA_PATH + "/val"

# Parent directory first, then the two split directories.
for directory in (DATA_PATH, TRAIN_PATH, VAL_PATH):
    create_dir(directory)

# Existing files with the same indices are overwritten by np.save.
for directory in (TRAIN_PATH, VAL_PATH):
    if os.listdir(directory):
        print("WARNING: '{}' is not empty! Files may be overwritten"
              .format(directory))

# (size, output dir, start message, timing message) for each split, in generation order.
SPLITS = (
    (N_TRAIN, TRAIN_PATH,
     "Generating training data: {0:d} examples",
     "Execution time for {0} training examples: {1:.2f} seconds"),
    (N_VAL, VAL_PATH,
     "Generating validation data: {0:d} examples",
     "Execution time for {0} validation examples: {1:.2f} seconds"),
)

for n_examples, split_path, start_message, timing_message in SPLITS:
    print(start_message.format(n_examples))
    START_TIME = time()

    generate_dataset(n_examples, split_path, MAX_NONZEROS)
    print(timing_message
          .format(n_examples, time() - START_TIME))

In [ ]:
# List the generated training files: data_distr<i>.npy (simulated array) and distr<i>.npy (weights) per example.
!ls ml_data/train

In [ ]:
# List the generated validation files: data_distr<i>.npy (simulated array) and distr<i>.npy (weights) per example.
!ls ml_data/val

In [ ]:
%matplotlib notebook
i = 5
fname = "ml_data/val/data_distr{}.npy".format(i)
valdata = np.log(np.load(fname)+1e-10)
plt.figure(figsize=(8,8))
plt.imshow(valdata)

In [ ]: