"""
=======================================
Plotting (:mod:`CADETProcess.plotting`)
=======================================

.. currentmodule:: CADETProcess.plotting

This module provides functionality for plotting in CADET-Process.

General Style
=============

.. autosummary::
    :toctree: generated/

    set_figure_style
    SecondaryAxis
    Layout
    set_layout

Setup Figure
============

.. autosummary::
    :toctree: generated/

    setup_figure
    create_and_save_figure

Annotations
===========

.. autosummary::
    :toctree: generated/

    Annotation
    add_annotations

Ticks
=====

.. autosummary::
    :toctree: generated/

    Tick
    set_yticks
    set_xticks

Fill Regions
============

.. autosummary::
    :toctree: generated/

    FillRegion
    add_fill_regions

Text
====

.. autosummary::
    :toctree: generated/

    add_text

Horizontal Lines
================

.. autosummary::
    :toctree: generated/

    HLines
    add_hlines
""" # noqa
import sys
from functools import wraps
from typing import Any, Optional
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from matplotlib import cycler
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from CADETProcess import CADETProcessError
from CADETProcess.dataStructure import (
Callable,
Integer,
List,
String,
Structure,
Tuple,
UnsignedFloat,
)
# Handle to this module object; functions such as ``get_fig_size`` read
# module-level settings (e.g. ``style``) through it.
this = sys.modules[__name__]

# Name of the active entry in ``figure_styles``.
style = "medium"

# Named RGB colors used throughout the plots (converted from hex codes).
color_dict = {
    name: matplotlib.colors.to_rgb(hex_code)
    for name, hex_code in (
        ("blue", "#000099"),
        ("red", "#990000"),
        ("green", "#009900"),
        ("orange", "#D79B00"),
        ("purple", "#896999"),
        ("grey", "#444444"),
    )
}
color_list = [*color_dict.values()]

chromapy_cycler = cycler(color=color_list)
linestyle_cycler = cycler("linestyle", ["--", ":", "-."])

# Default properties of the box drawn behind text added with ``add_text``.
textbox_props = {"facecolor": "white", "alpha": 1}

# Size, line width and font settings for each available figure style.
# Columns: name, width, height, linewidth, font_small, font_medium, font_large
figure_styles = {
    name: {
        "width": width,
        "height": height,
        "linewidth": linewidth,
        "font_small": font_small,
        "font_medium": font_medium,
        "font_large": font_large,
        "color_cycler": chromapy_cycler,
    }
    for name, width, height, linewidth, font_small, font_medium, font_large in (
        ("small", 5, 3, 2, 8, 10, 12),
        ("medium", 10, 6, 4, 20, 24, 28),
        ("large", 15, 9, 6, 25, 30, 40),
    )
}

# Apply the figure style (with its default arguments) on module import.
set_figure_style()
def get_fig_size(
    n_rows: Optional[int] = 1,
    n_cols: Optional[int] = 1,
    style: Optional[str] = None,
) -> tuple[float, float]:
    """
    Compute the size of a figure that holds a grid of Axes.

    The base width and height of the selected style are multiplied by the
    number of columns and rows, respectively, and a constant margin of 2 is
    added to each dimension.

    Parameters
    ----------
    n_rows : int, optional
        Number of Axes rows in the figure. The default is 1.
    n_cols : int, optional
        Number of Axes columns in the figure. The default is 1.
    style : str, optional
        Key into ``figure_styles``. If None, the module-level ``style``
        setting is used. The default is None.

    Returns
    -------
    tuple[float, float]
        Size of the figure as ``(width, height)``.
    """
    selected_style = this.style if style is None else style
    dimensions = figure_styles[selected_style]
    margin = 2

    total_width = n_cols * dimensions["width"] + margin
    total_height = n_rows * dimensions["height"] + margin
    return total_width, total_height
[docs]
class SecondaryAxis(Structure):
    """
    Parameters for secondary axis.

    Groups the settings describing an additional y-axis of a plot.
    """

    # Components shown on the secondary axis.
    # NOTE(review): presumably component names/indices — confirm with callers.
    components = List()
    # Label of the secondary y-axis.
    y_label = String()
    # Limits of the secondary y-axis.
    y_lim = Tuple()
    # NOTE(review): presumably applied to the data before it is drawn on the
    # secondary axis — confirm with callers.
    transform = Callable()
[docs]
class Layout(Structure):
    """
    General figure layout.

    Collects the title, axis labels, tick specifications and axis limits
    that ``set_layout`` applies to a matplotlib Axes.
    """

    # Figure style name. NOTE(review): presumably a key of ``figure_styles``.
    style = String()
    # Title of the Axes.
    title = String()
    # Label of the x-axis.
    x_label = String()
    # Custom x-ticks; passed to ``set_xticks`` (list of ``Tick`` objects).
    x_ticks = List()
    # Label of the y-axis.
    y_label = String()
    # Custom y-ticks; passed to ``set_yticks`` (list of ``Tick`` objects).
    y_ticks = List()
    # Limits of the x-axis, passed to ``Axes.set_xlim``.
    x_lim = Tuple()
    # Limits of the y-axis, passed to ``Axes.set_ylim``.
    y_lim = Tuple()
[docs]
def set_layout(
    ax: Axes,
    layout: Layout,
    show_legend: bool = True,
    ax_secondary: Optional[Axes] = None,
    secondary_layout: Optional[Layout] = None,
) -> None:
    """
    Configure the layout of a matplotlib Axes object.

    Hides the top and right spines of ``ax`` and applies the labels, limits,
    title and ticks stored in ``layout``. If a secondary Axes is given, its
    y-label and y-limits are taken from ``secondary_layout`` and the legend
    entries of both Axes are combined on the secondary Axes.

    Parameters
    ----------
    ax : Axes
        The primary matplotlib Axes object to configure.
    layout : Layout
        Layout object containing axis labels, limits, title, and ticks.
    show_legend : bool, optional
        Whether to display the legend. Default is True.
    ax_secondary : Optional[Axes], optional
        Secondary matplotlib Axes object, if applicable.
    secondary_layout : Optional[Layout], optional
        Layout object for the secondary axis. Only its ``y_label`` and
        ``y_lim`` are used. If None, the secondary axis labels and limits
        are left unchanged.
    """
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    ax.set_xlabel(layout.x_label)
    ax.set_ylabel(layout.y_label)
    ax.set_xlim(layout.x_lim)
    ax.set_ylim(layout.y_lim)
    ax.set_title(layout.title)

    # The tick helpers take the target Axes as their first argument.
    if layout.x_ticks is not None:
        set_xticks(ax, layout.x_ticks)
    if layout.y_ticks is not None:
        set_yticks(ax, layout.y_ticks)

    lines, labels = ax.get_legend_handles_labels()

    if ax_secondary is None:
        if show_legend and len(labels) != 0:
            ax.legend()
        return

    if secondary_layout is not None:
        ax_secondary.set_ylabel(secondary_layout.y_label)
        ax_secondary.set_ylim(secondary_layout.y_lim)

    if show_legend:
        lines_secondary, labels_secondary = ax_secondary.get_legend_handles_labels()
        # Draw a single legend holding the entries of both Axes, but only if
        # there is anything to show (consistent with the single-axis case).
        if len(labels_secondary) + len(labels) != 0:
            ax_secondary.legend(
                lines_secondary + lines, labels_secondary + labels, loc=0
            )
[docs]
class Tick(Structure):
    """
    Parameters for Axes ticks.

    Used by ``set_xticks`` and ``set_yticks``, which read ``location`` and
    ``label`` from each tick.
    """

    # Assign (``=``) rather than annotate (``:``) so the parameter
    # descriptors are actually bound on the class, like in all other
    # Structures of this module.
    # NOTE(review): the tick helpers stack locations into a NumPy array of
    # positions, so presumably each location is scalar-like — confirm.
    location = Tuple()
    # Text displayed at the tick.
    label = String()
[docs]
def set_yticks(ax: Axes, y_ticks: list[Tick]) -> None:
    """
    Place custom, labelled ticks on the y-axis of ``ax``.

    Parameters
    ----------
    ax : Axes
        The matplotlib Axes whose y-ticks are replaced.
    y_ticks : list[Tick]
        Ticks to place; each provides a ``location`` and a ``label``.
    """
    positions = []
    tick_labels = []
    for tick in y_ticks:
        positions.append(tick.location)
        tick_labels.append(tick.label)

    ax.set_yticks(np.array(positions), tick_labels)
[docs]
def set_xticks(ax: Axes, x_ticks: list[Tick]) -> None:
    """
    Set the x-ticks on a matplotlib Axes object with rotation.

    Tick labels are rotated by 72 degrees and horizontally centered.

    Parameters
    ----------
    ax : Axes
        The matplotlib Axes object to set the x-ticks on.
    x_ticks : list[Tick]
        List of Tick objects containing location and label for each x-tick.
    """
    locs = np.array([x_tick.location for x_tick in x_ticks])
    labels = [x_tick.label for x_tick in x_ticks]
    # Act on ``ax`` itself; ``plt.xticks`` would modify pyplot's *current*
    # Axes, which is not necessarily ``ax`` (e.g. in multi-Axes figures).
    ax.set_xticks(locs, labels, rotation=72, horizontalalignment="center")
[docs]
def add_text(
    ax: Axes,
    text: str,
    position: tuple[float, float] = (0.05, 0.9),
    tb_props: Optional[dict] = None,
    **kwargs: Any,
) -> None:
    """
    Add a text box to a matplotlib Axes object.

    The text is positioned in Axes coordinates (via ``ax.transAxes``) and is
    top-aligned by default. The box behind it uses the module-level
    ``textbox_props``, optionally overridden by ``tb_props`` for this call.

    Parameters
    ----------
    ax : Axes
        The matplotlib Axes object to add text to.
    text : str
        The text to be added.
    position : tuple[float, float], optional
        The position of the text in Axes coordinates.
        The default is (0.05, 0.9).
    tb_props : dict, optional
        Text box properties that override ``textbox_props`` for this call
        only. The module-level defaults are not modified.
    **kwargs
        Additional keyword arguments passed to ``Axes.text``. These may also
        override the default ``transform`` and ``verticalalignment``.
    """
    # Merge into a copy so per-call overrides do not leak into the shared
    # module-level defaults used by subsequent calls.
    bbox = dict(textbox_props)
    if tb_props is not None:
        bbox.update(tb_props)

    text_kwargs = {
        "transform": ax.transAxes,
        "verticalalignment": "top",
        **kwargs,
    }
    ax.text(*position, text, bbox=bbox, **text_kwargs)
def add_overlay(
    ax: Axes,
    y_overlay: npt.ArrayLike | list[npt.ArrayLike],
    x_overlay: Optional[npt.ArrayLike] = None,
    **plot_args: Any,
) -> None:
    """
    Add overlay plot(s) to a matplotlib Axes object.

    After all overlays are drawn, the property cycle of ``ax`` is reset, so
    the next plot on ``ax`` starts again from the first cycle entry.

    Parameters
    ----------
    ax : Axes
        The matplotlib Axes object to which the overlay is added.
    y_overlay : npt.ArrayLike | list[npt.ArrayLike]
        The y-data for the overlay plot(s). A list is treated as several
        overlays; anything else is plotted as a single overlay.
    x_overlay : npt.ArrayLike, optional
        The x-data for the overlay plot(s). If None, uses the x-data of the
        first line in ``ax``.
    **plot_args
        Additional keyword arguments passed to ``Axes.plot``.

    Raises
    ------
    CADETProcessError
        If ``x_overlay`` is None and ``ax`` contains no lines to take the
        x-data from.
    """
    if not isinstance(y_overlay, list):
        y_overlay = [y_overlay]

    if x_overlay is None:
        if len(ax.lines) == 0:
            raise CADETProcessError(
                "Cannot infer x_overlay: Axes contains no lines. "
                "Provide x_overlay explicitly."
            )
        x_overlay = ax.lines[0].get_xdata()

    for y_over in y_overlay:
        ax.plot(x_overlay, y_over, **plot_args)

    ax.set_prop_cycle(None)
[docs]
class Annotation(Structure):
    """
    Parameters for text annotations.

    Consumed by ``add_annotations``, which draws ``text`` with an arrow
    pointing to ``xy``.
    """

    # Text of the annotation.
    text = String()
    # Annotated point; ``add_annotations`` interprets it in data coordinates.
    xy = Tuple()
    # Text position; ``add_annotations`` interprets it as an offset in points.
    xytext = Tuple()
    # Matplotlib arrow style of the connecting arrow.
    # NOTE(review): this is a plain class attribute, not a Structure
    # parameter, so it is presumably not validated like the other fields —
    # confirm whether that is intended.
    arrowstyle = "-|>"
[docs]
def add_annotations(
    ax: Axes,
    annotations: list[Annotation],
) -> None:
    """
    Draw every annotation of ``annotations`` on the Axes ``ax``.

    Each annotation's ``xy`` is given in data coordinates and its ``xytext``
    as an offset in points from that location; the arrow drawn between them
    uses the annotation's ``arrowstyle``.
    """
    for ann in annotations:
        arrow = {"arrowstyle": ann.arrowstyle}
        ax.annotate(
            ann.text,
            xy=ann.xy,
            xytext=ann.xytext,
            xycoords="data",
            textcoords="offset points",
            arrowprops=arrow,
        )
[docs]
class FillRegion(Structure):
    """
    Parameters for fill region.

    Consumed by ``add_fill_regions`` to shade an x-interval of a plot.
    """

    # Index into ``color_list`` selecting the fill color.
    color_index = Integer()
    # x-coordinate where the shaded region starts.
    start = UnsignedFloat()
    # x-coordinate where the shaded region ends.
    end = UnsignedFloat()
    # Upper y-value of the shaded region.
    y_max = UnsignedFloat()
    # Optional label drawn inside the region (skipped when None).
    text = String()
[docs]
def add_fill_regions(
    ax: Axes,
    fill_regions: list[FillRegion],
    x_lim: Optional[npt.ArrayLike] = None,
) -> None:
    """
    Add shaded fill regions, with optional text labels, to a matplotlib Axes.

    Parameters
    ----------
    ax : Axes
        The matplotlib Axes object to draw on.
    fill_regions : list[FillRegion]
        Regions to draw. Each region is shaded (alpha 0.3) between ``start``
        and ``end`` on the x-axis, from 0 up to ``y_max``. Its color is taken
        from ``color_list`` at ``color_index`` (wrapping around the list).
        If ``text`` is set, it is centered horizontally in the region and
        vertically at half of ``y_max``.
    x_lim : npt.ArrayLike, optional
        Visible x-limits ``(x_min, x_max)``. If a region starts before
        ``x_min``, its label is centered between ``x_min`` and the region
        end so it stays within the visible range. If None, labels are
        centered on the full region.
    """
    for fill in fill_regions:
        # Cycle through the palette so indices beyond its length still work.
        color = color_list[fill.color_index % len(color_list)]
        ax.fill_between(
            [fill.start, fill.end],
            fill.y_max,
            alpha=0.3,
            color=color,
        )

        if fill.text is None:
            continue

        # Clip the label's left edge to the visible range only when x-limits
        # are actually provided.
        if x_lim is not None and fill.start < x_lim[0]:
            x_left = x_lim[0]
        else:
            x_left = fill.start
        x_position = (x_left + fill.end) / 2
        y_position = 0.5 * fill.y_max

        ax.text(
            x_position,
            y_position,
            fill.text,
            horizontalalignment="center",
            verticalalignment="center",
        )
[docs]
class HLines(Structure):
    """
    Parameters for plotting horizontal lines.

    Consumed by ``add_hlines``, which passes the fields to ``Axes.hlines``.
    """

    # y-position of the horizontal line.
    y = UnsignedFloat()
    # x-coordinate where the line starts.
    x_min = UnsignedFloat()
    # x-coordinate where the line ends.
    x_max = UnsignedFloat()
[docs]
def add_hlines(ax: Axes, hlines: list[HLines]) -> None:
    """
    Draw each ``HLines`` entry as a horizontal line on ``ax``.

    Every line is drawn at height ``y``, spanning from ``x_min`` to ``x_max``.
    """
    for hline in hlines:
        ax.hlines(y=hline.y, xmin=hline.x_min, xmax=hline.x_max)