"""High-level plotting helpers for spectra, responses, fits, and posteriors.
``Plot`` dispatches per-object figures (spectrum, response, dataunit, fit
comparison, corner), ``ModelPlot`` composes multi-model comparisons, and
``Figure`` wraps the returned figure with notebook display plus
filename-aware saving (HTML/PDF/JSON).
Every renderer accepts ``ploter='plotly'`` or ``ploter='matplotlib'`` and
returns a :class:`Figure`.
"""
from itertools import chain
import sys
import warnings
import corner
from getdist import MCSamples, plots
import matplotlib as mpl
from matplotlib import rcParams
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from ..data.data import Data, DataUnit
from ..data.response import Auxiliary, Response
from ..data.spectrum import Spectrum
from ..infer.analyzer import Bootstrap, Posterior
from ..infer.infer import BayesInfer, Infer
from ..infer.pair import Pair
from ..model.model import Model
from .corner import corner_plotly
from .tools import json_dump
[docs]
class Plot:
    """Static factory for single-object figures over bayspec data types.

    Every method takes a concrete bayspec object (``Spectrum``,
    ``Response``, ``DataUnit``, ``Data``, ``Pair``, ``Infer``,
    ``Posterior``/``Bootstrap``) and returns a :class:`Figure`. Plotly is
    the default backend; ``matplotlib`` is available for static output.
    """

    # Shared color cycle: five Plotly qualitative palettes concatenated in
    # this order. Renderers pick entries by position (``Plot.colors[i]``),
    # e.g. one entry per data unit or per series.
    colors = (
        px.colors.qualitative.Plotly
        + px.colors.qualitative.D3
        + px.colors.qualitative.G10
        + px.colors.qualitative.T10
        + px.colors.qualitative.Alphabet
    )
[docs]
@staticmethod
def get_rgb(color, opacity=1.0):
    """Build a Plotly ``rgba(...)`` string from a matplotlib color.

    Args:
        color: Any color specification accepted by
            ``matplotlib.colors.to_rgba``.
        opacity: Alpha value written into the result; the alpha carried by
            ``color`` itself is discarded.

    Returns:
        A string of the form ``'rgba(R, G, B, A)'`` with integer channels
        and the opacity formatted with ``:f``.
    """
    red, green, blue, _alpha = mpl.colors.to_rgba(color)
    # Scale the 0-1 channels to 0-255; ``int`` truncates toward zero.
    red8 = int(red * 255)
    green8 = int(green * 255)
    blue8 = int(blue * 255)
    return f'rgba({red8}, {green8}, {blue8}, {opacity:f})'
[docs]
@staticmethod
def spectrum(cls, ploter='plotly'):
    """Plot counts vs. channel index for a ``Spectrum``.

    The counts are drawn as a line with per-channel error bars on a
    logarithmic y-axis.

    Args:
        cls: ``Spectrum`` whose ``counts`` and ``errors`` are drawn.
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.

    Returns:
        A :class:`Figure` built from the backend figure, the plotted arrays
        (``{'spec': {'x', 'y', 'y_e'}}``) and the backend name.

    Raises:
        TypeError: If ``cls`` is not a ``Spectrum``.
        ValueError: If ``ploter`` is not a supported backend.
    """
    if not isinstance(cls, Spectrum):
        raise TypeError('cls is not Spectrum type, cannot call spectrum method')
    # Fail early on an unknown backend: otherwise neither branch below binds
    # ``fig`` and the call ends in an opaque UnboundLocalError.
    if ploter not in ('plotly', 'matplotlib'):
        raise ValueError(f'unsupported ploter argument: {ploter}')

    x = np.arange(len(cls.counts))
    y = cls.counts.astype(float)
    y_e = cls.errors.astype(float)

    if ploter == 'plotly':
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode='lines',
                showlegend=False,
                error_y=dict(type='data', array=y_e, thickness=1.5, width=0),
            )
        )
        fig.update_xaxes(title_text='Channel')
        fig.update_yaxes(title_text='Counts', type='log')
        fig.update_layout(template='plotly_white', height=600, width=600)
        fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
    else:
        # NOTE: rcParams is matplotlib's global configuration, so these
        # settings persist after this call returns.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        fig = plt.figure(figsize=(7, 6))
        gs = fig.add_gridspec(1, 1, wspace=0, hspace=0)
        ax = fig.add_subplot(gs[0, 0])
        ax.errorbar(
            x, y, yerr=y_e, fmt='-', lw=1.0, color=Plot.colors[0], elinewidth=1.0, capsize=0
        )
        ax.set_yscale('log')
        ax.set_xlabel('Channel')
        ax.set_ylabel('Counts')
        # Inward ticks on all four sides, labels on the bottom/left only.
        ax.minorticks_on()
        ax.xaxis.set_ticks_position('both')
        ax.yaxis.set_ticks_position('both')
        ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=True)
        ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)

    fig_data = {'spec': {'x': x, 'y': y, 'y_e': y_e}}
    return Figure(fig, fig_data, ploter)
[docs]
@staticmethod
def response(cls, ploter='plotly', ch_range=None, ph_range=None):
    """Plot the 2D detector response matrix as a filled contour.

    Bin centers are the row-wise means of ``cls.chbin`` / ``cls.phbin``; the
    optional ranges keep only bins whose center lies inside the (inclusive)
    window. Both axes are logarithmic.

    Args:
        cls: ``Response`` (non-``Auxiliary``) to visualize.
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.
        ch_range: Optional ``(min, max)`` channel-energy window.
        ph_range: Optional ``(min, max)`` photon-energy window.

    Returns:
        A :class:`Figure` built from the backend figure, the plotted arrays
        (``{'resp': {'x', 'y', 'z'}}``) and the backend name.

    Raises:
        TypeError: If ``cls`` is not a ``Response`` or is ``Auxiliary``.
        ValueError: If ``ploter`` is not a supported backend.
    """
    if not isinstance(cls, Response):
        raise TypeError('cls is not Response type, cannot call response method')
    if isinstance(cls, Auxiliary):
        raise TypeError('cls is Auxiliary type, cannot call response method')
    # Fail early on an unknown backend: otherwise neither branch below binds
    # ``fig`` and the call ends in an opaque UnboundLocalError.
    if ploter not in ('plotly', 'matplotlib'):
        raise ValueError(f'unsupported ploter argument: {ploter}')

    ch_mean = np.mean(cls.chbin, axis=1)
    ph_mean = np.mean(cls.phbin, axis=1)

    if ch_range is None:
        ch_idx = np.arange(len(ch_mean))
    else:
        ch_idx = np.where((ch_mean >= ch_range[0]) & (ch_mean <= ch_range[1]))[0]
    if ph_range is None:
        ph_idx = np.arange(len(ph_mean))
    else:
        ph_idx = np.where((ph_mean >= ph_range[0]) & (ph_mean <= ph_range[1]))[0]

    x = ch_mean[ch_idx].astype(float)
    y = ph_mean[ph_idx].astype(float)
    # ``drm`` is indexed photon-bin first, channel second.
    z = cls.drm[ph_idx, :][:, ch_idx].astype(float)

    if ploter == 'plotly':
        fig = go.Figure()
        fig.add_trace(go.Contour(z=z, x=x, y=y, colorscale='Jet'))
        fig.update_xaxes(title_text='Channel energy (keV)', type='log')
        fig.update_yaxes(title_text='Photon energy (keV)', type='log')
        fig.update_layout(template='plotly_white', height=600, width=600)
        fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
    else:
        # NOTE: rcParams is matplotlib's global configuration, so these
        # settings persist after this call returns.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        X, Y = np.meshgrid(x, y)
        fig = plt.figure(figsize=(7, 6))
        gs = fig.add_gridspec(1, 1, wspace=0, hspace=0)
        ax = fig.add_subplot(gs[0, 0])
        c = ax.contourf(X, Y, z, cmap='jet')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Channel energy (keV)')
        ax.set_ylabel('Photon energy (keV)')
        fig.colorbar(c, orientation='vertical')
        # Inward ticks on all four sides, labels on the bottom/left only.
        ax.minorticks_on()
        ax.xaxis.set_ticks_position('both')
        ax.yaxis.set_ticks_position('both')
        ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=True)
        ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)

    fig_data = {'resp': {'x': x, 'y': y, 'z': z}}
    return Figure(fig, fig_data, ploter)
[docs]
@staticmethod
def response_photon(cls, ploter='plotly', ph_range=None):
    """Plot the response as a function of photon energy.

    For an ``Auxiliary`` the curve is ``cls.srp``; for any other
    ``Response`` it is ``cls.drm`` summed over the channel axis. Both axes
    are logarithmic.

    Args:
        cls: ``Response`` or ``Auxiliary``.
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.
        ph_range: Optional ``(min, max)`` photon-energy window (inclusive,
            tested against the photon-bin centers).

    Returns:
        A :class:`Figure` built from the backend figure, the plotted arrays
        (``{'resp': {'x', 'y'}}``) and the backend name.

    Raises:
        TypeError: If ``cls`` is not a ``Response``.
        ValueError: If ``ploter`` is not a supported backend.
    """
    if not isinstance(cls, Response):
        raise TypeError('cls is not Response type, cannot call response_photon method')
    # Fail early on an unknown backend: otherwise neither branch below binds
    # ``fig`` and the call ends in an opaque UnboundLocalError.
    if ploter not in ('plotly', 'matplotlib'):
        raise ValueError(f'unsupported ploter argument: {ploter}')

    ph_mean = np.mean(cls.phbin, axis=1)
    if ph_range is None:
        ph_idx = np.arange(len(ph_mean))
    else:
        ph_idx = np.where((ph_mean >= ph_range[0]) & (ph_mean <= ph_range[1]))[0]

    x = ph_mean[ph_idx].astype(float)
    if isinstance(cls, Auxiliary):
        y = cls.srp[ph_idx].astype(float)
    else:
        y = np.sum(cls.drm[ph_idx, :], axis=1).astype(float)

    if ploter == 'plotly':
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=x, y=y, mode='lines', showlegend=False))
        fig.update_xaxes(title_text='Photon energy (keV)', type='log')
        fig.update_yaxes(type='log')
        fig.update_layout(template='plotly_white', height=600, width=600)
        fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
    else:
        # NOTE: rcParams is matplotlib's global configuration, so these
        # settings persist after this call returns.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        fig = plt.figure(figsize=(7, 6))
        gs = fig.add_gridspec(1, 1, wspace=0, hspace=0)
        ax = fig.add_subplot(gs[0, 0])
        ax.plot(x, y, lw=1.0, color=Plot.colors[0])
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Photon energy (keV)')
        # Inward ticks on all four sides, labels on the bottom/left only.
        ax.minorticks_on()
        ax.xaxis.set_ticks_position('both')
        ax.yaxis.set_ticks_position('both')
        ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=True)
        ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)

    fig_data = {'resp': {'x': x, 'y': y}}
    return Figure(fig, fig_data, ploter)
[docs]
@staticmethod
def response_channel(cls, ploter='plotly', ch_range=None):
    """Plot the response summed over photon bins vs. channel energy.

    The curve is ``cls.drm`` summed along its first axis for the selected
    channels. Both axes are logarithmic.

    Args:
        cls: ``Response`` (non-``Auxiliary``).
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.
        ch_range: Optional ``(min, max)`` channel-energy window (inclusive,
            tested against the channel-bin centers).

    Returns:
        A :class:`Figure` built from the backend figure, the plotted arrays
        (``{'resp': {'x', 'y'}}``) and the backend name.

    Raises:
        TypeError: If ``cls`` is not a ``Response`` or is ``Auxiliary``.
        ValueError: If ``ploter`` is not a supported backend.
    """
    if not isinstance(cls, Response):
        raise TypeError('cls is not Response type, cannot call response_channel method')
    if isinstance(cls, Auxiliary):
        raise TypeError('cls is Auxiliary type, cannot call response_channel method')
    # Fail early on an unknown backend: otherwise neither branch below binds
    # ``fig`` and the call ends in an opaque UnboundLocalError.
    if ploter not in ('plotly', 'matplotlib'):
        raise ValueError(f'unsupported ploter argument: {ploter}')

    ch_mean = np.mean(cls.chbin, axis=1)
    if ch_range is None:
        ch_idx = np.arange(len(ch_mean))
    else:
        ch_idx = np.where((ch_mean >= ch_range[0]) & (ch_mean <= ch_range[1]))[0]

    x = ch_mean[ch_idx].astype(float)
    y = np.sum(cls.drm[:, ch_idx], axis=0).astype(float)

    if ploter == 'plotly':
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=x, y=y, mode='lines', showlegend=False))
        fig.update_xaxes(title_text='Channel energy (keV)', type='log')
        fig.update_yaxes(type='log')
        fig.update_layout(template='plotly_white', height=600, width=600)
        fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
    else:
        # NOTE: rcParams is matplotlib's global configuration, so these
        # settings persist after this call returns.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        fig = plt.figure(figsize=(7, 6))
        gs = fig.add_gridspec(1, 1, wspace=0, hspace=0)
        ax = fig.add_subplot(gs[0, 0])
        ax.plot(x, y, lw=1.0, color=Plot.colors[0])
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Channel energy (keV)')
        # Inward ticks on all four sides, labels on the bottom/left only.
        ax.minorticks_on()
        ax.xaxis.set_ticks_position('both')
        ax.yaxis.set_ticks_position('both')
        ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=True)
        ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)

    fig_data = {'resp': {'x': x, 'y': y}}
    return Figure(fig, fig_data, ploter)
[docs]
@staticmethod
def dataunit(cls, ploter='plotly', style='CE'):
    """Plot the source, background and net spectra of one ``DataUnit``.

    All three series share the channel-energy axis (bin centers with
    half-width horizontal error bars) and are drawn on log-log axes.

    Args:
        cls: ``DataUnit`` to plot; must pass its completeness check.
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.
        style: Display style -- ``'CC'`` (counts/s/channel) or ``'CE'``
            (counts/s/keV). No other style is supported here.

    Returns:
        A :class:`Figure` built from the backend figure, the plotted arrays
        (keys ``'src'``, ``'bkg'``, ``'net'``) and the backend name.

    Raises:
        TypeError: If ``cls`` is not a ``DataUnit``.
        AttributeError: If the ``DataUnit`` fails its completeness check.
        ValueError: If ``style`` or ``ploter`` is not supported.
    """
    if not isinstance(cls, DataUnit):
        raise TypeError('cls is not DataUnit type, cannot call dataunit method')
    # Fail early on an unknown backend: otherwise neither branch below binds
    # ``fig`` and the call ends in an opaque UnboundLocalError.
    if ploter not in ('plotly', 'matplotlib'):
        raise ValueError(f'unsupported ploter argument: {ploter}')
    if not cls.completeness:
        raise AttributeError('failed for completeness check')

    x = cls.rsp_chbin_mean.astype(float)
    x_le = cls.rsp_chbin_width.astype(float) / 2
    x_he = cls.rsp_chbin_width.astype(float) / 2

    if style == 'CC':
        src_y = cls.src_ctsrate.astype(float)
        src_y_e = cls.src_ctsrate_error.astype(float)
        bkg_y = cls.bkg_ctsrate.astype(float)
        bkg_y_e = cls.bkg_ctsrate_error.astype(float)
        net_y = cls.net_ctsrate.astype(float)
        net_y_e = cls.net_ctsrate_error.astype(float)
        ylabel = 'Counts/s/channel'
    elif style == 'CE':
        src_y = cls.src_ctsspec.astype(float)
        src_y_e = cls.src_ctsspec_error.astype(float)
        bkg_y = cls.bkg_ctsspec.astype(float)
        bkg_y_e = cls.bkg_ctsspec_error.astype(float)
        net_y = cls.net_ctsspec.astype(float)
        net_y_e = cls.net_ctsspec_error.astype(float)
        ylabel = 'Counts/s/keV'
    else:
        raise ValueError(f'unsupported style argument: {style}')

    # (fig_data key, legend label, values, errors) in drawing order.
    series = (
        ('src', 'Source', src_y, src_y_e),
        ('bkg', 'Background', bkg_y, bkg_y_e),
        ('net', 'Net', net_y, net_y_e),
    )

    if ploter == 'plotly':
        fig = go.Figure()
        for _key, label, y, y_e in series:
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=y,
                    mode='markers',
                    name=label,
                    showlegend=True,
                    error_x=dict(
                        type='data',
                        symmetric=False,
                        array=x_he,
                        arrayminus=x_le,
                        thickness=1.5,
                        width=0,
                    ),
                    error_y=dict(type='data', array=y_e, thickness=1.5, width=0),
                    marker=dict(symbol='circle', size=3),
                )
            )
        fig.update_xaxes(title_text='Energy (keV)', type='log')
        fig.update_yaxes(title_text=ylabel, type='log')
        fig.update_layout(template='plotly_white', height=600, width=600)
        fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
    else:
        # NOTE: rcParams is matplotlib's global configuration, so these
        # settings persist after this call returns.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        fig = plt.figure(figsize=(7, 6))
        gs = fig.add_gridspec(1, 1, wspace=0, hspace=0)
        ax = fig.add_subplot(gs[0, 0])
        # Series take the first three palette entries in drawing order.
        for color, (_key, label, y, y_e) in zip(Plot.colors, series):
            ax.errorbar(
                x,
                y,
                xerr=[x_le, x_he],
                yerr=y_e,
                fmt='none',
                ecolor=color,
                elinewidth=1.0,
                capsize=0,
                label=label,
            )
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Energy (keV)')
        ax.set_ylabel(ylabel)
        # Inward ticks on all four sides, labels on the bottom/left only.
        ax.minorticks_on()
        ax.xaxis.set_ticks_position('both')
        ax.yaxis.set_ticks_position('both')
        ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=True)
        ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)
        ax.legend()

    fig_data = {
        key: {'x': x, 'y': y, 'x_le': x_le, 'x_he': x_he, 'y_e': y_e}
        for key, _label, y, y_e in series
    }
    return Figure(fig, fig_data, ploter)
[docs]
@staticmethod
def data(cls, ploter='plotly', style='CE'):
    """Plot the net spectrum of every unit in a ``Data`` container.

    Each unit named in ``cls.names`` becomes one error-bar series on shared
    log-log axes, colored from ``Plot.colors``.

    Args:
        cls: ``Data`` whose units are drawn together.
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.
        style: Display style -- ``'CC'`` (counts/s/channel) or ``'CE'``
            (counts/s/keV).

    Returns:
        A :class:`Figure` built from the backend figure, a per-unit dict of
        the plotted arrays (``{name: {'obs': {...}}}``) and the backend name.

    Raises:
        TypeError: If ``cls`` is not a ``Data``.
        ValueError: If ``style`` or ``ploter`` is not supported.
    """
    if not isinstance(cls, Data):
        raise TypeError('cls is not Data type, cannot call data method')
    # Fail early on an unknown backend: otherwise no branch below binds
    # ``fig`` and the call ends in an opaque UnboundLocalError.
    if ploter not in ('plotly', 'matplotlib'):
        raise ValueError(f'unsupported ploter argument: {ploter}')

    # Resolve the style before creating any figure so that a bad ``style``
    # cannot leave an orphaned pyplot figure behind.
    if style == 'CC':
        y = cls.net_ctsrate
        y_e = cls.net_ctsrate_error
        ylabel = 'Counts/s/channel'
    elif style == 'CE':
        y = cls.net_ctsspec
        y_e = cls.net_ctsspec_error
        ylabel = 'Counts/s/keV'
    else:
        raise ValueError(f'unsupported style argument: {style}')

    x = cls.rsp_chbin_mean
    x_le = [chw / 2 for chw in cls.rsp_chbin_width]
    x_he = [chw / 2 for chw in cls.rsp_chbin_width]

    use_plotly = ploter == 'plotly'
    if use_plotly:
        fig = go.Figure()
    else:
        # NOTE: rcParams is matplotlib's global configuration, so these
        # settings persist after this call returns.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        fig = plt.figure(figsize=(7, 6))
        gs = fig.add_gridspec(1, 1, wspace=0, hspace=0)
        ax = fig.add_subplot(gs[0, 0])

    fig_data = {}
    for i, name in enumerate(cls.names):
        # Wrap around the palette instead of raising IndexError when there
        # are more units than colors.
        color = Plot.colors[i % len(Plot.colors)]
        if use_plotly:
            fig.add_trace(
                go.Scatter(
                    x=x[i].astype(float),
                    y=y[i].astype(float),
                    mode='markers',
                    name=f'{name}',
                    showlegend=True,
                    error_x=dict(
                        type='data',
                        symmetric=False,
                        array=x_he[i].astype(float),
                        arrayminus=x_le[i].astype(float),
                        color=color,
                        thickness=1.5,
                        width=0,
                    ),
                    error_y=dict(
                        type='data',
                        array=y_e[i].astype(float),
                        color=color,
                        thickness=1.5,
                        width=0,
                    ),
                    marker=dict(symbol='circle', size=3, color=color),
                )
            )
        else:
            ax.errorbar(
                x[i],
                y[i],
                xerr=[x_le[i], x_he[i]],
                yerr=y_e[i],
                fmt='none',
                ecolor=color,
                elinewidth=0.8,
                capsize=0,
                capthick=0,
                label=name,
            )
        fig_data[name] = {
            'obs': {'x': x[i], 'y': y[i], 'x_le': x_le[i], 'x_he': x_he[i], 'y_e': y_e[i]}
        }

    if use_plotly:
        fig.update_xaxes(title_text='Energy (keV)', type='log')
        fig.update_yaxes(title_text=ylabel, type='log')
        fig.update_layout(template='plotly_white', height=600, width=600)
        fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
    else:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Energy (keV)')
        ax.set_ylabel(ylabel)
        # Inward ticks on all four sides, labels on the bottom/left only.
        ax.minorticks_on()
        ax.xaxis.set_ticks_position('both')
        ax.yaxis.set_ticks_position('both')
        ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=True)
        ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)
        ax.legend()

    return Figure(fig, fig_data, ploter)
[docs]
@staticmethod
def model(ploter='plotly', style='NE', post=False, yrange=None):
    """Return a new, empty :class:`ModelPlot` configured with the given options.

    All four arguments are forwarded unchanged, by keyword, to the
    ``ModelPlot`` constructor.

    Args:
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.
        style: Spectrum style -- ``'NE'``/``'Fv'``/``'ENE'``/``'vFv'``/
            ``'EENE'`` for additive models, ``'NoU'`` for
            multiplicative/mathematical models.
        post: If ``True``, also draw the posterior credible band.
        yrange: Optional ``(ymin, ymax)`` tuple for the y-axis.

    Returns:
        The freshly constructed :class:`ModelPlot`.
    """
    return ModelPlot(ploter=ploter, style=style, post=post, yrange=yrange)
[docs]
@staticmethod
def pair(cls, ploter='plotly', style='CE'):
    """Plot data and model together for a ``Pair``, with a residual panel.

    The top panel shows the observed points (with error bars) and the model
    curve for every data unit on log-log axes; the bottom panel shows the
    residuals ``(observed - model) / error``.

    Args:
        cls: ``Pair`` whose data and model are drawn.
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.
        style: Display style -- ``'CC'``, ``'CE'``, ``'NE'``, ``'Fv'``/
            ``'ENE'``, or ``'vFv'``/``'EENE'``.

    Returns:
        A :class:`Figure` built from the backend figure, a per-unit dict of
        the plotted arrays (``{name: {'obs', 'mo', 'res'}}``) and the
        backend name.

    Raises:
        TypeError: If ``cls`` is not a ``Pair``.
        ValueError: If ``style`` or ``ploter`` is not recognized.
    """
    if not isinstance(cls, Pair):
        raise TypeError('cls is not Pair type, cannot call pair method')
    # Fail early on an unknown backend: otherwise no branch below binds
    # ``fig`` and the call ends in an opaque UnboundLocalError.
    if ploter not in ('plotly', 'matplotlib'):
        raise ValueError(f'unsupported ploter argument: {ploter}')

    # Resolve the style before creating any figure so that a bad ``style``
    # cannot leave an orphaned pyplot figure behind.
    if style == 'CC':
        ylabel = 'Counts/s/channel'
        obs_y = cls.data.net_ctsrate
        obs_y_e = cls.data.net_ctsrate_error
        mo_y = cls.model.conv_ctsrate
    elif style == 'CE':
        ylabel = 'Counts/s/keV'
        obs_y = cls.data.net_ctsspec
        obs_y_e = cls.data.net_ctsspec_error
        mo_y = cls.model.conv_ctsspec
    elif style == 'NE':
        ylabel = 'Photons/cm2/s/keV'
        obs_y = cls.deconv_phtspec
        obs_y_e = cls.deconv_phtspec_error
        mo_y = cls.phtspec_at_rsp
    elif style == 'Fv' or style == 'ENE':
        ylabel = 'erg/cm2/s/keV'
        obs_y = cls.deconv_flxspec
        obs_y_e = cls.deconv_flxspec_error
        mo_y = cls.flxspec_at_rsp
    elif style == 'vFv' or style == 'EENE':
        ylabel = 'erg/cm2/s'
        obs_y = cls.deconv_ergspec
        obs_y_e = cls.deconv_ergspec_error
        mo_y = cls.ergspec_at_rsp
    else:
        raise ValueError(f'unsupported style argument: {style}')

    obs_x = cls.data.rsp_chbin_mean
    obs_x_le = [chw / 2 for chw in cls.data.rsp_chbin_width]
    obs_x_he = [chw / 2 for chw in cls.data.rsp_chbin_width]
    # Residuals in units of the observational error, one entry per data unit.
    res_y = [(oi - mi) / si for oi, mi, si in zip(obs_y, mo_y, obs_y_e)]

    # The log y-range is padded around the positive observed values only.
    # With no positive value there is nothing to anchor it to, so the
    # backend's autoscaling is kept instead of crashing in ``np.min``.
    yall = np.array(list(chain.from_iterable(obs_y)))
    positive = yall[yall > 0]
    if positive.size:
        ymin = 0.5 * np.min(positive).astype(float)
        ymax = 2 * np.max(positive).astype(float)
        y_limits = (ymin, ymax)
    else:
        y_limits = None

    use_plotly = ploter == 'plotly'
    if use_plotly:
        fig = make_subplots(
            rows=2,
            cols=1,
            row_heights=[0.75, 0.25],
            shared_xaxes=True,
            horizontal_spacing=0,
            vertical_spacing=0.02,
        )
    else:
        # NOTE: rcParams is matplotlib's global configuration, so these
        # settings persist after this call returns.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        fig = plt.figure(figsize=(6, 8))
        gs = fig.add_gridspec(4, 1, wspace=0, hspace=0)
        ax1 = fig.add_subplot(gs[0:3, 0])
        ax2 = fig.add_subplot(gs[3, 0], sharex=ax1)

    fig_data = {}
    for i, name in enumerate(cls.data.names):
        # Wrap around the palette instead of raising IndexError when there
        # are more units than colors.
        color = Plot.colors[i % len(Plot.colors)]
        if use_plotly:
            # Observed points: zero-size markers, so only the error bars show.
            obs = go.Scatter(
                x=obs_x[i].astype(float),
                y=obs_y[i].astype(float),
                mode='markers',
                name=name,
                showlegend=False,
                error_x=dict(
                    type='data',
                    symmetric=False,
                    array=obs_x_he[i].astype(float),
                    arrayminus=obs_x_le[i].astype(float),
                    color=color,
                    thickness=1.5,
                    width=0,
                ),
                error_y=dict(
                    type='data',
                    array=obs_y_e[i].astype(float),
                    color=color,
                    thickness=1.5,
                    width=0,
                ),
                marker=dict(symbol='cross-thin', size=0, color=color),
            )
            # The model line is the only trace of the unit shown in the legend.
            mo = go.Scatter(
                x=obs_x[i].astype(float),
                y=mo_y[i].astype(float),
                name=name,
                showlegend=True,
                mode='lines',
                line=dict(width=2, color=color),
            )
            res = go.Scatter(
                x=obs_x[i].astype(float),
                y=res_y[i].astype(float),
                name=name,
                showlegend=False,
                mode='markers',
                marker=dict(
                    symbol='cross-thin',
                    size=10,
                    color=color,
                    line=dict(width=1.5, color=color),
                ),
            )
            fig.add_trace(obs, row=1, col=1)
            fig.add_trace(mo, row=1, col=1)
            fig.add_trace(res, row=2, col=1)
        else:
            ax1.errorbar(
                obs_x[i],
                obs_y[i],
                xerr=[obs_x_le[i], obs_x_he[i]],
                yerr=obs_y_e[i],
                fmt='none',
                ecolor=color,
                elinewidth=0.8,
                capsize=0,
                capthick=0,
                label=name,
            )
            ax1.plot(obs_x[i], mo_y[i], color=color, lw=1.0)
            ax2.scatter(obs_x[i], res_y[i], marker='+', color=color, s=40, linewidths=0.8)
        fig_data[name] = {
            'obs': {
                'x': obs_x[i],
                'y': obs_y[i],
                'x_le': obs_x_le[i],
                'x_he': obs_x_he[i],
                'y_e': obs_y_e[i],
            },
            'mo': {'x': obs_x[i], 'y': mo_y[i]},
            'res': {'x': obs_x[i], 'y': res_y[i]},
        }

    if use_plotly:
        fig.update_xaxes(title_text='', row=1, col=1, type='log')
        fig.update_xaxes(title_text='Energy (keV)', row=2, col=1, type='log')
        top_yaxis = dict(title_text=ylabel, row=1, col=1, type='log')
        if y_limits is not None:
            # Plotly expects the range of a log axis in log10 units.
            top_yaxis['range'] = [np.log10(y_limits[0]), np.log10(y_limits[1])]
        fig.update_yaxes(**top_yaxis)
        fig.update_yaxes(title_text='Sigma', showgrid=False, range=[-3.5, 3.5], row=2, col=1)
        fig.update_layout(template='plotly_white', height=700, width=600)
        fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
    else:

        def _style_ticks(ax, show_xlabels):
            # Inward ticks on all four sides; x tick labels only where asked.
            ax.minorticks_on()
            ax.xaxis.set_ticks_position('both')
            ax.yaxis.set_ticks_position('both')
            ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
            ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
            ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=show_xlabels)
            ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)

        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.set_ylabel(ylabel)
        if y_limits is not None:
            ax1.set_ylim([y_limits[0], y_limits[1]])
        # The top panel shares its x-axis with the residual panel, so its
        # own x tick labels are hidden.
        _style_ticks(ax1, show_xlabels=False)
        ax1.legend()

        ax2.axhline(0, c='grey', lw=1, ls='--')
        ax2.set_xlabel('Energy (keV)')
        ax2.set_ylabel('Sigma')
        ax2.set_ylim([-3.49, 3.49])
        _style_ticks(ax2, show_xlabels=True)

    return Figure(fig, fig_data, ploter)
[docs]
@staticmethod
def emcee_walker(cls):
    """Plot per-parameter emcee walker trajectories with matplotlib.

    One stacked panel is drawn per free parameter, all sharing the x-axis
    (step number).

    Args:
        cls: ``BayesInfer`` or ``Posterior`` exposing ``posterior_sample``,
            ``free_nparams`` and ``free_plabels``.

    Returns:
        A :class:`Figure` wrapping the matplotlib figure, with no data
        payload (``None``).

    Raises:
        TypeError: If ``cls`` is not a ``BayesInfer`` or ``Posterior``.
    """
    if not isinstance(cls, (BayesInfer, Posterior)):
        raise TypeError('cls is not BayesInfer or Posterior type, cannot call walker method')

    nparams = cls.free_nparams
    # NOTE(review): the sample is sliced with two indices here but indexed
    # with three below. If ``posterior_sample`` is laid out as
    # (step, walker, parameter), this slice trims the walker axis rather
    # than the parameter axis -- confirm against the sampler output.
    params_sample = cls.posterior_sample[:, :nparams].copy()

    # ``squeeze=False`` always yields a 2-D array of Axes. With the default,
    # a single free parameter returns a bare Axes and ``axes[i]`` fails.
    fig, axes_grid = plt.subplots(
        nparams, figsize=(10, 2 * nparams), sharex='all', squeeze=False
    )
    axes = axes_grid[:, 0]

    for i in range(nparams):
        ax = axes[i]
        ax.plot(params_sample[:, :, i], 'k', alpha=0.3)
        ax.set_xlim(0, len(params_sample))
        ax.set_ylabel(cls.free_plabels[i])
        # Pin the y-labels to a common x-offset so they line up across panels.
        ax.yaxis.set_label_coords(-0.1, 0.5)
        ax.minorticks_on()
        ax.xaxis.set_ticks_position('both')
        ax.yaxis.set_ticks_position('both')
        ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
        ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=True)
        ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)
    axes[-1].set_xlabel('step number')

    fig_data = None
    return Figure(fig, fig_data, 'matplotlib')
[docs]
@staticmethod
def infer(cls, ploter='plotly', style='CE', rebin=True, at_par=None):
    """Plot data vs. inferred model (with residuals) from an ``Infer``.

    The top panel shows the observed points (with error bars) and the model
    curve for every data unit on log-log axes; the bottom panel shows the
    residuals taken from ``cls.residual`` / ``cls.re_residual``.

    Args:
        cls: ``Infer`` or one of its subclasses (``Posterior``,
            ``Bootstrap``) to visualize.
        ploter: Backend -- ``'plotly'`` or ``'matplotlib'``.
        style: Display style -- ``'CC'``, ``'CE'``, ``'NE'``, ``'Fv'``/
            ``'ENE'``, or ``'vFv'``/``'EENE'``.
        rebin: Draw with the re-binned (``*_re_*``) quantities when truthy,
            otherwise with the unbinned ones.
        at_par: Which parameter point to evaluate the model at --
            ``'best'``, ``'best-ci'``, ``'median'``, ``'mean'``, or
            ``'truth'``. Defaults to ``'best'`` for ``Posterior`` and
            ``'truth'`` for ``Bootstrap``. Only consulted when ``cls`` is a
            ``Posterior`` or ``Bootstrap``; the chosen point is applied via
            ``cls.at_par(...)`` before plotting.

    Returns:
        A :class:`Figure` built from the backend figure, a per-unit dict of
        the plotted arrays (``{name: {'obs', 'mo', 'res'}}``) and the
        backend name.

    Raises:
        TypeError: If ``cls`` is not an ``Infer``.
        ValueError: If ``at_par``, ``style`` or ``ploter`` is not
            recognized, or if ``at_par='truth'`` but some parameters lack a
            truth value.
    """
    if not isinstance(cls, Infer):
        raise TypeError('cls is not Infer type, cannot call infer method')
    # Fail early on an unknown backend: otherwise no branch below binds
    # ``fig`` and the call ends in an opaque UnboundLocalError.
    if ploter not in ('plotly', 'matplotlib'):
        raise ValueError(f'unsupported ploter argument: {ploter}')

    if at_par is None:
        if isinstance(cls, Posterior):
            at_par = 'best'
        if isinstance(cls, Bootstrap):
            at_par = 'truth'

    if isinstance(cls, (Posterior, Bootstrap)):
        if at_par == 'best':
            cls.at_par(cls.par_best)
        elif at_par == 'best-ci':
            cls.at_par(cls.par_best_ci)
        elif at_par == 'median':
            cls.at_par(cls.par_median)
        elif at_par == 'mean':
            cls.at_par(cls.par_mean)
        elif at_par == 'truth':
            if None in cls.par_truth:
                raise ValueError('no truth value for some parameters')
            cls.at_par(cls.par_truth)
        else:
            raise ValueError(f'unsupported at_par argument: {at_par}')
        # A former unconditional ``cls.at_par(cls.par_truth)`` for Bootstrap
        # was removed here: it silently overrode an explicit ``at_par`` and
        # bypassed the missing-truth check above. The Bootstrap default
        # (``'truth'``) is unaffected.

    # Resolve the style before creating any figure so that a bad ``style``
    # cannot leave an orphaned pyplot figure behind.
    if style == 'CC':
        ylabel, quantity = 'Counts/s/channel', 'ctsrate'
    elif style == 'CE':
        ylabel, quantity = 'Counts/s/keV', 'ctsspec'
    elif style == 'NE':
        ylabel, quantity = 'Photons/cm2/s/keV', 'phtspec'
    elif style == 'Fv' or style == 'ENE':
        ylabel, quantity = 'erg/cm2/s/keV', 'flxspec'
    elif style == 'vFv' or style == 'EENE':
        ylabel, quantity = 'erg/cm2/s', 'ergspec'
    else:
        raise ValueError(f'unsupported style argument: {style}')

    # Re-binned attributes differ from the unbinned ones only by an extra
    # ``re_`` name segment (``data_ctsspec`` vs ``data_re_ctsspec``,
    # ``residual`` vs ``re_residual``).
    tag = 're_' if rebin else ''
    obs_x = getattr(cls, f'data_{tag}chbin_mean')
    obs_x_le = [chw / 2 for chw in getattr(cls, f'data_{tag}chbin_width')]
    obs_x_he = [chw / 2 for chw in getattr(cls, f'data_{tag}chbin_width')]
    res_y = getattr(cls, f'{tag}residual')
    obs_y = getattr(cls, f'data_{tag}{quantity}')
    obs_y_e = getattr(cls, f'data_{tag}{quantity}_error')
    mo_y = getattr(cls, f'model_{tag}{quantity}')

    # The log y-range is padded around the positive observed values only.
    # With no positive value there is nothing to anchor it to, so the
    # backend's autoscaling is kept instead of crashing in ``np.min``.
    yall = np.array(list(chain.from_iterable(obs_y)))
    positive = yall[yall > 0]
    if positive.size:
        ymin = 0.5 * np.min(positive).astype(float)
        ymax = 2 * np.max(positive).astype(float)
        y_limits = (ymin, ymax)
    else:
        y_limits = None

    use_plotly = ploter == 'plotly'
    if use_plotly:
        fig = make_subplots(
            rows=2,
            cols=1,
            row_heights=[0.75, 0.25],
            shared_xaxes=True,
            horizontal_spacing=0,
            vertical_spacing=0.02,
        )
    else:
        # NOTE: rcParams is matplotlib's global configuration, so these
        # settings persist after this call returns.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        fig = plt.figure(figsize=(6, 8))
        gs = fig.add_gridspec(4, 1, wspace=0, hspace=0)
        ax1 = fig.add_subplot(gs[0:3, 0])
        ax2 = fig.add_subplot(gs[3, 0], sharex=ax1)

    fig_data = {}
    for i, name in enumerate(cls.data_names):
        # Wrap around the palette instead of raising IndexError when there
        # are more units than colors.
        color = Plot.colors[i % len(Plot.colors)]
        if use_plotly:
            # Observed points: zero-size markers, so only the error bars show.
            obs = go.Scatter(
                x=obs_x[i].astype(float),
                y=obs_y[i].astype(float),
                mode='markers',
                name=name,
                showlegend=False,
                error_x=dict(
                    type='data',
                    symmetric=False,
                    array=obs_x_he[i].astype(float),
                    arrayminus=obs_x_le[i].astype(float),
                    color=color,
                    thickness=1.5,
                    width=0,
                ),
                error_y=dict(
                    type='data',
                    array=obs_y_e[i].astype(float),
                    color=color,
                    thickness=1.5,
                    width=0,
                ),
                marker=dict(symbol='cross-thin', size=0, color=color),
            )
            # The model line is the only trace of the unit shown in the legend.
            mo = go.Scatter(
                x=obs_x[i].astype(float),
                y=mo_y[i].astype(float),
                name=name,
                showlegend=True,
                mode='lines',
                line=dict(width=2, color=color),
            )
            res = go.Scatter(
                x=obs_x[i].astype(float),
                y=res_y[i].astype(float),
                name=name,
                showlegend=False,
                mode='markers',
                marker=dict(
                    symbol='cross-thin',
                    size=10,
                    color=color,
                    line=dict(width=1.5, color=color),
                ),
            )
            fig.add_trace(obs, row=1, col=1)
            fig.add_trace(mo, row=1, col=1)
            fig.add_trace(res, row=2, col=1)
        else:
            ax1.errorbar(
                obs_x[i],
                obs_y[i],
                xerr=[obs_x_le[i], obs_x_he[i]],
                yerr=obs_y_e[i],
                fmt='none',
                ecolor=color,
                elinewidth=0.8,
                capsize=0,
                capthick=0,
                label=name,
            )
            ax1.plot(obs_x[i], mo_y[i], color=color, lw=1.0)
            ax2.scatter(obs_x[i], res_y[i], marker='+', color=color, s=40, linewidths=0.8)
        fig_data[name] = {
            'obs': {
                'x': obs_x[i],
                'y': obs_y[i],
                'x_le': obs_x_le[i],
                'x_he': obs_x_he[i],
                'y_e': obs_y_e[i],
            },
            'mo': {'x': obs_x[i], 'y': mo_y[i]},
            'res': {'x': obs_x[i], 'y': res_y[i]},
        }

    if use_plotly:
        fig.update_xaxes(title_text='', row=1, col=1, type='log')
        fig.update_xaxes(title_text='Energy (keV)', row=2, col=1, type='log')
        top_yaxis = dict(title_text=ylabel, row=1, col=1, type='log')
        if y_limits is not None:
            # Plotly expects the range of a log axis in log10 units.
            top_yaxis['range'] = [np.log10(y_limits[0]), np.log10(y_limits[1])]
        fig.update_yaxes(**top_yaxis)
        fig.update_yaxes(title_text='Sigma', showgrid=False, range=[-3.5, 3.5], row=2, col=1)
        fig.update_layout(template='plotly_white', height=700, width=600)
        fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
    else:

        def _style_ticks(ax, show_xlabels):
            # Inward ticks on all four sides; x tick labels only where asked.
            ax.minorticks_on()
            ax.xaxis.set_ticks_position('both')
            ax.yaxis.set_ticks_position('both')
            ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
            ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
            ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=show_xlabels)
            ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)

        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.set_ylabel(f'$\\rm {ylabel}$')
        if y_limits is not None:
            ax1.set_ylim([y_limits[0], y_limits[1]])
        # The top panel shares its x-axis with the residual panel, so its
        # own x tick labels are hidden.
        _style_ticks(ax1, show_xlabels=False)
        ax1.legend()

        ax2.axhline(0, c='grey', lw=1, ls='--')
        ax2.set_xlabel('Energy (keV)')
        ax2.set_ylabel('Sigma')
        ax2.set_ylim([-3.49, 3.49])
        _style_ticks(ax2, show_xlabels=True)

    return Figure(fig, fig_data, ploter)
[docs]
@staticmethod
def post_corner(cls, ploter='plotly', at_par=None):
    """Render the pairwise (corner) view of a parameter sample.

    The diagonal panels carry the median with its asymmetric error bar;
    every below-diagonal panel gets cross-hair lines and a square marker
    at the chosen reference point.

    Args:
        cls: ``Posterior`` or ``Bootstrap`` supplying ``param_sample`` and
            the point estimates.
        ploter: Backend -- ``'plotly'``, ``'getdist'``, or ``'cornerpy'``.
        at_par: Reference point to overlay -- ``'best'``, ``'best-ci'``,
            ``'median'``, ``'mean'``, or ``'truth'``. When omitted it is
            ``'truth'`` for a ``Bootstrap`` and ``'best'`` otherwise.

    Returns:
        A :class:`Figure` wrapping the plot. If the sample has no more rows
        than parameters, or has zero spread in every parameter, a warning
        is issued and a matplotlib placeholder figure is returned instead
        (tagged ``'matplotlib'`` regardless of ``ploter``).

    Raises:
        TypeError: If ``cls`` is not a ``Posterior`` or ``Bootstrap``.
        ValueError: If ``at_par`` is not recognized, if ``at_par='truth'``
            but some parameters lack a truth value, or if ``ploter`` is not
            one of the supported backends (checked only for a
            non-degenerate sample).
    """
    if not isinstance(cls, (Posterior, Bootstrap)):
        raise TypeError('cls is not Posterior or Bootstrap type, cannot call corner method')

    samples = cls.param_sample
    # Uniform weights: every draw counts equally.
    weights = np.ones(samples.shape[0], dtype=float) / samples.shape[0]

    # corner/getdist assert n_samples >= n_dims, so a collapsed sample (too
    # few rows, or no spread at all) cannot be drawn. Hand back a
    # placeholder so a batch loop survives the bad fit.
    nsample, ndim = samples.shape
    if nsample <= ndim or np.ptp(samples, axis=0).max() == 0:
        warnings.warn(
            f'Posterior too degenerate to corner-plot ({nsample} samples for '
            f'{ndim} parameters); the run likely did not converge. Returning a '
            f'placeholder figure.',
            stacklevel=2,
        )
        placeholder = plt.figure(figsize=(4, 4))
        placeholder.text(
            0.5, 0.5, 'degenerate posterior\nno corner plot', ha='center', va='center'
        )
        # The placeholder is a plain matplotlib figure, so it is tagged as
        # such: Figure.save must go through fig.savefig, not a
        # plotly/getdist export.
        return Figure(placeholder, None, 'matplotlib')

    title_fmt = '$%.2f_{-%.2f}^{+%.2f}~(%.2f)$'
    plabels = cls.free_indexed_plabels

    if at_par is None:
        # cls is one of the two accepted types at this point.
        at_par = 'truth' if isinstance(cls, Bootstrap) else 'best'

    if at_par == 'truth':
        if None in cls.par_truth:
            raise ValueError('no truth value for some parameters')
        truth = cls.par_truth
    else:
        estimate_attrs = (
            ('best', 'par_best'),
            ('best-ci', 'par_best_ci'),
            ('median', 'par_median'),
            ('mean', 'par_mean'),
        )
        for key, attr in estimate_attrs:
            if at_par == key:
                # Only the requested estimate is read from cls.
                truth = getattr(cls, attr)
                break
        else:
            raise ValueError(f'unsupported at_par argument: {at_par}')

    median = cls.par_median
    error = cls.par_error(median)

    if ploter not in ('plotly', 'getdist', 'cornerpy'):
        raise ValueError(f'unsupported ploter: {ploter}')

    nparams = cls.free_nparams
    # (row, column) of every panel strictly below the diagonal.
    below_diagonal = [(yi, xi) for yi in range(nparams) for xi in range(yi)]
    # Contour levels passed to the plotly and corner.py backends.
    levels = 1.0 - np.exp(-0.5 * np.array([1, 2]) ** 2)

    def annotate_diagonal(ax, i, bar_y, bar_lw):
        """Title a diagonal matplotlib panel and draw the median error bar."""
        ax.set_title(
            title_fmt % (median[i], error[i][0], error[i][1], truth[i]),
            math_fontfamily='stix',
        )
        ax.errorbar(
            median[i],
            bar_y,
            xerr=[[error[i][0]], [error[i][1]]],
            fmt='or',
            ms=2,
            ecolor='r',
            elinewidth=bar_lw,
        )

    def annotate_pair(ax, xi, yi, line_lw, **scatter_kw):
        """Draw the reference cross-hair and marker on an off-diagonal panel."""
        ax.axvline(truth[xi], color='r', lw=line_lw, ls='-')
        ax.axhline(truth[yi], color='r', lw=line_lw, ls='-')
        ax.scatter(truth[xi], truth[yi], marker='s', color='r', linewidths=0, **scatter_kw)

    if ploter == 'plotly':
        accent = '#FF0092'
        fig = corner_plotly(
            samples,
            bins=30,
            weights=weights,
            smooth1d=2,
            smooth=2,
            labels=plabels,
            levels=levels,
        )
        for i in range(nparams):
            median_marker = go.Scatter(
                x=[median[i]],
                y=[0.01],
                mode='markers',
                name=plabels[i],
                showlegend=False,
                error_x=dict(
                    type='data',
                    symmetric=False,
                    array=[error[i][1]],
                    arrayminus=[error[i][0]],
                    color=accent,
                    thickness=1,
                    width=0,
                ),
                marker=dict(symbol='circle', size=5, color=accent),
            )
            fig.add_trace(median_marker, row=i + 1, col=i + 1)
        for yi, xi in below_diagonal:
            cell = dict(row=yi + 1, col=xi + 1)
            fig.add_vline(truth[xi], line_width=1, line_color=accent, **cell)
            fig.add_hline(truth[yi], line_width=1, line_color=accent, **cell)
            reference_marker = go.Scatter(
                x=[truth[xi]],
                y=[truth[yi]],
                mode='markers',
                name=f'{plabels[xi]}&{plabels[yi]}',
                showlegend=False,
                marker=dict(symbol='square', size=5, color=accent),
            )
            fig.add_trace(reference_marker, **cell)
    elif ploter == 'getdist':
        fig = plots.get_subplot_plotter()
        fig.settings.num_plot_contours = 2
        fig.settings.num_shades = 30
        fig.settings.title_limit_fontsize = 10
        # Falls back to 'mcmc' when cls does not carry a sampler_type.
        sampler_type = getattr(cls, 'sampler_type', 'mcmc')
        mcsample = MCSamples(samples=samples, names=plabels, sampler=sampler_type)
        mcsample.updateSettings({'contours': [0.6827, 0.9545, 0.9973]})
        fig.triangle_plot(mcsample, plabels, shaded=True)
        for i in range(nparams):
            panel = fig.subplots[i, i]
            annotate_diagonal(panel, i, 0.05, 0.7)
            panel.tick_params(axis='both', which='both', zorder=10)
        for yi, xi in below_diagonal:
            panel = fig.subplots[yi, xi]
            annotate_pair(panel, xi, yi, 0.7, s=10, zorder=10)
            panel.tick_params(axis='both', which='both', zorder=10)
    else:
        # ploter == 'cornerpy'; note the rcParams changes are global.
        rcParams['font.family'] = 'sans-serif'
        rcParams['font.size'] = 12
        rcParams['pdf.fonttype'] = 42
        fig = corner.corner(
            samples,
            bins=30,
            color='#08519c',
            weights=weights,
            labels=plabels,
            show_titles=True,
            use_math_text=True,
            smooth1d=2,
            smooth=2,
            levels=levels,
            plot_datapoints=True,
            plot_density=True,
            plot_contours=True,
            fill_contours=False,
            no_fill_contours=False,
        )
        grid = np.array(fig.axes).reshape((nparams, nparams))
        for i in range(nparams):
            annotate_diagonal(grid[i, i], i, 0.005, 1)
        for yi, xi in below_diagonal:
            annotate_pair(grid[yi, xi], xi, yi, 1, s=20)

    return Figure(fig, None, ploter)
[docs]
class ModelPlot:
    """Accumulating figure that overlays multiple ``Model`` spectra.

    Build one via :meth:`Plot.model`, then call :meth:`add_model` for each
    model to compare, and finish with :meth:`get_fig`. The y-axis label
    and which model spectra are valid depend on ``style``.

    Attributes:
        ploter: Active backend (``'plotly'`` or ``'matplotlib'``).
        style: Spectrum style (see :meth:`Plot.model`).
        post: Whether posterior credible bands are drawn alongside lines.
        yrange: Optional y-axis range.
        fig: Underlying figure object from the chosen backend.
        ax: The single matplotlib axes (set only for ``'matplotlib'``).
        fig_data: Raw plotted arrays keyed by model expression.
        model_index: Index of the most recently added model (``-1`` before
            the first one); used for color cycling.
    """

    colors = (
        px.colors.qualitative.Plotly
        + px.colors.qualitative.D3
        + px.colors.qualitative.G10
        + px.colors.qualitative.T10
        + px.colors.qualitative.Alphabet
    )

    # style -> (y-axis label, model method suffix, accepted model types,
    #           whether the model methods take the time argument ``T``)
    _STYLES = {
        'NE': ('Photons/cm2/s/keV', 'phtspec', ('add',), True),
        'Fv': ('erg/cm2/s/keV', 'flxspec', ('add',), True),
        'ENE': ('erg/cm2/s/keV', 'flxspec', ('add',), True),
        'vFv': ('erg/cm2/s', 'ergspec', ('add',), True),
        'EENE': ('erg/cm2/s', 'ergspec', ('add',), True),
        'NoU': ('Dimensionless', 'nouspec', ('mul', 'math'), False),
    }

    # at_par -> prefix of the model method evaluating at that parameter point
    _AT_PAR_PREFIX = {
        'best': 'best',
        'best-ci': 'best_ci',
        'median': 'median',
        'mean': 'mean',
        'truth': 'truth',
    }

    def __init__(self, ploter='plotly', style='NE', post=False, yrange=None):
        """Initialize the figure for the requested backend and style.

        Args:
            ploter: ``'plotly'`` or ``'matplotlib'``.
            style: Spectrum style -- ``'NE'``/``'Fv'``/``'ENE'``/``'vFv'``/
                ``'EENE'`` for additive models, ``'NoU'`` for
                multiplicative/mathematical models.
            post: If ``True``, draw posterior credible bands.
            yrange: Optional ``(ymin, ymax)`` tuple.

        Raises:
            ValueError: If ``style`` or ``ploter`` is not recognized.
        """
        self.ploter = ploter
        self.style = style
        self.post = post
        self.yrange = yrange

        ylabel = self._style_spec(style)[0]

        if ploter == 'plotly':
            self.fig = go.Figure()
            self.fig.update_xaxes(title_text='Energy (keV)', type='log')
            self.fig.update_yaxes(title_text=ylabel, type='log')
            if yrange is not None:
                # A log-typed plotly axis takes its range in log10 units.
                self.fig.update_yaxes(range=[np.log10(yrange[0]), np.log10(yrange[1])])
            self.fig.update_layout(template='plotly_white', height=600, width=600)
            self.fig.update_layout(legend=dict(x=1, y=1, xanchor='right', yanchor='bottom'))
        elif ploter == 'matplotlib':
            # NOTE: these rcParams changes are global to the process.
            rcParams['font.family'] = 'sans-serif'
            rcParams['font.size'] = 12
            rcParams['pdf.fonttype'] = 42
            self.fig = plt.figure(figsize=(7, 6))
            gs = self.fig.add_gridspec(1, 1, wspace=0, hspace=0)
            self.ax = self.fig.add_subplot(gs[0, 0])
            self.ax.set_xscale('log')
            self.ax.set_yscale('log')
            self.ax.set_xlabel('Energy (keV)')
            self.ax.set_ylabel(ylabel)
            if yrange is not None:
                self.ax.set_ylim(yrange)
            self.ax.minorticks_on()
            self.ax.xaxis.set_ticks_position('both')
            self.ax.yaxis.set_ticks_position('both')
            self.ax.tick_params(axis='x', which='both', direction='in', labelcolor='k', colors='k')
            self.ax.tick_params(axis='y', which='both', direction='in', labelcolor='k', colors='k')
            self.ax.tick_params(axis='x', which='both', labeltop=False, labelbottom=True)
            self.ax.tick_params(axis='y', which='both', labelleft=True, labelright=False)
        else:
            # Fail here rather than leaving the instance without ``fig`` and
            # surfacing an AttributeError much later in get_fig().
            raise ValueError(f'unsupported ploter: {ploter}')

        self.fig_data = {}
        self.model_index = -1

    @classmethod
    def _style_spec(cls, style):
        """Return the ``_STYLES`` entry for ``style`` or raise ``ValueError``."""
        try:
            return cls._STYLES[style]
        except (KeyError, TypeError):
            # TypeError covers an unhashable ``style``; report it the same way.
            raise ValueError(f'unsupported style argument: {style}') from None

    @staticmethod
    def get_rgb(color, opacity=1.0):
        """Convert a matplotlib color plus opacity into a Plotly ``rgba`` string."""
        rgba = mpl.colors.to_rgba(color)
        r, g, b = (int(x * 255) for x in rgba[:3])
        return f'rgba({r}, {g}, {b}, {opacity:f})'

    def add_model(self, model, E, T=None, post=None, at_par=None):
        """Draw ``model`` at energies ``E`` onto the accumulating figure.

        The model spectrum method is selected from ``style`` and, when
        ``post`` is true, from ``at_par``; in that case the ``'Isigma'``
        band of the model's sampled spectrum is drawn alongside the point
        estimate. The evaluated arrays are stored in ``fig_data`` under
        ``model.expr``.

        Args:
            model: ``Model`` instance compatible with the current ``style``.
            E: Energy grid (keV) at which to evaluate the model.
            T: Optional time argument forwarded to the model methods
                (not used for the ``'NoU'`` style).
            post: Overrides the instance-level ``post`` flag for this call.
            at_par: Which parameter point to evaluate at -- ``'best'``,
                ``'best-ci'``, ``'median'``, ``'mean'``, or ``'truth'``.
                Defaults to ``'best'`` when any truth value is missing,
                otherwise ``'truth'``. Ignored when ``post`` is false.

        Raises:
            TypeError: If ``model`` is not a ``Model``.
            AttributeError: If the model type is incompatible with ``style``.
            ValueError: If ``at_par`` or ``style`` is not recognized, or if
                ``at_par='truth'`` but some parameters lack a truth value.
        """
        if not isinstance(model, Model):
            raise TypeError('model is not Model type, cannot call add_model method')
        if post is None:
            post = self.post

        # self.style is re-read here so a style changed after construction
        # is honoured, as before.
        _, suffix, allowed_types, takes_time = self._style_spec(self.style)
        if model.type not in allowed_types:
            raise AttributeError(f'{self.style} is invalid for {model.type} type model')

        args = (E, T) if takes_time else (E,)
        x = np.array(E).astype(float)
        y_ci = None

        if post:
            if at_par is None:
                at_par = 'best' if None in model.par_truth else 'truth'
            try:
                prefix = self._AT_PAR_PREFIX[at_par]
            except (KeyError, TypeError):
                raise ValueError(f'unsupported at_par argument: {at_par}') from None
            if at_par == 'truth' and None in model.par_truth:
                raise ValueError('no truth value for some parameters')
            # e.g. model.best_ci_phtspec(E, T) and model.phtspec_sample(E, T)
            y = getattr(model, f'{prefix}_{suffix}')(*args).astype(float)
            y_ci = getattr(model, f'{suffix}_sample')(*args)['Isigma'].astype(float)
        else:
            y = getattr(model, suffix)(*args).astype(float)

        # Count the model only once it has evaluated, so a rejected call does
        # not consume a palette entry; wrap around when the palette runs out.
        self.model_index += 1
        color = ModelPlot.colors[self.model_index % len(ModelPlot.colors)]

        if self.ploter == 'plotly':
            self.fig.add_trace(
                go.Scatter(
                    x=x,
                    y=y,
                    mode='lines',
                    name=model.expr,
                    showlegend=True,
                    line=dict(width=2, color=color),
                )
            )
            if post:
                # Invisible lower edge first; the upper edge then fills down
                # to it via 'tonexty'.
                self.fig.add_trace(
                    go.Scatter(
                        x=x,
                        y=y_ci[0],
                        mode='lines',
                        name=f'{model.expr} lower',
                        fill=None,
                        line_color='rgba(0,0,0,0)',
                        showlegend=False,
                    )
                )
                self.fig.add_trace(
                    go.Scatter(
                        x=x,
                        y=y_ci[1],
                        mode='lines',
                        name=f'{model.expr} CI',
                        fill='tonexty',
                        line_color='rgba(0,0,0,0)',
                        fillcolor=ModelPlot.get_rgb(color, 0.5),
                        showlegend=True,
                    )
                )
        elif self.ploter == 'matplotlib':
            self.ax.plot(x, y, lw=1.0, color=color, label=model.expr)
            if post:
                self.ax.fill_between(
                    x,
                    y_ci[0],
                    y_ci[1],
                    fc=color,
                    alpha=0.5,
                    label=f'{model.expr} CI',
                )
            self.ax.legend()

        if post:
            self.fig_data[model.expr] = {'x': x, 'y': y, 'y_ci': y_ci}
        else:
            self.fig_data[model.expr] = {'x': x, 'y': y}

    def get_fig(self):
        """Wrap the accumulated plot in a :class:`Figure` for display or saving."""
        return Figure(self.fig, self.fig_data, self.ploter)