"""
=======================================
Plotting (:mod:`CADETProcess.plotting`)
=======================================
.. currentmodule:: CADETProcess.plotting
This module provides functionality for plotting in CADET-Process.
General Utils
=============
.. autosummary::
:toctree: generated/
get_fig_size
setup_figure
get_all_twin_handles_labels
show_or_reopen
style_and_save_figure
Secondary Axis
==============
.. autosummary::
:toctree: generated/
SecondaryAxis
Text
====
.. autosummary::
:toctree: generated/
add_text
Annotations
===========
.. autosummary::
:toctree: generated/
annotate
Fill Regions
============
.. autosummary::
:toctree: generated/
fill_between
""" # noqa
import os
import warnings
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Literal, Optional

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy.typing as npt
from matplotlib import cycler
# %% Colors
# Named RGB colors of the default palette, in cycling order.
color_dict = {
    name: mpl.colors.to_rgb(hex_code)
    for name, hex_code in (
        ("blue", "#000099"),
        ("red", "#990000"),
        ("green", "#009900"),
        ("orange", "#D79B00"),
        ("purple", "#896999"),
        ("grey", "#444444"),
    )
}
# Same colors as a plain list (index-addressable, e.g. by `fill_between`).
color_list = [*color_dict.values()]
# Color cycle installed as `axes.prop_cycle` by the layout presets.
chromapy_cycler = cycler(color=color_list)
# Alternative line-style cycle.
linestyle_cycler = cycler(linestyle=["--", ":", "-."])
# %% Fig size
def mm_to_inches(mm: float) -> float:
    """
    Convert a length from millimetres to inches.

    Parameters
    ----------
    mm : float
        Length in millimetres.

    Returns
    -------
    float
        The same length in inches (one inch is 25.4 mm).
    """
    millimetres_per_inch = 25.4
    return mm / millimetres_per_inch
# %% Layout
# Figure sizes from Elsevier style guide:
# https://www.elsevier.com/about/policies-and-standards/author/artwork-and-media-instructions/artwork-sizing
# Layout presets. Widths/heights are in inches (converted from mm); font
# sizes are numeric Matplotlib font sizes. Within this module, the entries are
# consumed by `mpl_style_context` as follows:
#   linewidth    -> rcParams["lines.linewidth"]
#   marker_size  -> rcParams["lines.markersize"]
#   font_small   -> base font size, legend and tick labels
#   font_medium  -> axis labels
#   font_large   -> axes and figure titles
#   color_cycler -> rcParams["axes.prop_cycle"]
#   tick_length  -> intended tick mark length
#                   (NOTE(review): not read by `mpl_style_context`)
figure_layouts = {
    "minimal": dict(
        width=mm_to_inches(30),
        height=mm_to_inches(20),
        linewidth=1.0,
        marker_size=3,
        font_small=8,
        font_medium=10,
        font_large=12,
        tick_length=3,
        color_cycler=chromapy_cycler,
    ),
    "1_col": dict(
        width=mm_to_inches(90),
        height=mm_to_inches(60),
        linewidth=1.0,
        marker_size=5,
        font_small=8,
        font_medium=10,
        font_large=12,
        tick_length=4,
        color_cycler=chromapy_cycler,
    ),
    "1.5_col": dict(
        width=mm_to_inches(140),
        height=mm_to_inches(93.33),
        linewidth=1.2,
        marker_size=7,
        font_small=9,
        font_medium=11,
        font_large=14,
        tick_length=5,
        color_cycler=chromapy_cycler,
    ),
    "2_col": dict(
        width=mm_to_inches(190),
        height=mm_to_inches(126.67),
        linewidth=1.5,
        marker_size=9,
        font_small=10,
        font_medium=12,
        font_large=16,
        tick_length=6,
        color_cycler=chromapy_cycler,
    ),
}
# %% Figure size
[docs]
def get_fig_size(
layout: Literal["1_col", "1_5_col", "2_col"] = "1_col",
nrows: int = 1,
ncols: int = 1,
aspect: float | None = None,
scale_with_subplots: bool = False,
padding: float = 0.0,
figsize: tuple[float, float] | None = None,
) -> tuple[float, float]:
"""
Compute a publication-ready figure size in inches.
Parameters
----------
layout: Literal["1_col", "1_5_col", "2_col"] = "1_col",
Figure layout.
nrows : int
Number of subplot rows.
ncols : int
Number of subplot columns.
aspect : float | None
Width / height ratio. If provided, overrides preset height.
scale_with_subplots : bool
If True, multiply width/height by ncols/nrows.
padding : float
Extra inches to add.
figsize : tuple[float, float] | None
Override width/height directly.
Returns
-------
width, height : tuple[float, float]
"""
if figsize is not None:
width, height = figsize
else:
if layout not in figure_layouts:
raise ValueError(
f"Invalid layout {layout}. Options: {list(figure_layouts)}"
)
width = figure_layouts[layout]["width"]
height = figure_layouts[layout]["height"]
if aspect is not None:
height = width / aspect
if scale_with_subplots:
width = ncols * width + padding
height = nrows * height + padding
return width, height
# %% Style
@contextmanager
def mpl_style_context(
    layout: Literal["minimal", "1_col", "1.5_col", "2_col"] = "1_col",
) -> Iterator[None]:
    """
    Temporarily apply the rc parameters of a figure layout preset.

    Parameters
    ----------
    layout : {"minimal", "1_col", "1.5_col", "2_col"}, default="1_col"
        Key into ``figure_layouts`` selecting font sizes, line width,
        marker size and color cycle.

    Yields
    ------
    None
        Control returns to the ``with`` body while the rc parameters are
        active; ``matplotlib.rc_context`` restores them on exit.

    Raises
    ------
    KeyError
        If `layout` is not a key of ``figure_layouts``.

    Notes
    -----
    The preset's ``tick_length`` entry is currently not mapped to any rc
    parameter.
    """
    try:
        layout_settings = figure_layouts[layout]
    except KeyError:
        # Keep the KeyError type, but make the message actionable.
        raise KeyError(
            f"Invalid layout {layout!r}. Options: {list(figure_layouts)}"
        ) from None

    rc_params = {
        "axes.titlesize": layout_settings["font_large"],
        "axes.labelsize": layout_settings["font_medium"],
        "axes.prop_cycle": layout_settings["color_cycler"],
        "figure.titlesize": layout_settings["font_large"],
        "font.size": layout_settings["font_small"],
        "legend.fontsize": layout_settings["font_small"],
        "lines.linewidth": layout_settings["linewidth"],
        "lines.markersize": layout_settings["marker_size"],
        "xtick.labelsize": layout_settings["font_small"],
        "ytick.labelsize": layout_settings["font_small"],
    }
    with mpl.rc_context(rc_params):
        yield
# %% Twin Axis
def get_twins(
ax: plt.Axes,
axis: Literal["x", "y", "both"] = "x",
) -> list[plt.Axes]:
"""
Get twin axes of a specified axis.
Parameters
----------
ax : plt.Axes
Axes to get twins for.
axis : Literal["x", "y"], default="x"
Which twins to get .
Returns
-------
`list`[plt.Axes]
"""
axs = []
match axis:
case "x":
axs = ax.get_shared_x_axes().get_siblings(ax)
case "y":
axs = ax.get_shared_y_axes().get_siblings(ax)
return list(reversed(
[a for a in axs if (a is not ax) & (a.bbox.bounds == ax.bbox.bounds)]
))
[docs]
def get_all_twin_handles_labels(ax: plt.Axes) -> tuple[list, list]:
    """
    Collect legend handles and labels from an axes and its twins.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The reference axes.

    Returns
    -------
    handles : list
        Legend handles of `ax`, followed by those of its twin axes.
    labels : list of str
        Legend labels matching `handles` one-to-one.

    Notes
    -----
    Twins are looked up with ``get_twins(ax)`` using its default ``axis``,
    i.e. only axes sharing the x-axis with `ax` (such as those created by
    ``ax.twinx()``) are included.
    """
    handles: list = []
    labels: list = []
    for axes in (ax, *get_twins(ax)):
        axes_handles, axes_labels = axes.get_legend_handles_labels()
        handles += axes_handles
        labels += axes_labels
    return handles, labels
def offset_secondary_yaxes(
    ax: plt.Axes,
    spacing_factor: float = 0.05,
    min_spacing: float = 0.2,
    max_spacing: float = 0.5,
    side: Literal["left", "right"] = "right",
) -> None:
    """
    Offset secondary y-axis spines to avoid overlap, scaling with axes size.

    The first twin returned by ``get_twins(ax)`` keeps its spine position.
    Every further twin ``i`` (counting from 1) has its `side` spine moved
    outward by ``i * spacing`` in axes coordinates, plus a small extra
    offset if that twin has a y-label.

    Parameters
    ----------
    ax : plt.Axes
        The primary axes.
    spacing_factor : float, optional
        Factor multiplied with the axes width (in inches) to obtain the
        spacing. Default is 0.05.
    min_spacing : float, optional
        Lower bound for the spacing, in axes coordinates. Default is 0.2.
    max_spacing : float, optional
        Upper bound for the spacing, in axes coordinates. Default is 0.5.
    side : {"right", "left"}, optional
        Side on which to place the secondary spines. Default is "right".

    Raises
    ------
    ValueError
        If `side` is not "right" or "left".
    """
    if side not in ("right", "left"):
        raise ValueError(f"Invalid side {side!r}. Options: 'right', 'left'.")

    fig = ax.get_figure()
    twin_axes = get_twins(ax)

    # Axes width in inches: window extent (display units) mapped back
    # through the inverse of the figure's dpi scaling.
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    ax_width = bbox.width

    # Outward direction: > 1.0 on the right, < 0.0 on the left.
    sign = 1 if side == "right" else -1
    base = 1.0 if side == "right" else 0.0
    spacing = min(max(spacing_factor * ax_width, min_spacing), max_spacing)

    for i, sec_ax in enumerate(twin_axes[1:], start=1):
        position = base + sign * i * spacing
        if sec_ax.yaxis.get_label_text():
            # Extra outward nudge for labelled axes; scaled by `sign` so it
            # also points outward on the left side (was always +0.02).
            position += sign * 0.02
        sec_ax.spines[side].set_position(("axes", position))
# %% Figure Utils
def figure_utils(func: Callable) -> Callable:
    """
    Decorate a plotting function with unified styling and saving.

    The wrapped function is called inside :func:`mpl_style_context` for the
    requested layout. It must accept ``ax`` and ``setup_figure_kwargs``
    keyword arguments and return a ``(fig, ax)`` tuple.

    Parameters
    ----------
    func : Callable
        Plotting function to wrap.

    Returns
    -------
    Callable
        The wrapped function.
    """

    @wraps(func)
    def figure_utils_wrapper(
        *args: Any,
        ax: Optional[plt.Axes | npt.NDArray[plt.Axes]] = None,
        setup_figure_kwargs: Optional[dict] = None,
        file_name: Optional[os.PathLike] = None,
        dpi: int = 300,
        tight_layout: bool = True,
        **kwargs: Any,
    ) -> tuple[plt.Figure, plt.Axes | npt.NDArray[plt.Axes]]:
        """
        Apply styles, save, and optionally create a figure.

        Parameters
        ----------
        *args : Any
            Additional positional arguments passed to the wrapped function.
        ax : Optional[plt.Axes], default=None
            Optional Matplotlib Axes.
            If not provided, a new figure is created.
        setup_figure_kwargs : Optional[dict], default=None
            Additional options to setup the figure. Merged on top of
            ``{"layout": "1_col", "scale_with_subplots": True}``; the
            ``layout`` entry selects the style preset.
        file_name : Optional[os.PathLike], default=None
            File name to store the figure. If None is provided, the figure is
            not saved.
        dpi : int, default=300
            DPI for saving the figure.
        tight_layout : bool, default=True
            If True, set tight layout.
        **kwargs : Any
            Additional keyword arguments passed to the wrapped function.
            A ``show`` keyword is deprecated: it is removed, ignored, and
            triggers a DeprecationWarning.

        Returns
        -------
        tuple[plt.Figure, plt.Axes | npt.NDArray[plt.Axes]]
            The Matplotlib Figure and Axes objects.
        """
        # Defaults first so that user-supplied options take precedence; a new
        # dict is built so the caller's mapping is never mutated.
        setup_figure_kwargs = {
            "layout": "1_col",
            "scale_with_subplots": True,
            **(setup_figure_kwargs or {}),
        }

        show = kwargs.pop("show", None)
        if show is not None:
            warnings.warn(
                "`show` argument is deprecated.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller
            )

        with mpl_style_context(setup_figure_kwargs["layout"]):
            fig, ax = func(
                *args,
                ax=ax,
                setup_figure_kwargs=setup_figure_kwargs,
                **kwargs,
            )

        if tight_layout:
            fig.tight_layout()

        if file_name is not None:
            fig.savefig(file_name, dpi=dpi)

        return fig, ax

    return figure_utils_wrapper
# %% Secondary Axis
[docs]
@dataclass
class SecondaryAxis:
    """
    Convenience class for secondary axis configuration.

    Attributes
    ----------
    components : list[str]
        Identifiers of the components to show on the secondary axis
        (presumably component names - confirm against the consuming plot
        function).
    ylabel : str | None, default=None
        Label for the secondary y-axis; None presumably leaves it unset.
    ylim : tuple[float, float] | None, default=None
        (lower, upper) limits for the secondary y-axis; None presumably
        keeps Matplotlib's autoscaling.
    transform : Callable | None, default=None
        Optional callable. NOTE(review): presumably applied to the data
        before it is plotted on the secondary axis - confirm with the
        consumer of this class.
    """

    components: list[str]
    ylabel: str | None = None
    ylim: tuple[float, float] | None = None
    transform: Callable | None = None
# %% Text
# Default bbox properties for `add_text`; callers override them via `tb_props`.
textbox_props = {"facecolor": "white", "alpha": 1}
[docs]
def add_text(
    ax: plt.Axes,
    text: str,
    position: tuple[float, float] = (0.05, 0.9),
    *,
    tb_props: dict | None = None,
    **kwargs: Any,
) -> None:
    """
    Place a boxed text label on a matplotlib Axes.

    Parameters
    ----------
    ax : plt.Axes
        Axes that receives the text.
    text : str
        Text to display.
    position : tuple[float, float], default=(0.05, 0.9)
        Anchor of the text in axes coordinates; the text is top-aligned
        at this point (``verticalalignment="top"``).
    tb_props : dict | None, default=None
        Overrides merged on top of the module-level ``textbox_props`` to
        build the ``bbox`` around the text (e.g. `boxstyle`, `facecolor`).
    **kwargs
        Additional keyword arguments forwarded to `ax.text`.
    """
    bbox_style = dict(textbox_props)
    if tb_props:
        bbox_style.update(tb_props)

    ax.text(
        *position,
        text,
        transform=ax.transAxes,
        verticalalignment="top",
        bbox=bbox_style,
        **kwargs,
    )
# %% Annotations
[docs]
def annotate(
    ax: plt.Axes,
    text: str,
    xy: tuple[float, float],
    xytext: tuple[float, float],
    *,
    arrowstyle: str = "-|>",
    **kwargs: Any,
) -> None:
    """
    Add an arrow annotation to a matplotlib Axes object.

    Parameters
    ----------
    ax : plt.Axes
        The matplotlib Axes object to annotate.
    text : str
        The annotation text.
    xy : tuple[float, float]
        The point (x, y) to annotate; data coordinates by default.
    xytext : tuple[float, float]
        Placement of the annotation text. By default this is an offset in
        points relative to `xy` (``textcoords="offset points"``), not an
        absolute position.
    arrowstyle : str, default="-|>"
        Style of the arrow connecting `xy` and `xytext`.
    **kwargs
        Additional keyword arguments for `ax.annotate`.
        ``xycoords`` and ``textcoords`` override the defaults above.
        A dict passed as ``arrowprops`` is merged on top of
        ``{"arrowstyle": arrowstyle}``; ``arrowprops=None`` draws no arrow.
    """
    # Defaults that callers may now override instead of getting a
    # "multiple values for keyword argument" TypeError.
    kwargs.setdefault("xycoords", "data")
    kwargs.setdefault("textcoords", "offset points")

    user_arrowprops = kwargs.pop("arrowprops", {})
    if user_arrowprops is None:
        arrowprops = None
    else:
        arrowprops = {"arrowstyle": arrowstyle, **user_arrowprops}

    ax.annotate(
        text,
        xy=xy,
        xytext=xytext,
        arrowprops=arrowprops,
        **kwargs,
    )
# %% Fill regions
[docs]
def fill_between(
    ax: plt.Axes,
    start: float,
    end: float,
    y_max: float,
    alpha: float = 0.3,
    *,
    color_index: int | None = None,
    text: str | None = None,
    **kwargs: Any,
) -> None:
    """
    Add a shaded x-region with optional text labeling.

    The region spans ``[start, end]`` on the x-axis and is filled between
    y = 0 (the `ax.fill_between` default baseline) and `y_max`.

    Parameters
    ----------
    ax : plt.Axes
        Matplotlib axes to plot on.
    start : float
        Start x-position of the fill region.
    end : float
        End x-position of the fill region.
    y_max : float
        Maximum y-value for the fill.
    alpha : float, default=0.3
        Transparency of the fill.
    color_index : int | None, default=None
        Index into the module-level ``color_list``. Takes precedence over a
        ``color`` given in `kwargs`. If None, ``kwargs["color"]`` is used
        (or Matplotlib's default if absent).
    text : str | None, default=None
        Optional label, drawn at ``0.5 * y_max`` and horizontally centered
        on the part of the region inside the current x-limits.
    **kwargs
        Additional arguments passed to `ax.fill_between`.
    """
    # Always pop `color`: leaving it in kwargs while `color_index` is set
    # would pass `color` twice to `ax.fill_between` (TypeError).
    color = kwargs.pop("color", None)
    if color_index is not None:
        color = color_list[color_index]

    ax.fill_between(
        [start, end],
        y_max,
        alpha=alpha,
        color=color,
        **kwargs,
    )

    if text is not None:
        # Clip the region to the visible x-range on both sides; sorting the
        # limits keeps this correct for inverted x-axes.
        x_lo, x_hi = sorted(ax.get_xlim())
        x_position = (max(start, x_lo) + min(end, x_hi)) / 2
        y_position = 0.5 * y_max
        ax.text(
            x_position,
            y_position,
            text,
            ha="center",
            va="center",
        )