from collections import namedtuple
from functools import partial
import numpy as np
import numpyro.diagnostics as diags
import pandas as pd
from brmp.backend import data_from_numpy
from brmp.design import predictors
from brmp.family import free_param_names
from brmp.model import scalar_parameter_map, scalar_parameter_names
from brmp.utils import flatten
# Quantiles reported by default by `Fit.marginals` and `summary`.
# NOTE: This list is used directly as a default argument value by both
# of those, so the same list object is shared between calls. Treat it
# as read-only.
default_quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]
# `Fit` carries around `formula`, `metadata` and `contrasts` for the
# sole purpose of being able to encode any new data passed to
# `fitted`.
# TODO: Consider storing `formula`, `metadata` and `contrasts` on
# `ModelDescPre` (and then `ModelDesc`) as an alternative to storing
# them on `Fit`. (Since it seems more natural.)
# One snag is that `ModelDescPre` only sees the lengths of any custom
# coding and not the full matrix. However, while deferring having to
# give a concrete data frame is useful (because it saves having to
# make up data to tinker with a model), it's not clear that deferring
# having to give contrasts has a similar benefit.
class Fit(namedtuple('Fit', 'formula metadata contrasts data model_desc model samples backend')):
    """
    A fitted model, together with everything required to query it.

    ``formula``, ``metadata`` and ``contrasts`` are only used to encode
    new data frames passed to :meth:`fitted`. ``data`` is the (already
    encoded) data on which the model was fit. ``samples`` provides the
    ``get_param`` and ``location`` callables, and ``backend`` provides
    ``to_numpy``, ``inv_link``, ``expected_response`` and
    ``sample_response``.
    """

    # TODO: This doesn't match the brms interface, but the deviations
    # aren't improvements either. Figure out what to do about that.
    #
    # brms                                                | brmp
    # ----------------------------------------------------|------------------------------
    # fitted(fit, summary=FALSE)                          | fit.fitted()
    # fitted(dpar='mu', scale='linear', summary=FALSE)    | fit.fitted('linear')
    # fitted(dpar='mu', scale='response', summary=FALSE)  | fit.fitted('response')
    # fitted(fit, newdata=..., summary=FALSE)             | fit.fitted(data=...)
    # fitted(fit, ..., summary=TRUE)                      | summary(fit.fitted(...))
    # predict(fit, summary=FALSE)                         | fit.fitted('sample')
    # predict(fit, summary=TRUE)                          | summary(fit.fitted('sample'))
    #
    # https://rdrr.io/cran/brms/man/fitted.brmsfit.html

    def fitted(self, what='expectation', data=None, seed=None):
        """
        Produces predictions from the fitted model.

        Predicted values are computed for each sample collected during inference,
        and for each row in the data set.

        :param what: The value to predict. Valid arguments and their effect are described below:

                     .. list-table::
                        :widths: auto

                        * - ``'expectation'``
                          - Computes the expected value of the response distribution.
                        * - ``'sample'``
                          - Draws a sample from the response distribution.
                        * - ``'response'``
                          - Computes the output of the model followed by any
                            inverse link function. i.e. The value of the location
                            parameter of the response distribution.
                        * - ``'linear'``
                          - Computes the output of the model prior to the application
                            of any inverse link function.

        :type what: str
        :param data: The data from which to compute the predictions. When omitted,
                     the data on which the model was fit is used.
        :type data: pandas.DataFrame
        :param seed: Random seed. Used only when ``'sample'`` is given as the ``'what'`` argument.
        :type seed: int
        :return: An array with shape ``(S, N)``. Where ``S`` is the number of samples taken
                 during inference and ``N`` is the number of rows in the data set used for prediction.
        :rtype: numpy.ndarray
        :raises ValueError: If ``what`` is not one of the values listed above and
                            assertions are disabled.
        """
        assert what in ['sample', 'expectation', 'linear', 'response']
        # `isinstance` (rather than an exact type comparison) so that
        # subclasses of `DataFrame` / `int` are accepted too.
        assert data is None or isinstance(data, pd.DataFrame)
        assert seed is None or isinstance(seed, int)

        backend = self.backend

        # New data has to be run through the same encoding as the data
        # the model was fit on before it can be handed to the back end.
        if data is None:
            encoded = self.data
        else:
            design = predictors(self.formula, data, self.metadata, self.contrasts)
            encoded = data_from_numpy(backend, design)
        mu = self.samples.location(encoded)

        if what == 'linear':
            return backend.to_numpy(mu)
        if what == 'response':
            return backend.to_numpy(backend.inv_link(self.model, mu))
        if what in ('sample', 'expectation'):
            # The response functions take one argument per free
            # parameter of the response family. `mu` is computed above
            # (it depends on the data); every other parameter is read
            # straight from the collected samples.
            args = [mu if name == 'mu' else self.samples.get_param(name, False)
                    for name in free_param_names(self.model_desc.response.family)]
            if what == 'sample':
                response = backend.sample_response(self.model, seed, *args)
            else:
                response = backend.expected_response(self.model, *args)
            return backend.to_numpy(response)
        # Only reachable when assertions are disabled (e.g. `python -O`).
        raise ValueError('Unhandled value of the `what` parameter encountered.')

    # Similar to the following:
    # https://rdrr.io/cran/rstan/man/stanfit-method-summary.html

    # TODO: This produces the same output as the old implementation of
    # `marginal`, though it's less efficient. Can the previous efficiency
    # be recovered? The problem is that we pull out each individual scalar
    # parameter as a vector and then stack those, rather than just stack
    # entire parameters as before. One thought is that such an
    # optimisation might be best pushed into `get_scalar_param`. i.e. This
    # might accept a list of a parameter names and return the
    # corresponding scalar parameters stacked into a matrix. The aim would
    # be to do this without performing any unnecessary slicing. (Though
    # this sounds fiddly.)

    def marginals(self, qs=default_quantiles):
        """Produces a table containing statistics of the marginal
        distributions of the parameters of the fitted model.

        :param qs: A list of quantiles to include in the output.
        :type qs: list
        :return: A table of marginal statistics.
        :rtype: brmp.fit.ArrReprWrapper

        Example::

          fit = brm('y ~ x', df).fit()
          print(fit.marginals())

          #        mean    sd  2.5%   25%   50%   75% 97.5% n_eff r_hat
          # b_x    0.42  0.33 -0.11  0.14  0.48  0.65  0.88  5.18  1.00
          # sigma  0.78  0.28  0.48  0.61  0.68  0.87  1.32  5.28  1.10

        """
        names = scalar_parameter_names(self.model_desc)
        # TODO: Every call to `get_scalar_param` rebuilds the scalar
        # parameter map.
        vecs = [self.get_scalar_param(name, True) for name in names]
        col_labels = ['mean', 'sd'] + format_quantiles(qs) + ['n_eff', 'r_hat']
        # Chains are preserved above, so stacking on a new trailing axis
        # gives one slice per scalar parameter along axis 2. The
        # diagnostics need the per-chain structure; the summary
        # statistics are computed from the flattened array.
        samples = np.stack(vecs, axis=2)
        stats_arr = marginal_stats(flatten(samples), qs)
        n_eff = compute_diag_or_default(effective_sample_size, samples)
        r_hat = compute_diag_or_default(gelman_rubin, samples)
        arr = np.hstack([stats_arr, n_eff[..., np.newaxis], r_hat[..., np.newaxis]])
        return ArrReprWrapper(arr, names, col_labels)

    # A back end agnostic wrapper around back end specific implementations
    # of `fit.samples.get_param`.
    def get_param(self, name, preserve_chains=False):
        return self.backend.to_numpy(self.samples.get_param(name, preserve_chains))

    # TODO: If parameter and scalar parameter names never clash, perhaps
    # having a single lookup method would be convenient. Perhaps this
    # could be wired up to `fit.samples[...]`?

    # TODO: Mention other ways of obtaining valid parameter names?
    def get_scalar_param(self, name, preserve_chains=False):
        """
        Extracts the values sampled for a single parameter from a model fit.

        :param name: The name of a parameter of the model. Valid names are those
                     shown in the output of :func:`~brmp.fit.marginals`.
        :type name: str
        :param preserve_chains: Whether to group samples by the MCMC chain on which
                                they were collected.
        :type preserve_chains: bool
        :return: An array with shape ``(S,)`` when ``preserve_chains=False``, ``(C, S)``
                 otherwise. Where ``S`` is the number of samples taken during inference,
                 and ``C`` is the number of MCMC chains run.
        :rtype: numpy.ndarray
        :raises KeyError: If ``name`` is not the name of a scalar parameter of the model.
        """
        m = scalar_parameter_map(self.model_desc)
        res = [p for (n, p) in m if n == name]
        assert len(res) < 2
        if len(res) == 0:
            raise KeyError('unknown parameter name: {}'.format(name))
        param_name, index = res[0]
        # Construct a slice to pick out the given index at all chains (if
        # present) and all samples.
        slc = (Ellipsis,) + index
        return self.get_param(param_name, preserve_chains)[slc]

    # TODO: Consider delegating to `marginals` or similar?
    def __repr__(self):
        # The repr of namedtuple ends up long and not very useful for
        # Fit. This is similar to the default implementation of repr
        # used for classes.
        return '<brmp.fit.Fit at {}>'.format(hex(id(self)))
# Bundles what a back end hands back after inference. Within this
# module `get_param` is called as `get_param(name, preserve_chains)`
# and `location` is called with an encoded data set.
# NOTE(review): `raw_samples` is not read in this module; presumably it
# holds the back end's native sample object -- confirm against the back
# ends.
Samples = namedtuple(
    'Samples',
    (
        'raw_samples',
        'get_param',
        'location',
    ),
)
def format_quantiles(qs):
    """Turn quantiles into percentage labels for use as column headings.

    Each quantile is multiplied by 100 and rendered with the ``g``
    presentation type, e.g. ``0.025`` becomes ``'2.5%'`` and ``0.5``
    becomes ``'50%'``.

    :param qs: An iterable of quantiles.
    :return: A list of strings, one per quantile, in the same order.
    """
    template = '{:g}%'
    labels = []
    for q in qs:
        labels.append(template.format(q * 100))
    return labels
# Computes the mean, standard deviation and the requested quantiles of
# each column of a 2-D array of samples (one row per sample).
def marginal_stats(arr, qs):
    """Summarise each column of a 2-D array.

    :param arr: A 2-D array. Statistics are taken along axis 0, so each
                column is summarised independently.
    :param qs: A list of quantiles, each in ``[0, 1]``.
    :return: An array with one row per column of ``arr``. The columns
             of the result are the mean, the standard deviation
             (``np.std`` with its default arguments) and then one column
             per entry of ``qs``.
    """
    assert len(arr.shape) == 2
    assert type(qs) == list
    assert all(0 <= q <= 1 for q in qs)
    # Mean and sd come back 1-D; give each a trailing axis so they can
    # sit alongside the (transposed) quantiles as columns.
    mean_col = np.mean(arr, 0)[:, np.newaxis]
    sd_col = np.std(arr, 0)[:, np.newaxis]
    # `np.quantile` puts the quantiles on the leading axis; transpose
    # to get one row per column of `arr`.
    quantile_cols = np.quantile(arr, qs, 0).T
    return np.hstack((mean_col, sd_col, quantile_cols))
# TODO: Would it be better to replace these tables with pandas data
# frames? They also let you get at the underlying data as a numpy
# array (I assume), and have their own pretty printing.
class ArrReprWrapper:
    """A 2-D array paired with optional row and column labels.

    The only behaviour added over the raw array is a ``repr`` that
    renders the values as a text table.

    :param array: A 2-D array.
    :param row_labels: ``None``, or one label per row of ``array``.
    :param col_labels: ``None``, or one label per column of ``array``.
    """

    def __init__(self, array, row_labels, col_labels):
        assert len(array.shape) == 2
        # When given, labels must match the corresponding dimension.
        assert row_labels is None or len(row_labels) == array.shape[0]
        assert col_labels is None or len(col_labels) == array.shape[1]
        self.array = array
        self.row_labels = row_labels
        self.col_labels = col_labels

    def __repr__(self):
        # Every value is rendered with 2 decimal places, with a leading
        # space reserved for the sign.
        cells = [['{: .2f}'.format(value) for value in row]
                 for row in self.array.tolist()]
        labelled = add_labels(cells, self.col_labels, self.row_labels)
        return layout_table(labelled)
def add_labels(table, col_labels, row_labels):
    """Attach an optional header row and an optional label column to a table.

    Neither ``table`` nor the label sequences are modified; a new list
    of new rows is returned.

    :param table: A list of rows, each of which is a list of cells.
    :param col_labels: ``None``, or a sequence with one label per column.
                       When given it becomes the first row of the result.
    :param row_labels: ``None``, or a sequence with one label per row of
                       ``table``. When given, each label is prepended to
                       its row. If a header row is present it receives an
                       empty string in the label column.
    :return: A list of lists.
    """
    assert isinstance(table, list)
    assert all(isinstance(row, list) for row in table)
    # Labels are copied into fresh lists so that tuples (or any other
    # sequence) are accepted, and so that the header row of the result
    # does not alias the caller's `col_labels`.
    out = [list(col_labels)] if col_labels is not None else []
    out += table
    if row_labels is not None:
        rlabels = list(row_labels)
        if col_labels is not None:
            # The header row has no row label of its own.
            rlabels = [''] + rlabels
        assert len(out) == len(rlabels)
        out = [[name] + r for r, name in zip(out, rlabels)]
    return out
def layout_table(rows):
    """Render rows of strings as a right-aligned text table.

    Each column is made as wide as its widest cell. Cells within a row
    are joined by a single space and rows are joined by newlines.

    :param rows: A non-empty list of rows, all of the same length, whose
                 cells are strings.
    :return: The table as a single string.
    """
    assert len(rows) > 0
    num_cols = len(rows[0])
    assert all(len(row) == num_cols for row in rows)
    # Rows all have the same length (checked above), so transposing
    # with `zip` visits every cell of every column.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    fmt = ' '.join('{{:>{}}}'.format(width) for width in widths)
    lines = [fmt.format(*row) for row in rows]
    return '\n'.join(lines)
# TODO: We could follow brms and make this available via a `summary`
# flag on `fitted`?
def summary(arr, qs=default_quantiles, row_labels=None):
    """Build a table of per-column summary statistics for ``arr``.

    :param arr: A 2-D array; passed to ``marginal_stats``.
    :param qs: A list of quantiles to include in the table.
    :param row_labels: ``None``, or one label per row of the resulting
                       table (i.e. one per column of ``arr``).
    :return: An ``ArrReprWrapper`` whose columns are the mean, the
             standard deviation and then one column per quantile.
    """
    header = ['mean', 'sd'] + format_quantiles(qs)
    stats = marginal_stats(arr, qs)
    return ArrReprWrapper(stats, row_labels, header)
def gelman_rubin(samples):
    """Compute the Gelman-Rubin diagnostic, if there are enough samples.

    ``samples.shape[0]`` is treated as the number of chains and
    ``samples.shape[1]`` as the number of samples per chain. With two or
    more chains ``diags.gelman_rubin`` is used and at least 2 samples
    per chain are required. Otherwise ``diags.split_gelman_rubin`` is
    used and at least 4 samples are required.

    :return: The result of the diagnostic, or ``None`` when there are
             too few chains or samples to compute it.
    """
    num_chains = samples.shape[0]
    num_samples = samples.shape[1]
    multiple_chains = num_chains >= 2
    min_samples = 2 if multiple_chains else 4
    if num_samples < min_samples:
        return None  # Too few chains or samples.
    if multiple_chains:
        return diags.gelman_rubin(samples)
    return diags.split_gelman_rubin(samples)
def effective_sample_size(samples):
    """Compute the effective sample size, if there are enough samples.

    Delegates to ``diags.effective_sample_size`` when
    ``samples.shape[1]`` is at least 2.

    :return: The result of the diagnostic, or ``None`` when there are
             too few samples to compute it.
    """
    enough_samples = samples.shape[1] >= 2
    if not enough_samples:
        return None  # Too few samples.
    return diags.effective_sample_size(samples)
def compute_diag_or_default(diag, samples):
    """Apply a diagnostic, substituting NaNs when it cannot be computed.

    :param diag: A callable taking ``samples`` and returning either a
                 result or ``None``.
    :param samples: An array with at least three dimensions.
    :return: ``diag(samples)`` when that is not ``None``, otherwise a
             1-D array of NaNs of length ``samples.shape[2]``.
    """
    result = diag(samples)
    if result is None:
        # One placeholder entry per slice along the third axis.
        return np.full((samples.shape[2],), np.nan)
    return result