Eachan Johnson
Tidy up
5eb5274
from functools import partial
import os
from carabiner.mpl import add_legend, grid, colorblind_palette
import gradio as gr
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.stats import multivariate_hypergeom, nbinom
# Set the default color cycle: two greys first, then the colourblind palette
# without its first entry.
# NOTE(review): the tuple concatenation assumes colorblind_palette() returns a
# tuple — confirm against carabiner.
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(
color=("lightgrey", "dimgrey") + colorblind_palette()[1:],
)
# Base seed passed as `random_state` to every random draw in this file.
SEED: int = 42
# Last time point of the simulated growth curves (arbitrary units).
MAX_TIME: float = 5.
# Directory containing the Markdown files read by inject_markdown().
SOURCES_DIR: str = "sources"
def inject_markdown(filename):
    """Load a Markdown file from ``SOURCES_DIR`` into a ``gr.Markdown``.

    filename : name of the file, relative to ``SOURCES_DIR``

    Returns the ``gr.Markdown`` component, configured so that ``$$...$$`` is
    rendered as display maths and ``$...$`` as inline maths.
    """
    # Decode explicitly as UTF-8 so the text does not depend on the platform's
    # default encoding.
    with open(os.path.join(SOURCES_DIR, filename), 'r', encoding="utf-8") as f:
        md = f.read()
    return gr.Markdown(
        md,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )
def lotka_volterra(t, y, w, K):
    """Right-hand side of logistic growth with a shared carrying capacity.

    Each strain grows at its own rate, slowed by the fraction of the shared
    capacity that is already occupied::

        dy_i/dt = w_i * y_i * (1 - sum_j(y_j) / K)

    t : time (unused; present because ``solve_ivp`` passes it)
    y : population sizes, shape ``(n_strains,)`` or ``(n_strains, k)`` where
        every column is an independent state
    w : growth rate per strain, length ``n_strains``
    K : shared carrying capacity

    Returns the derivative with the same shape as ``y``.
    """
    y = np.asarray(y, dtype=float)
    w = np.asarray(w, dtype=float)
    if y.ndim > 1:
        # Align the per-strain rates with the first axis so that each column
        # of a vectorised call is handled independently.
        w = w.reshape((-1,) + (1,) * (y.ndim - 1))
    # Sum over strains only (not over columns): capacity is shared between
    # strains within one state, never between separate states.
    occupied_fraction = y.sum(axis=0) / K
    return w * y * (1. - occupied_fraction)
def grow(t, n0, w, K=None):
    """Deterministic population size at time t.

    n0 : initial cells (one entry per strain)
    w : growth rate (one entry per strain; list or array)
    t : time (arbitrary units), array of time points
    K : carrying capacity (None → pure exponential)

    Returns an array of population sizes. With ``K=None`` the shape is
    ``(len(t), n_strains)``; with a carrying capacity it is the solver's
    ``(n_strains, len(t))`` with columns in ascending time order.
    NOTE(review): the two branches return different orientations; this is
    kept as-is so existing callers are unaffected.

    Raises RuntimeError if the ODE solver reports a failure.
    """
    t = np.asarray(t, dtype=float)
    if K is None:
        # Accept plain lists for `w` as well as arrays.
        w = np.asarray(w, dtype=float)
        return n0 * np.exp(w[None] * t[:, None])
    # logistic with shared K
    ode_solution = solve_ivp(
        lotka_volterra,
        t_span=sorted({0., float(t.max())}),
        t_eval=np.sort(t),
        y0=n0,
        vectorized=True,
        args=(w, K),
    )
    if not ode_solution.success:
        # Surface solver problems here rather than returning a truncated
        # trajectory.
        raise RuntimeError(f"ODE integration failed: {ode_solution.message}")
    return ode_solution.y
def plotter_t(x, growths, scatter=False, **kwargs):
    """Plot every series in ``growths`` against ``x`` with a log y axis.

    x : x values shared by all series (flattened before plotting)
    growths : iterable of series, one per strain
    scatter : draw points (size 5) instead of lines
    kwargs : forwarded to ``axes.set`` (e.g. ``ylabel``)

    Returns the figure created by ``grid``.
    """
    fig, axes = grid(aspect_ratio=1.5)
    if scatter:
        draw = partial(axes.scatter, s=5.)
    else:
        draw = axes.plot
    for strain_index, series in enumerate(growths):
        # Only series from the third onwards get a "Mutant k" label; the first
        # two use the underscore-prefixed placeholder.
        if strain_index > 1:
            legend_label = f"Mutant {strain_index - 1}"
        else:
            legend_label = "_none"
        draw(x.flatten(), series.flatten(), label=legend_label)
    axes.set(xlabel="Time", yscale="log", **kwargs)
    add_legend(axes)
    return fig
def plotter_ref(x, growths, scatter=False, fitlines=None, text=None, **kwargs):
    """Plot every series in ``growths`` against ``x`` on log-log axes.

    x : x values shared by all series (flattened before plotting)
    growths : iterable of series, one per strain
    scatter : draw points (size 5) instead of lines
    fitlines : optional ``(fit_x, fit_y)``; for every slope ``b`` in ``fit_y``
        the power law ``y = fit_x ** b`` is drawn as a line
    text : optional annotation placed just right of the axes
    kwargs : forwarded to ``axes.set`` (e.g. ``xlabel``, ``ylabel``)

    Returns the figure created by ``grid``.
    """
    # A slightly wider figure when an annotation is shown beside the axes.
    fig, axes = grid(aspect_ratio=1.5 if text is None else 1.7)
    plotter_f = partial(axes.scatter, s=5.) if scatter else axes.plot
    for i, y in enumerate(growths):
        plotter_f(
            x.flatten(), y.flatten(),
            label=f"Mutant {i-1}" if i > 1 else "_none",
        )
    if fitlines is not None:
        fit_x, fit_y = fitlines
        log_fit_x = np.log(fit_x)  # loop-invariant, computed once
        for slope in fit_y.flatten():
            # exp(b * log(x)) == x ** b: a straight line on log-log axes.
            fitted = np.exp(log_fit_x @ slope[None])
            axes.plot(
                fit_x.flatten(), fitted.flatten(),
                label="_none",
            )
    if text is not None:
        axes.text(
            1.05, .1,
            text,
            fontsize=10,
            transform=axes.transAxes,
        )
    axes.set(
        xscale="log",
        yscale="log",
        **kwargs,
    )
    add_legend(axes)
    return fig
def calculate_growth_curves(inoculum, inoculum_var, carrying_capacity, fitness, n_timepoints=100):
    """Simulate growth curves for the wild type, a spike-in and the mutants.

    inoculum : mean starting cell count per strain
    inoculum_var : dispersion of the starting counts; their variance is
        ``inoculum + inoculum_var * inoculum**2``
    carrying_capacity : shared capacity as a multiple of ``inoculum``
    fitness : growth rates of the mutants; rates 1 and 0 are prepended for the
        first two strains
    n_timepoints : number of evenly spaced times between 0 and ``MAX_TIME``

    Returns ``(t, ref_expansion, growths)`` where ``ref_expansion`` is the
    first strain's curve divided by its starting count.
    """
    # Negative-binomial parameters (n, p) matching the requested mean/variance.
    count_variance = inoculum + inoculum_var * np.square(inoculum)
    success_prob = inoculum / count_variance
    n_successes = (inoculum ** 2.) / (count_variance - inoculum)

    growth_rates = [1., 0.]
    growth_rates.extend(fitness)

    initial_counts = nbinom.rvs(
        n_successes,
        success_prob,
        size=len(growth_rates),
        random_state=SEED,
    )
    times = np.linspace(0., MAX_TIME, num=int(n_timepoints))
    capacity = inoculum * carrying_capacity
    curves = grow(times, initial_counts, growth_rates, capacity)
    reference_fold_change = curves[0] / initial_counts[0]
    return times, reference_fold_change, curves
def growth_plotter(inoculum, inoculum_var, carrying_capacity, *fitness):
    """Build the two growth figures: counts vs time and vs WT expansion.

    The arguments are passed straight to ``calculate_growth_curves`` (with 100
    time points); ``fitness`` collects the mutant growth rates.

    Returns ``[figure_vs_time, figure_vs_wild_type_expansion]``.
    """
    times, wt_expansion, curves = calculate_growth_curves(
        inoculum,
        inoculum_var,
        carrying_capacity,
        fitness,
        n_timepoints=100,
    )
    figure_vs_time = plotter_t(
        times,
        curves,
        ylabel="Number of cells per strain",
    )
    figure_vs_expansion = plotter_ref(
        wt_expansion,
        curves,
        xlabel="Fold-expansion of wild-type",
        ylabel="Number of cells per strain",
    )
    return [figure_vs_time, figure_vs_expansion]
def reads_sampler(population, sample_frac, seq_depth, reps, variance):
    """Simulate sequencing read counts from sampled populations.

    For every time point (last axis of ``population``) a fraction of the cells
    is drawn without replacement, ``reps`` times. The sampled strain
    proportions are scaled to ``seq_depth`` reads per strain on average, and
    read counts are then drawn from a negative binomial with variance
    ``mean + variance * mean**2``.

    population : cell counts; assumed ``(n_strains, n_timepoints)`` as passed
        by this file — TODO confirm for other callers
    sample_frac : fraction of the total population taken per sample
    seq_depth : average reads per strain per sample
    reps : number of replicate samples per time point
    variance : dispersion of the read counts

    Returns integer read counts of shape ``(n_strains, n_timepoints, reps)``.
    Entries whose expected read count is zero (e.g. no cells of that strain
    were sampled) are returned as zero reads.
    """
    reps = int(reps)  # tolerate float-valued input from UI widgets
    cell_counts = population.astype(int)
    samples = []
    for i, timepoint_pop in enumerate(np.split(cell_counts, cell_counts.shape[-1], axis=-1)):
        sample_size = np.floor(timepoint_pop.sum() * sample_frac).astype(int)
        samples.append(
            multivariate_hypergeom.rvs(
                m=timepoint_pop.flatten(),
                n=sample_size,
                size=reps,
                random_state=SEED + i,
            ).T
        )
    samples = np.stack(samples, axis=-2)

    # An empty sample has a zero total; substitute 1 so the proportion stays a
    # well-defined 0 instead of 0/0.
    sampled_totals = samples.sum(axis=0, keepdims=True)
    safe_totals = np.where(sampled_totals > 0, sampled_totals, 1)
    read_means = np.floor(seq_depth * samples.shape[0] * samples / safe_totals)

    # A zero mean has no valid negative-binomial parameters (n and p would be
    # NaN and the draw would fail). Draw with a placeholder mean of 1 and
    # overwrite those entries with zero reads afterwards.
    has_reads = read_means > 0
    safe_means = np.where(has_reads, read_means, 1.)
    read_variance = safe_means + variance * np.square(safe_means)
    p = safe_means / read_variance
    n = (safe_means ** 2.) / (read_variance - safe_means)
    draws = np.stack([
        nbinom.rvs(n[..., i], p[..., i], random_state=SEED + i)
        for i in range(reps)
    ], axis=-1)
    return np.where(has_reads, draws, 0)
def fitness_fitter(read_counts, ref_expansion):
    """Fit per-strain slopes of corrected log read counts vs log expansion.

    Read counts are divided by their replicate-averaged value at the first
    time point, then by the first strain's normalised counts, and logged. Each
    strain's values are regressed (through the origin) on the log of
    ``ref_expansion``.

    read_counts : array indexed ``[strain, timepoint, replicate]``
    ref_expansion : 1-D fold-expansion of the reference, one value per
        time point

    Returns ``(log_read_count_correction, betas)``: the corrected log counts
    (same shape as ``read_counts``) and the slopes, shape ``(n_strains, 1)``.
    Observations that are not finite (e.g. zero read counts) are left out of
    the fit; a strain with no usable observation gets a NaN slope.
    """
    # Zero counts legitimately produce -inf/NaN here; they are masked below,
    # so silence the floating-point warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        baseline = np.mean(read_counts[:, :1], axis=-1, keepdims=True)
        log_expansion = np.log(read_counts / baseline)
        log_read_count_correction = log_expansion - log_expansion[:1]
        log_ref = np.log(ref_expansion)

    # One design row per (timepoint, replicate), in the same order as a
    # flattened [timepoint, replicate] slice of the corrected counts.
    n_reps = log_read_count_correction.shape[-1]
    design = np.tile(log_ref[:, None], (1, n_reps)).reshape((-1, 1))
    design_ok = np.isfinite(design[:, 0])

    betas = []
    for log_strain_counts_corrected in log_read_count_correction:
        target = log_strain_counts_corrected.flatten()
        usable = design_ok & np.isfinite(target)
        if usable.any():
            # rcond=None selects NumPy's current default explicitly.
            slope = np.linalg.lstsq(design[usable], target[usable], rcond=None)[0]
        else:
            slope = np.full(1, np.nan)
        betas.append(slope)
    return log_read_count_correction, np.asarray(betas)
def fitness_fitter_spike(log_read_count_corrected):
    """Fit per-strain slopes of corrected log counts vs the spike's values.

    The second strain (index 1) is used as the regressor: every strain's
    corrected log counts are regressed (through the origin) on the spike's
    corrected log counts.

    log_read_count_corrected : array indexed ``[strain, ...]``; each strain's
        values are flattened before fitting

    Returns ``(log_spike_count_corrected, betas)``: the spike regressor as a
    column vector and the slopes, shape ``(n_strains, 1)``. Observations where
    either value is not finite are left out of the fit; a strain with no
    usable observation gets a NaN slope.
    """
    log_spike_count_corrected = log_read_count_corrected[1, ...].flatten()[..., None]
    spike_ok = np.isfinite(log_spike_count_corrected[:, 0])
    betas = []
    for log_strain_counts_corrected in log_read_count_corrected:
        target = log_strain_counts_corrected.flatten()
        usable = spike_ok & np.isfinite(target)
        if usable.any():
            # rcond=None selects NumPy's current default explicitly.
            slope = np.linalg.lstsq(
                log_spike_count_corrected[usable],
                target[usable],
                rcond=None,
            )[0]
        else:
            slope = np.full(1, np.nan)
        betas.append(slope)
    return log_spike_count_corrected, np.asarray(betas)
def reads_plotter(
    sample_frac, seq_reps, seq_depth, read_var,
    inoculum, inoculum_var, carrying_capacity, *fitness
):
    """Build the growth figures plus the three read-count figures.

    Growth curves are simulated at 10 time points, read counts are sampled
    from them, and slopes are fitted against both the wild-type expansion and
    the spike counts. The fitted slopes are reported in the plot annotations
    as ``1 + b`` (expansion fit) and ``1 - b`` (spike fit).

    Returns the two figures from ``growth_plotter`` followed by: read counts
    vs time, corrected read counts vs wild-type expansion, and corrected read
    counts vs corrected spike counts.
    """
    t, ref_expansion, growths = calculate_growth_curves(
        inoculum, inoculum_var, carrying_capacity, fitness, n_timepoints=10,
    )
    read_counts = reads_sampler(growths, sample_frac, seq_depth, seq_reps, read_var)

    # Fit against the wild-type expansion; only mutants (index > 1) are listed.
    log_read_count_correction, betas = fitness_fitter(read_counts, ref_expansion)
    expansion_lines = []
    for i, b in enumerate(betas.flatten()):
        if i > 1:
            expansion_lines.append(f"Mutant {i-1}: $w_{i-1}/w_{'{wt}'}={1. + b:.2f}$")
    plot_text = "\n".join(expansion_lines)

    # Fit against the spike; the regressor itself is rebuilt below for plotting.
    _, spike_betas = fitness_fitter_spike(log_read_count_correction)
    spike_lines = []
    for i, b in enumerate(spike_betas.flatten()):
        if i > 1:
            spike_lines.append(f"Mutant {i-1}: $w_{i-1}/w_{'{wt}'}={1. - b:.2f}$")
    plot_text_spike = "\n".join(spike_lines)

    read_count_correction = np.exp(log_read_count_correction)
    spike_correction = read_count_correction[1:2, ...]
    ratio_label = "$\\frac{c_1(t)}{c_{wt}(t)} / \\frac{c_1(0)}{c_{wt}(0)}$"

    figures = growth_plotter(inoculum, inoculum_var, carrying_capacity, *fitness)
    reads_vs_time = plotter_t(
        np.tile(t[:, None], (1, seq_reps)),
        read_counts,
        scatter=True,
        ylabel="Read counts per strain",
    )
    reads_vs_expansion = plotter_ref(
        np.tile(ref_expansion[:, None], (1, seq_reps)),
        read_count_correction,
        scatter=True,
        fitlines=(ref_expansion[:, None], betas),
        text=plot_text,
        xlabel="Fold-expansion of wild-type",
        ylabel=ratio_label,
    )
    reads_vs_spike = plotter_ref(
        spike_correction,
        read_count_correction,
        scatter=True,
        fitlines=(spike_correction.flatten()[..., None], spike_betas),
        text=plot_text_spike,
        xlabel="$\\frac{c_{spike}(t)}{c_{wt}(t)} / \\frac{c_{spike}(0)}{c_{wt}(0)}$",
        ylabel=ratio_label,
    )
    return figures + [reads_vs_time, reads_vs_expansion, reads_vs_spike]
# Page layout: components appear in the order they are created below.
with gr.Blocks() as demo:
    inject_markdown("header.md")

    # Growth curves
    inject_markdown("growth-curve-intro.md")
    mut_fitness_defaults = [.5, 2., .2]
    with gr.Row():
        # One fitness slider per mutant, numbered from 1.
        relative_fitness = [
            gr.Slider(
                minimum=0.,
                maximum=3.,
                value=w,
                step=.1,
                label=f"Relative fitness, mutant {i + 1}",
            )
            for i, w in enumerate(mut_fitness_defaults)
        ]
    with gr.Row():
        n_mutants = len(mut_fitness_defaults)
        inoculum = gr.Slider(
            minimum=10,
            maximum=1_000_000,
            value=1000,
            step=10,
            label="Average inoculum per strain",
        )
        inoculum_var = gr.Slider(
            minimum=.001,
            maximum=1.,
            value=.001,
            step=.001,
            label="Inoculum variance between strains",
        )
        # Lower bound is one more than the number of mutants.
        carrying_capacity = gr.Slider(
            minimum=n_mutants + 1,
            maximum=10_000,
            value=10,
            step=1,
            label="Total carrying capacity (x inoculum)",
        )
    plot_growth = gr.Button("Plot growth curves")
    growth_curves_t = gr.Plot(label="Growth vs time", format="png")
    inject_markdown("growth-curve-t-independent.md")
    growth_curves_ref = gr.Plot(label="Growth vs WT expansion", format="png")
    growth_curves = [growth_curves_t, growth_curves_ref]

    # Read counts
    inject_markdown("read-counts-intro.md")
    with gr.Row():
        sample_frac = gr.Slider(
            minimum=.001,
            maximum=1.,
            value=.1,
            step=.001,
            label="Fraction of population per sample",
        )
        seq_reps = gr.Slider(
            minimum=1,
            maximum=10,
            value=3,
            step=1,
            label="Technical replicates",
        )
        seq_depth = gr.Slider(
            minimum=10,
            maximum=10_000,
            value=10_000,
            step=10,
            label="Average reads per strain per sample",
        )
        read_var = gr.Slider(
            minimum=.001,
            maximum=1.,
            value=.001,
            step=.001,
            label="Sequencing variance",
        )
    plot_reads = gr.Button("Plot read counts")
    read_curves_t = gr.Plot(label="Read counts vs time", format="png")
    inject_markdown("read-counts-expansion.md")
    read_curves_ref = gr.Plot(label="Read count diff vs WT expansion", format="png")
    inject_markdown("read-counts-spike.md")
    read_curves_t2 = gr.Plot(label="Read count diff vs spike count diff", format="png")
    read_curves = [read_curves_t, read_curves_ref, read_curves_t2]

    # Events
    # Input order must match the positional parameters of the callbacks.
    growth_inputs = [inoculum, inoculum_var, carrying_capacity, *relative_fitness]
    plot_growth.click(
        fn=growth_plotter,
        inputs=growth_inputs,
        outputs=growth_curves,
    )
    # The read-count button also redraws the growth figures.
    plot_reads.click(
        fn=reads_plotter,
        inputs=[sample_frac, seq_reps, seq_depth, read_var] + growth_inputs,
        outputs=growth_curves + read_curves,
    )

demo.launch(share=True)