Serializing an object in __main__ with pickle or dill

I have a pickling problem. I want to serialize a function in my main script, then load it and run it in another script. To demonstrate this, I've made 2 scripts:

Attempt 1: The naive way:

dill_pickle_script_1.py

import pickle
import time

def my_func(a, b):
    time.sleep(0.1)  # The purpose of this will become evident at the end
    return a+b

if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        pickle.dump(my_func, f)

dill_pickle_script_2.py

import pickle

if __name__ == '__main__':
    with open('testfile.pkl', 'rb') as f:  # pickle files are binary
        func = pickle.load(f)
        assert func(1, 2)==3

Problem: when I run script 2, I get AttributeError: 'module' object has no attribute 'my_func'. I understand why: when my_func is serialized in script 1, it belongs to the __main__ module, so the pickle only records a reference to "__main__.my_func". When dill_pickle_script_2 loads it, __main__ refers to dill_pickle_script_2 itself, which has no my_func, so the reference can't be resolved.

Attempt 2: Inserting an absolute import

I fix the problem with a little hack: I add an absolute import of my_func in dill_pickle_script_1 before pickling it.

dill_pickle_script_1.py

import pickle
import time

def my_func(a, b):
    time.sleep(0.1)
    return a+b

if __name__ == '__main__':
    from dill_pickle_script_1 import my_func  # Added absolute import
    with open('testfile.pkl', 'wb') as f:
        pickle.dump(my_func, f)

Now it works! However, I'd like to avoid doing this hack every time. (Also, I want my pickling to be done inside some other module, which wouldn't know which module my_func came from.)
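Attempt 2 works because `from dill_pickle_script_1 import my_func` imports the script a second time as an ordinary module, and that copy of `my_func` has `__module__ == "dill_pickle_script_1"`, which is the name pickle then records. The sketch below reproduces this in one file by writing a throwaway module (the name `helper_mod` is made up for the example) to a temporary directory and importing it:

```python
import importlib
import pickle
import sys
import tempfile
from pathlib import Path

# Write a throwaway module to a temp dir ("helper_mod" is a made-up name).
tmp_dir = Path(tempfile.mkdtemp())
(tmp_dir / "helper_mod.py").write_text(
    "def my_func(a, b):\n"
    "    return a + b\n"
)
sys.path.insert(0, str(tmp_dir))
importlib.invalidate_caches()

helper_mod = importlib.import_module("helper_mod")
print(helper_mod.my_func.__module__)  # helper_mod

# The pickle now names "helper_mod" instead of "__main__" ...
data = pickle.dumps(helper_mod.my_func)
# ... so loading imports helper_mod and looks up my_func there.
restored = pickle.loads(data)
print(restored(1, 2))  # 3
```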

Attempt 3: Dill

I heard that the dill package lets you serialize things in __main__ and load them elsewhere, so I tried that:

dill_pickle_script_1.py

import dill
import time

def my_func(a, b):
    time.sleep(0.1)
    return a+b

if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        dill.dump(my_func, f)

dill_pickle_script_2.py

import dill

if __name__ == '__main__':
    with open('testfile.pkl', 'rb') as f:  # pickle files are binary
        func = dill.load(f)
        assert func(1, 2)==3

Now, however, I have another problem: when running dill_pickle_script_2.py, I get NameError: global name 'time' is not defined. It seems that dill did not realize that my_func references the time module, which has to be imported on load.
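To make "references the time module" concrete: the compiled function records the global names it looks up, and those must be resolvable in whatever namespace the function lands in after loading. This stdlib sketch only illustrates that idea; it is not how dill is implemented, and it ignores nested functions, closures and names reached indirectly.

```python
import time


def my_func(a, b):
    time.sleep(0.1)
    return a + b


# Names the bytecode looks up as globals or attributes (e.g. 'time', 'sleep').
names = my_func.__code__.co_names

# Keep only those that resolve in the function's own module globals;
# here that is just 'time' ('sleep' is an attribute, not a global).
needed_globals = {n: my_func.__globals__[n] for n in names if n in my_func.__globals__}
print(needed_globals)
```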

My Question

How can I serialize an object in __main__ and load it again in another script, so that all the imports used by that object are also loaded, without the nasty little hack from Attempt 2?

Phago answered 10/8, 2017 at 14:31 Comment(1)
Conchylicultor's answer is probably what you're looking for. Keep in mind, though, that you should also read the pickled/dilled file in binary mode: open('testfile.pkl', 'rb') - Evince

Well, I found a solution. It is a horrible but tidy kludge, and it is not guaranteed to work in all cases. Any suggestions for improvement are welcome. The solution replaces the __main__ module reference in the pickled data with the script's absolute module path, using the following helper functions:

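import os
import pickle
import sys

def pickle_dumps_without_main_refs(obj):
    """
    Yeah, this is horrible, but it allows you to pickle an object in the main module so that it can be reloaded in another
    module.
    :param obj: The object to pickle.
    :return: The pickled bytes, with '__main__' replaced by the script's module path.
    """
    currently_run_file = sys.argv[0]
    module_path = file_path_to_absolute_module(currently_run_file)
    # Protocol 0 writes module names as plain text, so they can be edited in place.
    pickle_str = pickle.dumps(obj, protocol=0)
    # Hack! On Python 3, pickle.dumps() returns bytes, so replace bytes with bytes.
    pickle_str = pickle_str.replace(b'__main__', module_path.encode('ascii'))
    return pickle_str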


def pickle_dump_without_main_refs(obj, file_obj):
    string = pickle_dumps_without_main_refs(obj)
    file_obj.write(string)


def file_path_to_absolute_module(file_path):
    """
    Given a file path, return an import path.
    :param file_path: A file path.
    :return:
    """
    assert os.path.exists(file_path)
    file_loc, ext = os.path.splitext(file_path)
    assert ext in ('.py', '.pyc')
    directory, module = os.path.split(file_loc)
    module_path = [module]
    while True:
        if os.path.exists(os.path.join(directory, '__init__.py')):
            directory, package = os.path.split(directory)
            module_path.append(package)
        else:
            break
    path = '.'.join(module_path[::-1])
    return path

Now, I can simply change dill_pickle_script_1.py to say

import time
from artemis.remote.child_processes import pickle_dump_without_main_refs


def my_func(a, b):
    time.sleep(0.1)
    return a+b

if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        pickle_dump_without_main_refs(my_func, f)

And then dill_pickle_script_2.py works!

Phago answered 11/8, 2017 at 8:30 Comment(0)

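You can use dill.dump with recurse=True, or set dill.settings["recurse"] = True globally. With recurse enabled, dill recursively traces the globals the function refers to (here, the time module) and pickles them along with it, instead of trying to store the function's entire global dictionary: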

In file A:

import time
import dill

def my_func(a, b):
    time.sleep(0.1)
    return a + b

with open("tmp.pkl", "wb") as f:
    dill.dump(my_func, f, recurse=True)

In file B:

import dill

with open("tmp.pkl", "rb") as f:
    my_func = dill.load(f)

Unschooled answered 16/2, 2021 at 19:20 Comment(0)

Here's another solution that modifies the serialization so that it will deserialize without any special measures. You could argue it is less hacky than Peter's solution.

Instead of hacking the output of pickle.dumps(), this subclasses Pickler to change how it pickles objects that refer back to __main__. The subclass has to derive from the pure-Python pickle._Pickler, so the fast C implementation can't be used and there is a performance penalty. It also overrides the save_pers() method of Pickler, which isn't intended to be overridden, so this could break in a future version of Python (though that seems unlikely).

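import importlib
import inspect
import io
import pickle
import sys
from pathlib import Path


def get_function_module_str(func):
    """Returns (dotted module string suitable for importlib.import_module(), function
    name) for a function reference.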
    """
    source_file = Path(inspect.getsourcefile(func))
    # (Doesn't work with built-in functions)
    if not source_file.is_absolute():
        rel_path = source_file
    else:
        # It's an absolute path so find the longest entry in sys.path that shares a
        # common prefix and remove the prefix.
        for path_str in sorted(sys.path, key=len, reverse=True):
            try:
                rel_path = source_file.relative_to(Path(path_str))
                break
            except ValueError:
                pass
        else:
            raise ValueError(f"{source_file!r} is not on the Python path")
    # Replace path separators with dots.
    modules_str = ".".join(p for p in rel_path.with_suffix("").parts if p != "__init__")
    return modules_str, func.__name__


class ResolveMainPickler(pickle._Pickler):
    """Subclass of Pickler that replaces __main__ references with the actual module
    name."""

    def persistent_id(self, obj):
        """Override to see if this object is defined in "__main__" and if so to replace
        __main__ with the actual module name."""
        if getattr(obj, "__module__", None) == "__main__":
            module_str, obj_name = get_function_module_str(obj)
            obj_ref = getattr(importlib.import_module(module_str), obj_name)
            return obj_ref
        return None

    def save_pers(self, pid):
        """Override the function to save a persistent ID so that it saves it as a
        normal reference. So it can be unpickled with no special arrangements.
        """
        self.save(pid, save_persistent_id=False)


# `obj` is the object to pickle, e.g. my_func from the question.
with io.BytesIO() as pickled:
    pickler = ResolveMainPickler(pickled)
    pickler.dump(obj)
    print(pickled.getvalue())

If you already know the name of the __main__ module then you could dispense with get_function_module_str() and just supply the name directly.

Trave answered 28/3, 2022 at 16:31 Comment(0)
