Here's the utility I use for this. See kronecker_test
for an example of usage.
def fix_shape(tf_shape):
    """Convert a shape into a plain tuple of Python ints.

    Every dimension yielded by ``tf_shape`` is passed through ``int()``, so
    the result can be unpacked and used in ordinary integer arithmetic.
    """
    return tuple(map(int, tf_shape))
def concat_blocks(blocks, validate_dims=True):
    """Concatenate a 2-D grid of matrix blocks into one matrix (aka ArrayFlatten).

    Args:
        blocks: sequence of rows, each row being a sequence of 2-D blocks.
            Blocks in a row are joined along axis 1, then the resulting block
            rows are stacked along axis 0.
        validate_dims: if True, check the blocks' static ``.shape`` values
            before concatenating. Pass False when static shapes are not
            available.

    Returns:
        The result of the two-level ``tf.concat``.

    Raises:
        AssertionError: if ``validate_dims`` is True and the block dimensions
            do not line up. Raised explicitly, so the check is not stripped
            when Python runs with ``-O``.
    """
    if validate_dims:
        row_widths = []
        for row_idx, row in enumerate(blocks):
            # Blocks joined side by side must all have the same height.
            # Comparing per-column sums of heights is not enough: a grid such
            # as [[2x1, 3x1], [3x1, 2x1]] has equal column sums but cannot be
            # concatenated.
            heights = {int(b.shape[0]) for b in row}
            if len(heights) > 1:
                raise AssertionError(
                    "blocks in row %d have differing heights: %s"
                    % (row_idx, sorted(heights)))
            row_widths.append(sum(int(b.shape[1]) for b in row))
        # Block rows stacked on top of each other must have equal total width.
        if len(set(row_widths)) > 1:
            raise AssertionError(
                "block rows have differing total widths: %s" % (row_widths,))
    block_rows = [tf.concat(row, axis=1) for row in blocks]
    return tf.concat(block_rows, axis=0)
def chunks(l, n):
    """Generate consecutive slices of ``l`` holding at most ``n`` items each.

    The final slice is shorter when ``len(l)`` is not a multiple of ``n``.
    """
    starts = range(0, len(l), n)
    for start in starts:
        stop = start + n
        yield l[start:stop]
from tensorflow.python.framework import ops
# Keep a reference to the original shape hook so that
# enable_shape_inference() can put it back after disable_shape_inference()
# has replaced it.
original_shape_func = ops.set_shapes_for_outputs
def disable_shape_inference():
    """Swap ``ops.set_shapes_for_outputs`` for a pass-through function.

    The replacement simply returns its argument. This is a module-wide
    monkeypatch; call ``enable_shape_inference()`` to restore the original.
    """
    def _passthrough(arg):
        return arg

    ops.set_shapes_for_outputs = _passthrough
def enable_shape_inference():
    """Restore the ``ops.set_shapes_for_outputs`` hook saved at import time.

    Undoes the monkeypatch installed by ``disable_shape_inference()``.
    """
    setattr(ops, "set_shapes_for_outputs", original_shape_func)
def kronecker(A, B, do_shape_inference=True):
    """Return the Kronecker product of matrices A and B.

    Args:
        A: 2-D tensor whose static shape converts to two ints.
        B: 2-D tensor whose static shape converts to two ints.
        do_shape_inference: if False, shape inference is switched off while
            the per-element blocks are built and concatenated. The original
            author reports this takes a 10x10 kron from 2.4 sec to 0.9 sec.

    Returns:
        A tensor with static shape (Arows*Brows, Acols*Bcols).
    """
    Arows, Acols = fix_shape(A.shape)
    Brows, Bcols = fix_shape(B.shape)
    Crows, Ccols = Arows * Brows, Acols * Bcols

    # temp[i*Acols + j] == A[i, j] * B, so temp is [Arows*Acols, Brows, Bcols].
    temp = tf.reshape(A, [-1, 1, 1]) * tf.expand_dims(B, 0)
    # Built before shape inference is (optionally) turned off below.
    Bshape = tf.constant((Brows, Bcols))

    if not do_shape_inference:
        disable_shape_inference()
    try:
        # One slice per element of A. Splitting into Crows pieces is only
        # correct when Brows == Acols, so split by the true leading dimension.
        # [1, n, m] => [n, m]
        slices = [tf.reshape(s, Bshape)
                  for s in tf.split(temp, Arows * Acols)]
        grid = list(chunks(slices, Acols))
        assert len(grid) == Arows
        # Static shapes are unavailable without inference, so skip validation.
        result = concat_blocks(grid, validate_dims=do_shape_inference)
    finally:
        # Always undo the global monkeypatch, even if graph building failed.
        if not do_shape_inference:
            enable_shape_inference()
    result.set_shape((Crows, Ccols))
    return result
def kronecker_test():
    """Check kronecker() on a 2x2 example.

    The result is compared against both a hand-computed matrix and
    ``np.kron``.
    """
    A0 = [[1, 2], [3, 4]]
    B0 = [[6, 7], [8, 9]]
    A = tf.constant(A0)
    B = tf.constant(B0)
    C = kronecker(A, B)
    # The context manager closes the session even if run() raises.
    with tf.Session() as sess:
        C0 = sess.run(C)
    Ct = [[6, 7, 12, 14], [8, 9, 16, 18], [18, 21, 24, 28], [24, 27, 32, 36]]
    Cnp = np.kron(A0, B0)
    check_equal(C0, Ct)
    check_equal(C0, Cnp)