I am trying to apply an RNN to a variable-length multivariate sequence classification problem.
I have defined the following function to get the output of the sequence (i.e. the output of the RNN cell after the final input from the sequence has been fed in):
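```python
def get_sequence_output(x_sequence, initial_hidden_state):
    # Feed the inputs to the GRU one at a time and return the hidden
    # state produced after the last input of the sequence.
    previous_hidden_state = initial_hidden_state
    hidden_state = initial_hidden_state  # defined even if the sequence is empty
    for x_single in x_sequence:  # this line raises "'Tensor' object is not iterable"
        hidden_state = gru_unit(previous_hidden_state, x_single)
        previous_hidden_state = hidden_state
    final_hidden_state = hidden_state
    return final_hidden_state
```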
Here `x_sequence` is a tensor of shape `(?, ?, 10)`, where the first `?` is the batch size, the second `?` is the sequence length, and each input element has length 10. The `gru_unit` function takes the previous hidden state and the current input and returns the next hidden state (a standard gated recurrent unit).
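For reference, here is a minimal NumPy sketch of what one step of a standard GRU computes, using one common formulation with separate input (`W_*`), recurrent (`U_*`) and bias (`b_*`) parameters. The parameter names, the hidden size and the random initialisation are illustrative assumptions, not necessarily the `gru_unit` used above. The same operations exist in TensorFlow as `tf.matmul`, `tf.sigmoid` and `tf.tanh`, and some libraries swap the roles of `z` and `1 - z` in the final interpolation.

```python
import numpy as np

rng = np.random.default_rng(0)
INPUT_DIM, HIDDEN = 10, 4  # input length 10 as above; HIDDEN is an arbitrary choice

# Illustrative random parameters (hypothetical; a real model learns these).
params = {
    name: rng.normal(scale=0.1, size=shape)
    for name, shape in {
        "W_z": (INPUT_DIM, HIDDEN), "U_z": (HIDDEN, HIDDEN), "b_z": (HIDDEN,),
        "W_r": (INPUT_DIM, HIDDEN), "U_r": (HIDDEN, HIDDEN), "b_r": (HIDDEN,),
        "W_h": (INPUT_DIM, HIDDEN), "U_h": (HIDDEN, HIDDEN), "b_h": (HIDDEN,),
    }.items()
}


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_unit(previous_hidden_state, x_single, p=params):
    """One GRU step: (batch, HIDDEN), (batch, INPUT_DIM) -> (batch, HIDDEN)."""
    h = previous_hidden_state
    z = sigmoid(x_single @ p["W_z"] + h @ p["U_z"] + p["b_z"])  # update gate
    r = sigmoid(x_single @ p["W_r"] + h @ p["U_r"] + p["b_r"])  # reset gate
    # Candidate state; the reset gate scales the previous state first.
    candidate = np.tanh(x_single @ p["W_h"] + (r * h) @ p["U_h"] + p["b_h"])
    # Interpolate between the previous state and the candidate state.
    return (1.0 - z) * h + z * candidate


batch = 2
h0 = np.zeros((batch, HIDDEN))
x_t = rng.normal(size=(batch, INPUT_DIM))
h1 = gru_unit(h0, x_t)
print(h1.shape)  # (2, 4)
```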
I am getting the error `'Tensor' object is not iterable`.
How do I iterate over a tensor sequentially (reading a single element at a time)?
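A side note on the loop itself, sketched with NumPy because NumPy arrays can be iterated directly: iterating over an array of shape `(batch, time, 10)` walks the first (batch) axis, so even where iteration is allowed, `for x_single in x_sequence` would yield one whole sequence per step rather than one time step for the whole batch. Reading one time step at a time needs the time axis first, for example by transposing to `(time, batch, 10)`. The sizes below are arbitrary.

```python
import numpy as np

batch, seq_len, input_dim = 2, 5, 10
x_sequence = np.zeros((batch, seq_len, input_dim))

# Iterating walks axis 0 (the batch): 2 items, each a full (5, 10) sequence.
batch_major_steps = [x.shape for x in x_sequence]
print(batch_major_steps)  # [(5, 10), (5, 10)]

# Swap the batch and time axes to get time-major data of shape (5, 2, 10).
time_major = np.transpose(x_sequence, (1, 0, 2))

# Now each item is one time step for the whole batch: 5 items of shape (2, 10).
time_major_steps = [x.shape for x in time_major]
print(time_major_steps)  # [(2, 10), (2, 10), (2, 10), (2, 10), (2, 10)]
```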
My objective is to apply the `gru_unit` function to every input in the sequence and get the final hidden state.
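One possible approach in TensorFlow (a sketch under stated assumptions, not a definitive implementation) is to make the tensor time-major with `tf.transpose` and let `tf.foldl` apply the step function along the first axis, carrying the hidden state as the accumulator; `tf.scan` works the same way if the state at every step is needed. Because the sequences have variable length, the sketch assumes a hypothetical `lengths` tensor holding each sequence's true length and uses `tf.sequence_mask` to leave the state unchanged on padded steps, so the result is the state after each sequence's own last input. `simple_rnn_step` is a hypothetical stand-in for `gru_unit` with the same `(previous_hidden_state, x_single)` signature, used only to keep the example short. It is written against TensorFlow 2.

```python
import numpy as np
import tensorflow as tf

BATCH, MAX_LEN, INPUT_DIM, HIDDEN = 3, 5, 10, 4

rng = np.random.default_rng(0)
# Hypothetical fixed weights for the stand-in step function.
W = tf.constant(rng.normal(scale=0.1, size=(INPUT_DIM, HIDDEN)), dtype=tf.float32)
U = tf.constant(rng.normal(scale=0.1, size=(HIDDEN, HIDDEN)), dtype=tf.float32)


def simple_rnn_step(previous_hidden_state, x_single):
    """Stand-in for gru_unit: (batch, HIDDEN), (batch, INPUT_DIM) -> (batch, HIDDEN)."""
    return tf.tanh(tf.matmul(x_single, W) + tf.matmul(previous_hidden_state, U))


def get_sequence_output(x_sequence, initial_hidden_state, lengths, step=simple_rnn_step):
    """Return the hidden state after each sequence's last real (unpadded) input.

    x_sequence: (batch, time, INPUT_DIM), padded after each sequence's length.
    lengths:    (batch,) int tensor with the true length of each sequence.
    """
    time_major = tf.transpose(x_sequence, [1, 0, 2])  # (time, batch, INPUT_DIM)
    mask = tf.sequence_mask(lengths, maxlen=tf.shape(x_sequence)[1], dtype=tf.float32)
    mask = tf.transpose(mask)  # (time, batch)

    def fold_fn(previous_hidden_state, elems):
        x_single, m = elems  # one time step: (batch, INPUT_DIM) and (batch,)
        new_state = step(previous_hidden_state, x_single)
        m = m[:, None]  # (batch, 1), broadcasts over the hidden units
        # On padded steps (m == 0) keep the previous state unchanged.
        return m * new_state + (1.0 - m) * previous_hidden_state

    return tf.foldl(fold_fn, (time_major, mask), initializer=initial_hidden_state)


x = tf.constant(rng.normal(size=(BATCH, MAX_LEN, INPUT_DIM)), dtype=tf.float32)
lengths = tf.constant([5, 3, 1])
h0 = tf.zeros([BATCH, HIDDEN])
final_state = get_sequence_output(x, h0, lengths)
print(tuple(final_state.shape))  # (3, 4)
```

Masking the update rather than slicing each sequence keeps the whole batch in one fold; `tf.keras.layers.GRU` with a mask is a higher-level alternative to writing the loop by hand.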