import itertools
import re
from collections import namedtuple
from enum import Enum
from brmp.utils import join
# Removes duplicates from `xs`, keeping the first occurrence of each
# element and preserving the original order. Elements must be
# hashable. Always returns a new list.
def unique(xs):
    already_seen = set()
    result = []
    for item in xs:
        if item in already_seen:
            continue
        already_seen.add(item)
        result.append(item)
    return result
# An immutable collection of distinct, hashable items that remembers
# the order in which the items were supplied.
#
# Iteration, indexing and `repr` follow that order. Equality and
# hashing are based on the *set* of items only, so two ordered sets
# holding the same items in a different order compare (and hash)
# equal.
#
# Pass `items_are_unique=True` to skip de-duplication when the caller
# already guarantees that `items` contains no duplicates.
class OrderedSet():
    def __init__(self, *items, items_are_unique=False):
        # For most methods on this class `items` could be an arbitrary
        # iterable. However, for `union`, we need any two ordered sets
        # to have `items` be of the same type, in order to
        # straight-forwardly concatenate them with `+`. I'll use a
        # tuple, since it also has the benefit of being immutable, but
        # list would work too.
        self.items = tuple(items if items_are_unique else unique(items))
        self.fset = frozenset(self.items)
        assert len(self.fset) == len(self.items)  # `unique` ensures nodups

    def __hash__(self):
        # Consistent with `__eq__`: depends only on the set of items.
        return hash((OrderedSet, self.fset))

    def __eq__(self, other):
        # Returning `NotImplemented` (rather than touching `other.fset`
        # unconditionally) lets comparisons against unrelated objects
        # evaluate to False instead of raising AttributeError.
        if not isinstance(other, OrderedSet):
            return NotImplemented
        return self.fset == other.fset

    # NOTE: There is deliberately no `__next__`. An ordered set is an
    # iterable, not an iterator; `__iter__` hands out a fresh tuple
    # iterator each time. (Tuples have no `__next__`, so delegating to
    # `self.items.__next__()` could only ever raise.)
    def __iter__(self):
        return iter(self.items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]

    def __repr__(self):
        return '<{}>'.format(','.join(str(item) for item in self.items))

    # Returns a new ordered set holding the items of `self` followed by
    # those items of `other` that are not already in `self`.
    def union(self, other):
        items = self.items + tuple(x for x in other.items if x not in self.fset)
        return OrderedSet(*items, items_are_unique=True)
# Tokens

# Left/right parentheses, and left/right operator associativity.
Paren = Enum('Paren', ['L', 'R'])
Assoc = Enum('Assoc', ['L', 'R'])

# An operator token (its text, associativity and precedence) and a
# variable token (its name).
Op = namedtuple('Op', ['name', 'assoc', 'precedence'])
Var = namedtuple('Var', ['name'])

# The supported operators, keyed by their text. `shunt` pops operators
# with a larger `precedence` before pushing one with a smaller
# `precedence`, so larger values bind more tightly.
OPS = {
    op.name: op
    for op in (
        Op(':', Assoc.L, 5),
        Op('+', Assoc.L, 4),
        Op('||', Assoc.L, 3),
        Op('|', Assoc.L, 3),
        Op('~', Assoc.L, 2),
    )
}
# AST

# A `Leaf` holds a single value (`rpn2ast` stores a variable's name
# there). A `Node` is a binary operator application: `op` is the
# operator's text, `l` and `r` are the left and right sub-trees.
Leaf = namedtuple('Leaf', ['value'])
Node = namedtuple('Node', ['op', 'l', 'r'])
# TODO: Make into classes. Add validation. Add repr.

# TODO: Add intercepts by default. That probably ought to happen
# somewhere around here.

# TODO: Check for (and disallow) multiple groups using the same
# grouping column. I assume that this is the case elsewhere in the
# package. (e.g. Specifying priors for a particular group is done
# using the name of the grouping column, which would be ambiguous
# without the assumption. brms does this too.)

# TODO: I don't think attaching these docs to the "private" `Formula`
# class makes much sense. What's a better alternative?

# The parsed form of a formula, as built by `evalf`.
Formula = namedtuple('Formula',
                     ['response',  # response column name
                      'terms',     # an OrderedSet of population level terms
                      'groups'])   # list of groups
"""
Represents an lme4 formula.

.. list-table::
   :widths: auto

   * - ``~``
     - A valid formula contains exactly one occurrence of ``~`` . The LHS gives
       the name of the scalar response variable. The RHS describes the
       structure of the model.
   * - ``+``
     - A combination of terms.
   * - ``:``
     - An interaction between two terms. Can also appear on the RHS of ``|`` or
       ``||`` to specify grouping by multiple factors.
   * - ``|``
     - Introduces a group-level term. (i.e. random effect.)
   * - ``||``
     - Introduces a group-level term, omitting modelling of group-level correlations.
   * - ``1``
     - Intercept term. Note that intercept terms are not added automatically.

The following examples are all parsed as valid formulae:

.. code-block:: text

   y ~ x
   y ~ 1 + x
   y ~ 1 + x1:x2
   y ~ 1 + x1 + (1 + x2 | a)
   y ~ 1 + x1 + (1 + x2 || a:b)

"""
# A single group-level component of a formula. Fields:
#   terms   -- an OrderedSet of group-level terms
#   columns -- names of grouping columns
#   corr    -- model correlation between coeffs?
Group = namedtuple('Group', 'terms columns corr')
# TODO: Make it possible to union terms directly? (Could use in the
# `:` case of eval.)

# A term is described by its `factors`: an OrderedSet of factors, in
# the Patsy sense.
Term = namedtuple('Term', 'factors')

# The intercept is the term that has no factors.
_1 = Term(OrderedSet())
# Collects the names that appear in `formula`: the response first, then
# the factors of every population-level term, then, for each group in
# turn, the factors of its terms followed by its grouping columns.
# Duplicates are not removed here.
#
# NOTE(review): `join` comes from `brmp.utils`; it is assumed to
# concatenate an iterable of lists into one list -- confirm there.
def allfactors(formula):
    assert type(formula) == Formula

    def factors_of(terms):
        return join(list(term.factors) for term in terms)

    population = factors_of(formula.terms)
    grouped = join(factors_of(group.terms) + group.columns
                   for group in formula.groups)
    return [formula.response] + population + grouped
# Splits the formula string `inp` into a list of tokens (see
# `str2token`). Recognised lexemes are words (runs of `\w` characters),
# parentheses, `~`, `+`, `:`, `|` and `||`. Anything the pattern does
# not match -- whitespace, but also any other character -- is skipped
# without an error.
def tokenize(inp):
    lexeme_pattern = r'\b\w+\b|[()~+:]|\|\|?'
    lexemes = re.findall(lexeme_pattern, inp)
    return list(map(str2token, lexemes))
# Converts a single lexeme to its token: the matching `Op` from `OPS`
# for operators, `Paren.L` / `Paren.R` for parentheses, and a `Var`
# carrying the text for everything else.
def str2token(s):
    operator = OPS.get(s)
    if operator is not None:
        return operator
    if s == '(':
        return Paren.L
    if s == ')':
        return Paren.R
    return Var(s)
# https://en.wikipedia.org/wiki/Shunting-yard_algorithm

# Converts a list of tokens in infix order (as produced by `tokenize`)
# into a list of tokens in reverse Polish notation. Parentheses are
# consumed here and never appear in the result.
#
# Raises ValueError when the parentheses in `tokens` are not balanced,
# and Exception for a token that is not a Var, Op or Paren.

# TODO: Add better error handling. (Wiki article has some extra checks
# I skipped.)

# TODO: Use stack/queue with correct asymptotic performance.
def shunt(tokens):
    opstack = []
    output = []
    for token in tokens:
        if type(token) == Var:
            output.append(token)
        elif type(token) == Op:
            # Pop operators that bind at least as tightly as `token`.
            # On equal precedence it is the associativity of the
            # *incoming* operator that decides; left-associative
            # operators pop the operator on the stack first.
            while (len(opstack) > 0 and opstack[-1] != Paren.L and
                   (opstack[-1].precedence > token.precedence or
                    (opstack[-1].precedence == token.precedence and
                     token.assoc == Assoc.L))):
                output.append(opstack.pop())
            opstack.append(token)
        elif token == Paren.L:
            opstack.append(token)
        elif token == Paren.R:
            # Unwind to the matching left paren. Running out of stack
            # means this right paren has no partner.
            while opstack and opstack[-1] != Paren.L:
                output.append(opstack.pop())
            if not opstack:
                raise ValueError('mismatched parentheses: unmatched ")"')
            opstack.pop()  # Discard the left paren itself.
        else:
            raise Exception('unhandled token type')
    while opstack:
        top = opstack.pop()
        # A left paren still on the stack was never closed.
        if top == Paren.L:
            raise ValueError('mismatched parentheses: unmatched "("')
        output.append(top)
    return output
# Builds an ast from tokens in reverse Polish notation (as produced by
# `shunt`). Each variable becomes a `Leaf`; each operator becomes a
# `Node` over the two most recently built sub-trees. Exactly one tree
# must remain once all tokens are consumed.
def rpn2ast(tokens):
    stack = []
    for tok in tokens:
        kind = type(tok)
        if kind == Var:
            stack.append(Leaf(tok.name))
            continue
        if kind != Op:
            # No parens once in rpn.
            raise Exception('unhandled token type')
        # The right operand was pushed last, so it comes off first.
        rhs = stack.pop()
        lhs = stack.pop()
        stack.append(Node(tok.name, lhs, rhs))
    assert len(stack) == 1
    return stack[0]
# Evaluates the ast of the right-hand side of a formula. Returns a
# pair: an ordered set of population-level terms, and a list of groups.
#
# `allow_groups` is turned off when recursing into positions where a
# group may not appear (the operands of `:` and the lhs of `|`/`||`);
# meeting a group there raises.
def eval_rhs(ast, allow_groups=True):
    if type(ast) == Leaf:
        # "1" is the intercept; any other leaf is a one-factor term.
        term = _1 if ast.value == "1" else Term(OrderedSet(ast.value))
        return OrderedSet(term), []

    if type(ast) != Node:
        raise Exception('unhandled ast')

    op = ast.op

    if op == '+':
        lterms, lgroups = eval_rhs(ast.l, allow_groups)
        rterms, rgroups = eval_rhs(ast.r, allow_groups)
        return lterms.union(rterms), lgroups + rgroups

    if op == ':':
        # lme4/brms say a formula has the general form:
        # response ~ pterms + (gterms | group) + ...
        # This suggests that interaction between groups is not
        # possible. (Which is good, because I don't know what the
        # semantics would be.) However, neither package complains if
        # you write something like `y ~ (a | b) : (b | a)`, which is
        # odd. Here we don't allow interactions between groups.
        lterms, lgroups = eval_rhs(ast.l, allow_groups=False)
        rterms, rgroups = eval_rhs(ast.r, allow_groups=False)
        assert len(lgroups) == 0
        assert len(rgroups) == 0
        # Every lhs term interacts with every rhs term.
        interactions = []
        for lterm, rterm in itertools.product(lterms, rterms):
            interactions.append(Term(lterm.factors.union(rterm.factors)))
        return OrderedSet(*interactions), []

    if op in ('|', '||') and allow_groups:
        columns = eval_group_rhs(ast.r)
        # Nesting of groups is not allowed.
        group_terms, nested = eval_rhs(ast.l, allow_groups=False)
        assert len(nested) == 0
        # Only the single-bar form models correlations.
        corr = op == '|'
        return OrderedSet(), [Group(group_terms, columns, corr)]

    # The cases above are not exhaustive, so this can occur in regular
    # use. e.g. When nested groups are present.
    raise Exception('unhandled ast')
# Evaluates the expression to the right of the `|` or `||` in a group
# to a list of factor names. Only leaves combined with `:` are
# accepted; anything else raises.
# e.g. `a:b:c` -> ['a', 'b', 'c']

# TODO: Have `a:a` evaluate to `a`. (By using OrderedSet?)
def eval_group_rhs(ast):
    kind = type(ast)
    if kind == Leaf:
        return [ast.value]
    if kind == Node and ast.op == ':':
        left_names = eval_group_rhs(ast.l)
        right_names = eval_group_rhs(ast.r)
        return left_names + right_names
    # TODO: Better error. Catch and re-throw within `eval_rhs` so
    # that the whole group can be included in the message?
    raise Exception('unhandled ast')
# Evaluates the ast of a whole formula, `y ~ <rhs>`, to a `Formula`.
# The rhs is a sum of population terms and groups. The root of the ast
# must be a `~` node.
def evalf(ast):
    assert type(ast) == Node and ast.op == '~'
    lhs, rhs = ast.l, ast.r
    # The lhs is expected to be a (response) variable
    assert type(lhs) == Leaf
    terms, groups = eval_rhs(rhs)
    return Formula(lhs.value, terms, groups)
# Parses the formula string `s` into a `Formula`, by running the
# stages in order: tokenize, convert to rpn, build the ast, evaluate.
def parse(s):
    tokens = tokenize(s)
    rpn = shunt(tokens)
    ast = rpn2ast(rpn)
    return evalf(ast)
# Parses a sample formula and prints the result. Runs only when this
# file is executed as a script.
def main():
    formula = parse('y ~ x1 + x2 + (1 + x3 | x4) + x5:x6')
    print(formula)


if __name__ == '__main__':
    main()