# Source code for brmp.priors

from collections import defaultdict, namedtuple

from brmp.family import LKJ, Cauchy, Family, HalfCauchy, Type, fully_applied
from brmp.model_pre import GroupPre, ModelDescPre, PopulationPre
from brmp.utils import join

# `is_param` indicates whether a node corresponds to a parameter in
# the model. (Nodes without this flag set exist only to add structure
# to the parameters.) This information is used when extracting
# information from the tree.
Node = namedtuple('Node', 'name prior_edit is_param checks children')


def leaf(name, prior_edit=None, checks=None):
    """
    Build a childless :class:`Node` that is flagged as a model parameter.

    :param name: Name of the node.
    :param prior_edit: Value stored in the node's ``prior_edit`` field
                       (``None`` when no prior has been attached).
    :param checks: List stored in the node's ``checks`` field. When omitted,
                   a fresh empty list is created for each call.
    :return: A ``Node`` with ``is_param`` set to ``True`` and no children.
    """
    # A literal `[]` default would be created once and then shared by
    # every leaf built without explicit checks, so mutating the `checks`
    # of one node would silently change all of the others.
    if checks is None:
        checks = []
    return Node(name, prior_edit, True, checks, [])


# Default priors for parameters of the response distribution, keyed
# first by family name and then by parameter name.
RESPONSE_PRIORS = {'Normal': {'sigma': HalfCauchy(3.)}}


def get_response_prior(family, parameter):
    """
    Look up the default prior for a parameter of a response family.

    :param family: Name of the response family (a key of
                   ``RESPONSE_PRIORS``).
    :param parameter: Name of the response distribution parameter.
    :return: The registered default prior, or ``None`` when no default
             exists for this family/parameter pair.
    """
    # An unknown family and an unknown parameter of a known family are
    # treated alike: both mean "no default", reported as `None`. (A plain
    # `[family][parameter]` lookup would raise `KeyError` in the second
    # case only.)
    return RESPONSE_PRIORS.get(family, {}).get(parameter)


# This is similar to brms `set_prior`. (e.g. `set_prior('<prior>',
# coef='x1')` is similar to `Prior(('b', 'x1'), '<prior>')`.) By specifying
# paths (rather than class/group/coef) we're diverging from brms, but
# the hope is that a brms-like interface can be put in front of this.

Prior = namedtuple('Prior', 'path prior')
"""

A :class:`~brmp.priors.Prior` instance associates a prior distribution with one
or more parameters of a model. One or more such instances may be passed to
:func:`~brmp.brm` to override its default choice of priors.

The parameters of the model to which a prior should be applied are specified
using a path. The following examples illustrate how this works:

.. list-table::
   :widths: auto
   :header-rows: 1

   * - Path
     - Selected Parameters
   * - ``('b',)``
     - All population level coefficients
   * - ``('b', 'intercept')``
     - The population level intercept
   * - ``('b', 'x')``
     - The population level coefficient ``x``
   * - ``('sd',)``
     - All standard deviations in all groups
   * - ``('sd', 'a')``
     - All standard deviations in the group which groups by column ``a``
   * - ``('sd', 'a:b')``
     - All standard deviations in the group which groups by columns ``a`` and ``b``
   * - ``('sd', 'a', 'intercept')``
     - The standard deviation of the intercept in the group which groups by column ``a``
   * - ``('cor',)``
     - All correlation matrices
   * - ``('cor', 'a')``
     - The correlation matrix of the group which groups by column ``a``
   * - ``('cor', 'a:b')``
     - The correlation matrix of the group which groups by columns ``a`` and ``b``
   * - ``('resp', 'sigma')``
     - The ``sigma`` parameter of the response distribution

Example::

  Prior(('b', 'intercept'), Normal(0., 1.))

:param path: A path describing one or more parameters of the model.
:type path: tuple
:param prior: A prior distribution, given as a :class:`~brmp.family.Family`
              with all of its parameters specified.
:type prior: brmp.family.Family
"""


def walk(node, path):
    """
    Follow ``path`` down the tree, collecting every node that is visited.

    :param node: The ``Node`` at which to start.
    :param path: Tuple of child names, followed one level at a time. When
                 several children share a name the first one is taken.
    :return: List of nodes from ``node`` (first) to the node selected by
             the full path (last).
    :raises ValueError: If a name in ``path`` matches no child of the node
                        reached so far.
    """
    assert type(node) is Node
    assert type(path) is tuple
    visited = [node]
    current = node
    for name in path:
        match = None
        for child in current.children:
            if child.name == name:
                match = child
                break
        if match is None:
            raise ValueError('Invalid path')
        # Every node on the path must itself be a `Node`.
        assert type(match) is Node
        visited.append(match)
        current = match
    return visited


def select(node, path):
    """
    Return the node reached by following ``path`` from ``node``.

    An empty path selects ``node`` itself. Errors are those of ``walk``.
    """
    # `walk` always returns at least the starting node, so unpacking the
    # final element is safe.
    *_, selected = walk(node, path)
    return selected


def edit(node, path, f):
    """
    Return a copy of the tree in which the node(s) at ``path`` are replaced
    by the result of applying ``f`` to them.

    :param node: Root ``Node`` of the (sub)tree to edit.
    :param path: Tuple of child names locating the node(s) to edit. The
                 empty tuple picks out ``node`` itself.
    :param f: Function taking a ``Node`` and returning its replacement
              ``Node``.
    :return: A new ``Node``; nodes off the path are reused unchanged.
    """
    assert type(node) is Node
    assert type(path) is tuple
    if not path:
        # We're at the node to be edited.
        replacement = f(node)
        assert type(replacement) is Node
        return replacement
    head, rest = path[0], path[1:]
    assert any(child.name == head for child in node.children), 'Node "{}" not found.'.format(head)
    # Recurse into every child carrying the requested name (names may be
    # duplicated); all other children are kept as they are, in order.
    new_children = []
    for child in node.children:
        new_children.append(edit(child, rest, f) if child.name == head else child)
    return node._replace(children=new_children)


# TODO: Match default priors used by brms. (An improper uniform is
# used for `b`. A Half Student-t here is used for priors on standard
# deviations, with its scale derived from the data.)

# TODO: It might be a good idea to build the tree with checks but no
# priors, and then add the priors in the same way as user edits
# are applied, in order to ensure that the defaults meet the
# constraints. Or, perhaps a more convenient way of achieving the same
# thing is to make a separate pass over the entire default tree once
# built, and assert its consistency.

def default_prior(model_desc_pre):
    """
    Build the tree of default priors for a model.

    The root has four children: ``b`` (one leaf per entry of
    ``population.coefs``), ``sd`` (one node per group, holding one leaf per
    entry of the group's ``coefs``), ``cor`` (one leaf per group whose
    ``corr`` flag is set) and ``resp`` (one leaf per entry of
    ``response.nonlocparams``).

    :param model_desc_pre: A ``ModelDescPre`` describing the model.
    :return: The root ``Node`` of the default prior tree.
    """
    assert type(model_desc_pre) is ModelDescPre
    response = model_desc_pre.response
    family = response.family
    assert type(family) is Family
    assert family.link is not None
    population = model_desc_pre.population
    assert type(population) is PopulationPre
    groups = model_desc_pre.groups
    assert type(groups) is list
    assert all(type(group) is GroupPre for group in groups)

    b_children = [leaf(coef) for coef in population.coefs]
    cor_children = [leaf(cols2str(group.columns)) for group in groups if group.corr]

    sd_children = []
    for group in groups:
        group_name = cols2str(group.columns)
        coef_leaves = [leaf(coef) for coef in group.coefs]
        sd_children.append(Node(group_name, None, False, [], coef_leaves))

    def mk_resp_prior_edit(param_name):
        # Response parameters without a registered default get no prior.
        prior = get_response_prior(family.name, param_name)
        return None if prior is None else Prior(('resp', param_name), prior)

    resp_children = []
    for param in response.nonlocparams:
        prior_edit = mk_resp_prior_edit(param.name)
        support_check = chk_support(param.type)
        resp_children.append(leaf(param.name, prior_edit, [support_check]))

    b_node = Node('b', Prior(('b',), Cauchy(0., 1.)), False, [chk_support(Type['Real']())], b_children)
    sd_node = Node('sd', Prior(('sd',), HalfCauchy(3.)), False, [chk_support(Type['PosReal']())], sd_children)
    cor_node = Node('cor', Prior(('cor',), LKJ(1.)), False, [chk_lkj], cor_children)
    resp_node = Node('resp', None, False, [], resp_children)
    return Node('root', None, False, [], [b_node, sd_node, cor_node, resp_node])


def cols2str(cols):
    """Join an iterable of column names into one colon-separated string."""
    separator = ':'
    return separator.join(cols)


def customize_prior(tree, priors):
    """
    Apply a list of user supplied priors to a prior tree.

    :param tree: Root ``Node`` of the tree to customize.
    :param priors: List of ``Prior`` instances, applied in order.
    :return: A new tree in which each node selected by a prior's path has
             that ``Prior`` as its ``prior_edit``.
    :raises Exception: If a prior's distribution is not fully applied.
    """
    assert type(tree) is Node
    assert type(priors) is list
    assert all(type(p) is Prior for p in priors)
    customized = tree
    for override in priors:
        # TODO: It probably makes sense to move this to the
        # constructor of Prior, once such a thing exists.
        if not fully_applied(override.prior):
            raise Exception('Distribution arguments missing from prior "{}"'.format(override.prior.name))

        def attach(target, override=override):
            # Only the `prior_edit` field changes; everything else is kept.
            return target._replace(prior_edit=override)

        customized = edit(customized, override.path, attach)
    return customized


# It's important that trees maintain the order of their children, so
# that coefficients in the prior tree continue to line up with columns
# in the design matrix.
def build_prior_tree(model_desc_pre, priors, chk=True):
    """
    Build the filled prior tree for a model from its default priors and a
    list of user supplied priors.

    :param model_desc_pre: Model description passed to ``default_prior``.
    :param priors: List of ``Prior`` instances overriding the defaults.
    :param chk: When true, verify that every parameter has a prior and that
                all checks pass.
    :return: The filled prior tree.
    :raises Exception: When ``chk`` is true and a prior is missing or a
                       check fails.
    """
    defaults = default_prior(model_desc_pre)
    tree = fill(customize_prior(defaults, priors))
    if not chk:
        return tree
    # TODO: I might consider delaying this check (that all
    # parameters have priors) until just before code generation
    # happens. This could allow an under-specified model to be
    # pretty-printed, which might make it easier for users to see
    # what's going on. (Once `brm` returns a model rather than
    # running inference.) Doing so would require the `ModelDesc`
    # data structure and pretty printing code to handle missing
    # priors. (Does something similar apply to the response/family
    # compatibility checks currently in model.py?)
    missing_prior_paths = leaves_without_prior(tree)
    if missing_prior_paths:
        quoted = ['"{}"'.format('/'.join(path)) for path in missing_prior_paths]
        raise Exception('Prior missing from {}.'.format(', '.join(quoted)))
    errors = check(tree)
    if errors:
        raise Exception(format_errors(errors))
    return tree


# `fill` populates the `prior_edit` and `checks` properties of all
# nodes in a tree. Each node uses its own `prior_edit` value if set,
# otherwise the first `prior_edit` value encountered when walking up
# the tree from the node is used. The final value of `checks` comes
# from concatenating all of the lists of checks encountered when
# walking from the node to the root. (This is the behaviour, not the
# implementation.)
def fill(node, default=None, upstream_checks=None):
    """
    Return a copy of the tree with ``prior_edit`` and ``checks`` propagated
    from each node to its descendants.

    :param node: Root ``Node`` of the (sub)tree to fill.
    :param default: ``prior_edit`` to use when ``node`` has none of its own.
    :param upstream_checks: List of checks gathered from the ancestors of
                            ``node``; omitted (or ``None``) for the root.
    :return: A new ``Node`` whose ``prior_edit`` is its own value or else
             ``default``, and whose ``checks`` are the ancestors' checks
             followed by its own.
    """
    # `None` stands in for "no upstream checks" so that no mutable list is
    # shared between calls through the default argument.
    if upstream_checks is None:
        upstream_checks = []
    prior = node.prior_edit if node.prior_edit is not None else default
    # Concatenation builds a new list, leaving both inputs untouched.
    checks = upstream_checks + node.checks
    children = [fill(child, prior, checks) for child in node.children]
    return Node(node.name, prior, node.is_param, checks, children)


def leaves(node, path=None):
    """
    Collect every parameter node in the tree together with its path.

    :param node: Root ``Node`` of the (sub)tree to search.
    :param path: List of names leading to ``node``; omitted (or ``None``)
                 for the root, which then gets a fresh empty path.
    :return: List of ``(node, path)`` pairs, one per node whose
             ``is_param`` flag is set, with a node listed before its
             descendants.
    """
    # The path is handed back to the caller when `node` is itself a
    # parameter, so a shared `[]` default could be mutated by one caller
    # and leak into later calls. Create a new list per call instead.
    if path is None:
        path = []
    this = [(node, path)] if node.is_param else []
    # `join` (from brmp.utils) must yield a list here since it is added to
    # `this`; presumably it concatenates the per-child lists -- confirm.
    rest = join(leaves(child, path + [child.name]) for child in node.children)
    return this + rest


# Sanity checks

class Chk:
    """
    A named predicate over priors which can be applied to a tree node.

    Calling an instance with a ``Node`` applies the predicate to the prior
    held in the node's ``prior_edit``; nodes without a ``prior_edit`` pass
    trivially.
    """

    def __init__(self, predicate, name):
        self.name = name
        self.predicate = predicate

    def __call__(self, node):
        assert type(node) is Node
        prior_edit = node.prior_edit
        if prior_edit is not None:
            return self.predicate(prior_edit.prior)
        # There is no prior to check.
        return True

    def __repr__(self):
        return 'Chk("{}")'.format(self.name)


def chk(name):
    """
    Decorator factory that turns a predicate function into a ``Chk`` with
    the given name.

    :param name: Name given to the resulting ``Chk``.
    :return: A one-argument function wrapping a predicate in a ``Chk``.
    """
    return lambda predicate: Chk(predicate, name)


def chk_support(typ):
    """
    Make a ``Chk`` that passes when the support of a prior equals ``typ``.

    :param typ: The type that ``prior.support()`` is compared against.
    :return: A ``Chk`` named ``'has support of <typ>'``.
    """
    # TODO: This could probably be relaxed to only require that the
    # support of the prior is a subset of type of the parameter.
    # (However this is easier and good enough for now.)
    description = 'has support of {}'.format(typ)
    return Chk(lambda prior: prior.support() == typ, description)


@chk('is LKJ')
def chk_lkj(prior):
    """Predicate: the prior is the distribution named ``LKJ``."""
    prior_name = prior.name
    return prior_name == 'LKJ'


def check(tree):
    """
    Run every check attached to every parameter node of the tree.

    :param tree: Root ``Node`` of a prior tree.
    :return: Nested mapping from the path of the offending ``Prior``, to
             the failed check, to the list of parameter paths at which it
             failed. Empty when all checks pass.
    """
    failures = defaultdict(lambda: defaultdict(list))
    for node, path in leaves(tree):
        for checker in node.checks:
            if checker(node):
                continue
            # This holds because checks can only fail when a node
            # has a `prior_edit`.
            assert node.prior_edit is not None
            failures[node.prior_edit.path][checker].append(path)
    return failures


# TODO: There's info in `errors` which we're not making use of here.
def format_errors(errors):
    """
    Render the result of ``check`` as a single error message.

    :param errors: Mapping whose keys are prior paths (sequences of
                   strings); only the keys are used.
    :return: Message listing each path, quoted and slash-separated.
    """
    quoted = ['"{}"'.format('/'.join(path)) for path in errors]
    return 'Invalid prior specified at {}.'.format(', '.join(quoted))


def leaves_without_prior(tree):
    """
    List the paths of all parameter nodes that have no ``prior_edit``.

    :param tree: Root ``Node`` of a prior tree.
    :return: List of paths, in the order produced by ``leaves``.
    """
    missing = []
    for node, path in leaves(tree):
        if node.prior_edit is None:
            missing.append(path)
    return missing