X-ray crystallography simulations

In order to execute all cells at once, select "Cell -> Run all".


In [1]:
from itertools import product
from functools import wraps

import numpy as np
from lll import lll
import matplotlib.pyplot as plt
import plotly.plotly as py
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot, iplot_mpl
from ipywidgets import interact

init_notebook_mode()



In [92]:
def store_last_correct(func):
    last_correct = None
    
    @wraps(func)
    def wrapped(*args, **kwargs):
        nonlocal last_correct
        try:
            last_correct = func(*args, **kwargs)
            return last_correct
        except:
            return last_correct
    
    return wrapped

def unit_cell(a, b, c, atom_pos, Nx, Ny, Nz):
    """Make arrays of x-, y- and z-positions of a lattice from the
    lattice vectors, the atom positions and the number of unit cells.
    
    Parameters:
    -----------
    a : list
        First lattice vector
    b : list
        Second lattice vector
    c : list
        Third lattice vector
    atom_pos : list
        Positions of atoms in the unit cells in terms of a, b and c
    Nx : int
        number of unit cells in the x-direction to be plotted
    Ny : int
        number of unit cells in the y-direction to be plotted
    Nz : int
        number of unit cells in the z-direction to be plotted
        
    Returns:
    --------
    latt_coord_x : numpy.ndarray
        Array containing the x-coordinates of all atoms to be plotted
    latt_coord_y : numpy.ndarray
        Array containing the y-coordinates of all atoms to be plotted
    latt_coord_z : numpy.ndarray
        Array containing the z-coordinates of all atoms to be plotted
    """
    latt_coord_x = []
    latt_coord_y = []
    latt_coord_z = []
    for atom in atom_pos:
        xpos = atom[0]*a[0] + atom[1]*b[0] + atom[2]*c[0]
        ypos = atom[0]*a[1] + atom[1]*b[1] + atom[2]*c[1]
        zpos = atom[0]*a[2] + atom[1]*b[2] + atom[2]*c[2]
        xpos_all = [xpos + n*a[0] + m*b[0] + k*c[0] for n, m, k in
                     product(range(Nx), range(Ny), range(Nz))]
        ypos_all = [ypos + n*a[1] + m*b[1] + k*c[1] for n, m, k in
                     product(range(Nx), range(Ny), range(Nz))]
        zpos_all = [zpos + n*a[2] + m*b[2] + k*c[2] for n, m, k in
                     product(range(Nx), range(Ny), range(Nz))]
        latt_coord_x.append(xpos_all)
        latt_coord_y.append(ypos_all)
        latt_coord_z.append(zpos_all)
    latt_coord_x = np.array(latt_coord_x).flatten()
    latt_coord_y = np.array(latt_coord_y).flatten()
    latt_coord_z = np.array(latt_coord_z).flatten()
    return latt_coord_x, latt_coord_y, latt_coord_z


def reciprocal_lattice_vectors(*vectors):
    """Calculate reciprocal lattice vectors given lattice vectors.
    
    Using formula from
    https://en.wikipedia.org/wiki/Reciprocal_lattice#Generalization_of_a_dual_lattice
    """
    vectors = np.array(vectors).T
    reciprocal_vectors = vectors.dot(np.linalg.inv(vectors.T.dot(vectors)))
    return reciprocal_vectors


def laue_conditions(coords, k, dk, elastic=True):
    """Calculate which coordinates on reciprocal lattice satisfy Laue conditions.
    
    To not get nothing almost always, we're assuming that the wave vector `k`
    has a small spread dk.
    """
    k = np.array(k)[..., :]
    coords = np.asarray(coords)
    dk /= np.linalg.norm(k)
    where = ((np.linalg.norm(coords + k * (1 + elastic * dk/2), axis=1) < 
              np.linalg.norm(k) * (1 + dk/2)) &
             (np.linalg.norm(coords + k * (1 - elastic *dk/2), axis=1) > 
              np.linalg.norm(k) * (1 - dk/2)))
    return where

def rotation_matrix(k):
    if len(k) == 3:
        if np.allclose(k/np.linalg.norm(k), (1., 0, 0)):
            return np.identity(3)
        rotation_axis = np.cross(k, (1., 0, 0))
        rotation_axis /= np.linalg.norm(rotation_axis)
        cross_product_matrix = np.array([[0., -rotation_axis[2], rotation_axis[1]],
                                         [0, 0, -rotation_axis[0]], [0, 0, 0]])
        cross_product_matrix = cross_product_matrix - cross_product_matrix.T
        # rotation matrix
        R = (k[0] * np.identity(3) + np.linalg.norm(k[1:]) * cross_product_matrix +
             (np.linalg.norm(k) - k[0]) * np.outer(rotation_axis, rotation_axis))
        R /= np.linalg.norm(k)
    elif len(k) == 2:
        return np.array([[k[0], k[1]], [-k[1], k[0]]]) / np.linalg.norm(k)
    else:
        raise ValueError('Only 2D or 3D are supported.')
    return R

def k_image(coordinates, k, spherical_aberration=False):
    coordinates = coordinates.T + k[:, None]
    coordinates = rotation_matrix(k).dot(coordinates)
    small_angles = (2 * coordinates[0] > np.linalg.norm(coordinates, axis=0))
    if spherical_aberration:
        coordinates[1:] /= coordinates[0]
    return coordinates.T[small_angles, 1:], small_angles

def laue_visualizaton(lattice, k, dk, basis, form_factors, elastic=False):
    if np.array(lattice).shape != (2, 2):
        raise ValueError("Only 2D lattices are supported")
    k = np.array(k)
    reciprocal = lll(reciprocal_lattice_vectors(*lattice))[0]
    center = np.round(np.linalg.solve(reciprocal, k))
    r = np.linalg.norm(k) + dk/2
    r *= np.sqrt(np.max(np.diag(np.linalg.inv(reciprocal.T.dot(reciprocal)))))
    r = np.round(3 * r)
    points = np.mgrid[tuple(slice(i - r, i + r) for i in center)]
    points = points.reshape(len(center), -1).T
    point_coords = np.dot(points, reciprocal)
    where = laue_conditions(point_coords, k, dk, elastic)
    points, point_coords = np.array(points[where], dtype=int), point_coords[where]
    color = brightness(point_coords, basis, form_factors)

    data = [go.Scatter(x=point_coords[:, 0], y=point_coords[:, 1], 
                       text=[str(i) for i in points],
                       mode='markers', hoverinfo='text', 
                       marker=dict(size=5, color=color, cmin=0, cmax=max(color),
                                   colorscale="Greys",
                                   showscale=True, reversescale=True))]


    center1 = -k * (1 + elastic * dk / (2 * np.linalg.norm(k)))
    center2 = -k * (1 - elastic * dk / (2 * np.linalg.norm(k)))
    radius1 =  np.linalg.norm(k) + dk/2
    radius2 =  np.linalg.norm(k) - dk/2
    layout = go.Layout(
                    autosize=False,
                    width=600,
                    height=600,
                    xaxis=dict(
                        showgrid=False,
                        zeroline=False,
                        showline=True),
                    yaxis=dict(
                        showgrid=False,
                        zeroline=False,
                        showline=True),
                    hovermode='closest',
                    shapes = [
                        {
                            'type': 'circle', 'xref': 'x', 'yref': 'y',
                            'x0': center1[0] - radius1,
                            'y0': center1[1] - radius1,
                            'x1': center1[0] + radius1,
                            'y1': center1[1] + radius1,
                            'line': {'color': 'blue'}},
                        {
                            'type': 'circle', 'xref': 'x', 'yref': 'y',
                            'x0': center2[0] - radius2,
                            'y0': center2[1] - radius2,
                            'x1': center2[0] + radius2,
                            'y1': center2[1] + radius2,
                            'line': {'color': 'blue'}},
                        {
                            'type': 'line',
                            'x0': 0,
                            'y0': 0,
                            'x1': -k[0],
                            'y1': -k[1],
                            'line': {
                                'color': 'blue',
                                'width': 4}}])

    fig = go.Figure(data=data, layout=layout)
    iplot(fig, show_link=False)
    

def diffraction_pattern(lattice, k, dk, basis, form_factors, spherical_aberration=False, elastic=True):
    """Calculate diffraction pattern of a given lattice.
    
    Parameters:
    -----------
    lattice : array
        3x3 array of lattice vectors.
    k : array
        A vector of X-ray momentum.
    dk : float
        Spread in X-ray momentum (should be much smaller than k).
        
    Returns:
    --------
    data : array
        A 2D array with the coordinates of the diffraction pattern.
    """
    k = np.array(k)
    reciprocal = lll(reciprocal_lattice_vectors(*lattice))[0]
    center = np.round(np.linalg.solve(reciprocal, k))
    r = np.linalg.norm(k) + dk/2
    r /= np.sqrt(np.min(np.diag(np.linalg.inv(reciprocal.T.dot(reciprocal)))))
    r = np.round(3 * r)
    points = np.mgrid[tuple(slice(i - r, i + r) for i in center)]
    points = points.reshape(len(center), -1).T
    point_coords = np.dot(points, reciprocal)
    where = laue_conditions(point_coords, k, dk, elastic)
    points, point_coords = points[where], point_coords[where]
    image, indices = k_image(point_coords, k, spherical_aberration)
    projections, labels = image, points[indices]

    labels = np.array(labels, dtype=int)
    color = brightness(labels, basis, form_factors)
    data = [go.Scatter(x=projections.T[0], y=projections.T[1], text=[str(i) for i in labels],
                       mode='markers', hoverinfo='text', marker=dict(size=5, color=color,
                                                                     cmin=0, cmax=max(color),
                                                                     colorscale="Greys",
                                                                     showscale=True,
                                                                     reversescale=True))]
    layout = go.Layout(
            autosize=False,
            width=600,
            height=600,
            xaxis=dict(
                showgrid=False,
                zeroline=False,
                showline=True),
            yaxis=dict(
                showgrid=False,
                zeroline=False,
                showline=True),
            hovermode='closest')
    fig = go.Figure(data=data, layout=layout)
    iplot(fig, show_link=False)

def brightness(k, basis, form_factors):
    """Calculate the brightness of a point, based on the basis and form factors."""
    return np.abs(np.exp(2j * np.pi * k.dot(basis.T)).dot(form_factors))

In [95]:
@store_last_correct
def validate_input2(lattice, k, basis, form_factors):
    lattice = np.array(eval(lattice))
    k = np.array(eval(k))
    basis = np.array(eval(basis))
    form_factors = np.array(eval(form_factors))
    assert lattice.shape == (2, 2)
    assert k.shape == (2,)
    assert len(basis.shape) == 2 and basis.shape[1] == 2
    assert form_factors.shape == (basis.shape[0],)
    return lattice, k, basis, form_factors


@interact(lattice='[[1, 0], [0, 1]]', k='(8, 8)', dk=(0., 4.), elastic=False,
          basis='[[0, 0], [0.5, 0.5]]', form_factors='[1, .5]')
def plot_laue(lattice, k, dk, elastic, basis, form_factors):
    lattice, k, basis, form_factors = validate_input2(lattice, k, basis, form_factors)

    laue_visualizaton(lattice, k, dk, basis, form_factors, elastic)



In [94]:
@store_last_correct
def validate_input(lattice, k, basis, form_factors):
    lattice = np.array(eval(lattice))
    k = np.array(eval(k))
    basis = np.array(eval(basis))
    form_factors = np.array(eval(form_factors))
    assert lattice.shape == (3, 3)
    assert k.shape == (3,)
    assert len(basis.shape) == 2 and basis.shape[1] == 3
    assert form_factors.shape == (basis.shape[0],)
    return lattice, k, basis, form_factors


@interact(lattice='[[1, 0, 0], [0, 1, 0], [0, 0, 1]]', 
          k='(4, 4, 4)', basis='[[0, 0, 0], [0.5, 0.5, 0.5]]', form_factors='[1, .5]',
          dk=(0., 2.), spherical_aberration=True, elastic=False)
def calculate_and_plot_diffraction(lattice, k, dk, basis, form_factors,
                                   spherical_aberration, elastic):
    lattice, k, basis, form_factors = validate_input(lattice, k, basis, form_factors)
    
    diffraction_pattern(lattice, k, dk, basis, form_factors,
                        spherical_aberration, elastic)



In [ ]: