"""Inference drivers: base aggregator, Bayesian samplers, and max-likelihood fits.
:class:`Infer` aggregates one or more ``(Data, Model)`` pairs, resolves
shared parameters through the linking machinery, and exposes the joint
prior, log-likelihood, and fit-statistic entry points that the samplers
consume. :class:`BayesInfer` adds driver methods for MultiNest and
emcee; :class:`MaxLikeFit` adds lmfit and iminuit minimisers plus
covariance-driven bootstrap sampling.
"""
from collections import OrderedDict
from collections.abc import Callable
import ctypes
import json
import os
import warnings
import numpy as np
from ..data.data import Data
from ..model.model import Model
from ..util.info import Info
from ..util.tools import JsonEncoder, SuperDict, json_dump
from .pair import Pair
[docs]
class Infer:
"""Aggregate of one or more ``(Data, Model)`` pairs with shared parameters.
Collects per-pair parameters into a single flat index, identifies
which of them are free (after honouring ``Par.link``/``unlink``
relations and the frozen flag), and publishes the pooled
log-likelihood, log-prior, fit statistic, and residuals.
**Group docstring for aggregator properties.** List-valued
``data_*`` properties flatten the same-named ``Data`` property over
every bound ``Data``; ``model_*`` properties flatten the matching
``Model`` property over every bound ``Model``. The family covers the
following patterns (the ``_re_`` infix uses re-binned channels):
- ``data_chbin_mean``/``_re_``, ``data_chbin_width``/``_re_``
(sourced from ``rsp_chbin_*``);
- ``data_{ctsrate,ctsspec,phtspec,flxspec,ergspec}`` plus their
``_error``/``_re_`` variants (net counts and deconvolved spectra);
- ``model_{ctsrate,ctsspec,phtspec,flxspec,ergspec}`` plus their
``_re_`` variants (convolved model spectra sampled at the data
channels).
Attributes:
pairs: The raw list of ``(Data, Model)`` tuples.
Data/Model/Pair: Per-pair unpacked containers.
inference_type: Display label shown in :meth:`__str__`.
loglike_func/logprior_func/prior_transform_func: Optional user
overrides for the corresponding computations; ``None`` means
use the built-in implementation.
"""
def __init__(self, pairs=None):
    """Create the container and register the initial ``(Data, Model)`` pairs.

    Args:
        pairs: ``None`` or a list of two-element tuples holding one
            ``Data`` and one ``Model``, in either order.
    """
    # Goes through the ``pairs`` property setter, which validates the input.
    self.pairs = pairs
    # ``None`` selects the built-in implementation for each computation.
    for hook in ('loglike_func', 'logprior_func', 'prior_transform_func'):
        setattr(self, hook, None)
    self.inference_type = 'Inference'
@property
def pairs(self):
    """List of normalised ``(Data, Model)`` tuples currently registered."""
    return self._pairs

@pairs.setter
def pairs(self, new_pairs):
    """Replace all pairs from a list of tuples, then :meth:`_extract`.

    Args:
        new_pairs: ``None`` (empty container) or a list whose items are
            two-element tuples/lists with one ``Data`` and one ``Model``.

    Raises:
        ValueError: If ``new_pairs`` is neither ``None`` nor a list, or if
            one of its items is not a tuple/list.
    """
    self._pairs = []
    if new_pairs is not None and not isinstance(new_pairs, list):
        raise ValueError('unsupported pair type')
    for pair in new_pairs or ():
        if not isinstance(pair, (tuple, list)):
            # A malformed entry used to be skipped silently, leaving the
            # container with fewer pairs than the caller supplied.
            raise ValueError('unsupported pair type')
        self._addpair(*pair)
    # Also run for ``None`` so an empty container still exposes the derived
    # lists and free-parameter caches instead of missing attributes.
    self._extract()
def _addpair(self, *pair):
    """Validate one pair and store it normalised as ``(Data, Model)``.

    Args:
        *pair: Exactly two objects, one ``Data`` and one ``Model``, in
            either order.

    Raises:
        ValueError: If the two objects are not one ``Data`` plus one
            ``Model`` (a wrong argument count fails at the unpacking).
    """
    first, second = pair
    if isinstance(first, Data):
        if not isinstance(second, Model):
            raise ValueError('p1 is Data type, then p2 should be Model type')
        ordered = (first, second)
    elif isinstance(first, Model):
        if not isinstance(second, Data):
            raise ValueError('p1 is Model type, then p2 should be Data type')
        ordered = (second, first)
    else:
        raise ValueError('unsupported pair type')
    self._pairs.append(ordered)
[docs]
def append(self, *pair):
    """Register one more pair and refresh every derived container.

    Args:
        *pair: Two positional arguments, one ``Data`` and one ``Model``,
            in either order.
    """
    self._addpair(*pair)
    # The per-pair lists and free-parameter caches depend on the pair list.
    self._extract()
def _extract(self):
    """Rebuild the per-pair containers and the free-parameter caches.

    Raises:
        ValueError: If ``pairs`` is ``None``.
    """
    pairs = self.pairs
    if pairs is None:
        raise ValueError('pairs is None')
    # Fresh marker object on every extraction; its consumer is not visible
    # here (presumably a change token for code elsewhere).
    self._EXTRACT = object()
    self.nparis = len(pairs)
    self.Data = [entry[0] for entry in pairs]
    self.Model = [entry[1] for entry in pairs]
    self.Pair = [Pair(*entry) for entry in pairs]
    self.data_names = [name for dataset in self.Data for name in dataset.names]
    self.model_exprs = [mdl.expr for mdl in self.Model]
    self._you_free()
@property
def cdicts(self):
    """Ordered mapping from each model/data expression to its ``cdicts``."""
    combined = OrderedDict()
    for md in self.Model + self.Data:
        key = md.expr
        combined[key] = md.cdicts
    return combined
@property
def pdicts(self):
    """Ordered mapping from each model/data expression to its ``pdicts``."""
    combined = OrderedDict()
    for md in self.Model + self.Data:
        key = md.expr
        combined[key] = md.pdicts
    return combined
@property
def cfg(self):
    """Flat :class:`SuperDict` of every model and data config parameter.

    Entries are keyed ``'1'``, ``'2'``, ... in traversal order: all models
    first, then all data sets.
    """
    entries = (
        cg
        for md in self.Model + self.Data
        for config in md.cdicts.values()
        for cg in config.values()
    )
    flat = SuperDict()
    for cid, cg in enumerate(entries, start=1):
        flat[str(cid)] = cg
    return flat
@property
def par(self):
    """Flat :class:`SuperDict` of every model and data parameter.

    Entries are keyed ``'1'``, ``'2'``, ... in traversal order (all models
    first, then all data sets), whether the parameter is free or frozen.
    """
    entries = (
        pr
        for md in self.Model + self.Data
        for params in md.pdicts.values()
        for pr in params.values()
    )
    flat = SuperDict()
    for pid, pr in enumerate(entries, start=1):
        flat[str(pid)] = pr
    return flat
@property
def pvalues(self):
    """Tuple of each parameter's ``value``, in the same order as ``par``."""
    return tuple(entry.value for entry in self.par.values())
[docs]
@staticmethod
def foo(id):
    """Return the Python object whose ``id()`` equals ``id``.

    The integer is reinterpreted as an object pointer through ``ctypes``.
    NOTE(review): presumably only safe while the object is still alive;
    confirm callers keep a reference.
    """
    handle = ctypes.cast(id, ctypes.py_object)
    return handle.value
@property
def idpid(self):
    """Map ``str(id(Par))`` to the set of ``par#`` slots that object fills.

    Slots sharing one underlying :class:`Par` instance (linked parameters)
    end up in the same set.
    """
    slots = SuperDict()
    pid = 0
    for md in self.Model + self.Data:
        for params in md.pdicts.values():
            for pr in params.values():
                pid += 1
                key = str(id(pr))
                if key in slots:
                    slots[key].add(str(pid))
                else:
                    slots[key] = {str(pid)}
    return slots
@property
def all_config(self):
    """List of one row per config entry, tagged ``model`` or ``data``.

    Rows are numbered ``'1'``, ``'2'``, ... in traversal order; the first
    ``nparis`` containers are the models, the remainder the data sets.
    """
    rows = []
    for position, md in enumerate(self.Model + self.Data):
        if position < self.nparis:
            owner = 'model'
        else:
            owner = 'data'
        for component, config in md.cdicts.items():
            for label, cg in config.items():
                row = {
                    'cfg#': str(len(rows) + 1),
                    'Class': owner,
                    'Expression': md.expr,
                    'Component': component,
                    'Parameter': label,
                    'Value': cg.val,
                }
                rows.append(row)
    return rows
@property
def all_params(self):
    """List of one row per parameter slot, with linked mates resolved.

    Each row carries the ``par#`` index, owner class (``model``/``data``),
    expression, component, label, value, prior, frozen flag, posterior and
    ``Mates``: the other ``par#`` slots that share this slot's ``Par`` or
    one of its ``mates``.
    """
    # ``idpid`` walks every parameter on each access; fetching it once keeps
    # this traversal linear instead of rebuilding it per slot and per mate.
    idpid = self.idpid
    rows = []
    pid = 0
    for i, md in enumerate(self.Model + self.Data):
        cls = 'model' if i < self.nparis else 'data'
        for expr, params in md.pdicts.items():
            for pl, pr in params.items():
                pid += 1
                own_slots = idpid[str(id(pr))]
                mate_slots = [idpid[str(id(mate))] for mate in pr.mates]
                # ``union`` returns a new set, so the cached map is untouched.
                mates = own_slots.union(*mate_slots)
                mates.remove(str(pid))
                rows.append(
                    {
                        'par#': str(pid),
                        'Class': cls,
                        'Expression': md.expr,
                        'Component': expr,
                        'Parameter': pl,
                        'Value': pr.val,
                        'Prior': f'{pr.prior_info}',
                        'Frozen': pr.frozen,
                        'Mates': mates,
                        'Posterior': f'{pr.post_info}',
                    }
                )
    return rows
def _you_free(self):
    """Rebuild the ``free_*`` caches by walking ``all_params`` once.

    A parameter is free if it is not ``frozen`` and no earlier slot that
    shares its underlying :class:`Par` has already claimed it. Every slot,
    free or not, marks its mates as claimed.
    """
    # ``par`` rebuilds the flat table on every access; read it once.
    par = self.par
    claimed = set()
    self._free_par = SuperDict()
    self._free_params = []
    for param in self.all_params:
        pid = param['par#']
        if not param['Frozen'] and pid not in claimed:
            self._free_par[pid] = par[pid]
            self._free_params.append(param)
        claimed.update(param['Mates'])
    self._free_plabels = [param['Parameter'] for param in self._free_params]
    self._free_pvalues = [param['Value'] for param in self._free_params]
    self._free_pranges = [entry.range for entry in self._free_par.values()]
    self._free_nparams = len(self._free_plabels)
[docs]
def link(self, pids):
    """Link every :class:`Par` in ``pids`` so they share value/prior/posterior.

    Args:
        pids: Iterable of ``par#`` indices (as strings or ints).
    """
    # Materialise once: the pairwise walk re-reads ``pids``, which silently
    # skipped pairs when a one-shot iterator was passed.
    pids = list(pids)
    for i, ip in enumerate(pids):
        for jp in pids[i + 1:]:
            # Re-read the table for every pair because an earlier ``link``
            # may have changed which object sits in a slot.
            par = self.par
            first, second = par[ip], par[jp]
            if first is not second:
                first.link(second)
    self._you_free()
[docs]
def unlink(self, pids):
    """Undo any linking between every pair drawn from ``pids``.

    Args:
        pids: Iterable of ``par#`` indices.
    """
    # Materialise once: the pairwise walk re-reads ``pids``, which silently
    # skipped pairs when a one-shot iterator was passed.
    pids = list(pids)
    for i, ip in enumerate(pids):
        for jp in pids[i + 1:]:
            # Re-read the table for every pair because an earlier ``unlink``
            # may have changed which object sits in a slot.
            par = self.par
            first, second = par[ip], par[jp]
            if first is not second:
                first.unlink(second)
    self._you_free()
@property
def free_par(self):
    """Cached :class:`SuperDict` of free :class:`Par` objects, keyed by ``par#``."""
    cached = self._free_par
    return cached

@property
def free_params(self):
    """Cached :attr:`all_params` rows that belong to free parameters."""
    cached = self._free_params
    return cached

@property
def free_plabels(self):
    """Cached ``Parameter`` labels of the free parameters, in ``par#`` order."""
    cached = self._free_plabels
    return cached
@property
def clean_free_plabels(self):
    """Free-parameter labels with ``$``, ``{``, ``}`` and ``\\`` removed."""
    markup = str.maketrans('', '', '${}\\')
    return [label.translate(markup) for label in self._free_plabels]
@property
def free_indexed_plabels(self):
"""Free-parameter labels prefixed with their ``par#`` index."""
return [
f'p{key}({label})'
for label, key in zip(self.free_plabels, self.free_par.keys(), strict=False)
]
@property
def clean_free_indexed_plabels(self):
"""Indexed labels with LaTeX markup removed."""
return [
f'p{key}({label})'
for label, key in zip(self.clean_free_plabels, self.free_par.keys(), strict=False)
]
@property
def free_pvalues(self):
    """Cached ``Value`` entries of the free-parameter rows."""
    cached = self._free_pvalues
    return cached

@property
def free_pranges(self):
    """Cached ``range`` attribute of every free :class:`Par`."""
    cached = self._free_pranges
    return cached

@property
def free_nparams(self):
    """Cached count of free parameters."""
    cached = self._free_nparams
    return cached
@property
def cfg_info(self):
    """:class:`Info` table built from a copy of the :attr:`all_config` rows."""
    rows = list(self.all_config)
    return Info.from_list_dict(rows)
@property
def par_info(self):
    """Tabular parameter view tagging free slots with ``*`` and resolving linked aliases.

    Free rows get ``*`` appended to ``par#``. Non-free rows have ``Prior``
    replaced by ``frozen`` or by ``=par#{...}`` listing their mates. The
    ``Posterior``, ``Mates`` and ``Frozen`` columns are dropped.
    """
    self._you_free()
    free = self.free_par
    # ``all_params`` builds a fresh list of fresh dicts on every access, so
    # the rows can be edited in place.
    rows = self.all_params
    for row in rows:
        if row['par#'] in free:
            row['par#'] = row['par#'] + '*'
        elif row['Frozen']:
            row['Prior'] = 'frozen'
        else:
            # ``Mates`` is a set; sort numerically so the alias text does not
            # change from run to run.
            mates = ','.join(sorted(row['Mates'], key=int))
            row['Prior'] = '=par#{{{}}}'.format(mates)
    table = Info.list_dict_to_dict(rows)
    for column in ('Posterior', 'Mates', 'Frozen'):
        del table[column]
    return Info.from_dict(table)
@property
def notable_par_info(self):
    """Parameter view like :attr:`par_info` but hides frozen data-level rows.

    Free rows get ``*`` appended to ``par#``. Frozen rows show ``frozen`` as
    prior and are skipped when their class is ``data``. Linked rows show
    ``=par#{...}`` listing their mates. The ``Posterior``, ``Mates`` and
    ``Frozen`` columns are dropped.
    """
    self._you_free()
    free = self.free_par
    notable = []
    # ``all_params`` builds fresh rows on every access, so edit them in place.
    for row in self.all_params:
        if row['par#'] in free:
            row['par#'] = row['par#'] + '*'
        elif row['Frozen']:
            row['Prior'] = 'frozen'
            if row['Class'] == 'data':
                continue
        else:
            # ``Mates`` is a set; sort numerically so the alias text does not
            # change from run to run.
            mates = ','.join(sorted(row['Mates'], key=int))
            row['Prior'] = '=par#{{{}}}'.format(mates)
        notable.append(row)
    table = Info.list_dict_to_dict(notable)
    for column in ('Posterior', 'Mates', 'Frozen'):
        del table[column]
    return Info.from_dict(table)
@property
def free_par_info(self):
    """:class:`Info` table of the free-parameter rows only.

    The caches are refreshed first; the ``Posterior``, ``Mates`` and
    ``Frozen`` columns are dropped.
    """
    self._you_free()
    table = Info.list_dict_to_dict(list(self.free_params))
    for column in ('Posterior', 'Mates', 'Frozen'):
        del table[column]
    return Info.from_dict(table)
[docs]
def save(self, savepath):
    """Dump config and parameter tables under ``savepath``.

    Writes ``infer_cfg.json`` and ``infer_par.json``.

    Args:
        savepath: Directory path (string or path-like). Created if missing.
    """
    # ``exist_ok`` avoids the check-then-create race when several runs share
    # the same output directory.
    os.makedirs(savepath, exist_ok=True)
    json_dump(self.cfg_info.data_list_dict, os.path.join(savepath, 'infer_cfg.json'))
    json_dump(self.par_info.data_list_dict, os.path.join(savepath, 'infer_par.json'))
@property
def data_chbin_mean(self):
    """Per-channel midpoints (``rsp_chbin_mean``) of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.rsp_chbin_mean]

@property
def data_re_chbin_mean(self):
    """``rsp_re_chbin_mean`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.rsp_re_chbin_mean]

@property
def data_chbin_width(self):
    """``rsp_chbin_width`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.rsp_chbin_width]

@property
def data_re_chbin_width(self):
    """``rsp_re_chbin_width`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.rsp_re_chbin_width]
@property
def data_ctsrate(self):
    """``net_ctsrate`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.net_ctsrate]

@property
def data_re_ctsrate(self):
    """``net_re_ctsrate`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.net_re_ctsrate]

@property
def data_ctsrate_error(self):
    """``net_ctsrate_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.net_ctsrate_error]

@property
def data_re_ctsrate_error(self):
    """``net_re_ctsrate_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.net_re_ctsrate_error]

@property
def data_ctsspec(self):
    """``net_ctsspec`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.net_ctsspec]

@property
def data_re_ctsspec(self):
    """``net_re_ctsspec`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.net_re_ctsspec]

@property
def data_ctsspec_error(self):
    """``net_ctsspec_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.net_ctsspec_error]

@property
def data_re_ctsspec_error(self):
    """``net_re_ctsspec_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.net_re_ctsspec_error]
@property
def data_phtspec(self):
    """``deconv_phtspec`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_phtspec]

@property
def data_re_phtspec(self):
    """``deconv_re_phtspec`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_re_phtspec]

@property
def data_phtspec_error(self):
    """``deconv_phtspec_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_phtspec_error]

@property
def data_re_phtspec_error(self):
    """``deconv_re_phtspec_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_re_phtspec_error]

@property
def data_flxspec(self):
    """``deconv_flxspec`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_flxspec]

@property
def data_re_flxspec(self):
    """``deconv_re_flxspec`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_re_flxspec]

@property
def data_flxspec_error(self):
    """``deconv_flxspec_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_flxspec_error]

@property
def data_re_flxspec_error(self):
    """``deconv_re_flxspec_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_re_flxspec_error]

@property
def data_ergspec(self):
    """``deconv_ergspec`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_ergspec]

@property
def data_re_ergspec(self):
    """``deconv_re_ergspec`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_re_ergspec]

@property
def data_ergspec_error(self):
    """``deconv_ergspec_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_ergspec_error]

@property
def data_re_ergspec_error(self):
    """``deconv_re_ergspec_error`` of every ``Data``, concatenated in pair order."""
    return [item for dataset in self.Data for item in dataset.deconv_re_ergspec_error]
@property
def model_ctsrate(self):
    """``conv_ctsrate`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.conv_ctsrate]

@property
def model_re_ctsrate(self):
    """``conv_re_ctsrate`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.conv_re_ctsrate]

@property
def model_ctsspec(self):
    """``conv_ctsspec`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.conv_ctsspec]

@property
def model_re_ctsspec(self):
    """``conv_re_ctsspec`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.conv_re_ctsspec]

@property
def model_phtspec(self):
    """``phtspec_at_rsp`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.phtspec_at_rsp]

@property
def model_re_phtspec(self):
    """``re_phtspec_at_rsp`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.re_phtspec_at_rsp]

@property
def model_flxspec(self):
    """``flxspec_at_rsp`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.flxspec_at_rsp]

@property
def model_re_flxspec(self):
    """``re_flxspec_at_rsp`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.re_flxspec_at_rsp]

@property
def model_ergspec(self):
    """``ergspec_at_rsp`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.ergspec_at_rsp]

@property
def model_re_ergspec(self):
    """``re_ergspec_at_rsp`` of every ``Model``, concatenated in pair order."""
    return [item for mdl in self.Model for item in mdl.re_ergspec_at_rsp]
@property
def residual(self):
    """List of ``(data - model) / error`` count-rate residuals over every pair."""
    observed = self.data_ctsrate
    predicted = self.model_ctsrate
    sigma = self.data_ctsrate_error
    return [(oi - mi) / si for oi, mi, si in zip(observed, predicted, sigma)]
@property
def re_residual(self):
    """List of ``(data - model) / error`` residuals from the ``_re_`` count rates."""
    observed = self.data_re_ctsrate
    predicted = self.model_re_ctsrate
    sigma = self.data_re_ctsrate_error
    return [(oi - mi) / si for oi, mi, si in zip(observed, predicted, sigma)]
@property
def prior_list(self):
    """Prior density of each free parameter, evaluated at its current ``val``."""
    densities = []
    for entry in self.free_par.values():
        densities.append(entry.prior.pdf(entry.val))
    return densities

@property
def prior(self):
    """Joint prior density: the product of :attr:`prior_list`."""
    densities = self.prior_list
    return np.prod(densities)
@property
def logprior(self):
    """Joint log-prior; ``-inf`` when the prior vanishes."""
    # Read the joint prior once: each access re-evaluates every free
    # parameter's prior density.
    joint = self.prior
    if joint == 0:
        return -np.inf
    return np.log(joint)
@property
def stat_list(self):
    """``stat_list`` of every pair, stacked into one array with ``np.hstack``."""
    per_pair = [pr.stat_list for pr in self.Pair]
    return np.hstack(per_pair)

@property
def pseudo_residual_list(self):
    """Flat list of the ``pseudo_residual_list`` entries of every pair."""
    return [entry for pr in self.Pair for entry in pr.pseudo_residual_list]

@property
def weight_list(self):
    """``weight_list`` of every pair, stacked into one array with ``np.hstack``."""
    per_pair = [pr.weight_list for pr in self.Pair]
    return np.hstack(per_pair)

@property
def stat(self):
    """Sum of the ``stat`` of every pair."""
    per_pair = [pr.stat for pr in self.Pair]
    return np.sum(per_pair)

@property
def pseudo_residual(self):
    """``pseudo_residual`` of every pair, stacked into one array with ``np.hstack``."""
    per_pair = [pr.pseudo_residual for pr in self.Pair]
    return np.hstack(per_pair)
@property
def loglike_list(self):
    """``loglike_list`` of every pair, stacked into one array with ``np.hstack``."""
    per_pair = [pr.loglike_list for pr in self.Pair]
    return np.hstack(per_pair)

@property
def loglike(self):
    """Sum of the ``loglike`` of every pair."""
    per_pair = [pr.loglike for pr in self.Pair]
    return np.sum(per_pair)

@property
def npoint_list(self):
    """``npoint_list`` of every pair, stacked into one array with ``np.hstack``."""
    per_pair = [pr.npoint_list for pr in self.Pair]
    return np.hstack(per_pair)

@property
def npoint(self):
    """Sum of the ``npoint`` of every pair."""
    per_pair = [pr.npoint for pr in self.Pair]
    return np.sum(per_pair)

@property
def dof(self):
    """Degrees of freedom: :attr:`npoint` minus :attr:`free_nparams`."""
    points = self.npoint
    return points - self.free_nparams
@property
def all_stat(self):
"""Per-pair statistic summary plus a totals row, ready for ``Info.from_dict``."""
all_stat = OrderedDict()
all_stat['Data'] = ['Total']
all_stat['Model'] = ['Total']
all_stat['Statistic'] = ['stat/dof']
all_stat['Value'] = [f'{self.stat:.3f}/{self.dof:d}']
all_stat['Bins'] = [self.npoint]
for dt, mo in zip(self.Data, self.Model, strict=False):
mex = mo.expr
for sex, stat in zip(dt.names, dt.stats, strict=False):
all_stat['Data'].insert(-1, sex)
all_stat['Model'].insert(-1, mex)
all_stat['Statistic'].insert(-1, stat)
all_stat['Value'] = [stat for stat in self.stat_list] + all_stat['Value']
all_stat['Bins'] = [point for point in self.npoint_list] + all_stat['Bins']
return all_stat
@property
def stat_info(self):
    """:class:`Info` table built from a shallow copy of :attr:`all_stat`."""
    columns = self.all_stat.copy()
    table = Info.from_dict(columns)
    return table
def __str__(self):
    """Plain-text report: title, configuration table, notable-parameter table."""
    lines = [
        f'*** {self.inference_type} ***',
        '*** Configurations ***',
        f'{self.cfg_info.text_table}',
        '*** Parameters ***',
        f'{self.notable_par_info.text_table}',
    ]
    return '\n'.join(lines)

def __repr__(self):
    """Same text as :meth:`__str__`."""
    text = self.__str__()
    return text
def _repr_html_(self):
    """HTML report: style block plus collapsible configuration and parameter tables."""
    fragments = [
        f'{self.cfg_info.html_style}',
        '<details open>',
        f'<summary style="margin-bottom: 10px;"><b>{self.inference_type}</b></summary>',
        '<details open style="margin-top: 10px;">',
        '<summary style="margin-bottom: 10px;"><b>Configurations</b></summary>',
        f'{self.cfg_info.html_table}',
        '</details>',
        '<details open style="margin-top: 10px;">',
        '<summary style="margin-bottom: 10px;"><b>Parameters</b></summary>',
        f'{self.notable_par_info.html_table}',
        '</details>',
        '</details>',
    ]
    return ''.join(fragments)
[docs]
def at_par(self, theta):
    """Assign ``theta`` to the free parameters, element ``k`` to ``free_par[k + 1]``.

    Args:
        theta: Sequence of values, one per free parameter.
    """
    for position, value in enumerate(theta, start=1):
        self.free_par[position].val = value
@property
def prior_transform_func(self):
    """User override for the prior transform, or ``None``."""
    return self._prior_transform_func

@prior_transform_func.setter
def prior_transform_func(self, new_prior_transform_func):
    """Store a prior-transform callable, or ``None`` to clear the override.

    Raises:
        ValueError: If the argument is neither callable nor ``None``.
    """
    candidate = new_prior_transform_func
    if candidate is not None and not isinstance(candidate, Callable):
        raise ValueError('prior_transform_func is expected to be Callable or None')
    self._prior_transform_func = candidate
@property
def logprior_func(self):
    """User override for the log-prior, or ``None``."""
    return self._logprior_func

@logprior_func.setter
def logprior_func(self, new_logprior_func):
    """Store a log-prior callable, or ``None`` to clear the override.

    Raises:
        ValueError: If the argument is neither callable nor ``None``.
    """
    candidate = new_logprior_func
    if candidate is not None and not isinstance(candidate, Callable):
        raise ValueError('logprior_func is expected to be Callable or None')
    self._logprior_func = candidate
@property
def loglike_func(self):
    """User override for the log-likelihood, or ``None``."""
    return self._loglike_func

@loglike_func.setter
def loglike_func(self, new_loglike_func):
    """Store a log-likelihood callable, or ``None`` to clear the override.

    Raises:
        ValueError: If the argument is neither callable nor ``None``.
    """
    candidate = new_loglike_func
    if candidate is not None and not isinstance(candidate, Callable):
        raise ValueError('loglike_func is expected to be Callable or None')
    self._loglike_func = candidate
[docs]
def calc_logprior(self, theta):
    """Set the free parameters to ``theta`` and return the log-prior.

    When ``logprior_func`` is set it is called as ``logprior_func(self, theta)``
    instead of using the built-in :attr:`logprior`.
    """
    self.at_par(theta)
    override = self.logprior_func
    if override is not None:
        return override(self, theta)
    return self.logprior
[docs]
def calc_stat(self, theta):
    """Set the free parameters to ``theta`` and return :attr:`stat`."""
    self.at_par(theta)
    value = self.stat
    return value
[docs]
def calc_pseudo_residual(self, theta):
    """Set the free parameters to ``theta`` and return :attr:`pseudo_residual`."""
    self.at_par(theta)
    vector = self.pseudo_residual
    return vector
[docs]
def calc_loglike(self, theta):
    """Set the free parameters to ``theta`` and return the log-likelihood.

    When ``loglike_func`` is set it is called as ``loglike_func(self, theta)``
    instead of using the built-in :attr:`loglike`.
    """
    self.at_par(theta)
    override = self.loglike_func
    if override is not None:
        return override(self, theta)
    return self.loglike
[docs]
def calc_logprob(self, theta):
    """Return log-prior plus log-likelihood at ``theta``.

    The likelihood is not evaluated when the log-prior is not finite; the
    result is then ``-inf``.
    """
    lp = self.calc_logprior(theta)
    if np.isfinite(lp):
        return lp + self.calc_loglike(theta)
    return -np.inf
[docs]
def calc_logprior_sample(self, theta_sample):
    """Vectorized log-prior over a sample matrix; returns ``-inf`` where it vanishes.

    Args:
        theta_sample: ``(nsample, nparams)`` array of draws; column ``i`` is
            scored with the prior of ``free_par[i + 1]``.

    Returns:
        ``(nsample,)`` array of log-prior values.
    """
    theta_sample = np.asarray(theta_sample)
    prior_list_sample = np.zeros_like(theta_sample, dtype=float)
    for i in range(theta_sample.shape[1]):
        prior_list_sample[:, i] = self.free_par[i + 1].prior.pdf(theta_sample[:, i])
    prior_sample = np.prod(prior_list_sample, axis=1)
    # Only take the log where the prior is non-zero. ``np.where`` would
    # evaluate ``np.log(0)`` eagerly and emit a divide-by-zero RuntimeWarning
    # for every draw outside the prior support.
    logprior_sample = np.full(prior_sample.shape, -np.inf)
    np.log(prior_sample, out=logprior_sample, where=prior_sample != 0)
    return logprior_sample
[docs]
class BayesInfer(Infer):
""":class:`Infer` extension that wires up MultiNest and emcee drivers.
Adds ``multinest`` and ``emcee`` methods that run the samplers,
persist chains to ``savepath``, and return a :class:`Posterior` for
downstream analysis.
"""
def __init__(self, pairs=None):
    """Initialise like :class:`Infer` and tag the inference type.

    Args:
        pairs: Forwarded unchanged to the parent constructor.
    """
    super().__init__(pairs=pairs)
    # Set after the parent constructor, which assigns its own default label.
    self.inference_type = 'Bayesian Inference'
[docs]
def multinest_calc_loglike(self, theta):
    """Return :meth:`calc_loglike` at ``theta`` (MultiNest-facing entry point)."""
    loglike = self.calc_loglike(theta)
    return loglike
[docs]
def multinest_safe_calc_loglike(self, cube, ndim, nparams, lnew):
    """Log-likelihood callback handed to MultiNest.

    Args:
        cube: Indexable parameter buffer; only the first ``ndim`` entries
            are read.
        ndim: Number of entries to read from ``cube``.
        nparams: Unused here.
        lnew: Unused here.

    Returns:
        The log-likelihood as a float, or ``-2e100`` when it is not finite.
        Any exception is reported on stderr and ends the process with
        exit status 1.
    """
    try:
        theta = np.array([cube[k] for k in range(ndim)])
        value = float(self.multinest_calc_loglike(theta))
        return value if np.isfinite(value) else -2e100
    except Exception as e:
        import sys
        sys.stderr.write(f'ERROR in loglikelihood: {e}\n')
        sys.exit(1)
@staticmethod
def _read_ns_global_evidence(stats_file):
    """Read the log-evidence value from the first line of a MultiNest stats file.

    The value is the text between the first ``:`` and the following ``+/-``.

    Args:
        stats_file: Path to MultiNest's ``stats.dat``.

    Returns:
        The value as a float, or ``None`` when the file cannot be opened or
        the first line cannot be parsed.
    """
    try:
        with open(stats_file) as handle:
            first_line = handle.readline()
        _, _, tail = first_line.partition(':')
        value_text = tail.partition('+/-')[0]
        return float(value_text)
    except (OSError, IndexError, ValueError):
        return None
[docs]
def multinest(
    self,
    nlive=500,
    resume=True,
    verbose=False,
    max_iter=None,
    ins=True,
    savepath='./',
    random_seed=None,
):
    """Run MultiNest and return a :class:`Posterior` wrapping the result.

    Args:
        nlive: Number of live points.
        resume: Resume from any chain already present under ``savepath``.
        verbose: Forward MultiNest's verbose flag.
        max_iter: Hard cap on nested-sampling iterations, a backstop
            against runs that never reach the evidence tolerance (e.g.
            likelihood plateaus). ``None`` (default) caps at
            ``nlive * 50``, well above the ``~nlive * H`` a normal run
            needs, so it never clips genuine convergence. A run that hits
            this cap has not converged; its result is unreliable.
        ins: Enable INS for a more accurate evidence. ``True``
            (default) suits most fits, but INS weights can underflow
            when the posterior rails against a hard prior boundary,
            yielding a NaN/subnormal evidence that corrupts
            ``stats.dat``. Set ``False`` for such boundary-pinned fits
            to fall back to the robust plain nested-sampling evidence.
        savepath: Directory for MultiNest outputs and cached samples.
        random_seed: Seed forwarded to MultiNest for reproducible runs.
            ``None`` (default) lets MultiNest pick a system-time seed,
            so different runs differ.

    Returns:
        A :class:`~bayspec.infer.analyzer.Posterior`.

    Raises:
        RuntimeError: If ``stats.dat`` cannot be parsed and the plain
            nested-sampling evidence is unreadable as well.
    """
    import pymultinest

    from .analyzer import Posterior

    self.sampler_type = 'nested'
    self._you_free()
    max_iter = nlive * 50 if max_iter is None else int(max_iter)
    savepath_prefix = savepath + '/1-'
    # ``exist_ok`` avoids the check-then-create race when several runs share
    # the same output directory.
    os.makedirs(savepath, exist_ok=True)
    # NOTE(review): ``multinest_safe_prior_transform`` is not defined in the
    # part of this file visible here; confirm it is provided elsewhere.
    pymultinest.run(
        LogLikelihood=self.multinest_safe_calc_loglike,
        Prior=self.multinest_safe_prior_transform,
        n_dims=self.free_nparams,
        resume=resume,
        verbose=verbose,
        n_live_points=nlive,
        outputfiles_basename=savepath_prefix,
        sampling_efficiency=0.3,
        importance_nested_sampling=ins,
        multimodal=False,
        max_iter=max_iter,
        seed=-1 if random_seed is None else int(random_seed),
    )
    # The number of lines in ``ev.dat`` is compared against the cap to tell
    # whether the run stopped on ``max_iter``.
    capped = False
    if os.path.exists(savepath_prefix + 'ev.dat'):
        with open(savepath_prefix + 'ev.dat') as f:
            niter = sum(1 for _ in f)
        if niter >= max_iter:
            capped = True
            msg = (
                f'MultiNest stopped at the max_iter cap ({max_iter} iterations) '
                'without reaching the evidence tolerance: the posterior and '
                'evidence are unreliable. Check for likelihood plateaus or overly '
                'wide priors, or rerun with a larger max_iter.'
            )
            warnings.warn(msg, stacklevel=2)
            # Also persist the warning next to the outputs.
            with open(savepath_prefix + 'max_iter_warning.txt', 'w') as f:
                f.write(msg + '\n')
    multinest_analyzer = pymultinest.Analyzer(
        outputfiles_basename=savepath_prefix, n_params=self.free_nparams
    )
    try:
        posterior_stats = multinest_analyzer.get_stats()
    except ValueError as e:
        # get_stats() chokes when INS writes a subnormal evidence without an
        # exponent marker (e.g. ``0.19...-322``) into stats.dat: a sign the INS
        # evidence is garbage from a posterior pinned to a prior boundary. The
        # plain NS evidence is on a separate, well-formed line, so recover that
        # and carry on rather than failing the whole run.
        posterior_stats = None
        self.logevidence = self._read_ns_global_evidence(savepath_prefix + 'stats.dat')
        if self.logevidence is None:
            raise RuntimeError(
                f'pymultinest could not parse the MultiNest stats file ({e}), and '
                f'the plain nested-sampling evidence was also unreadable. This '
                f'usually means the run did not converge and the posterior '
                f'collapsed onto a prior boundary. '
                + (f'The max_iter cap ({max_iter}) was also hit. ' if capped else '')
                + 'Inspect the model and priors for this dataset.'
            ) from e
        warnings.warn(
            'INS evidence in stats.dat was unparseable (likely a boundary-pinned, '
            'non-converged posterior); falling back to the plain nested-sampling '
            'log-evidence.',
            stacklevel=2,
        )
    else:
        # Prefer the INS evidence only when INS was requested and the value
        # is present and finite.
        ins_logevidence = posterior_stats.get('nested importance sampling global log-evidence')
        if ins and ins_logevidence is not None and np.isfinite(ins_logevidence):
            self.logevidence = ins_logevidence
        else:
            self.logevidence = posterior_stats['nested sampling global log-evidence']
    if (not resume) or (not os.path.exists(savepath_prefix + 'posterior_sample.txt')):
        self.posterior_sample = multinest_analyzer.get_equal_weighted_posterior()
        # The log-prior of each draw is added to the last column before the
        # sample is cached to disk.
        self.posterior_sample[:, -1] = self.posterior_sample[:, -1] + self.calc_logprior_sample(
            self.posterior_sample[:, 0:-1]
        )
        np.savetxt(savepath_prefix + 'posterior_sample.txt', self.posterior_sample)
    else:
        self.posterior_sample = np.loadtxt(savepath_prefix + 'posterior_sample.txt')
    with open(savepath_prefix + 'nlive.json', 'w') as f:
        json.dump(nlive, f, indent=4, cls=JsonEncoder)
    if posterior_stats is not None:
        with open(savepath_prefix + 'posterior_stats.json', 'w') as f:
            json.dump(
                posterior_stats,
                f,
                indent=4,
                cls=JsonEncoder,
            )
    return Posterior(self)
[docs]
def emcee_calc_logprob(self, theta):
    """Return :meth:`calc_logprob` at ``theta`` (emcee-facing entry point)."""
    logprob = self.calc_logprob(theta)
    return logprob
[docs]
def emcee(self, nstep=1000, discard=100, resume=True, savepath='./', random_seed=None):
    """Run emcee and return a :class:`Posterior` wrapping the flattened chain.

    Args:
        nstep: Number of MCMC steps per walker.
        discard: Burn-in steps discarded before flattening.
        resume: Reuse an existing chain cached under ``savepath``.
        savepath: Directory for chain outputs.
        random_seed: Seed for reproducible runs. ``None`` (default)
            lets emcee draw fresh entropy, so different runs differ.

    Returns:
        A :class:`~bayspec.infer.analyzer.Posterior`.
    """
    import emcee

    from .analyzer import Posterior

    self.sampler_type = 'mcmc'
    self._you_free()
    savepath_prefix = savepath + '/1-'
    # ``exist_ok`` avoids the check-then-create race when several runs share
    # the same output directory.
    os.makedirs(savepath, exist_ok=True)
    rng = np.random.default_rng(random_seed)
    ndim = self.free_nparams
    # At least 32 walkers, and never fewer than two per dimension.
    nwalkers = max(32, 2 * ndim)
    # Walkers start in a small Gaussian ball around the current free values.
    pos = self.free_pvalues + 1e-4 * rng.standard_normal((nwalkers, ndim))
    if (not resume) or (not os.path.exists(savepath_prefix + '.npz')):
        # emcee proposals draw from numpy's global RNG; seed it only on opt-in.
        if random_seed is not None:
            np.random.seed(int(random_seed))
        emcee_sampler = emcee.EnsembleSampler(nwalkers, ndim, self.emcee_calc_logprob)
        emcee_sampler.run_mcmc(pos, nstep, progress=True)
        params_sample = emcee_sampler.get_chain()
        np.savez(savepath_prefix + '.npz', sample=params_sample)
        logprob_sample = emcee_sampler.get_log_prob()
        np.savetxt(savepath_prefix + 'logprob.dat', logprob_sample)
        try:
            autocorr_time = emcee_sampler.get_autocorr_time()
            with open(savepath_prefix + 'autocorr_time.json', 'w') as f:
                json.dump(
                    autocorr_time,
                    f,
                    indent=4,
                    cls=JsonEncoder,
                )
        except Exception as exc:
            # Still best-effort (the chain is already on disk), but say why
            # the diagnostic file is missing instead of hiding the failure.
            warnings.warn(
                f'autocorrelation time was not saved: {exc}',
                stacklevel=2,
            )
    # Always read back from disk so fresh and resumed runs share one path.
    params_sample = np.load(savepath_prefix + '.npz')['sample']
    logprob_sample = np.loadtxt(savepath_prefix + 'logprob.dat')
    flat_params_sample = params_sample[discard:, :, :].reshape(-1, ndim)
    flat_logprob_sample = logprob_sample[discard:, :].reshape(-1)
    # The last column holds the log-probability of each draw.
    self.posterior_sample = np.hstack(
        (flat_params_sample, np.reshape(flat_logprob_sample, (-1, 1)))
    )
    np.savetxt(savepath_prefix + 'posterior_sample.txt', self.posterior_sample)
    with open(savepath_prefix + 'nstep.json', 'w') as f:
        json.dump(nstep, f, indent=4, cls=JsonEncoder)
    with open(savepath_prefix + 'discard.json', 'w') as f:
        json.dump(discard, f, indent=4, cls=JsonEncoder)
    return Posterior(self)
[docs]
class MaxLikeFit(Infer):
""":class:`Infer` extension for maximum-likelihood fits with bootstrap sampling.
Provides :meth:`lmfit` and :meth:`iminuit` drivers. Both run the
minimiser, cache the best fit, build a covariance-driven bootstrap
sample respecting the free-parameter ranges, and return a
:class:`Bootstrap` for downstream analysis.
"""
def __init__(self, pairs=None):
    """Initialise like :class:`Infer` and tag the inference type.

    Args:
        pairs: Forwarded unchanged to the parent constructor.
    """
    super().__init__(pairs=pairs)
    # Set after the parent constructor, which assigns its own default label.
    self.inference_type = 'Maximum Likelihood Estimation'
def _make_bootstrap_sample(
    self, values, covar=None, errors=None, nsample=1000, random_seed=450001
):
    """Draw a covariance-respecting bootstrap sample and score each draw.

    Samples are drawn from a multivariate normal centred on ``values``.
    When ``covar`` is missing, has the wrong shape, or contains
    non-finite entries, a warning is issued and a diagonal covariance
    built from ``errors`` (or zeros) is used instead. Draws that fall
    outside any free parameter's range are rejected. The first row of
    the sample is always ``values`` itself.

    The result is stored on ``self.bootstrap_sample`` as an array with
    one row per draw: the parameter vector followed by its
    log-likelihood. The model is put back at ``values`` before
    returning, including when scoring a draw raises.

    Args:
        values: Best-fit free-parameter vector.
        covar: Optional parameter covariance matrix.
        errors: Optional uncertainties used for the fallback diagonal
            covariance; either one entry per parameter or a scalar
            applied to every parameter. Non-finite entries count as 0.
        nsample: Target number of valid draws (at least 1).
        random_seed: Seed for reproducibility.

    Raises:
        ValueError: If the fallback is needed and ``errors`` is neither
            a scalar nor one value per parameter.
    """
    values = np.asarray(values, dtype=float)
    ndim = values.size
    nsample = max(int(nsample), 1)

    if covar is not None:
        covar = np.asarray(covar, dtype=float)
    if covar is None or covar.shape != (ndim, ndim) or not np.isfinite(covar).all():
        msg = (
            'Covariance matrix is not provided or invalid. '
            'Using diagonal covariance with variances from errors or zeros.'
        )
        warnings.warn(msg, stacklevel=2)
        if errors is None:
            err = np.zeros(ndim, dtype=float)
        else:
            # Validate up front: a scalar would crash np.diag and a wrong-length
            # vector would only fail later with an opaque sampler error.
            try:
                err = np.broadcast_to(np.asarray(errors, dtype=float), (ndim,))
            except ValueError as exc:
                raise ValueError(
                    f'errors must be a scalar or have one entry per free parameter ({ndim}).'
                ) from exc
        err = np.where(np.isfinite(err), np.abs(err), 0.0)
        covar = np.diag(err * err)

    # Symmetrise, then lift eigenvalues to a tiny positive floor so the
    # matrix handed to the sampler is positive definite.
    covar = 0.5 * (covar + covar.T)
    eigval, eigvec = np.linalg.eigh(covar)
    scale = np.max(np.abs(eigval)) if eigval.size else 1.0
    floor = np.finfo(float).eps * (scale if scale > 0 else 1.0)
    eigval = np.clip(eigval, floor, None)
    covar = eigvec @ np.diag(eigval) @ eigvec.T

    lower = np.array([pr[0] for pr in self.free_pranges], dtype=float)
    upper = np.array([pr[1] for pr in self.free_pranges], dtype=float)
    rng = np.random.default_rng(random_seed)

    # Seed the sample with the best fit, then top up by rejection sampling.
    param_sample = [values.copy()]
    tries = 0
    while len(param_sample) < nsample and tries < 10:
        # Oversample so that a moderate rejection rate still fills the quota.
        batch_size = max(4 * (nsample - len(param_sample)), 128)
        draw = rng.multivariate_normal(values, covar, size=batch_size, check_valid='ignore')
        draw = np.atleast_2d(draw)
        inside = np.all((draw >= lower) & (draw <= upper), axis=1)
        param_sample.extend(draw[inside][: nsample - len(param_sample)])
        tries += 1
    if len(param_sample) < nsample:
        msg = f'Only {len(param_sample)} valid samples were generated after {tries} attempts.'
        warnings.warn(msg, stacklevel=2)
    param_sample = np.asarray(param_sample[:nsample], dtype=float)

    try:
        loglike_sample = np.array(
            [self.calc_loglike(theta) for theta in param_sample], dtype=float
        )
        self.bootstrap_sample = np.hstack((param_sample, loglike_sample[:, None]))
    finally:
        # Scoring moves the model through every draw; always park it back
        # at the best fit, even if a draw made the likelihood raise.
        self.at_par(values)
@staticmethod
def _display_results(*objects):
    """Show every non-``None`` argument, preferring IPython's rich display.

    Falls back to plain ``print`` when IPython cannot be imported.

    Args:
        *objects: Items to show; ``None`` entries are skipped.
    """
    shown = [item for item in objects if item is not None]
    # Choose the renderer once, then use the same call for every item.
    try:
        from IPython.display import display as emit
    except ImportError:
        emit = print
    for item in shown:
        emit(item)
[docs]
def lmfit_residual(self, params):
    """lmfit-facing residual callback; delegates to :meth:`calc_pseudo_residual`.

    Args:
        params: Mapping from each label in ``clean_free_plabels`` to
            either an object exposing ``.value`` (as lmfit parameters
            do) or a plain number.

    Returns:
        Whatever :meth:`calc_pseudo_residual` returns for the extracted
        parameter vector.
    """
    theta = []
    for pl in self.clean_free_plabels:
        entry = params[pl]
        # Hand plain numbers to the model code, not lmfit parameter
        # objects; entries without ``.value`` are passed through as-is.
        theta.append(getattr(entry, 'value', entry))
    return self.calc_pseudo_residual(theta)
[docs]
def lmfit(self, savepath=None):
    """Run ``lmfit.minimize`` on the pseudo-residuals and bootstrap the result.

    Every free parameter is registered with lmfit using its current
    value and range. After the fit, the best-fit values, standard
    errors (``nan`` where lmfit reports none) and covariance are passed
    to :meth:`_make_bootstrap_sample`.

    Args:
        savepath: Optional directory (``str`` or path-like) for the
            persisted bootstrap sample and summary JSON; it is created
            if missing. Pass ``None`` to skip disk IO.

    Returns:
        A :class:`~bayspec.infer.analyzer.Bootstrap`.
    """
    import lmfit

    from .analyzer import Bootstrap

    self._you_free()
    # NOTE(review): lmfit keys parameters by name, so these labels are
    # presumably unique across pairs -- confirm against clean_free_plabels.
    plabels = self.clean_free_plabels

    lmfit_params = lmfit.Parameters()
    for pl, pv, pr in zip(plabels, self.free_pvalues, self.free_pranges, strict=False):
        lmfit_params.add(pl, value=pv, min=pr[0], max=pr[1], vary=True)

    lmfit_result = lmfit.minimize(self.lmfit_residual, lmfit_params)
    self._display_results(lmfit_result)

    fitted = [lmfit_result.params[pl] for pl in plabels]
    values = np.array([par.value for par in fitted])
    # ``stderr`` is None when lmfit could not estimate it; record that as NaN.
    errors = np.array([np.nan if par.stderr is None else par.stderr for par in fitted])
    covar = getattr(lmfit_result, 'covar', None)

    self._make_bootstrap_sample(values, covar=covar, errors=errors)
    maxlike_res = {'values': values, 'errors': errors, 'covar': covar}

    if savepath is not None:
        # Accept path-like objects and make sure the directory exists, so a
        # finished fit is not lost to a missing output folder.
        savedir = os.fspath(savepath)
        if savedir:
            os.makedirs(savedir, exist_ok=True)
        savepath_prefix = savedir + '/1-'
        np.savetxt(savepath_prefix + 'bootstrap_sample.txt', self.bootstrap_sample)
        with open(savepath_prefix + 'maxlike_res.json', 'w') as f:
            json.dump(maxlike_res, f, indent=4, cls=JsonEncoder)
    return Bootstrap(self)
[docs]
def iminuit_cost(self, *theta):
    """iminuit-facing cost function built on :meth:`calc_stat`.

    Args:
        *theta: Free-parameter values, passed positionally and handed
            to :meth:`calc_stat` as one tuple.

    Returns:
        The fit statistic as a ``float``, or ``1e100`` when it is not
        finite.
    """
    stat = self.calc_stat(theta)
    # A NaN/inf statistic is replaced by a huge finite penalty.
    if not np.isfinite(stat):
        return 1e100
    return float(stat)
[docs]
def iminuit(self, savepath=None):
    """Run iminuit's ``migrad`` + ``hesse`` + ``minos`` and bootstrap the result.

    The free parameters are handed to ``Minuit`` with their current
    values and ranges as limits. After the three minimiser stages, the
    best-fit values, symmetric errors and covariance are passed to
    :meth:`_make_bootstrap_sample`; the MINOS errors are kept only for
    the saved summary.

    Args:
        savepath: Optional directory (``str`` or path-like) for the
            persisted bootstrap sample and summary JSON; it is created
            if missing. Pass ``None`` to skip disk IO.

    Returns:
        A :class:`~bayspec.infer.analyzer.Bootstrap`.
    """
    import iminuit

    from .analyzer import Bootstrap

    self._you_free()
    plabels = self.clean_free_indexed_plabels

    minuit = iminuit.Minuit(self.iminuit_cost, *self.free_pvalues, name=plabels)
    # Twice the likelihood errordef; presumably because calc_stat is a
    # -2*log-likelihood-like statistic -- TODO confirm.
    minuit.errordef = 2 * iminuit.Minuit.LIKELIHOOD
    minuit.print_level = 0
    for pl, pr in zip(plabels, self.free_pranges, strict=False):
        minuit.limits[pl] = pr

    minuit.migrad()
    minuit.hesse()
    # NOTE(review): minos is believed to raise when migrad did not reach a
    # valid minimum; that exception propagates to the caller -- confirm.
    minuit.minos()
    self._display_results(minuit)

    fitted = list(minuit.params)
    values = np.array([par.value for par in fitted])
    errors = np.array([par.error for par in fitted])
    minos_errors = np.array([par.merror for par in fitted])
    covar = None if minuit.covariance is None else np.asarray(minuit.covariance)

    self._make_bootstrap_sample(values, covar=covar, errors=errors)
    maxlike_res = {
        'values': values,
        'errors': errors,
        'minos_errors': minos_errors,
        'covar': covar,
    }

    if savepath is not None:
        # Accept path-like objects and make sure the directory exists, so a
        # finished fit is not lost to a missing output folder.
        savedir = os.fspath(savepath)
        if savedir:
            os.makedirs(savedir, exist_ok=True)
        savepath_prefix = savedir + '/1-'
        np.savetxt(savepath_prefix + 'bootstrap_sample.txt', self.bootstrap_sample)
        with open(savepath_prefix + 'maxlike_res.json', 'w') as f:
            json.dump(maxlike_res, f, indent=4, cls=JsonEncoder)
    return Bootstrap(self)