
Utils

jax_gather_nd

jax_gather_nd(params, indices)

Slice a params array at a set of indices.

This is a reimplementation of TensorFlow's gather_nd function.
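This page does not show the function body. The following is a minimal sketch of the gather_nd semantics it reimplements, written in plain JAX. It is an assumption about the approach, not the library's source, and gather_nd_sketch is a hypothetical name. It treats the last axis of indices as coordinates into the leading axes of params. For the documented index shape M × 1, this reduces to selecting slices along the first axis.

```python
import jax
import jax.numpy as jnp


def gather_nd_sketch(params: jax.Array, indices: jax.Array) -> jax.Array:
    """Hypothetical sketch of gather_nd semantics (not the library's source).

    The last axis of `indices` (length K) holds coordinates into the first K
    axes of `params`; each coordinate tuple selects one slice of `params`.
    """
    # Split the last axis of `indices` into K integer arrays, one per leading
    # axis of `params`, so they can be used together as an advanced index.
    coords = tuple(indices[..., k] for k in range(indices.shape[-1]))
    # Result shape: indices.shape[:-1] + params.shape[K:].
    return params[coords]


params = jnp.arange(12.0).reshape(4, 3)  # N = 4, rest = (3,)
indices = jnp.array([[2], [0], [2]])     # shape (M, 1) with M = 3
out = gather_nd_sketch(params, indices)
print(out.shape)  # (3, 3)
```

Each row of out is the row of params named by the corresponding entry of indices. Here that gives rows 2, 0 and 2.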

Parameters:

  • params (Float[Array, ' N *rest']) –

    an arbitrary array whose leading axis has length N; this is the axis along which params is sliced.

  • indices (Int[Array, ' M 1']) –

    an integer array of length M with values in the range [0, N); the value at index i selects the slice of params that is placed at index i of the output.

Returns:

  • Float[Array, ' M *rest'] –

    An array whose leading axis has length M; its entry at index i is the slice of params selected by indices[i].
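When indices has the documented shape M × 1, the output shape is M followed by the trailing rest axes of params. The sketch below checks this with plain JAX indexing on made-up example arrays rather than calling the library function. It also confirms that, for in-range indices, the result matches jax.numpy.take along axis 0.

```python
import jax.numpy as jnp

# Example inputs (assumed for illustration): N = 5, rest = (2, 3).
params = jnp.arange(30.0).reshape(5, 2, 3)
indices = jnp.array([[4], [1], [1], [0]])  # shape (M, 1) with M = 4

# With a single index column, gather_nd picks whole slices along axis 0,
# so output[i] == params[indices[i, 0]].
via_indexing = params[indices[:, 0]]
via_take = jnp.take(params, indices[:, 0], axis=0)

print(via_indexing.shape)  # (4, 2, 3), i.e. (M, *rest)
print(bool(jnp.array_equal(via_indexing, via_take)))  # True
```

jnp.take with axis=0 states the same row-selection intent explicitly. The two forms may differ in how they treat out-of-range indices, which is one more reason to keep indices within [0, N).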