from functools import partial
import os

from carabiner.mpl import add_legend, grid, colorblind_palette
import gradio as gr
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.stats import multivariate_hypergeom, nbinom

# Set the default color cycle: the first two series drawn on each axes
# (wild-type and spike-in in the plots below) are grey, the rest use a
# colorblind-friendly palette.
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(
    color=("lightgrey", "dimgrey") + colorblind_palette()[1:],
)

SEED: int = 42
MAX_TIME: float = 5.
SOURCES_DIR: str = "sources"


def inject_markdown(filename):
    """Render ``SOURCES_DIR/filename`` as a Gradio Markdown component.

    Both ``$$...$$`` (display) and ``$...$`` (inline) LaTeX delimiters are
    enabled.
    """
    with open(os.path.join(SOURCES_DIR, filename), 'r', encoding="utf-8") as f:
        md = f.read()
    return gr.Markdown(
        md,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )


def lotka_volterra(t, y, w, K):
    """Right-hand side of logistic growth with one capacity shared by all strains.

    dy_i/dt = w_i * y_i * (1 - sum_j y_j / K)

    t : time (unused; required by ``solve_ivp``)
    y : strain sizes, shape (n,) or (n, k) when ``solve_ivp`` evaluates k
        states at once (``vectorized=True``)
    w : per-strain growth rates, length n
    K : total carrying capacity

    Returns an array with the same shape as ``y``.
    """
    y = np.asarray(y, dtype=float)
    states = y.reshape(len(y), -1)  # always (n, k)
    # Fraction of the capacity already occupied, computed per state (column);
    # summing over everything would couple unrelated states in vectorized calls.
    occupied_fraction = states.sum(axis=0, keepdims=True) / K
    rates = np.asarray(w, dtype=float).reshape(-1, 1)
    dy = rates * states * (1. - occupied_fraction)
    return dy.reshape(y.shape)


def grow(t, n0, w, K=None):
    """Deterministic population size of each strain at times t.

    t : time points (arbitrary units)
    n0 : initial cells per strain
    w : growth rate per strain
    K : total carrying capacity shared by all strains (None → pure exponential)

    Returns an array of shape (n_strains, n_timepoints) in both branches.
    In the logistic branch the columns follow ``sorted(t)``.

    Raises RuntimeError if the logistic ODE integration fails.
    """
    t = np.asarray(t, dtype=float)
    n0 = np.asarray(n0, dtype=float)
    w = np.asarray(w, dtype=float)
    if K is None:
        # Strains on rows, time on columns, matching the ODE branch below.
        return n0[:, None] * np.exp(w[:, None] * t[None, :])

    # logistic with shared K
    ode_solution = solve_ivp(
        lotka_volterra,
        t_span=sorted({0., float(np.max(t))}),
        t_eval=np.sort(t),
        y0=n0,
        vectorized=True,
        args=(w, K),
    )
    if not ode_solution.success:
        raise RuntimeError(f"Growth ODE integration failed: {ode_solution.message}")
    return ode_solution.y


def plotter_t(x, growths, scatter=False, **kwargs):
    """Plot each row of ``growths`` against ``x`` (time) on a log y-axis.

    Rows 0 and 1 (wild-type and spike-in in this app) are drawn unlabelled;
    rows 2+ are labelled "Mutant 1", "Mutant 2", ...  Extra keyword arguments
    are passed to ``Axes.set``. Returns the matplotlib figure.
    """
    fig, axes = grid(aspect_ratio=1.5)
    plotter_f = partial(axes.scatter, s=5.) if scatter else axes.plot
    for i, y in enumerate(growths):
        plotter_f(
            x.flatten(),
            y.flatten(),
            label=f"Mutant {i-1}" if i > 1 else "_none",
        )
    axes.set(
        xlabel="Time",
        yscale="log",
        **kwargs,
    )
    add_legend(axes)
    return fig


def plotter_ref(x, growths, scatter=False, fitlines=None, text=None, **kwargs):
    """Plot each row of ``growths`` against a reference ``x`` on log-log axes.

    fitlines : optional ``(fit_x, fit_y)``; for every slope ``b`` in ``fit_y``
        the power law ``y = fit_x ** b`` (a straight line in log-log space)
        is overlaid.
    text : optional annotation drawn to the right of the axes.
    Extra keyword arguments are passed to ``Axes.set``. Returns the figure.
    """
    fig, axes = grid(aspect_ratio=1.5 if text is None else 1.7)
    plotter_f = partial(axes.scatter, s=5.) if scatter else axes.plot
    for i, y in enumerate(growths):
        plotter_f(
            x.flatten(),
            y.flatten(),
            label=f"Mutant {i-1}" if i > 1 else "_none",
        )
    if fitlines is not None:
        fit_x, fit_y = fitlines
        for b in np.ravel(fit_y):
            y = np.exp(np.log(fit_x) * b)
            axes.plot(
                fit_x.flatten(),
                y.flatten(),
                label="_none",
            )
    if text is not None:
        axes.text(
            1.05, .1,
            text,
            fontsize=10,
            transform=axes.transAxes,
        )
    axes.set(
        xscale="log",
        yscale="log",
        **kwargs,
    )
    add_legend(axes)
    return fig


def calculate_growth_curves(inoculum, inoculum_var, carrying_capacity, fitness, n_timepoints=100):
    """Simulate wild-type, spike-in and mutant growth from a random inoculum.

    Strain order is [wild-type (w=1), spike-in (w=0), *mutants (w=fitness)].
    Initial counts are negative-binomial with mean ``inoculum`` and variance
    ``inoculum + inoculum_var * inoculum**2``; the shared carrying capacity is
    ``inoculum * carrying_capacity``.

    Returns ``(t, ref_expansion, growths)``: the time points, the wild-type
    fold-expansion at each time point, and cells per strain with shape
    (n_strains, n_timepoints).
    """
    variance = inoculum + inoculum_var * np.square(inoculum)
    # Convert (mean, variance) to scipy's (n, p) negative-binomial parameters.
    p = inoculum / variance
    n = (inoculum ** 2.) / (variance - inoculum)
    w = [1., 0.] + list(fitness)
    n0 = nbinom.rvs(n, p, size=len(w), random_state=SEED)
    t = np.linspace(0., MAX_TIME, num=int(n_timepoints))
    growths = grow(t, n0, w, inoculum * carrying_capacity)
    ref_expansion = growths[0] / n0[0]
    return t, ref_expansion, growths


def growth_plotter(inoculum, inoculum_var, carrying_capacity, *fitness):
    """Return [growth-vs-time figure, growth-vs-WT-expansion figure]."""
    t, ref_expansion, growths = calculate_growth_curves(
        inoculum, inoculum_var, carrying_capacity, fitness, n_timepoints=100,
    )
    return [
        plotter_t(
            t,
            growths,
            ylabel="Number of cells per strain",
        ),
        plotter_ref(
            ref_expansion,
            growths,
            xlabel="Fold-expansion of wild-type",
            ylabel="Number of cells per strain",
        ),
    ]


def reads_sampler(population, sample_frac, seq_depth, reps, variance):
    """Simulate sequencing read counts from a population time course.

    population : cells per strain, shape (n_strains, n_timepoints)
    sample_frac : fraction of the cells drawn (without replacement) per timepoint
    seq_depth : average reads per strain per sample
    reps : number of technical replicates per timepoint
    variance : read overdispersion; read variance = mean + variance * mean**2

    Returns integer read counts of shape (n_strains, n_timepoints, reps).
    """
    reps = int(reps)
    samples = []
    timepoints = np.split(population.astype(int), population.shape[-1], axis=-1)
    for i, timepoint_pop in enumerate(timepoints):
        sample_size = np.floor(timepoint_pop.sum() * sample_frac).astype(int)
        samples.append(
            multivariate_hypergeom.rvs(
                m=timepoint_pop.flatten(),
                n=sample_size,
                size=reps,
                random_state=SEED + i,
            ).T
        )
    samples = np.stack(samples, axis=-2)
    with np.errstate(divide="ignore", invalid="ignore"):
        read_means = np.floor(
            seq_depth * samples.shape[0] * samples / samples.sum(axis=0, keepdims=True)
        )
    # A strain can be absent from a sample or get < 1 expected read. The (n, p)
    # parameterization is undefined at mean 0 (nbinom.rvs would raise), so draw
    # with a dummy mean there and force those counts to zero afterwards.
    observed = read_means > 0
    safe_means = np.where(observed, read_means, 1.)
    read_vars = safe_means + variance * np.square(safe_means)
    p = safe_means / read_vars
    n = (safe_means ** 2.) / (read_vars - safe_means)
    counts = np.stack([
        nbinom.rvs(n[..., i], p[..., i], random_state=SEED + i)
        for i in range(reps)
    ], axis=-1)
    return np.where(observed, counts, 0)


def _fit_slope(x, y):
    """Least-squares slope (no intercept) of y on x, skipping non-finite pairs.

    Returns an array of shape (1,); NaN if no finite pairs remain.
    """
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    y = np.asarray(y, dtype=float).reshape(-1)
    finite = np.isfinite(x[:, 0]) & np.isfinite(y)
    if not finite.any():
        return np.full(1, np.nan)
    solution, *_ = np.linalg.lstsq(x[finite], y[finite], rcond=None)
    return solution


def fitness_fitter(read_counts, ref_expansion):
    """Regress each strain's log read ratio on log wild-type expansion.

    read_counts : shape (n_strains, n_timepoints, reps); row 0 is wild-type
    ref_expansion : wild-type fold-expansion per timepoint

    Each strain's counts are normalized to their replicate-mean at the first
    timepoint, then divided by the wild-type's normalized counts (in log space).
    Zero counts give non-finite logs and are left out of the fit.

    Returns ``(log_read_count_correction, betas)`` where betas has shape
    (n_strains, 1).
    """
    baseline = np.mean(read_counts[:, :1], axis=-1, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        read_count_expansion = read_counts / baseline
        log_read_count_correction = (
            np.log(read_count_expansion) - np.log(read_count_expansion[:1])
        )
    log_ref_expansion = np.tile(
        np.log(ref_expansion)[:, None],
        (1, log_read_count_correction.shape[-1]),
    ).reshape((-1, 1))
    betas = [
        _fit_slope(log_ref_expansion, log_strain_counts_corrected)
        for log_strain_counts_corrected in log_read_count_correction
    ]
    return log_read_count_correction, np.asarray(betas)


def fitness_fitter_spike(log_read_count_corrected):
    """Regress each strain's log read ratio on that of the spike-in (row 1).

    Non-finite points are left out of the fit.

    Returns ``(log_spike_count_corrected, betas)`` where the first element has
    shape (n_timepoints * reps, 1) and betas has shape (n_strains, 1).
    """
    log_spike_count_corrected = log_read_count_corrected[1, ...].flatten()[..., None]
    betas = [
        _fit_slope(log_spike_count_corrected, log_strain_counts_corrected)
        for log_strain_counts_corrected in log_read_count_corrected
    ]
    return log_spike_count_corrected, np.asarray(betas)


def reads_plotter(
    sample_frac,
    seq_reps,
    seq_depth,
    read_var,
    inoculum,
    inoculum_var,
    carrying_capacity,
    *fitness
):
    """Return the two growth figures followed by the three read-count figures."""
    # Coerce defensively in case the slider delivers a float; used as a count.
    seq_reps = int(seq_reps)
    t, ref_expansion, growths = calculate_growth_curves(
        inoculum, inoculum_var, carrying_capacity, fitness, n_timepoints=10,
    )
    read_counts = reads_sampler(growths, sample_frac, seq_depth, seq_reps, read_var)
    log_read_count_correction, betas = fitness_fitter(read_counts, ref_expansion)
    # Slope b against WT expansion is reported as w_i / w_wt = 1 + b.
    plot_text = "\n".join(
        f"Mutant {i-1}: $w_{i-1}/w_{'{wt}'}={1. + b:.2f}$"
        for i, b in enumerate(betas.flatten())
        if i > 1
    )
    log_spike_count_corrected, spike_betas = fitness_fitter_spike(log_read_count_correction)
    # The spike-in has w = 0, so slope b against it is reported as 1 - b.
    plot_text_spike = "\n".join(
        f"Mutant {i-1}: $w_{i-1}/w_{'{wt}'}={1. - b:.2f}$"
        for i, b in enumerate(spike_betas.flatten())
        if i > 1
    )
    read_count_correction = np.exp(log_read_count_correction)
    return growth_plotter(inoculum, inoculum_var, carrying_capacity, *fitness) + [
        plotter_t(
            np.tile(t[:, None], (1, seq_reps)),
            read_counts,
            scatter=True,
            ylabel="Read counts per strain",
        ),
        plotter_ref(
            np.tile(ref_expansion[:, None], (1, seq_reps)),
            read_count_correction,
            scatter=True,
            fitlines=(ref_expansion[:, None], betas),
            text=plot_text,
            xlabel="Fold-expansion of wild-type",
            ylabel="$\\frac{c_1(t)}{c_{wt}(t)} / \\frac{c_1(0)}{c_{wt}(0)}$",
        ),
        plotter_ref(
            read_count_correction[1:2, ...],
            read_count_correction,
            scatter=True,
            fitlines=(read_count_correction[1:2, ...].flatten()[..., None], spike_betas),
            text=plot_text_spike,
            xlabel="$\\frac{c_{spike}(t)}{c_{wt}(t)} / \\frac{c_{spike}(0)}{c_{wt}(0)}$",
            ylabel="$\\frac{c_1(t)}{c_{wt}(t)} / \\frac{c_1(0)}{c_{wt}(0)}$",
        ),
    ]


with gr.Blocks() as demo:
    inject_markdown("header.md")

    # Growth curves
    inject_markdown("growth-curve-intro.md")
    mut_fitness_defaults = [.5, 2., .2]
    with gr.Row():
        relative_fitness = [
            gr.Slider(0., 3., step=.1, value=w, label=f"Relative fitness, mutant {i + 1}")
            for i, w in enumerate(mut_fitness_defaults)
        ]
    with gr.Row():
        inoculum = gr.Slider(
            10, 1_000_000,
            step=10,
            value=1000,
            label="Average inoculum per strain",
        )
        inoculum_var = gr.Slider(
            .001, 1.,
            step=.001,
            value=.001,
            label="Inoculum variance between strains",
        )
        # Wild-type + spike-in + mutants each start at ~1x inoculum, so a
        # smaller total capacity would make every strain shrink from t=0.
        carrying_capacity = gr.Slider(
            len(mut_fitness_defaults) + 2, 10_000,
            step=1,
            value=10,
            label="Total carrying capacity (x inoculum)",
        )
    plot_growth = gr.Button("Plot growth curves")
    growth_curves_t = gr.Plot(label="Growth vs time", format="png")
    inject_markdown("growth-curve-t-independent.md")
    growth_curves_ref = gr.Plot(label="Growth vs WT expansion", format="png")
    growth_curves = [growth_curves_t, growth_curves_ref]

    # Read counts
    inject_markdown("read-counts-intro.md")
    with gr.Row():
        sample_frac = gr.Slider(
            .001, 1.,
            step=.001,
            value=.1,
            label="Fraction of population per sample",
        )
        seq_reps = gr.Slider(
            1, 10,
            step=1,
            value=3,
            label="Technical replicates",
        )
        seq_depth = gr.Slider(
            10, 10_000,
            step=10,
            value=10_000,
            label="Average reads per strain per sample",
        )
        read_var = gr.Slider(
            .001, 1.,
            step=.001,
            value=.001,
            label="Sequencing variance",
        )
    plot_reads = gr.Button("Plot read counts")
    read_curves_t = gr.Plot(label="Read counts vs time", format="png")
    inject_markdown("read-counts-expansion.md")
    read_curves_ref = gr.Plot(label="Read count diff vs WT expansion", format="png")
    inject_markdown("read-counts-spike.md")
    read_curves_t2 = gr.Plot(label="Read count diff vs spike count diff", format="png")
    read_curves = [
        read_curves_t,
        read_curves_ref,
        read_curves_t2,
    ]

    # Events
    plot_growth.click(
        fn=growth_plotter,
        inputs=[inoculum, inoculum_var, carrying_capacity, *relative_fitness],
        outputs=growth_curves,
    )
    plot_reads.click(
        fn=reads_plotter,
        inputs=[sample_frac, seq_reps, seq_depth, read_var] + [inoculum, inoculum_var, carrying_capacity, *relative_fitness],
        outputs=growth_curves + read_curves,
    )

demo.launch(share=True)