TypeError: __init__() got an unexpected keyword argument 'name' when loading a model with Custom Layer

I made a custom layer in Keras that reshapes the outputs of a CNN before feeding them to a ConvLSTM2D layer:

class TemporalReshape(Layer):
    def __init__(self,batch_size,num_patches):
        super(TemporalReshape,self).__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches

    def call(self,inputs):
        nshape = (self.batch_size,self.num_patches)+inputs.shape[1:]
        return tf.reshape(inputs, nshape)

    def get_config(self):
        config = super().get_config().copy()
        config.update({'batch_size':self.batch_size,'num_patches':self.num_patches})
        return config

When I try to load the best model using

model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})

I get the error

TypeError                                 Traceback (most recent call last)
<ipython-input-83-40b46da33e91> in <module>()
----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})


/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
    180     if (h5py is not None and (
    181         isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 182       return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
    183 
    184     filepath = path_to_string(filepath)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
    176     model_config = json.loads(model_config.decode('utf-8'))
    177     model = model_config_lib.model_from_config(model_config,
--> 178                                                custom_objects=custom_objects)
    179 
    180     # set weights

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
     53                     '`Sequential.from_config(config)`?')
     54   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 55   return deserialize(config, custom_objects=custom_objects)
     56 
     57 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    356             custom_objects=dict(
    357                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
    360         return cls.from_config(cls_config)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects)
    615     """
    616     input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 617         config, custom_objects)
    618     model = cls(inputs=input_tensors, outputs=output_tensors,
    619                 name=config.get('name'))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1202   # First, we create all layers and enqueue nodes to be processed
   1203   for layer_data in config['layers']:
-> 1204     process_layer(layer_data)
   1205   # Then we process nodes in order of layer depth.
   1206   # Nodes that cannot yet be processed (if the inbound node

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data)
   1184       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   1185 
-> 1186       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
   1187       created_layers[layer_name] = layer
   1188 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
--> 360         return cls.from_config(cls_config)
    361     else:
    362       # Then `cls` may be a function returning a class.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
    695         A layer instance.
    696     """
--> 697     return cls(**config)
    698 
    699   def compute_output_shape(self, input_shape):

TypeError: __init__() got an unexpected keyword argument 'name'

When building the model, I used the custom layer like this:

x = TemporalReshape(batch_size = 8, num_patches = 16)(x)

What is causing the error, and how can I load the model without it?

Medication answered 13/10, 2020 at 14:22 Comment(3)
What if you put **kwargs in __init__? - Zingale
@NicolasGervais Post your comment as an answer and I will accept it. What you suggested worked, thanks a lot! I changed the signature to def __init__(self,batch_size,num_patches,**kwargs): - Medication
That is quite interesting. - Zingale
Z
20

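Based on the error message alone, I would suggest adding **kwargs to __init__. When the model is loaded, from_config calls cls(**config), and the config returned by super().get_config() also contains base Layer arguments such as name. With **kwargs, the constructor accepts those extra keyword arguments instead of raising a TypeError.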

def __init__(self, batch_size, num_patches, **kwargs):
    # Required: forward the extra keyword arguments (e.g. name) to the parent Layer.
    # Thanks https://stackoverflow.com/users/349130/dr-snoopy
    super(TemporalReshape, self).__init__(**kwargs)
    self.batch_size = batch_size
    self.num_patches = num_patches
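The mechanism can be reproduced without TensorFlow. The following is only a sketch under stated assumptions: BaseLayer is a hypothetical stand-in for the Keras Layer class, reduced to the two behaviours visible in the question (get_config() returns the base arguments, here just name, and from_config() calls cls(**config) as in the last traceback frame). The class names Forwarding and Strict are made up for illustration.

```python
class BaseLayer:
    """Hypothetical stand-in for the Keras Layer class (not the real one)."""

    def __init__(self, name=None):
        # The real Layer generates a unique default name; a fixed one keeps this simple.
        self.name = name or "layer"

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Mirrors `return cls(**config)` from the traceback.
        return cls(**config)


class Forwarding(BaseLayer):
    """Accepts extra keyword arguments and forwards them to the parent."""

    def __init__(self, batch_size, num_patches, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size,
                       "num_patches": self.num_patches})
        return config


class Strict(Forwarding):
    """Same signature as the layer in the question: no **kwargs."""

    def __init__(self, batch_size, num_patches):
        super().__init__(batch_size, num_patches)


# The saved config contains 'name' in addition to the two custom arguments.
config = Forwarding(8, 16, name="temporal_reshape").get_config()

strict_error = None
try:
    Strict.from_config(config)  # 'name' reaches an __init__ that does not accept it
except TypeError as err:
    strict_error = str(err)

restored = Forwarding.from_config(config)  # 'name' is absorbed by **kwargs
print(strict_error)
print(restored.name, restored.batch_size, restored.num_patches)
```

In this sketch the strict constructor fails for the same reason as in the question, while the forwarding one rebuilds the object with its original name.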
Zingale answered 13/10, 2020 at 15:19 Comment(6)
This is correct, but you are missing one key thing: the kwargs need to be passed to the parent __init__. - Gib
Like this? super(TemporalReshape,self).__init__(**kwargs) - Zingale
Yes, that is what I mean. - Gib
Even without that I got no error, but thanks for the suggestion. - Medication
That's because the missing kwargs have default values. As a consequence, you run the risk of not reconstructing the exact same thing you serialized. - Withhold
I got a "missing 1 required positional argument" error after adding this. - Siam
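The point about default values can be illustrated with a small sketch. It makes the same kind of assumption as any toy model: BaseLayer is a hypothetical stand-in for the Keras Layer (it only keeps name and trainable), and Forwarding and Swallowing are made-up names. Swallowing accepts **kwargs but never passes them on, so loading succeeds while the saved values are silently replaced by defaults.

```python
class BaseLayer:
    """Hypothetical stand-in for the Keras Layer class (not the real one)."""

    def __init__(self, name=None, trainable=True):
        self.name = name or "layer"
        self.trainable = trainable

    def get_config(self):
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class Forwarding(BaseLayer):
    """Passes the extra keyword arguments on to the parent constructor."""

    def __init__(self, num_patches, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config["num_patches"] = self.num_patches
        return config


class Swallowing(Forwarding):
    """Accepts **kwargs but drops them, so the parent falls back to defaults."""

    def __init__(self, num_patches, **kwargs):
        super().__init__(num_patches)  # kwargs are not forwarded


saved = Forwarding(16, name="reshape_frozen", trainable=False).get_config()

lossy = Swallowing.from_config(saved)  # no error, but name/trainable are lost
exact = Forwarding.from_config(saved)  # round-trips to the same config

print(lossy.get_config())
print(exact.get_config())
```

No exception is raised in either case, which is why the problem is easy to miss: only the forwarding variant reproduces the config that was saved.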
C
5

Add **kwargs to the __init__() function.

Error message: "TypeError: __init__() missing 3 required positional arguments: 'batch_size', 'num_patches'"
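A "missing required positional arguments" error on load usually points at a different problem from the one in the question: the saved config does not contain the constructor arguments, for example because get_config() was not overridden when the model was saved. A minimal sketch of that situation, assuming a hypothetical BaseLayer stand-in for the Keras Layer and made-up class names NoConfig and WithConfig:

```python
class BaseLayer:
    """Hypothetical stand-in for the Keras Layer class (not the real one)."""

    def __init__(self, name=None):
        self.name = name or "layer"

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class NoConfig(BaseLayer):
    """Has **kwargs, but does not override get_config()."""

    def __init__(self, batch_size, num_patches, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_patches = num_patches


class WithConfig(NoConfig):
    """Adds the constructor arguments to the saved config."""

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size,
                       "num_patches": self.num_patches})
        return config


# Only 'name' is saved, so the required arguments are missing on reload.
incomplete = NoConfig(8, 16, name="reshape").get_config()
load_error = None
try:
    NoConfig.from_config(incomplete)
except TypeError as err:
    load_error = str(err)

# With get_config() overridden, every constructor argument is saved and restored.
complete = WithConfig(8, 16, name="reshape").get_config()
restored = WithConfig.from_config(complete)

print(load_error)
print(complete)
```

If this is the cause, the model has to be saved again after get_config() is fixed, because the existing file still holds the incomplete config.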

Cyathus answered 16/11, 2020 at 6:29 Comment(1)
This is not an answer, and the recent edit doesn't correspond to what the original answerer had said. Furthermore, this edit is a perfect copy of another answer. - Zingale
N
0
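class TemporalReshape(Layer):
    def __init__(self, batch_size, num_patches, **kwargs):
        # Pass the saved layer name (if any) to the parent constructor.
        # Assigning self.name before super().__init__() runs is not reliable.
        super(TemporalReshape, self).__init__(name=kwargs.get('name'))
        ....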
Nonce answered 6/6 at 6:12 Comment(2)
As it's currently written, your answer is unclear. Please edit to add additional details that will help others understand how this addresses the question asked. You can find more information on how to write good answers in the help center. - Predicant
Please update your answer with some explanatory text and move your code to a code block. It will make your answer more useful for users who may land here in the future. - Farmergeneral
