How can we "associate" a Python context manager to the variables appearing in its block?

As I understand it, context managers are used in Python to define initialization and finalization code (__enter__ and __exit__) for an object.
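For reference, here is a minimal sketch of that protocol (the Tracer class is hypothetical and only for illustration): with calls __enter__ when the block starts and __exit__ when it ends. Nothing in the protocol itself mentions the variables assigned inside the block, which is what makes the PyMC3 behaviour below puzzling.

```python
class Tracer:
    """Hypothetical context manager that records when its hooks run."""

    def __init__(self):
        self.events = []

    def __enter__(self):
        self.events.append("enter")
        return self  # this value is what "with ... as name" binds

    def __exit__(self, exc_type, exc_value, traceback):
        self.events.append("exit")
        return False  # do not suppress exceptions raised in the block


tracer = Tracer()
with tracer as t:
    tracer.events.append("body")
print(tracer.events)  # ['enter', 'body', 'exit']
print(t is tracer)  # True
```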

However, in the tutorial for PyMC3 they show the following context manager example:

basic_model = pm.Model()

with basic_model:

    # Priors for unknown model parameters
    alpha = pm.Normal('alpha', mu=0, sd=10)
    beta = pm.Normal('beta', mu=0, sd=10, shape=2)
    sigma = pm.HalfNormal('sigma', sd=1)

    # Expected value of outcome
    mu = alpha + beta[0]*X1 + beta[1]*X2

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal('Y_obs', mu=mu, sd=sigma, observed=Y)

and mention that this serves to associate the variables alpha, beta, sigma, mu and Y_obs with the model basic_model.

I would like to understand how such a mechanism works. In the explanations of context managers I have found, I did not see anything suggesting how variables or objects defined within the context's block somehow get "associated" with the context manager. It would seem that the library (PyMC3) somehow has access to the "current" context manager, so it can associate each newly created variable with it behind the scenes. But how can the library get access to the context manager?

Leucocyte answered 14/8, 2018 at 20:35 Comment(2)
This could be done by an implementation of __enter__ pushing information onto a thread-local stack. – Tilley
@Tilley Yep. – Winer

PyMC3 does this by maintaining a thread-local variable as a class variable inside the Context class. Models inherit from Context.

Each time you use a model in a with statement, the model gets pushed onto the thread-specific context stack. The top of the stack thus always refers to the innermost (most recent) model used as a context manager.

Contexts (and thus Models) have a .get_context() class method to obtain the top of the context stack.

Distributions call Model.get_context() when they are created to associate themselves with the innermost model.

So in short:

  1. with model pushes model onto the context stack. This means that inside the with block, type(model).contexts, Model.contexts or Context.contexts now contains model as its last (top-most) element.
  2. Distribution.__init__() calls Model.get_context() (note capital M), which returns the top of the context stack. In our case this is model. The context stack is thread-local (there is one per thread), but it is not instance-specific. If there is only a single thread, there also is only a single context stack, regardless of the number of models.
  3. When the with block exits, model gets popped from the context stack.
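The three steps above can be sketched in a few lines. This is a simplified sketch loosely modeled on the description in this answer, not PyMC3's actual source; the names Context, Model, Distribution, get_contexts and variables are assumptions chosen for illustration.

```python
import threading


class Context:
    """Base class whose instances push themselves onto a per-thread stack."""

    # One threading.local stored on the class; each thread sees its own attributes.
    _local = threading.local()

    @classmethod
    def get_contexts(cls):
        # Lazily create this thread's stack the first time it is needed.
        if not hasattr(cls._local, "stack"):
            cls._local.stack = []
        return cls._local.stack

    @classmethod
    def get_context(cls):
        # Called on the class, so no reference to a model instance is needed.
        stack = cls.get_contexts()
        if not stack:
            raise RuntimeError("No context on the stack")
        return stack[-1]

    def __enter__(self):
        type(self).get_contexts().append(self)  # step 1: push
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        type(self).get_contexts().pop()  # step 3: pop
        return False


class Model(Context):
    def __init__(self, name):
        self.name = name
        self.variables = []


class Distribution:
    def __init__(self, name):
        self.name = name
        # Step 2: find the innermost active model via the class and register.
        self.model = Model.get_context()
        self.model.variables.append(self)


outer = Model("outer")
with outer:
    a = Distribution("a")
    with Model("inner") as inner:
        b = Distribution("b")
    c = Distribution("c")

print([v.name for v in outer.variables])  # ['a', 'c']
print([v.name for v in inner.variables])  # ['b']
```

Because the stack is consulted at construction time, nesting works naturally: c is registered with outer again once inner has been popped.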
Winer answered 14/8, 2018 at 21:11 Comment(2)
Great, thanks. But there is an apparent circularity I am not following: you get the Model/Context from the context stack, which is obtained from the Model/Context (get_context), which you get from the context stack, which you get from the Model/Context... How does a Distribution get access to either the Model/Context or the context stack to begin with? – Leucocyte
I might need to add some emphasis that get_context() is a class method and that the context stack is a thread-local class variable. get_context() isn't called on a model instance, but on the Model class. – Winer
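To make that comment concrete, here is a minimal sketch (hypothetical class and function names, not PyMC3 code). The entry point is the class itself, which any code can import by name, so there is no circularity: worker() below receives no reference to ctx at all. It also shows that the stack is per-thread: a thread started inside the with block sees an empty stack.

```python
import threading


class Context:
    # One threading.local object stored on the class; its attributes are per-thread.
    _local = threading.local()

    @classmethod
    def stack(cls):
        if not hasattr(cls._local, "stack"):
            cls._local.stack = []
        return cls._local.stack

    @classmethod
    def get_context(cls):
        # Class method: looks up the current thread's stack, no instance required.
        stack = cls.stack()
        return stack[-1] if stack else None

    def __enter__(self):
        type(self).stack().append(self)
        return self

    def __exit__(self, *exc_info):
        type(self).stack().pop()
        return False


seen_in_worker = []


def worker():
    # Runs in another thread, so it sees that thread's own (empty) stack.
    seen_in_worker.append(Context.get_context())


with Context() as ctx:
    t = threading.Thread(target=worker)
    t.start()
    t.join()
    print(Context.get_context() is ctx)  # True
print(seen_in_worker)  # [None]
print(Context.get_context())  # None
```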

I don't know how it works in this specific case, but in general you will use some 'behind the scenes magic':

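class Parent:
    def __init__(self):
        # The child whose with block is currently active (None outside any block).
        self.active_child = None

    def ContextManager(self):
        # Returns a new Child bound to this parent, meant for use in a with statement.
        return Child(self)

    def Attribute(self):
        # Delegate to whichever child is currently active.
        return self.active_child.Attribute()

class Child:
    def __init__(self, parent):
        self.parent = parent

    def __enter__(self):
        # Entering the with block registers this child on its parent.
        self.parent.active_child = self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Leaving the block (normally or via an exception) unregisters it.
        self.parent.active_child = None

    def Attribute(self):
        print("Called Attribute of child")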

Using this code:

p = Parent()
with p.ContextManager():
    attr = p.Attribute()

will yield the following output:

Called Attribute of child
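One caveat with this pattern: because __exit__ resets active_child to None, nesting two blocks on the same parent loses the outer child when the inner block ends. A minimal variation (method names lowercased and a hypothetical label parameter added for illustration) remembers the previous child on entry and restores it on exit, which is effectively a one-element-at-a-time stack:

```python
class Parent:
    def __init__(self):
        self.active_child = None

    def context_manager(self, label):
        return Child(self, label)

    def attribute(self):
        # Delegate to whichever child is currently active.
        return self.active_child.attribute()


class Child:
    def __init__(self, parent, label):
        self.parent = parent
        self.label = label
        self.previous = None

    def __enter__(self):
        # Remember the outer child so it can be restored on exit.
        self.previous = self.parent.active_child
        self.parent.active_child = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.parent.active_child = self.previous
        return False

    def attribute(self):
        return f"attribute of child {self.label}"


p = Parent()
results = []
with p.context_manager("outer"):
    with p.context_manager("inner"):
        results.append(p.attribute())
    results.append(p.attribute())
print(results)  # ['attribute of child inner', 'attribute of child outer']
```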
Nabal answered 14/8, 2018 at 20:58 Comment(2)
Thanks, but I didn't get it. The idea was to have something like with contextmanager: foo() and then have foo() somehow get access to contextmanager. In your example, this seems to be done by keeping p around and using it as a bridge, which does not seem to achieve the same goal. – Leucocyte
@Leucocyte But that is happening in your code snippet too: pm is used as a bridge (you call pm.Model() and pm.Normal(...), and the return values of Model and Normal are then connected). – Nabal

One can also inspect the caller's locals() via the stack when entering and exiting the context manager block, and identify which ones have changed.

import inspect


class VariablePostProcessor(object):
    """Context manager that applies a function to all newly defined variables in the context manager block.

    with VariablePostProcessor(print):
        a = 1
        b = 3

    It uses the (name, id(obj)) pair of each variable and object to detect whether a variable has been added.
    If a name is already bound to an object before the block, an assignment to this name inside the block
    is detected only if the id of the object has changed.

    a = 1
    b = 2
    with VariablePostProcessor(print):
        a = 1
        b = 3
    # only 'b' is detected as a newly defined variable/object; 'a' is not detected because it still
    # points to the same object 1
    """

    @staticmethod
    def variables():
        # get the locals two frames up the stack
        # (0 is this function, 1 is __init__/__exit__, 2 is the frame containing the with block)
        return {(k, id(v)): v for k, v in inspect.stack()[2].frame.f_locals.items()}

    def __init__(self, post_process):
        self.post_process = post_process
        # snapshot the caller's variables before the block runs
        # (the snapshot keeps the old objects alive, so their ids cannot be reused)
        self.dct = self.variables()

    def __enter__(self):
        return

    def __exit__(self, type, value, traceback):
        # compare the variables defined at __exit__ with those captured before the block
        dct_exit, dct_enter = self.variables(), self.dct
        for (name, id_) in set(dct_exit).difference(dct_enter):
            self.post_process(name, dct_exit[(name, id_)])

A typical use is:

# let us define a Variable object that has a 'name' attribute that can be defined at initialisation time or later
class Variable:
    def __init__(self, name=None):
        self.name = name

# the following code
x = Variable('x')
y = Variable('y')
print(x.name, y.name)

# can be replaced by
with VariablePostProcessor(lambda name, obj: setattr(obj, "name", name)):
    x = Variable()
    y = Variable()
print(x.name, y.name)

# in such a case, you can also define, as a convenience
import functools
AutoRenamer = functools.partial(VariablePostProcessor, post_process=lambda name, obj: setattr(obj, "name", name))

# and rewrite the above code as
with AutoRenamer():
    x = Variable()
    y = Variable()
print(x.name, y.name)  # => x y
Connecticut answered 24/11, 2022 at 5:49 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.