from typing import Any, Dict, List, Optional, Tuple, Union
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
[docs]
def get_black_to_grey(n):
    """Return `n` grey levels as ``(v, v, v)`` tuples on a 0-255 integer scale.

    The levels are spaced evenly from 0 to 255 and converted to integers, so
    the first entry is always pure black ``(0, 0, 0)``.

    Args:
        n: Number of colors to generate. Must be at least 1.

    Returns:
        list: `n` three-element tuples whose channels are identical.

    Raises:
        ValueError: If `n` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    ramp = np.linspace(0, 255, num=n, dtype=int)
    # Pin the first shade explicitly so the ramp always starts at black.
    ramp[0] = 0
    return list(zip(ramp, ramp, ramp))
[docs]
def get_timeseries_data(
    df_quantiles: pd.DataFrame, column: str, quantile: float
) -> pd.DataFrame:
    """
    Extracts the time series data for a specific column (compartment or demographic group) and quantile.

    Rows are selected by comparing the ``"quantile"`` column with `quantile`
    using a floating-point tolerance (``np.isclose``), so levels that carry
    rounding noise (e.g. ``0.1 + 0.2`` stored for ``0.3``) are still matched.

    Args:
        df_quantiles (pd.DataFrame): DataFrame containing quantile data for compartments and demographic groups.
            Must provide the columns ``"date"``, ``"quantile"`` and `column`.
        column (str): The name of the column to extract data for.
        quantile (float): The quantile to extract data for.

    Returns:
        pd.DataFrame: A DataFrame with the columns ``"date"`` and `column`,
        restricted to the rows of the requested quantile (empty if none match).
    """
    # Exact ``==`` on floats silently drops rows whose level differs by an ulp.
    matches = np.isclose(df_quantiles["quantile"], quantile)
    return df_quantiles.loc[matches, ["date", column]]
[docs]
def plot_quantiles(
    df_quantiles: pd.DataFrame,
    columns: Union[List[str], str],
    data: Optional[pd.DataFrame] = None,
    data_date_column: str = "date",
    ax: Optional[plt.Axes] = None,
    lower_q: float = 0.05,
    upper_q: float = 0.95,
    show_median: bool = True,
    show_data: bool = False,
    ci_alpha: float = 0.3,
    title: str = "",
    ylabel: str = "",
    xlabel: str = "",
    show_legend: bool = True,
    legend_loc: str = "upper left",
    palette: str = "Dark2",
    colors: Optional[Union[List[str], str]] = None,
    labels: Optional[Union[List[str], str]] = None,
    y_scale: str = "linear",
    show_grid: bool = True,
) -> plt.Axes:
    """
    Plots quantiles for compartments over time with optional observed data.

    For every requested column a shaded band between `lower_q` and `upper_q`
    is drawn, optionally with the median (quantile 0.5) as a line on top.

    Args:
        df_quantiles: DataFrame with columns: 'date', 'quantile', and data columns
        columns: Column name(s) to plot
        data: Optional DataFrame containing observed data
        data_date_column: Column name for date in data
        ax: Matplotlib axes to plot on. A new figure is created if None
        lower_q: Lower quantile value (0.05 = 5th percentile)
        upper_q: Upper quantile value (0.95 = 95th percentile)
        show_median: Whether to show median line
        show_data: Whether to show observed data points (requires `data`)
        ci_alpha: Alpha value for confidence interval shading
        title: Plot title
        ylabel: Y-axis label
        xlabel: X-axis label
        show_legend: Whether to show legend
        legend_loc: Legend location
        palette: Color palette name (used when `colors` is None)
        colors: Custom colors for lines
        labels: Custom labels for legend (defaults to the column names)
        y_scale: Scale for y-axis ('linear' or 'log')
        show_grid: Whether to show grid lines

    Returns:
        plt.Axes: The matplotlib axes object
    """
    if not isinstance(columns, list):
        columns = [columns]
    if ax is None:
        _, ax = plt.subplots(dpi=300, figsize=(10, 4))
    if colors is None:
        colors = sns.color_palette(palette, len(columns))
    elif not isinstance(colors, list):
        colors = [colors]
    if labels is None:
        labels = columns
    elif not isinstance(labels, list):
        labels = [labels]

    # Width of the shaded band in percent; derived from both bounds so that
    # asymmetric intervals (e.g. 0.1-0.95) are labelled correctly.
    ci_width = np.round((upper_q - lower_q) * 100, 0)

    pleg, handles = [], []
    for column, color, label in zip(columns, colors, labels):
        median_line = None
        if show_median:
            df_med = get_timeseries_data(df_quantiles, column, 0.5)
            (median_line,) = ax.plot(
                df_med.date, df_med[column].values, color=color, label=label, zorder=2
            )
        df_q1 = get_timeseries_data(df_quantiles, column, lower_q)
        df_q2 = get_timeseries_data(df_quantiles, column, upper_q)
        band = ax.fill_between(
            df_q1.date,
            df_q1[column].values,
            df_q2[column].values,
            alpha=ci_alpha,
            color=color,
            linewidth=0.0,
            zorder=1,
        )
        if show_median:
            # A tuple handle makes the legend draw line and band together.
            pleg.append((median_line, band))
            handles.append(f"{label} (median, {ci_width}% CI)")
        else:
            pleg.append(band)
            handles.append(f"{label} ({ci_width}% CI)")

    if show_data and data is not None and columns:
        grey_levels = get_black_to_grey(len(columns))
        for column, grey, label in zip(columns, grey_levels, labels):
            # get_black_to_grey works on a 0-255 scale, whereas Matplotlib
            # only accepts RGB channels in the 0-1 range.
            data_color = tuple(channel / 255 for channel in grey)
            p_actual = ax.scatter(
                data[data_date_column],
                data[column],
                s=10,
                color=data_color,
                zorder=3,
                label=f"observed ({label})",
            )
            if show_legend:
                pleg.append(p_actual)
                handles.append(f"observed ({label})")

    # Style improvements
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if show_grid:
        ax.grid(axis="y", linestyle="--", linewidth=0.3, alpha=0.5, zorder=0)

    # Labels and formatting
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.set_yscale(y_scale)
    if show_legend and pleg:
        ax.legend(pleg, handles, loc=legend_loc, frameon=False)
    plt.tight_layout()
    return ax
[docs]
def plot_posterior_distribution(
    posterior: pd.DataFrame,
    parameter: str,
    ax: Optional[plt.Axes] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    kind: str = "hist",
    color: str = "dodgerblue",
    xlim: Optional[Tuple[float, float]] = None,
    prior: Optional[Any] = None,
    prior_range: bool = False,
    title: Optional[str] = None,
    fontsize: int = 10,
    show_grid: bool = True,
    show_kde: bool = True,
    show_rug: bool = False,
    figsize: Tuple[int, int] = (10, 4),
    stat: str = "density",
    bins: Union[int, str] = "auto",
    alpha: float = 0.4,
    vertical_lines: Optional[Dict[str, Dict[str, Any]]] = None,
    **kwargs,
) -> plt.Axes:
    """
    Plots the distribution of a parameter.

    Args:
        posterior: DataFrame containing the parameter values
        parameter: The parameter to plot
        ax: Matplotlib axes to plot on. Creates new figure if None
        xlabel: X-axis label. If None, uses parameter name
        ylabel: Y-axis label. If None, the label set by seaborn is kept
        kind: Type of plot ('hist', 'kde', or 'ecdf')
        color: Color for the plot
        xlim: Tuple of (min, max) for x-axis limits (ignored when
            prior_range is True)
        prior: Prior distribution object with ppf method
        prior_range: Whether to set x-axis limits to prior range. Bounds
            that are not finite (unbounded priors) are left to Matplotlib
        title: Plot title. If None, no title is set
        fontsize: Base font size for labels and ticks
        show_grid: Whether to show grid lines
        show_kde: Whether to show KDE curve with histogram (only for kind='hist')
        show_rug: Whether to show rug plot
        figsize: Figure size if creating new figure
        stat: Statistic to plot ('count', 'density', 'probability')
        bins: Number of bins or method for histogram
        alpha: Transparency of the plot
        vertical_lines: Dict of vertical lines to add, format:
            {
                'name': {
                    'x': value,
                    'color': 'color',
                    'linestyle': '--',
                    'label': 'label'
                }
            }
        **kwargs: Additional arguments passed to plotting functions

    Returns:
        plt.Axes: The matplotlib axes object

    Raises:
        ValueError: If kind is not 'hist', 'kde', or 'ecdf'
        ValueError: If prior_range is True but no prior is provided
    """
    # Validate up front so a bad call does not leave a half-drawn figure behind.
    if kind not in ("hist", "kde", "ecdf"):
        raise ValueError(
            f"Unknown kind for plot: {kind}. Must be 'hist', 'kde', or 'ecdf'"
        )
    if prior_range and prior is None:
        raise ValueError("prior must be provided when prior_range is True")

    if ax is None:
        _, ax = plt.subplots(dpi=300, figsize=figsize)

    # Set default labels
    xlabel = xlabel or parameter

    if kind == "hist":
        sns.histplot(
            data=posterior,
            x=parameter,
            ax=ax,
            color=color,
            stat=stat,
            bins=bins,
            alpha=alpha,
            kde=show_kde,
            **kwargs,
        )
    elif kind == "kde":
        sns.kdeplot(
            data=posterior,
            x=parameter,
            ax=ax,
            color=color,
            fill=True,
            alpha=alpha,
            **kwargs,
        )
    else:  # kind == "ecdf"
        sns.ecdfplot(data=posterior, x=parameter, ax=ax, color=color, **kwargs)

    # Add rug plot if requested
    if show_rug:
        sns.rugplot(data=posterior, x=parameter, ax=ax, color=color, alpha=alpha / 2)

    # Add vertical lines if specified
    if vertical_lines:
        for line_specs in vertical_lines.values():
            ax.axvline(
                x=line_specs["x"],
                color=line_specs.get("color", "red"),
                linestyle=line_specs.get("linestyle", "--"),
                label=line_specs.get("label", None),
                alpha=line_specs.get("alpha", 1.0),
            )
        # Build the legend once, and only if at least one line is labelled.
        if any(line_specs.get("label") for line_specs in vertical_lines.values()):
            ax.legend(frameon=False)

    # Style improvements
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if show_grid:
        ax.grid(axis="y", linestyle="--", linewidth=0.3, alpha=0.5)

    # Set axis limits based on prior range or explicit limits
    if prior_range:
        lower, upper = prior.ppf(0), prior.ppf(1)
        # Unbounded priors report +/-inf here, which set_xlim rejects;
        # passing None keeps the automatic limit on that side.
        ax.set_xlim(
            lower if np.isfinite(lower) else None,
            upper if np.isfinite(upper) else None,
        )
    elif xlim is not None:
        ax.set_xlim(xlim)

    # Labels and title
    ax.set_xlabel(xlabel, fontsize=fontsize)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=fontsize)
    if title is not None:
        ax.set_title(title, fontsize=fontsize + 2, pad=20)

    # Tick formatting
    ax.tick_params(axis="both", which="major", labelsize=fontsize - 2)

    # Adjust layout
    plt.tight_layout()
    return ax
[docs]
def plot_posterior_distribution_2d(
    posterior: pd.DataFrame,
    parameter_x: str,
    parameter_y: str,
    ax: Optional[plt.Axes] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    kind: str = "hist",
    palette: str = "Blues",
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    prior_x: Optional[Any] = None,
    prior_y: Optional[Any] = None,
    prior_range: bool = False,
    title: Optional[str] = None,
    fontsize: int = 10,
    cmap: Optional[str] = None,
    show_grid: bool = True,
    levels: int = 10,
    figsize: Tuple[int, int] = (6, 6),
    scatter: bool = False,
    scatter_alpha: float = 0.5,
    scatter_size: int = 20,
    scatter_color: str = "k",
    **kwargs,
) -> plt.Axes:
    """
    Plots the 2D joint distribution of two parameters.

    Args:
        posterior: DataFrame containing the parameter values
        parameter_x: Parameter to plot on x-axis
        parameter_y: Parameter to plot on y-axis
        ax: Matplotlib axes to plot on. Creates new figure if None
        xlabel: X-axis label. If None, uses parameter_x
        ylabel: Y-axis label. If None, uses parameter_y
        kind: Plot type ('hist', 'kde', or 'scatter')
        palette: Color palette for histogram/kde (used when cmap is None)
        xlim: Tuple of (min, max) for x-axis limits (ignored when
            prior_range is True)
        ylim: Tuple of (min, max) for y-axis limits (ignored when
            prior_range is True)
        prior_x: Prior distribution object (with ppf method) for x parameter
        prior_y: Prior distribution object (with ppf method) for y parameter
        prior_range: Whether to set axis limits to prior ranges. Bounds
            that are not finite (unbounded priors) are left to Matplotlib
        title: Plot title. If None, auto-generates
        fontsize: Base font size for labels and ticks
        cmap: Colormap for 2D histogram/kde
        show_grid: Whether to show grid lines
        levels: Number of contour levels for kde
        figsize: Figure size if creating new figure
        scatter: Whether to overlay scatter plot on hist/kde
        scatter_alpha: Alpha value for scatter points
        scatter_size: Size of scatter points
        scatter_color: Color for scatter points
        **kwargs: Additional arguments passed to sns.histplot/kdeplot
            (or to ax.scatter when kind='scatter')

    Returns:
        plt.Axes: The matplotlib axes object

    Raises:
        ValueError: If kind is not 'hist', 'kde', or 'scatter'
        ValueError: If prior_range is True but priors are not provided
    """
    # Validate up front so a bad call does not leave a half-drawn figure behind.
    if kind not in ("hist", "kde", "scatter"):
        raise ValueError(
            f"Unknown plot kind: {kind}. Must be 'hist', 'kde', or 'scatter'"
        )
    if prior_range and (prior_x is None or prior_y is None):
        raise ValueError(
            "Both prior_x and prior_y must be provided when prior_range is True"
        )

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=300)

    # Set default labels if not provided
    xlabel = xlabel or parameter_x
    ylabel = ylabel or parameter_y

    # Plot based on kind; the default colormap is only resolved where it is used.
    if kind == "hist":
        sns.histplot(
            data=posterior,
            x=parameter_x,
            y=parameter_y,
            ax=ax,
            cmap=palette if cmap is None else cmap,
            **kwargs,
        )
    elif kind == "kde":
        if cmap is None:
            cmap = sns.color_palette(palette, as_cmap=True)
        sns.kdeplot(
            data=posterior,
            x=parameter_x,
            y=parameter_y,
            ax=ax,
            cmap=cmap,
            levels=levels,
            fill=True,
            **kwargs,
        )
    else:  # kind == "scatter"
        ax.scatter(
            posterior[parameter_x],
            posterior[parameter_y],
            alpha=scatter_alpha,
            s=scatter_size,
            c=scatter_color,
            **kwargs,
        )

    # Add scatter points if requested (for hist/kde)
    if scatter and kind != "scatter":
        ax.scatter(
            posterior[parameter_x],
            posterior[parameter_y],
            alpha=scatter_alpha,
            s=scatter_size,
            c=scatter_color,
            zorder=2,
        )

    # Style improvements
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if show_grid:
        ax.grid(linestyle="--", linewidth=0.3, alpha=0.5)

    # Set axis limits based on prior ranges or explicit limits
    if prior_range:
        x_low, x_high = prior_x.ppf(0), prior_x.ppf(1)
        y_low, y_high = prior_y.ppf(0), prior_y.ppf(1)
        # Unbounded priors report +/-inf here, which set_xlim/set_ylim reject;
        # passing None keeps the automatic limit on that side.
        ax.set_xlim(
            x_low if np.isfinite(x_low) else None,
            x_high if np.isfinite(x_high) else None,
        )
        ax.set_ylim(
            y_low if np.isfinite(y_low) else None,
            y_high if np.isfinite(y_high) else None,
        )
    else:
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

    # Labels and title
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    if title is None:
        title = f"Joint Distribution\n{parameter_x} vs {parameter_y}"
    ax.set_title(title, fontsize=fontsize + 2, pad=20)

    # Tick formatting
    ax.tick_params(axis="both", which="major", labelsize=fontsize - 2)

    # Adjust layout
    plt.tight_layout()
    return ax
[docs]
def plot_selected_trajectories(
    calibration_results,
    ax=None,
    show_data=True,
    columns="data",
    lower_q=0.05,
    upper_q=0.95,
    show_median=True,
    ci_alpha=0.3,
    title="",
    show_legend=True,
    ylabel="",
    palette="Dark2",
):
    """
    Placeholder for plotting selected calibration trajectories.

    TODO: not implemented yet. Every argument is accepted but currently
    ignored, and nothing is drawn.

    Returns:
        int: Always 0.
    """
    return 0
def _format_human_readable(value: float) -> str:
    """Format a number with K/M/B suffixes for readability.

    Values with a magnitude of at least one thousand are scaled and printed
    with one decimal place; smaller values are printed without decimals.
    When rounding to one decimal would display ``1000.0`` of a unit (e.g.
    999,999 as ``1000.0K``), the next larger unit is used instead (``1.0M``).

    Args:
        value: The number to format. May be negative.

    Returns:
        str: The formatted number, e.g. ``"1.5K"``, ``"-2.5M"`` or ``"42"``.
    """
    units = ((1e9, "B"), (1e6, "M"), (1e3, "K"))  # largest unit first
    magnitude = abs(value)
    for index, (divisor, suffix) in enumerate(units):
        if magnitude >= divisor:
            scaled = value / divisor
            # Roll over to the next larger unit if rounding reaches 1000.0.
            if index > 0 and abs(round(scaled, 1)) >= 1000:
                divisor, suffix = units[index - 1]
                scaled = value / divisor
            return f"{scaled:.1f}{suffix}"
    return f"{value:.0f}"
[docs]
def plot_population(
    population: Any,
    ax: Optional[plt.Axes] = None,
    title: Optional[str] = None,
    color: str = "dodgerblue",
    show_perc: bool = False,
    fontsize: int = 10,
    rotation: int = 45,
    figsize: Tuple[int, int] = (10, 6),
    bar_width: float = 0.8,
    show_grid: bool = True,
    ylabel: Optional[str] = None,
    xlabel: str = "Demographic group",
    show_values: bool = True,
    fmt: Optional[str] = None,
    value_fontsize: Optional[int] = None,
) -> plt.Axes:
    """
    Plot the population distribution across demographic groups.

    Args:
        population: Population object containing demographic data. The
            attributes ``Nk`` (group sizes), ``Nk_names`` (group labels) and,
            for the default title, ``name`` are read
        ax: Matplotlib axes to plot on. Creates new figure if None
        title: Plot title. If None, uses default title
        color: Color for the bars
        show_perc: Whether to show population as percentages
        fontsize: Base font size for labels and ticks
        rotation: Rotation angle for x-axis labels
        figsize: Figure size if creating new figure
        bar_width: Width of the bars (between 0 and 1)
        show_grid: Whether to show grid lines
        ylabel: Y-axis label. If None, uses default based on show_perc
        xlabel: X-axis label
        show_values: Whether to show values above bars
        fmt: Format spec applied to each value via ``format``. If None, uses
            one decimal followed by "%" for percentages and human-readable
            suffixes (K/M/B) for absolute numbers.
        value_fontsize: Font size for bar values. If None, uses fontsize

    Returns:
        plt.Axes: The matplotlib axes object
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=300)

    # Calculate values. Coerce to an array first: with a plain list,
    # ``100 * values`` would repeat the list instead of scaling it.
    values = np.asarray(population.Nk)
    if show_perc:
        values = 100 * values / np.sum(values)

    # Create bars
    bars = ax.bar(population.Nk_names, values, color=color, width=bar_width)

    # Show values on bars
    if show_values:
        value_fontsize = value_fontsize or fontsize
        for bar in bars:
            height = bar.get_height()
            if fmt is not None:
                label = format(height, fmt)
            elif show_perc:
                label = f"{height:.1f}%"
            else:
                label = _format_human_readable(height)
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height,
                label,
                ha="center",
                va="bottom",
                fontsize=value_fontsize,
            )

    # Style improvements
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if show_grid:
        ax.grid(axis="y", linestyle="--", linewidth=0.3, alpha=0.5)

    # Labels
    if ylabel is None:
        ylabel = "Population (%)" if show_perc else "Population"
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)

    # Title
    if title is None:
        title = f"Population Distribution - {population.name}"
    ax.set_title(title, fontsize=fontsize + 2, pad=20)

    # Tick formatting
    ax.tick_params(axis="both", which="major", labelsize=fontsize)
    plt.setp(ax.get_xticklabels(), rotation=rotation, ha="right")

    # Adjust y-axis to start at 0
    ax.set_ylim(bottom=0)
    # Add some padding above highest bar for values
    if show_values:
        ax.set_ylim(top=ax.get_ylim()[1] * 1.1)

    # Adjust layout
    plt.tight_layout()
    return ax
[docs]
def plot_spectral_radius(
    epimodel: Any,
    ax: Optional[plt.Axes] = None,
    title: Optional[str] = None,
    color: str = "k",
    show_perc: bool = False,
    layer: str = "overall",
    show_interventions: bool = True,
    interventions_palette: str = "Dark2",
    interventions_colors: Optional[List[str]] = None,
    fontsize: int = 10,
    date_format: str = "%Y-%m-%d",
    ylabel: Optional[str] = None,
    xlabel: str = "Date",
    show_grid: bool = True,
    alpha: float = 0.2,
    legend_loc: str = "upper left",
) -> plt.Axes:
    """
    Plots the spectral radius of the contact matrices over time.

    The spectral radius is the largest eigenvalue magnitude of the contact
    matrix stored in ``epimodel.Cs[date][layer]`` for each date.

    Args:
        epimodel: The EpiModel object containing contact matrices and interventions
        ax: Matplotlib axes to plot on. Creates new figure if None
        title: Plot title. If None, uses default title
        color: Color of the spectral radius line
        show_perc: Whether to show as percentage change with respect to the initial value
        layer: Contact matrix layer to analyze
        show_interventions: Whether to show intervention periods
        interventions_palette: Color palette for interventions
        interventions_colors: Custom colors for interventions
        fontsize: Base font size for labels and ticks
        date_format: Format string for date labels
        ylabel: Y-axis label. If None, auto-generated based on show_perc
        xlabel: X-axis label
        show_grid: Whether to show grid lines
        alpha: Transparency for intervention highlights
        legend_loc: Location of the legend

    Returns:
        plt.Axes: The matplotlib axes object

    Raises:
        ValueError: If no contact matrices are defined or layer doesn't exist
    """
    if len(epimodel.Cs) == 0:
        raise ValueError("No contact matrices defined over time")
    available_layers = list(epimodel.population.layers) + ["overall"]
    if layer not in available_layers:
        raise ValueError(
            f"Layer '{layer}' not found. Available layers: {available_layers}"
        )

    # Create figure if needed
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6), dpi=300)

    # Compute spectral radius as max |eigenvalue|; taking the max of the raw
    # (possibly complex) eigenvalues would order them by real part instead.
    dates = list(epimodel.Cs.keys())
    rho = np.array(
        [np.abs(np.linalg.eigvals(epimodel.Cs[date][layer])).max() for date in dates]
    )

    # Normalize and convert to percentage if requested
    if show_perc:
        rho = (rho / rho[0] - 1) * 100

    # Plot spectral radius
    ax.plot(dates, rho, color=color, label="Spectral radius", linewidth=2)

    # Show interventions if requested
    has_interventions = show_interventions and hasattr(epimodel, "interventions")
    if has_interventions:
        # Select interventions for the layer (if layer is "overall", all interventions are selected)
        if layer == "overall":
            interventions = epimodel.interventions
        else:
            interventions = [
                intervention
                for intervention in epimodel.interventions
                if intervention["layer"] == layer
            ]
        # get colors
        span_colors = (
            interventions_colors
            if interventions_colors
            else sns.color_palette(interventions_palette, len(interventions))
        )
        # A distinct loop name avoids shadowing the `color` parameter.
        for intervention, span_color in zip(interventions, span_colors):
            ax.axvspan(
                intervention["start_date"],
                intervention["end_date"],
                alpha=alpha,
                color=span_color,
                label=intervention["name"],
            )

    # Style improvements
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if show_grid:
        ax.grid(axis="y", linestyle="--", alpha=0.3)

    # Format dates
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

    # Labels
    if ylabel is None:
        ylabel = "Change in spectral radius (%)" if show_perc else "Spectral radius"
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    if title is None:
        title = f"Contact Pattern Intensity - {layer.title()} Layer"
    ax.set_title(title, fontsize=fontsize + 2, pad=20)

    # Legend if interventions are shown
    if has_interventions:
        ax.legend(loc=legend_loc, fontsize=fontsize - 2)

    # Adjust layout
    plt.tight_layout()
    return ax
[docs]
def plot_distance_distribution(
    distances: Union[np.ndarray, List[float], pd.Series],
    ax: Optional[plt.Axes] = None,
    kind: str = "hist",
    color: str = "dodgerblue",
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    title: Optional[str] = None,
    fontsize: int = 10,
    show_grid: bool = True,
    show_kde: bool = True,
    show_rug: bool = False,
    figsize: Tuple[int, int] = (10, 4),
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    stat: str = "density",
    bins: Union[int, str] = "auto",
    alpha: float = 0.4,
    vertical_lines: Optional[Dict[str, Dict[str, Any]]] = None,
    **kwargs,
) -> plt.Axes:
    """
    Plots the distribution of distances/errors from calibration.

    Args:
        distances: Array-like object containing the distance/error values
        ax: Matplotlib axes to plot on. Creates new figure if None
        kind: Type of plot ('hist', 'kde', or 'ecdf')
        color: Color for the plot
        xlabel: X-axis label. If None, uses "Distance"
        ylabel: Y-axis label. If None, derived from kind (and from stat for
            histograms, e.g. "Density", "Count", "Probability")
        title: Plot title. If None, no title is set
        fontsize: Base font size for labels and ticks
        show_grid: Whether to show grid lines
        show_kde: Whether to show KDE curve with histogram (only for kind='hist')
        show_rug: Whether to show rug plot
        figsize: Figure size if creating new figure
        xlim: Tuple of (min, max) for x-axis limits
        ylim: Tuple of (min, max) for y-axis limits
        stat: Statistic to plot for histogram ('count', 'density', 'probability')
        bins: Number of bins or method for histogram
        alpha: Transparency of the plot
        vertical_lines: Dict of vertical lines to add, format:
            {
                'name': {
                    'x': value,
                    'color': 'color',
                    'linestyle': '--',
                    'label': 'label'
                }
            }
        **kwargs: Additional arguments passed to plotting functions

    Returns:
        plt.Axes: The matplotlib axes object

    Raises:
        ValueError: If kind is not 'hist', 'kde', or 'ecdf'
    """
    # Validate up front so a bad call does not leave an empty figure behind.
    if kind not in ("hist", "kde", "ecdf"):
        raise ValueError(
            f"Unknown kind for plot: {kind}. Must be 'hist', 'kde', or 'ecdf'"
        )

    if ax is None:
        _, ax = plt.subplots(dpi=300, figsize=figsize)

    # Convert input to pandas Series for consistent handling
    if not isinstance(distances, pd.Series):
        distances = pd.Series(distances)

    # Set default labels if not provided
    if xlabel is None:
        xlabel = "Distance"
    if ylabel is None:
        if kind == "hist":
            # Name the axis after the statistic actually plotted instead of
            # calling everything that is not a density a "Count".
            ylabel = stat.capitalize()
        elif kind == "kde":
            ylabel = "Density"
        else:
            ylabel = "Cumulative Probability"

    # Plot based on kind
    if kind == "hist":
        sns.histplot(
            data=distances,
            ax=ax,
            color=color,
            stat=stat,
            bins=bins,
            alpha=alpha,
            kde=show_kde,
            **kwargs,
        )
    elif kind == "kde":
        sns.kdeplot(
            data=distances, ax=ax, color=color, fill=True, alpha=alpha, **kwargs
        )
    else:  # kind == "ecdf"
        sns.ecdfplot(data=distances, ax=ax, color=color, **kwargs)

    # Add rug plot if requested
    if show_rug:
        sns.rugplot(data=distances, ax=ax, color=color, alpha=alpha / 2)

    # Add vertical lines if specified
    if vertical_lines:
        for line_specs in vertical_lines.values():
            ax.axvline(
                x=line_specs["x"],
                color=line_specs.get("color", "red"),
                linestyle=line_specs.get("linestyle", "--"),
                label=line_specs.get("label", None),
                alpha=line_specs.get("alpha", 1.0),
            )
        # Build the legend once, and only if at least one line is labelled.
        if any(line_specs.get("label") for line_specs in vertical_lines.values()):
            ax.legend(frameon=False)

    # Style improvements
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if show_grid:
        ax.grid(axis="y", linestyle="--", linewidth=0.3, alpha=0.5)

    # Set axis limits if provided
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    # Labels and title
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    if title is not None:
        ax.set_title(title, fontsize=fontsize + 2, pad=20)

    # Tick formatting
    ax.tick_params(axis="both", which="major", labelsize=fontsize - 2)

    # Adjust layout
    plt.tight_layout()
    return ax
[docs]
def plot_trajectories(
    stacked: Dict[str, np.ndarray],
    columns: Union[List[str], str],
    data: Optional[pd.DataFrame] = None,
    ax: Optional[plt.Axes] = None,
    show_data: bool = False,
    alpha: float = 0.1,
    title: str = "",
    ylabel: str = "",
    xlabel: str = "",
    show_legend: bool = True,
    legend_loc: str = "upper left",
    palette: str = "Dark2",
    colors: Optional[Union[List[str], str]] = None,
    labels: Optional[Union[List[str], str]] = None,
    y_scale: str = "linear",
    show_grid: bool = True,
    dates: Optional[np.ndarray] = None,
) -> plt.Axes:
    """
    Plots individual trajectories over time with optional observed data.

    Every simulated trajectory is drawn as a thin, transparent line and the
    pointwise median across simulations is drawn on top as a thick line.

    Args:
        stacked: Dictionary mapping column names to arrays of shape (n_simulations, timesteps)
        columns: Column name(s) to plot
        data: Optional DataFrame containing observed data; its columns are
            plotted against the same x values as the trajectories
        ax: Matplotlib axes to plot on. A new figure is created if None
        show_data: Whether to show observed data points (requires `data`)
        alpha: Alpha value for individual trajectories
        title: Plot title
        ylabel: Y-axis label
        xlabel: X-axis label
        show_legend: Whether to show legend
        legend_loc: Legend location
        palette: Color palette name (used when `colors` is None)
        colors: Custom colors for lines
        labels: Custom labels for legend (defaults to the column names)
        y_scale: Scale for y-axis ('linear' or 'log')
        show_grid: Whether to show grid lines
        dates: Array of dates for x-axis. If None, uses range(timesteps)

    Returns:
        plt.Axes: The matplotlib axes object
    """
    if not isinstance(columns, list):
        columns = [columns]
    if ax is None:
        _, ax = plt.subplots(dpi=300, figsize=(10, 4))
    if colors is None:
        colors = sns.color_palette(palette, len(columns))
    elif not isinstance(colors, list):
        colors = [colors]
    if labels is None:
        labels = columns
    elif not isinstance(labels, list):
        labels = [labels]

    # Create x-axis values
    if dates is None:
        x = np.arange(stacked[columns[0]].shape[1])
    else:
        x = dates

    # Legend handles and their labels are collected side by side so they
    # cannot get out of step if zip() truncates one of the inputs.
    pleg, legend_labels = [], []
    for column, color, label in zip(columns, colors, labels):
        trajectories = stacked[column]
        # Plot individual trajectories
        for traj in trajectories:
            ax.plot(x, traj, color=color, alpha=alpha, linewidth=0.5, zorder=1)
        # Plot median trajectory with higher alpha
        median_traj = np.median(trajectories, axis=0)
        (median_line,) = ax.plot(
            x, median_traj, color=color, alpha=1.0, linewidth=2, label=label, zorder=2
        )
        pleg.append(median_line)
        legend_labels.append(label)

    if show_data and data is not None and columns:
        grey_levels = get_black_to_grey(len(columns))
        for column, grey, label in zip(columns, grey_levels, labels):
            # get_black_to_grey works on a 0-255 scale, whereas Matplotlib
            # only accepts RGB channels in the 0-1 range.
            data_color = tuple(channel / 255 for channel in grey)
            p_actual = ax.scatter(
                x,
                data[column],
                s=10,
                color=data_color,
                zorder=3,
                label=f"observed ({label})",
            )
            if show_legend:
                pleg.append(p_actual)
                legend_labels.append(f"observed ({label})")

    # Style improvements
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if show_grid:
        ax.grid(axis="y", linestyle="--", linewidth=0.3, alpha=0.5, zorder=0)

    # Labels and formatting
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.set_yscale(y_scale)
    if show_legend and pleg:
        ax.legend(pleg, legend_labels, loc=legend_loc, frameon=False)
    plt.tight_layout()
    return ax