
Utils

jax_gather_nd

jax_gather_nd(params, indices)

Slice a params array at a set of indices.

This is a reimplementation of TensorFlow's gather_nd function.
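In TensorFlow's gather_nd (with no batch dimensions), the last axis of indices holds a coordinate tuple into the leading axes of params, and any axes of params that are not indexed are carried through to the output. The sketch below reproduces that behaviour with JAX advanced indexing. It is an illustration under that assumption, not this library's source; the name gather_nd_sketch is hypothetical.

```python
import jax.numpy as jnp


def gather_nd_sketch(params, indices):
    # Hypothetical helper: split the last axis of `indices` into one
    # integer array per indexed axis of `params`.
    coords = tuple(indices[..., k] for k in range(indices.shape[-1]))
    # Advanced indexing gathers params[coords[0][i], coords[1][i], ...]
    # for every i; axes of `params` that are not indexed stay as trailing axes.
    return params[coords]


params = jnp.arange(12.0).reshape(4, 3)       # N = 4 rows, rest = (3,)
indices = jnp.array([[2], [0], [2]])          # shape (M, 1) with M = 3
out = gather_nd_sketch(params, indices)       # shape (3, 3); rows 2, 0, 2 of params
```

With indices of shape (M, 1) only the leading axis is indexed, so each output row is a full slice of params; an indices array of shape (M, 2) would instead pick out individual elements.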

Parameters:

  • params (Float[Array, ' N *rest']) –

    an arbitrary array whose leading axis has length \(N\); slices are taken along this axis.

  • indices (Int[Array, ' M 1']) –

    an integer array of shape \((M, 1)\) with values in the range \([0, N)\); row \(i\) of the output is the slice of params at index indices[i, 0].

Returns:

  • Float[Array, ' M *rest'] –

    An array whose leading axis has length \(M\) and whose remaining axes match the trailing axes (*rest) of params.
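For the documented case, where indices has shape \((M, 1)\), the shape contract above (N *rest in, M *rest out) should reduce to ordinary row selection along the leading axis. The sketch below checks that contract using plain jax.numpy indexing and jnp.take rather than calling jax_gather_nd itself; the array sizes are arbitrary illustrative choices.

```python
import jax.numpy as jnp

N, M = 5, 3
params = jnp.arange(N * 2 * 2, dtype=jnp.float32).reshape(N, 2, 2)  # shape (N, *rest) with rest = (2, 2)
indices = jnp.array([[4], [1], [4]], dtype=jnp.int32)               # shape (M, 1), values in [0, N)

# Row i of the result is params[indices[i, 0]], giving shape (M, *rest).
rows = params[indices[:, 0]]
# The same gather written with jnp.take along the leading axis.
taken = jnp.take(params, indices[:, 0], axis=0)
```

Both forms produce an array of shape (3, 2, 2) whose first and last rows equal params[4]; repeated indices are allowed and simply duplicate the corresponding slice.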