ntt123 commited on
Commit
ab9e1be
·
verified ·
1 Parent(s): e5efec0

Update tacotron.py

Browse files
Files changed (1) hide show
  1. tacotron.py +1 -1
tacotron.py CHANGED
@@ -71,7 +71,7 @@ class BiGRU(pax.Module):
71
  return jnp.where(reset_mask, x0, xt)
72
 
73
  state, _ = self.rnn_bwd(prev, x)
74
- state = jax.tree_map(reset_state, x_bwd_states0, state)
75
  return state, state.hidden
76
 
77
  x_bwd_states, x_bwd = pax.scan(
 
71
  return jnp.where(reset_mask, x0, xt)
72
 
73
  state, _ = self.rnn_bwd(prev, x)
74
+ state = jax.tree_util.tree_map(reset_state, x_bwd_states0, state)
75
  return state, state.hidden
76
 
77
  x_bwd_states, x_bwd = pax.scan(