"""
Functions to visualise sync in coupled oscillators
"""
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
import xgi
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
# Public API of this module: the names exported by a star-import.
# Every entry corresponds to a plotting function defined below.
__all__ = [
"plot_series",
"plot_order_param",
"plot_phases",
"plot_sync",
"plot_phases_line",
"plot_phases_ring",
"plot_summary",
]
def plot_series(thetas, times, color="grey", alpha=0.1, n=None, ax=None, **kwargs):
    """
    Plot sin(theta) over time for the given phases thetas.

    Parameters
    ----------
    thetas : ndarray
        The values of the phases. It is iterated along its first axis and
        each entry is drawn as one line against `times`.
    times : ndarray
        The corresponding times.
    color : color
        Color of the lines.
    alpha : float
        Transparency of the lines.
    n : int, optional
        The number of thetas to plot, by default None (plots all thetas).
    ax : Matplotlib axis, optional
        The Matplotlib axis to plot on, by default None (uses the current
        axis as returned by ``plt.gca()``).
    **kwargs
        Additional arguments that will be passed to matplotlib's plot.

    Returns
    -------
    ax : Matplotlib axis
        The Matplotlib axis the plot was drawn on.
    """
    if ax is None:
        ax = plt.gca()

    # One line per entry; slicing with n=None keeps every entry.
    for theta in thetas[:n]:
        ax.plot(times, np.sin(theta), c=color, alpha=alpha, **kwargs)

    ax.set_xlabel("Time")
    ax.set_ylabel(r"$\sin(\theta)$")

    return ax
def plot_order_param(thetas, times, order=1, color="r", ls="-", ax=None, **kwargs):
    """
    Plot the order parameter over time for the given phases thetas.

    The plotted quantity is ``|mean_j exp(1j * order * theta_j)|``, where the
    mean runs over the first axis of `thetas`.

    Parameters
    ----------
    thetas : array-like
        The values of the phases over time. The first axis is averaged over;
        the remaining axis is plotted against `times`.
    times : ndarray
        The corresponding times.
    order : int, optional
        The order of the order parameter, by default 1.
    color : str, optional
        The color of the plot, by default "r".
    ls : str, optional
        The line style of the plot, by default "-".
    ax : Matplotlib axis, optional
        The Matplotlib axis to plot on, by default None (uses the current
        axis as returned by ``plt.gca()``).
    **kwargs
        Additional arguments that will be passed to matplotlib's plot.

    Returns
    -------
    ax : Matplotlib axis
        The Matplotlib axis the plot was drawn on.
    """
    if ax is None:
        ax = plt.gca()

    # Coerce so nested lists work: a list cannot be multiplied by a complex.
    phases = np.asarray(thetas)
    R = np.mean(np.exp(1j * order * phases), axis=0)

    # Brace the subscript: without braces mathtext only subscripts the first
    # character, so e.g. order=10 would render as "R_1" followed by "0".
    label = f"$R_{{{order}}}$"
    ax.plot(times, np.abs(R), c=color, ls=ls, label=label, **kwargs)

    ax.set_xlabel("Time")
    ax.set_ylabel(r"$R$")
    # |R| lies in [0, 1]; the small margin keeps lines at 0 or 1 visible.
    ax.set_ylim([-0.01, 1.01])

    return ax
def plot_phases(thetas, radius=1, color="b", ms=2, ax=None, **kwargs):
    """
    Plot the phase thetas of oscillators on a circle.

    Each phase is drawn as a marker at
    ``(radius * cos(theta), radius * sin(theta))`` on top of a light grey
    reference circle of the same radius.

    Parameters
    ----------
    thetas : np.ndarray
        The phases to draw. Other functions in this module pass a single
        time snapshot, e.g. ``thetas[:, 0]``.
    radius : float, optional
        Radius of the reference circle and of the markers, by default 1.
    color : str, optional
        The color of the phase plot, by default "b".
    ms : float, optional
        Marker size of the phase markers, by default 2.
    ax : plt.Axes, optional
        The axes to plot on, by default None (uses the current axes as
        returned by ``plt.gca()``).
    **kwargs
        Additional arguments that will be passed to matplotlib's plot
        (applied to the phase markers only, not the reference circle).

    Returns
    -------
    plt.Axes
        The plot's axes.
    """
    if ax is None:
        ax = plt.gca()

    # draw circle as reference
    circle = np.linspace(0, 2 * np.pi, num=100, endpoint=True)
    ax.plot(radius * np.cos(circle), radius * np.sin(circle), "-", c="lightgrey")

    # draw phases
    ax.plot(
        radius * np.cos(thetas), radius * np.sin(thetas), "o", c=color, ms=ms, **kwargs
    )

    # Strip spines and ticks: only the circle and the markers carry meaning.
    sb.despine(ax=ax, left=True, bottom=True)
    ax.set_yticks([])
    ax.set_xticks([])
    # Equal aspect so the reference circle is not drawn as an ellipse.
    ax.set_aspect("equal")

    return ax
def plot_phases_line(thetas, ax=None, **kwargs):
    """
    Plot the phases thetas of oscillators in order of node index.

    Phases are wrapped into ``[0, 2*pi)`` and drawn as markers against
    their position in `thetas`.

    Parameters
    ----------
    thetas : array-like
        The phases to plot. Other functions in this module pass a single
        time snapshot, e.g. ``thetas[:, -1]``.
    ax : plt.Axes, optional
        The axes to plot on, by default None (uses the current axes as
        returned by ``plt.gca()``).
    **kwargs : dict
        All arguments to pass to matplotlib's plot.

    Returns
    -------
    plt.Axes
        The plot's axes.
    """
    if ax is None:
        ax = plt.gca()

    # Coerce first: a plain Python list does not support ``% float``.
    wrapped = np.asarray(thetas) % (2 * np.pi)
    ax.plot(wrapped, "o", **kwargs)

    # Fixed range with a small margin so markers at 0 and 2*pi stay visible.
    ax.set_ylim([-0.1, 2 * np.pi + 0.1])
    ax.set_yticks([0, np.pi, 2 * np.pi])
    ax.set_yticklabels([0, r"$\pi$", r"$2\pi$"])

    ax.set_xlabel("Node index")
    ax.set_ylabel("Phase")

    return ax
def plot_phases_ring(H, thetas, cmap="twilight", ax=None, colorbar=True, **kwargs):
    """
    Plot the phases of oscillators as node colors on a circle.

    The phase values are represented by the node color, and the
    oscillators are positioned evenly spaced on the circle.

    Parameters
    ----------
    H : xgi Hypergraph
        Hypergraph to plot.
    thetas : np.ndarray
        The phase of each oscillator at a single time. Shape is (N,).
    cmap : colormap
        Colormap used to map the phases to colors.
    ax : plt.Axes, optional
        The axes to plot on, by default None (uses the current axes as
        returned by ``plt.gca()``).
    colorbar : bool, optional
        If True (default), plot a colorbar.
    **kwargs : dict
        All arguments to pass to xgi's `draw_nodes`.

    Returns
    -------
    ax : plt.Axes
        The plot's axes.
    im : object
        The second value returned by xgi's `draw_nodes`; it is the mappable
        used to build the colorbar.
    """
    if ax is None:
        ax = plt.gca()

    pos = xgi.circular_layout(H)

    # Wrap into [0, 2*pi) so the values fall inside the fixed vmin/vmax range.
    thetas = thetas % (2 * np.pi)

    ax, im = xgi.draw_nodes(
        H,
        pos=pos,
        ax=ax,
        node_fc=thetas,
        vmin=0,
        vmax=2 * np.pi,
        node_fc_cmap=cmap,
        **kwargs,
    )
    ax.set_aspect("equal")

    if colorbar:
        # Attach the colorbar to this axes' own figure instead of going
        # through pyplot's "current figure" global state.
        cbar = ax.figure.colorbar(im, ax=ax)
        cbar.set_ticks(ticks=[0, np.pi, 2 * np.pi], labels=[0, r"$\pi$", r"$2\pi$"])

    return ax, im
def plot_sync(thetas, times, n=None, figsize=(4, 2), width_ratios=(3, 1)):
    """
    Plot the time series of oscillators, their phase plots, and the order parameter.

    The figure is a 2x2 grid: the left column shows the sin(theta) time
    series (top) and the order parameters of order 1 and 2 (bottom, order 2
    dashed); the right column shows the phases on a circle at the first
    (top) and last (bottom) time index.

    Parameters
    ----------
    thetas : np.ndarray
        The phase of each oscillator over time. Shape is (N, T).
    times : np.ndarray
        The time stamps for the `thetas` data.
    n : int, optional
        Number of time series to plot, by default None (plots all).
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``, by default (4, 2).
    width_ratios : sequence of 2 numbers, optional
        Relative widths of the two columns, passed to ``plt.subplots``,
        by default (3, 1).

    Returns
    -------
    tuple
        (`fig`, `axs`) where `fig` is a `plt.Figure` and `axs` is a numpy ndarray of `plt.Axes`.
    """
    fig, axs = plt.subplots(
        2, 2, figsize=figsize, width_ratios=width_ratios, sharex="col"
    )

    # left column: dynamics over time
    plot_series(thetas, times, ax=axs[0, 0], n=n)
    plot_order_param(thetas, times, ax=axs[1, 0], order=1)
    plot_order_param(thetas, times, ax=axs[1, 0], order=2, ls="--")

    # right column: phase snapshots at the first and last time index
    plot_phases(thetas[:, 0], ax=axs[0, 1])
    plot_phases(thetas[:, -1], ax=axs[1, 1])

    return fig, axs
def plot_summary(thetas, times, H):
    """
    Plot a summary figure of the dynamics of oscillators on a hypergraph.

    The figure is a 2x3 grid:

    * left column: sin(theta) time series (top) and the order parameters of
      order 1 and 2 with a legend (bottom, order 2 dashed), plus an inset
      showing the final phases against node index;
    * middle column: phases on a circle at the first (top) and last
      (bottom) time index, titled with the corresponding time;
    * right column: phases as node colors on a ring layout of `H` at the
      first (top) and last (bottom) time index.

    Parameters
    ----------
    thetas : np.ndarray
        The phase of each oscillator over time. Shape is (N, T).
    times : np.ndarray
        The time stamps for the `thetas` data.
    H : xgi Hypergraph
        Hypergraph passed to `plot_phases_ring`.

    Returns
    -------
    plt.Figure
        The figure that was created.
    """
    # Unpacking also enforces that thetas is two-dimensional.
    n_oscillators, _ = thetas.shape

    fig, axs = plt.subplots(
        2, 3, figsize=(5, 2), width_ratios=[2.5, 1, 1], sharex="col"
    )

    # left column: dynamics over time
    plot_series(thetas, times, ax=axs[0, 0], n=n_oscillators)
    plot_order_param(thetas, times, ax=axs[1, 0], order=1)
    plot_order_param(thetas, times, ax=axs[1, 0], order=2, ls="--")
    axs[0, 0].set_xlabel("")
    axs[1, 0].legend(loc="lower right", fontsize="x-small", frameon=False)

    # middle column: phase snapshots on a circle
    plot_phases(thetas[:, 0], ax=axs[0, 1])
    plot_phases(thetas[:, -1], ax=axs[1, 1])
    axs[0, 1].set_title(f"$t={times[0]}$s", fontsize="x-small")
    axs[1, 1].set_title(f"$t={times[-1]}$s", fontsize="x-small")

    # right column: phase snapshots as node colors on the hypergraph
    ring_kwargs = dict(node_size=5, alpha=0.8, node_lw=0.1)
    plot_phases_ring(H, thetas[:, 0], ax=axs[0, 2], **ring_kwargs)
    plot_phases_ring(H, thetas[:, -1], ax=axs[1, 2], **ring_kwargs)

    # inset in the order-parameter panel: final phases against node index
    axins = inset_axes(
        axs[1, 0],
        width="100%",
        height="100%",
        bbox_to_anchor=(0.2, 0.4, 0.4, 0.6),
        bbox_transform=axs[1, 0].transAxes,
    )
    plot_phases_line(thetas[:, -1], ax=axins, mfc=None, alpha=0.8, ms=2)
    axins.set_ylabel("")
    axins.set_xlabel("")
    axins.set_yticks([])
    axins.set_xticks([])
    axins.patch.set_alpha(0.5)  # make inset transparent
    sb.despine(ax=axins)

    # Adjust this figure explicitly rather than pyplot's current figure.
    fig.subplots_adjust(hspace=0.5, top=0.8)

    return fig