MXNet print intermediate symbol values
How do I find the actual numerical values held in an MXNet symbol?

Suppose I have:

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y

If x = [100, 200] and y = [300, 400], I want to print:

z = [400, 600]

sort of like TensorFlow's eval() method.

Polyamide asked 24/3, 2017 at 22:3

After looking around a bit, I found that you can do this by binding concrete arrays to the symbol and running a forward pass:

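import mxnet as mx

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y
# Bind concrete input arrays to the symbol's arguments, then run a forward pass
executor = z.bind(mx.cpu(), {'x': mx.nd.array([100, 200]), 'y': mx.nd.array([300, 400])})
output = executor.forward()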

This gives you 'output', a list containing one NDArray:

[<NDArray 2 @cpu(0)>]

To see the actual numerical values, convert the NDArray to a NumPy array:

>>> output[0].asnumpy()
array([ 400.,  600.], dtype=float32)
Polyamide answered 24/3, 2017 at 22:11
