Utils
jax_gather_nd
Slice a params array at a set of indices.
This is a reimplementation of TensorFlow's gather_nd function (tf.gather_nd).
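The gather operation described above can be illustrated with a minimal sketch. It assumes that gather_nd semantics map onto NumPy-style advanced indexing: each row of indices holds a K-dimensional index into the leading K axes of params. The helper name gather_nd_sketch is hypothetical, and this is not presented as the library's own implementation.

```python
import jax.numpy as jnp
from jax import Array


def gather_nd_sketch(params: Array, indices: Array) -> Array:
    """Hypothetical gather_nd sketch built on advanced indexing.

    indices has shape (M, K); each row indexes the leading K axes of params.
    The result has shape (M, *params.shape[K:]).
    """
    # Split the last axis of indices into K index arrays, one per leading
    # axis of params: (indices[:, 0], indices[:, 1], ...).
    per_axis = tuple(indices[..., k] for k in range(indices.shape[-1]))
    # Indexing with a tuple of integer arrays gathers one slice per row.
    return params[per_axis]


params = jnp.arange(12.0).reshape(4, 3)  # N = 4, rest = (3,)
indices = jnp.array([[2], [0], [2]])     # M = 3, shape (M, 1), values in [0, 4)
out = gather_nd_sketch(params, indices)
print(out.shape)  # (3, 3)
```

With K = 1 (the (M, 1) case documented below), this reduces to params[indices[:, 0]]. Note that JAX clamps out-of-bounds indices during retrieval instead of raising an error, so the requirement that values lie in [0, N) is not enforced at runtime.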
Parameters:
- params (Float[Array, ' N *rest']): an arbitrary array whose leading axis has length \(N\); this is the axis along which we slice.
- indices (Int[Array, ' M 1']): an integer array of shape \((M, 1)\) with values in the range \([0, N)\). Its \(i\)-th entry selects the slice of params that is placed at index \(i\) of the output.
Returns:
- Float[Array, ' M *rest']: an array whose leading axis has length \(M\) and whose trailing axes match the *rest axes of params.
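The shape contract above (params of shape N *rest and indices of shape M 1 giving an output of shape M *rest) can be checked with a short example. It assumes that, for a single index column, the result is equivalent to jnp.take along axis 0. This equivalence is our own formulation for illustration, not the library's implementation, and the concrete sizes are arbitrary.

```python
import jax.numpy as jnp

# params: Float[Array, 'N *rest'] with N = 5 and rest = (2, 3).
params = jnp.arange(5 * 2 * 3, dtype=jnp.float32).reshape(5, 2, 3)
# indices: Int[Array, 'M 1'] with M = 4; every value lies in [0, N).
indices = jnp.array([[4], [1], [1], [0]])

# Vectorised gather: row i of the output is params[indices[i, 0]].
out = jnp.take(params, indices[:, 0], axis=0)
# Reference loop that builds the same result one slice at a time.
loop = jnp.stack([params[int(i)] for i in indices[:, 0]])

print(out.shape)                        # (4, 2, 3)
print(bool(jnp.array_equal(out, loop)))  # True
```

The leading axis changes from N to M, while the trailing *rest axes pass through unchanged. Repeated indices, such as 1 here, simply duplicate the corresponding slice.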