# Source code for bayspec.util.corner

"""Plotly-based corner-plot renderer for posterior samples."""

from __future__ import annotations

from collections.abc import Sequence
import logging

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.ndimage import gaussian_filter


def corner_plotly(
    xs: np.ndarray | Sequence,
    bins: int = 30,
    ranges: list[tuple[float, float]] | None = None,
    weights: np.ndarray | None = None,
    color: str | None = None,
    smooth1d: float | None = 1.0,
    smooth: float | None = 1.0,
    labels: list[str] | None = None,
    quantiles: list[float] | None = None,
    levels: list[float] | None = None,
) -> go.Figure:
    """Render a Plotly corner plot of 1D marginals and 2D contours.

    Draws smoothed 1D histograms on the diagonal and 2D contour plots on the
    lower triangle, with optional quantile markers. Every panel in column
    ``j`` spans ``ranges[j]`` horizontally and every off-diagonal panel in
    row ``i`` spans ``ranges[i]`` vertically, so the grid lines up; the
    unused upper-triangle panels are hidden.

    Args:
        xs: 1D or 2D array of samples; a 2D array is treated as rows of
            observations (one column per parameter).
        bins: Number of bins used for both 1D and 2D histograms.
        ranges: Per-parameter plotting range. Defaults to the full data
            range; a parameter whose samples are all equal gets the range
            ``value - 0.5 .. value + 0.5`` so it can still be binned.
        weights: Sample weights.
        color: Line color for 1D histograms; defaults to a dark blue.
        smooth1d: Gaussian smoothing sigma for 1D histograms.
        smooth: Gaussian smoothing sigma for 2D histograms.
        labels: Axis labels, one per parameter.
        quantiles: Quantiles to mark on the 1D histograms (a scalar is
            accepted too).
        levels: Contour levels as enclosed probability mass. Defaults to the
            one- and two-sigma enclosed masses.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    xs = _parse_input(xs)
    K = xs.shape[0]
    bins_list = [int(bins)] * K

    if ranges is None:
        ranges = []
        for x in xs:
            lo, hi = np.min(x), np.max(x)
            if lo == hi:
                # np.histogram rejects zero-width bins, so widen a constant parameter.
                lo, hi = lo - 0.5, hi + 0.5
            ranges.append((lo, hi))
    if levels is None:
        # Mass of a 2D Gaussian enclosed within 1 and 2 sigma (~0.393 and ~0.865).
        levels = 1.0 - np.exp(-0.5 * np.array([1, 2]) ** 2)
    # Normalise to an array so both lists and numpy arrays (or a scalar) work below.
    quantiles = np.atleast_1d([] if quantiles is None else quantiles)
    if labels is None:
        labels = [f'label{i}' for i in range(K)]
    if color is None:
        color = '#08519c'

    fig = make_subplots(
        rows=K,
        cols=K,
        vertical_spacing=0.02,
        horizontal_spacing=0.02,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    for i, x in enumerate(xs):
        n_bins_1d = bins_list[i]
        bins_1d = np.linspace(ranges[i][0], ranges[i][1], n_bins_1d + 1)
        hist, _ = np.histogram(x, bins=bins_1d, weights=weights, density=True)
        if smooth1d is not None:
            hist = gaussian_filter(hist, smooth1d)

        # Step-curve vertices: every interior bin edge appears twice.
        x0 = np.repeat(bins_1d, 2)[1:-1]
        y0 = np.repeat(hist, 2)

        # Plot 1D histogram on the diagonal
        fig.add_trace(
            go.Scatter(
                x=x0,
                y=y0,
                mode='lines',
                name=labels[i],
                showlegend=False,
                line=dict(width=2, color=color),
            ),
            row=i + 1,
            col=i + 1,
        )

        # Plot quantiles on the 1D histogram
        if quantiles.size:
            qvalues = quantile(x, quantiles, weights=weights)
            for q in qvalues:
                # Height of the bin that actually contains q; quantiles outside
                # the plotting range are clamped to the outermost bins.
                k = int(np.clip(np.searchsorted(bins_1d, q, side='right') - 1, 0, n_bins_1d - 1))
                fig.add_shape(
                    go.layout.Shape(
                        type='line',
                        x0=q,
                        y0=0,
                        x1=q,
                        y1=hist[k],
                        name=labels[i],
                        showlegend=False,
                        line=dict(color=color, dash='dash'),
                    ),
                    row=i + 1,
                    col=i + 1,
                )

        # Plot 2D histograms on the off-diagonals (lower triangle):
        # parameter j on the horizontal axis, parameter i on the vertical one.
        for j, y in enumerate(xs[:i]):
            fig = plot_hist2d(
                y,
                x,
                bins=[bins_list[j], bins_list[i]],
                ranges=[ranges[j], ranges[i]],
                weights=weights,
                smooth=smooth,
                labels=[labels[j], labels[i]],
                levels=levels,
                fig=fig,
                subfig_idx=(i, j),
            )

    fig.update_layout(template='plotly_white', height=200 * K, width=200 * K)

    # Hide all tick labels by default, set angle
    fig.update_xaxes(tickangle=-45, showticklabels=False)
    fig.update_yaxes(tickangle=-45, showticklabels=False)

    for i in range(K):
        for j in range(K):
            row, col = i + 1, j + 1
            if j > i:
                # The upper triangle carries no data in a corner plot.
                fig.update_xaxes(visible=False, row=row, col=col)
                fig.update_yaxes(visible=False, row=row, col=col)
                continue
            # Pin limits to `ranges`: the padded contour grid would otherwise
            # autorange past the requested limits and misalign the columns.
            fig.update_xaxes(range=[float(ranges[j][0]), float(ranges[j][1])], row=row, col=col)
            if j < i:
                fig.update_yaxes(range=[float(ranges[i][0]), float(ranges[i][1])], row=row, col=col)

    # Enable X tick labels for the bottom row, and set titles
    for i in range(K):
        fig.update_xaxes(title_text=labels[i], row=K, col=i + 1, showticklabels=True)

    # Enable Y tick labels for the leftmost column
    # (skipping the top-left diagonal plot), and set titles
    for i in range(1, K):
        fig.update_yaxes(title_text=labels[i], row=i + 1, col=1, showticklabels=True)

    return fig
def _contour_levels(H: np.ndarray, levels) -> np.ndarray | None:
    """Convert enclosed-mass fractions into strictly increasing density thresholds.

    For each fraction ``p`` the threshold is the value of the least dense bin
    among the densest bins whose combined normalised mass does not exceed
    ``p``; if even the densest bin exceeds ``p``, it is the densest value.
    Equal thresholds are nudged apart so that every contour is distinct.

    Returns:
        The thresholds sorted ascending, or ``None`` when ``H`` holds no
        positive total mass (for example no samples fall inside the range).
    """
    Hflat = np.sort(H.ravel())[::-1]
    sm = np.cumsum(Hflat)
    total = sm[-1]
    if not total > 0:  # also catches NaN
        return None
    sm = sm / total

    # Last index whose cumulative mass is <= each requested level.
    idx = np.searchsorted(sm, levels, side='right') - 1
    idx = np.clip(idx, 0, len(Hflat) - 1)
    V = np.sort(Hflat[idx])

    # Handle edge case where contours might be identical
    m = np.diff(V) == 0
    if np.any(m):
        logging.warning('Too few points to create valid contours.')
    while np.any(m):
        k = np.where(m)[0][0]
        nudged = V[k] * (1.0 - 1e-4)
        if nudged == V[k]:
            # Scaling cannot move zero (or the tiniest subnormals); step one ulp
            # down instead so the loop is guaranteed to make progress.
            nudged = np.nextafter(V[k], -np.inf)
        V[k] = nudged
        m = np.diff(V) == 0
    V.sort()
    return V


def _padded_centers(edges: np.ndarray) -> np.ndarray:
    """Return bin centres extended by two extra steps on each side.

    Matches the two rings of padding ``plot_hist2d`` adds around the
    histogram. A single bin falls back to the bin width as the step.
    """
    centers = 0.5 * (edges[1:] + edges[:-1])
    if len(centers) > 1:
        lo_step = centers[1] - centers[0]
        hi_step = centers[-1] - centers[-2]
    else:
        lo_step = hi_step = edges[1] - edges[0]
    return np.concatenate(
        [
            centers[0] + np.array([-2, -1]) * lo_step,
            centers,
            centers[-1] + np.array([1, 2]) * hi_step,
        ]
    )


def _remap_to_levels(Z: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Monotonically remap ``Z`` so that threshold ``V[k]`` maps exactly to ``k``.

    The map is piecewise linear through the thresholds, extended to
    ``Z.min() -> -1`` and ``Z.max() -> len(V)`` when those lie outside them.
    Ordering is preserved, so evenly spaced contours at ``0..len(V)-1`` on
    the result trace the (arbitrarily spaced) thresholds on ``Z``.
    """
    xp = list(V)
    fp = list(range(len(V)))
    zmin, zmax = Z.min(), Z.max()
    if zmin < xp[0]:
        xp.insert(0, zmin)
        fp.insert(0, -1)
    if zmax > xp[-1]:
        xp.append(zmax)
        fp.append(len(V))
    return np.interp(Z, xp, fp)


def plot_hist2d(
    x: np.ndarray,
    y: np.ndarray,
    bins: list[int],
    ranges: list[tuple[float, float]],
    weights: np.ndarray | None,
    smooth: float | None,
    labels: list[str],
    levels: np.ndarray,
    fig: go.Figure,
    subfig_idx: tuple[int, int],
) -> go.Figure:
    """Add a smoothed 2D histogram with contour levels to ``fig``.

    Contour levels are chosen so that each encloses a requested cumulative
    probability mass of the 2D histogram. Any number of levels is drawn.
    When no sample mass falls inside ``ranges``, a warning is logged and
    ``fig`` is returned unchanged.

    Args:
        x: Horizontal samples.
        y: Vertical samples.
        bins: Bin counts ``[nx, ny]``.
        ranges: Axis ranges ``[(xmin, xmax), (ymin, ymax)]``.
        weights: Sample weights.
        smooth: Gaussian smoothing sigma for the 2D histogram.
        labels: ``[xlabel, ylabel]`` used to tag the trace.
        levels: Enclosed probability masses for contour generation.
        fig: Figure to add the contour trace to.
        subfig_idx: ``(row, col)`` zero-based index of the target subplot.

    Returns:
        The modified figure.
    """
    row, col = subfig_idx
    bins_x = np.linspace(ranges[0][0], ranges[0][1], bins[0] + 1)
    bins_y = np.linspace(ranges[1][0], ranges[1][1], bins[1] + 1)
    H, X, Y = np.histogram2d(x.flatten(), y.flatten(), bins=[bins_x, bins_y], weights=weights)
    if smooth is not None:
        H = gaussian_filter(H, smooth)

    V = _contour_levels(H, levels)
    if V is None:
        logging.warning('No sample mass inside the plotting range; skipping 2D contours.')
        return fig

    # Pad with one ring of edge values, then one ring at the global minimum,
    # so contours that touch the plotting range still close.
    H_edge = np.pad(H, pad_width=1, mode='edge')
    H2 = np.pad(H_edge, pad_width=1, mode='constant', constant_values=H.min())
    X2 = _padded_centers(X)
    Y2 = _padded_centers(Y)

    if len(V) <= 2:
        # One or two thresholds fit Plotly's evenly spaced start/end/size directly.
        z = H2
        contours = dict(start=V[0], end=V[-1], size=V[-1] - V[0] if V[-1] > V[0] else 1)
    else:
        # A single start/end/size would only draw the two outermost of several
        # unevenly spaced thresholds; remap so threshold k sits at value k.
        z = _remap_to_levels(H2, V)
        contours = dict(start=0, end=len(V) - 1, size=1)

    fig.add_trace(
        go.Contour(
            z=z.T,
            x=X2,
            y=Y2,
            name=f'{labels[0]}&{labels[1]}',
            showlegend=False,
            contours=contours,
            ncontours=len(V),
            colorscale='Blues',
            line=dict(width=2),
            showscale=False,
        ),
        row=row + 1,
        col=col + 1,
    )
    return fig
def quantile(
    x: np.ndarray, q: float | list[float], weights: np.ndarray | None = None
) -> list[float]:
    """Return (optionally weighted) quantiles of ``x``.

    Falls back to ``numpy.percentile`` when ``weights`` is ``None``. With
    weights, the samples are sorted and each quantile is linearly
    interpolated in a cumulative-weight table built from them.

    Args:
        x: 1D sample array.
        q: Quantile or list of quantiles, each in ``[0, 1]``.
        weights: Optional sample weights of any numeric dtype (integer
            weights are accepted).

    Returns:
        The requested quantile values as a list.

    Raises:
        ValueError: If any ``q`` is outside ``[0, 1]``, or if ``weights`` has
            a different length than ``x``.
    """
    x = np.atleast_1d(x)
    q = np.atleast_1d(q)
    if np.any(q < 0.0) or np.any(q > 1.0):
        raise ValueError('Quantiles must be between 0 and 1.')
    if weights is None:
        return np.percentile(x, 100.0 * q).tolist()

    weights = np.atleast_1d(weights)
    if len(x) != len(weights):
        raise ValueError('Dimension mismatch: len(weights) must equal len(x).')

    order = np.argsort(x)
    x_sorted = x[order]
    if len(x_sorted) == 1:
        # Every quantile of a single sample is that sample; the cumulative
        # table below needs at least two points.
        return np.full(q.shape, x_sorted[0], dtype=float).tolist()

    # Accumulate in float so integer weights do not break the in-place
    # normalisation below.
    cdf = np.cumsum(weights[order], dtype=float)[:-1]
    cdf /= cdf[-1]
    cdf = np.insert(cdf, 0, 0.0)
    return np.interp(q, cdf, x_sorted).tolist()
def _parse_input(xs: np.ndarray | Sequence) -> np.ndarray:
    """Coerce raw samples into a parameter-major ``(n_params, n_samples)`` array.

    A 1D input is one parameter and gains a leading axis. A 2D input is read
    as one observation per row and is transposed, so each row of the result
    holds all samples of one parameter.

    Raises:
        ValueError: If the samples have more than two dimensions.
    """
    samples = np.atleast_1d(xs)
    if samples.ndim > 2:
        raise ValueError('The input sample array must be 1- or 2-D.')
    return samples.T if samples.ndim == 2 else samples[None, :]