Since the previous answers do not take into account the memory required for gradients, intermediate outputs, mixed dtypes, or nested models, I decided to give it a go as well. Note that the function returns the estimated memory requirement in bits, that the model must be compiled with a fully known input shape (including batch_size), and that the function does not consider the memory required for internal computations (e.g., neural attention). Microsoft has developed a method that is likely more accurate, but has not released the code.
import warnings

import tensorflow as tf


# Define function to calculate one layer's memory requirement
def layer_mem(layer: tf.keras.layers.Layer, prev_layer_mem: int) -> int:
    # Check whether calculations can be performed
    if not hasattr(layer, "output_shape") or (None in layer.output_shape):
        msg = f"Check `model.summary(expand_nested=True)` and recompile the model to ensure that {layer.name} has a fully defined `output_shape`, including `batch_size`. Using the previous layer's memory requirement."
        warnings.warn(msg)
        return prev_layer_mem
    # Collect sizes
    out_size = int(tf.reduce_prod(layer.output_shape))
    # Assume one gradient per parameter
    params = gradients = int(layer.count_params())
    # Assume the dtype name ends in its bit width (e.g., "float32" -> 32)
    bits = int(layer.dtype[-2:])
    # Calculate memory requirement
    return (params + gradients + out_size) * bits


# Define recursive function to gather all layers' memory requirements
def model_mem(model: tf.keras.Model) -> int:
    # Make limitations known
    warnings.warn("This function does not take into account the memory required for calculations (e.g., outer products)")
    # Initialize
    total_bits = 0
    prev_layer_mem = 0
    # Loop over layers in model
    for layer in model.layers:
        # In case of nested model...
        if hasattr(layer, "layers"):
            # ... apply recursion
            total_bits += model_mem(layer)
        else:
            # Calculate and add layer's memory requirement
            prev_layer_mem = layer_mem(layer, prev_layer_mem)
            total_bits += prev_layer_mem
    return total_bits
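
As a minimal usage sketch (the architecture, layer sizes, and batch size below are arbitrary, and it assumes a TF 2.x-style Keras where layers expose `output_shape`), you can build a model with a fully defined batch shape and convert the estimate from bits to a more readable unit:

# Build a small model with a fully defined input shape, including batch_size
model = tf.keras.Sequential([
    tf.keras.Input(shape=(64,), batch_size=32),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10),
])

# Estimate memory in bits and convert to mebibytes (8 bits per byte, 1024**2 bytes per MiB)
bits = model_mem(model)
print(f"Estimated memory: {bits / 8 / 1024**2:.2f} MiB")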