Utils

gpjax.kernels.non_euclidean.utils

jax_gather_nd(params: Float[Array, ' N *rest'], indices: Int[Array, ' M 1']) -> Float[Array, ' M *rest']

Slice a params array at a set of indices.

This is a reimplementation of TensorFlow's tf.gather_nd function.
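The gather described above can be sketched in a few lines of JAX. This is a minimal illustration that assumes indices has shape $(M, 1)$, as in the signature. The helper name gather_nd_sketch is hypothetical and is not presented as the library's exact code.

```python
import jax.numpy as jnp
from jax import Array


def gather_nd_sketch(params: Array, indices: Array) -> Array:
    """Hypothetical gather: output[i] = params[indices[i, 0]]."""
    # Split the trailing index dimension into one index array per indexed axis.
    # For indices of shape (M, 1) this yields a single integer array of shape (M,).
    tuple_indices = tuple(indices[..., k] for k in range(indices.shape[-1]))
    # Advanced indexing then selects rows along the leading axis of params.
    return params[tuple_indices]


# Usage: gather rows 2, 0 and 2 of a (4, 3) array.
params = jnp.arange(12.0).reshape(4, 3)
idx = jnp.array([[2], [0], [2]])
out = gather_nd_sketch(params, idx)  # shape (3, 3)
```

Building a tuple of index arrays, rather than indexing with indices directly, keeps the sketch close to gather_nd semantics, where the last axis of indices enumerates coordinates into the leading axes of params.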

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | Float[Array] | An arbitrary array with a leading axis of length $N$ along which we slice. | required |
| indices | Int[Array] | An integer array of shape $(M, 1)$ with values in the range $[0, N)$. Row $i$ selects the slice params[indices[i, 0]], which becomes row $i$ of the output. | required |
Returns:

Float[Array]: An array whose leading axis has length $M$ and whose remaining axes match the trailing axes of params.
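The shape contract (N, *rest) to (M, *rest) can be checked with plain jax.numpy. The arrays below are made-up illustrative inputs, and the comparison with jnp.take is a sketch of an equivalent formulation for in-range indices, not a statement about how the library implements the function.

```python
import jax.numpy as jnp

# Illustrative inputs: N = 4 items, each carrying a (2, 3) block ("*rest").
params = jnp.arange(4 * 2 * 3, dtype=jnp.float32).reshape(4, 2, 3)
# M = 5 indices stored as a column of shape (M, 1); repeated indices are allowed.
indices = jnp.array([[3], [0], [0], [1], [3]], dtype=jnp.int32)

# Gather along the leading axis: result[i] = params[indices[i, 0]].
result = params[indices[:, 0]]

# The same gather written with jnp.take along axis 0.
via_take = jnp.take(params, indices[:, 0], axis=0)

print(result.shape)  # (5, 2, 3)
```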