Wright-Fisher model of mutation and random genetic drift

A Wright-Fisher model has a fixed population size N and discrete non-overlapping generations. Each generation, each individual has a random number of offspring whose mean is proportional to the individual's fitness. Each generation, mutation may occur.
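As a rough illustration of the sampling rule described above, the sketch below draws one generation of offspring counts with probabilities proportional to count times fitness. The function name next_generation, the fitness argument, and the seeded generator are assumptions made for this example only; the rest of this notebook treats all haplotypes as equally fit.

```python
import numpy as np

def next_generation(counts, fitness, rng):
    """Draw offspring counts for one Wright-Fisher generation.

    counts  -- current number of individuals of each type
    fitness -- relative fitness of each type (hypothetical values)
    """
    counts = np.asarray(counts, dtype=float)
    fitness = np.asarray(fitness, dtype=float)
    weights = counts * fitness
    probs = weights / weights.sum()   # normalize so probabilities sum to 1
    n = int(counts.sum())             # population size N stays fixed
    return rng.multinomial(n, probs)

rng = np.random.default_rng(1)        # seeded only to make the example repeatable
offspring = next_generation([40, 30, 30], [1.0, 1.0, 1.0], rng)
print(offspring, offspring.sum())
```

With equal fitness values, as here, the draw reduces to pure genetic drift: the counts change at random but always sum to N.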

Setup


In [1]:
import numpy as np
import itertools

Build the population dynamics model

Basic parameters


In [2]:
pop_size = 100  # N; must equal the sum of the haplotype counts set below

In [3]:
seq_length = 10  # must match the length of base_haplotype

In [4]:
alphabet = ['A', 'T', 'G', 'C']

In [5]:
base_haplotype = "AAAAAAAAAA"

Set up a population of sequences

Store this as a lightweight dictionary that maps each sequence (a string) to its count. The counts of all sequences together sum to N.
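To make the representation concrete, here is a minimal sketch of how a list with one sequence per individual collapses into such a dictionary. The names individuals and toy_pop are made up for this example and are not used elsewhere in the notebook.

```python
from collections import Counter

# Hypothetical list of individuals, one sequence per individual
individuals = ["AAAAAAAAAA"] * 40 + ["AAATAAAAAA"] * 30 + ["AATTTAAAAA"] * 30

# Collapse identical sequences into a haplotype -> count mapping
toy_pop = dict(Counter(individuals))

total = sum(toy_pop.values())
print(len(toy_pop), total)   # 3 distinct haplotypes, 100 individuals
```

The dictionary stores 3 entries instead of 100 strings, which is why this layout stays cheap when most individuals share a haplotype.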


In [6]:
pop = {}

In [7]:
pop["AAAAAAAAAA"] = 40

In [8]:
pop["AAATAAAAAA"] = 30

In [9]:
pop["AATTTAAAAA"] = 30

In [10]:
pop["AAATAAAAAA"]


Out[10]:
30

Add mutation

Each generation, a mutation can occur at any base pair of any individual.


In [11]:
mutation_rate = 0.0001 # per gen per individual per site

Walk through the population and mutate base pairs. Use Poisson splitting to speed this up (you may be familiar with Poisson splitting from its use in the Gillespie algorithm).

  • In naive scenario A: take each element in turn and check whether an event occurs for it. For example, with 100 elements, each with a 1% chance, this requires 100 random numbers.
  • In Poisson splitting scenario B: draw a Poisson random number for the total number of events and distribute them randomly among the elements. In the above example, this will most likely involve 1 random number draw to see how many events occur and then a few more draws to see which elements are hit.
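The two scenarios can be sketched side by side as follows. This is an illustration under stated assumptions rather than the notebook's own code: the names n_elements, p_event, hits_naive and hits_poisson are invented here, and the example uses NumPy's seeded Generator interface for repeatability. The Poisson draw approximates the per-element check when the per-element probability is small, and it allows the same element to be hit more than once.

```python
import numpy as np

rng = np.random.default_rng(42)
n_elements = 100      # e.g. sites summed over all individuals
p_event = 0.01        # per-element probability of an event

# Scenario A: one uniform draw per element (100 random numbers)
hits_naive = np.flatnonzero(rng.random(n_elements) < p_event)

# Scenario B: one Poisson draw for the number of events,
# then one draw per event to pick which element is hit
n_events = rng.poisson(n_elements * p_event)
hits_poisson = rng.integers(0, n_elements, size=n_events)

print(len(hits_naive), len(hits_poisson))
```

Scenario B uses 1 + n_events random numbers, which is far fewer than n_elements whenever events are rare.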

First, we need to draw the total number of mutations that occur in the population in one generation.


In [12]:
def get_mutation_count():
    mean = mutation_rate * pop_size * seq_length
    return np.random.poisson(mean)

Here we use NumPy's Poisson random number generator, np.random.poisson.


In [13]:
get_mutation_count()


Out[13]:
0

Next, we need to draw a random haplotype from the population, weighted by its frequency.


In [14]:
pop.keys()


Out[14]:
['AAAAAAAAAA', 'AAATAAAAAA', 'AATTTAAAAA']

In [15]:
[x/float(pop_size) for x in pop.values()]


Out[15]:
[0.4, 0.3, 0.3]

In [16]:
def get_random_haplotype():
    haplotypes = list(pop.keys())  # a list works under both Python 2 and Python 3
    frequencies = [x/float(pop_size) for x in pop.values()]
    total = sum(frequencies)
    frequencies = [x / total for x in frequencies]
    return np.random.choice(haplotypes, p=frequencies)

Here we use NumPy's weighted random choice, np.random.choice.


In [17]:
get_random_haplotype()


Out[17]:
'AAATAAAAAA'

Here, we take a supplied haplotype and mutate a site at random.


In [18]:
def get_mutant(haplotype):
    # pick the site from the haplotype itself so the index is always in range
    site = np.random.randint(len(haplotype))
    possible_mutations = list(alphabet)
    possible_mutations.remove(haplotype[site])
    mutation = np.random.choice(possible_mutations)
    new_haplotype = haplotype[:site] + mutation + haplotype[site+1:]
    return new_haplotype

In [19]:
get_mutant("AAAAAAAAAA")

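Putting things together, in a single mutation event, we grab a random haplotype from the population, decrement its count, mutate it, and then check whether the mutant already exists in the population. If it does, we increment the count of this mutant haplotype; if it doesn't, we create a new haplotype with count 1. To avoid leaving a haplotype with a count of 0, the code below only mutates haplotypes whose count is greater than 1.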


In [ ]:
def mutation_event():
    haplotype = get_random_haplotype()
    if pop[haplotype] > 1:
        pop[haplotype] -= 1
        new_haplotype = get_mutant(haplotype)
        if new_haplotype in pop:
            pop[new_haplotype] += 1
        else:
            pop[new_haplotype] = 1

In [ ]:
mutation_event()

In [ ]:
pop

To create all the mutations that occur in a single generation, we draw the total count of mutations and then iteratively add mutation events.


In [ ]:
def mutation_step():
    mutation_count = get_mutation_count()
    for i in range(mutation_count):
        mutation_event()

In [ ]:
mutation_step()

In [ ]:
pop

Add genetic drift

Given a list of haplotype frequencies currently in the population, we can take a multinomial draw to get haplotype counts in the following generation.


In [ ]:
def get_offspring_counts():
    haplotypes = pop.keys() 
    frequencies = [x/float(pop_size) for x in pop.values()]
    return list(np.random.multinomial(pop_size, frequencies))

Here we use NumPy's multinomial random sample, np.random.multinomial.


In [ ]:
get_offspring_counts()
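As a quick sanity check on the multinomial draw, the sketch below repeats it many times for a fixed set of frequencies. The frequencies, the number of replicates, and the seeded generator are choices made for this example only. Every draw should sum to the population size, and the average count of each haplotype should sit close to N times its frequency.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100
freqs = [0.4, 0.3, 0.3]

# 10000 replicate "next generations" drawn from the same parent frequencies
draws = rng.multinomial(n, freqs, size=10000)

print(draws[0], draws[0].sum())   # one draw; its counts sum to n
print(draws.mean(axis=0))         # close to n * freqs
```

Individual draws scatter around the expectation, and that scatter is the genetic drift this step is meant to model.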

We then need to assign this new list of haplotype counts to the pop dictionary. To save memory and computation, if a haplotype goes to 0, we remove it entirely from the pop dictionary.
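One detail to watch when implementing this update: in Python 3, deleting keys while looping over a live view of the dictionary raises a RuntimeError. The sketch below, with a made-up toy_pop and a made-up list new_counts standing in for a multinomial draw, loops over a copied list of keys instead.

```python
toy_pop = {"AAAA": 5, "AAAT": 3, "AATT": 2}
new_counts = [7, 0, 3]   # hypothetical offspring counts, in key order

# list(...) takes a snapshot of the keys, so deleting entries is safe
for haplotype, count in zip(list(toy_pop.keys()), new_counts):
    if count > 0:
        toy_pop[haplotype] = count
    else:
        del toy_pop[haplotype]

print(toy_pop)   # {'AAAA': 7, 'AATT': 3}
```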


In [ ]:
def offspring_step():
    counts = get_offspring_counts()
    for (haplotype, count) in zip(list(pop.keys()), counts):  # copy keys: entries are deleted below
        if (count > 0):
            pop[haplotype] = count
        else:
            del pop[haplotype]

In [ ]:
offspring_step()

In [ ]:
pop

Combine and iterate

Each generation is simply a mutation step where a random number of mutations are thrown down, and an offspring step where haplotype counts are updated.
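For readers who prefer to see the two steps in one place, here is a compact, self-contained variant that avoids global variables. It is a sketch under assumptions, not the notebook's implementation: the function simulate_wf and its default parameter values are invented for this example, it uses a seeded NumPy Generator, and it removes a haplotype as soon as its count reaches 0 rather than skipping mutation of singletons.

```python
import numpy as np

def simulate_wf(pop_size=50, seq_length=20, mutation_rate=0.001,
                generations=100, seed=0):
    """Neutral Wright-Fisher simulation; returns a list of {haplotype: count}."""
    rng = np.random.default_rng(seed)
    alphabet = "ATGC"
    pop = {"A" * seq_length: pop_size}
    history = [dict(pop)]
    for gen in range(generations):
        # Mutation step: Poisson number of events, each hitting a random individual
        n_mutations = rng.poisson(mutation_rate * pop_size * seq_length)
        for event in range(n_mutations):
            haps = list(pop)
            probs = np.array([pop[h] for h in haps], dtype=float) / pop_size
            hap = haps[rng.choice(len(haps), p=probs)]
            site = rng.integers(seq_length)
            new_base = rng.choice([b for b in alphabet if b != hap[site]])
            mutant = hap[:site] + str(new_base) + hap[site + 1:]
            pop[hap] -= 1
            if pop[hap] == 0:
                del pop[hap]
            pop[mutant] = pop.get(mutant, 0) + 1
        # Offspring step: multinomial resampling of haplotype counts
        haps = list(pop)
        probs = np.array([pop[h] for h in haps], dtype=float) / pop_size
        counts = rng.multinomial(pop_size, probs)
        pop = {h: int(c) for h, c in zip(haps, counts) if c > 0}
        history.append(dict(pop))
    return history

history_sketch = simulate_wf()
print(len(history_sketch), sum(history_sketch[-1].values()))   # 101 50
```

Passing the parameters as arguments makes it easy to rerun the simulation with different values without resetting shared state.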


In [ ]:
def time_step():
    mutation_step()
    offspring_step()

We can iterate this over a number of generations.


In [ ]:
generations = 500

In [ ]:
def simulate():
    for i in range(generations):
        time_step()

In [ ]:
simulate()

In [ ]:
pop

Record

We want to keep a record of past population frequencies to understand dynamics through time. At each step in the simulation, we append to a history object.
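The history has to hold copies of the population rather than references to it, because the population dictionary is modified in place at every step. The following sketch, using a made-up toy_pop and a decrement as a stand-in for a simulation step, shows the difference.

```python
toy_pop = {"AAAA": 10}
aliases = []      # stores references to the same dictionary
snapshots = []    # stores independent copies

for step in range(3):
    aliases.append(toy_pop)
    snapshots.append(dict(toy_pop))   # shallow copy is enough: values are ints
    toy_pop["AAAA"] -= 1              # stand-in for one simulation step

print([p["AAAA"] for p in aliases])    # [7, 7, 7]: every entry shows the final state
print([p["AAAA"] for p in snapshots])  # [10, 9, 8]: each entry keeps its own state
```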


In [ ]:
pop = {"AAAAAAAAAA": pop_size}

In [ ]:
history = []

In [ ]:
def simulate():
    clone_pop = dict(pop)
    history.append(clone_pop)
    for i in range(generations):
        time_step()
        clone_pop = dict(pop)
        history.append(clone_pop)

In [ ]:
simulate()

In [ ]:
pop

In [ ]:
history[0]

In [ ]:
history[1]

In [ ]:
history[2]

Analyze trajectories

Calculate diversity

Here, diversity in population genetics is usually shorthand for the statistic π, which measures pairwise differences between random individuals in the population. π is usually measured as substitutions per site.


In [ ]:
pop

First, we need to calculate the number of differences per site between two arbitrary sequences.


In [ ]:
def get_distance(seq_a, seq_b):
    diffs = 0
    length = len(seq_a)
    assert len(seq_a) == len(seq_b)
    for chr_a, chr_b in zip(seq_a, seq_b):
        if chr_a != chr_b:
            diffs += 1
    return diffs / float(length)

In [ ]:
get_distance("AAAAAAAAAA", "AAAAAAAAAB")

We calculate diversity as a weighted average between all pairs of haplotypes, weighted by pairwise haplotype frequency.
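Written out, with $x_i$ denoting the frequency of haplotype $i$ and $d_{ij}$ the per-site distance between haplotypes $i$ and $j$ (notation introduced here for illustration), this weighted average is

$$\pi = \sum_{i} \sum_{j} x_i \, x_j \, d_{ij}$$

Pairs with $i = j$ contribute nothing to the sum because $d_{ii} = 0$.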


In [ ]:
def get_diversity(population):
    haplotypes = list(population.keys())  # a list can be indexed under Python 3 as well
    haplotype_count = len(haplotypes)
    diversity = 0
    for i in range(haplotype_count):
        for j in range(haplotype_count):
            haplotype_a = haplotypes[i]
            haplotype_b = haplotypes[j]
            frequency_a = population[haplotype_a] / float(pop_size)
            frequency_b = population[haplotype_b] / float(pop_size)
            frequency_pair = frequency_a * frequency_b
            diversity += frequency_pair * get_distance(haplotype_a, haplotype_b)
    return diversity

In [ ]:
get_diversity(pop)

In [ ]:
def get_diversity_trajectory():
    trajectory = [get_diversity(generation) for generation in history]
    return trajectory

In [ ]:
get_diversity_trajectory()

Plot diversity

Here, we use matplotlib for all Python plotting.


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib as mpl

Here, we make a simple line plot using matplotlib's plot function.


In [ ]:
plt.plot(get_diversity_trajectory())

Here, we style the plot a bit with x and y axes labels.


In [ ]:
def diversity_plot():
    mpl.rcParams['font.size']=14
    trajectory = get_diversity_trajectory()
    plt.plot(trajectory, "#447CCD")    
    plt.ylabel("diversity")
    plt.xlabel("generation")

In [ ]:
diversity_plot()

Analyze and plot divergence

In population genetics, divergence is generally the number of substitutions away from a reference sequence. In this case, we can measure the average distance of the population to the starting haplotype. Again, this will be measured in terms of substitutions per site.
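As a small worked example of this definition, the sketch below computes the divergence of a three-haplotype toy population from an all-A reference. The helper per_site_distance and the names reference and toy_pop are assumptions made for this illustration.

```python
def per_site_distance(seq_a, seq_b):
    """Fraction of sites at which two equal-length sequences differ."""
    assert len(seq_a) == len(seq_b)
    diffs = sum(1 for a, b in zip(seq_a, seq_b) if a != b)
    return diffs / float(len(seq_a))

reference = "AAAAAAAAAA"
toy_pop = {"AAAAAAAAAA": 40, "AAATAAAAAA": 30, "AATTTAAAAA": 30}
n = sum(toy_pop.values())

# Frequency-weighted average distance to the reference
divergence = sum(count / float(n) * per_site_distance(reference, hap)
                 for hap, count in toy_pop.items())
print(divergence)
```

The three haplotypes differ from the reference at 0, 1 and 3 of 10 sites, so the result is 0.4 × 0 + 0.3 × 0.1 + 0.3 × 0.3, or approximately 0.12 substitutions per site.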


In [ ]:
def get_divergence(population):
    haplotypes = population.keys()
    divergence = 0
    for haplotype in haplotypes:
        frequency = population[haplotype] / float(pop_size)
        divergence += frequency * get_distance(base_haplotype, haplotype)
    return divergence

In [ ]:
def get_divergence_trajectory():
    trajectory = [get_divergence(generation) for generation in history]
    return trajectory

In [ ]:
get_divergence_trajectory()

In [ ]:
def divergence_plot():
    mpl.rcParams['font.size']=14
    trajectory = get_divergence_trajectory()
    plt.plot(trajectory, "#447CCD")
    plt.ylabel("divergence")
    plt.xlabel("generation")

In [ ]:
divergence_plot()

Plot haplotype trajectories

We also want to directly look at haplotype frequencies through time.


In [ ]:
def get_frequency(haplotype, generation):
    pop_at_generation = history[generation]
    if haplotype in pop_at_generation:
        return pop_at_generation[haplotype]/float(pop_size)
    else:
        return 0

In [ ]:
get_frequency("AAAAAAAAAA", 4)

In [ ]:
def get_trajectory(haplotype):
    trajectory = [get_frequency(haplotype, gen) for gen in range(generations)]
    return trajectory

In [ ]:
get_trajectory("AAAAAAAAAA")

We want to plot all haplotypes seen during the simulation.


In [ ]:
def get_all_haplotypes():
    haplotypes = set()   
    for generation in history:
        for haplotype in generation:
            haplotypes.add(haplotype)
    return haplotypes

In [ ]:
get_all_haplotypes()

Here is a simple plot of their overall frequencies.


In [ ]:
haplotypes = get_all_haplotypes()
for haplotype in haplotypes:
    plt.plot(get_trajectory(haplotype))
plt.show()

In [ ]:
colors = ["#781C86", "#571EA2", "#462EB9", "#3F47C9", "#3F63CF", "#447CCD", "#4C90C0", "#56A0AE", "#63AC9A", "#72B485", "#83BA70", "#96BD60", "#AABD52", "#BDBB48", "#CEB541", "#DCAB3C", "#E49938", "#E68133", "#E4632E", "#DF4327", "#DB2122"]

In [ ]:
colors_lighter = ["#A567AF", "#8F69C1", "#8474D1", "#7F85DB", "#7F97DF", "#82A8DD", "#88B5D5", "#8FC0C9", "#97C8BC", "#A1CDAD", "#ACD1A0", "#B9D395", "#C6D38C", "#D3D285", "#DECE81", "#E8C77D", "#EDBB7A", "#EEAB77", "#ED9773", "#EA816F", "#E76B6B"]

We can use stackplot to stack these trajectories on top of each other to get a better picture of what's going on.


In [ ]:
def stacked_trajectory_plot(xlabel="generation"):
    mpl.rcParams['font.size']=18
    haplotypes = get_all_haplotypes()
    trajectories = [get_trajectory(haplotype) for haplotype in haplotypes]
    plt.stackplot(range(generations), trajectories, colors=colors_lighter)
    plt.ylim(0, 1)
    plt.ylabel("frequency")
    plt.xlabel(xlabel)

In [ ]:
stacked_trajectory_plot()

Plot SNP trajectories


In [ ]:
def get_snp_frequency(site, generation):
    minor_allele_frequency = 0.0
    pop_at_generation = history[generation]
    for haplotype in pop_at_generation.keys():
        allele = haplotype[site]
        frequency = pop_at_generation[haplotype] / float(pop_size)
        if allele != "A":
            minor_allele_frequency += frequency
    return minor_allele_frequency

In [ ]:
get_snp_frequency(3, 5)

In [ ]:
def get_snp_trajectory(site):
    trajectory = [get_snp_frequency(site, gen) for gen in range(generations)]
    return trajectory

In [ ]:
get_snp_trajectory(3)

Find all variable sites.


In [ ]:
def get_all_snps():
    snps = set()   
    for generation in history:
        for haplotype in generation:
            for site in range(seq_length):
                if haplotype[site] != "A":
                    snps.add(site)
    return snps

In [ ]:
def snp_trajectory_plot(xlabel="generation"):
    mpl.rcParams['font.size']=18
    snps = get_all_snps()
    trajectories = [get_snp_trajectory(snp) for snp in snps]
    data = []
    for trajectory, color in zip(trajectories, itertools.cycle(colors)):  # zip stops when trajectories run out
        data.append(range(generations))
        data.append(trajectory)    
        data.append(color)
    plt.plot(*data)   
    plt.ylim(0, 1)
    plt.ylabel("frequency")
    plt.xlabel(xlabel)

In [ ]:
snp_trajectory_plot()

Scale up

Here, we scale up to more interesting parameter values.


In [1]:
pop_size = 50
seq_length = 100
generations = 500
mutation_rate = 0.0001 # per gen per individual per site

In this case, each sequence receives on average $\mu$ = 0.01 mutations per generation.


In [ ]:
seq_length * mutation_rate

And the population genetic parameter $\theta$, which equals $2N\mu$, is 1.


In [ ]:
2 * pop_size * seq_length * mutation_rate

In [ ]:
base_haplotype = ''.join(["A" for i in range(seq_length)])
pop.clear()
del history[:]
pop[base_haplotype] = pop_size

In [ ]:
simulate()

In [ ]:
plt.figure(num=None, figsize=(14, 14), dpi=80, facecolor='w', edgecolor='k')
plt.subplot2grid((3,2), (0,0), colspan=2)
stacked_trajectory_plot(xlabel="")
plt.subplot2grid((3,2), (1,0), colspan=2)
snp_trajectory_plot(xlabel="")
plt.subplot2grid((3,2), (2,0))
diversity_plot()
plt.subplot2grid((3,2), (2,1))
divergence_plot()

In [ ]: