Utils
gpjax.kernels.non_euclidean.utils
jax_gather_nd(params: Float[Array, ' N *rest'], indices: Int[Array, ' M 1']) -> Float[Array, ' M *rest']
Slice a `params` array at a set of `indices`.
This is a reimplementation of TensorFlow's `gather_nd` function.
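For comparison with TensorFlow's `gather_nd`, here is a minimal sketch of the same gather pattern in JAX. The name `gather_nd_sketch` is hypothetical and this is not GPJax's source. It assumes the TensorFlow convention that the last axis of `indices` holds coordinates into the leading axes of `params`. With an index depth of 1, which is the `' M 1'` shape in the signature above, it reduces to selecting rows of `params`.

```python
import jax.numpy as jnp


def gather_nd_sketch(params, indices):
    """Hypothetical gather_nd-style helper (illustrative, not GPJax's code).

    The last axis of `indices` holds coordinates into the leading axes of
    `params`; each coordinate tuple selects one slice of `params`.
    """
    # Split the trailing coordinate axis into one integer array per
    # indexed dimension, then use them together as an advanced index.
    coords = tuple(indices[..., k] for k in range(indices.shape[-1]))
    return params[coords]


params = jnp.arange(12.0).reshape(4, 3)  # N = 4 rows, rest = (3,)
indices = jnp.array([[2], [0], [2]])     # M = 3, index depth 1
out = gather_nd_sketch(params, indices)
print(out.shape)  # (3, 3)

# Index depth 2 (TensorFlow allows this; the documented signature does not):
# each row (r, c) picks the scalar params[r, c].
pairs = jnp.array([[3, 1], [0, 2]])
print(gather_nd_sketch(params, pairs))  # values 10.0 and 2.0
```

The depth-2 call is only there to show what the general TensorFlow semantics add; the documented function fixes the index depth at 1.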
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`params` | `Float[Array, ' N *rest']` | An arbitrary array with a leading axis of length $N$ along which it is sliced. | required |
`indices` | `Int[Array, ' M 1']` | An integer array of shape $(M, 1)$ with values in the range $[0, N)$. The value in row $i$ selects the slice of `params` placed at position $i$ of the output. | required |
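To make the shape contract concrete, the sketch below uses illustrative shapes ($N = 5$, `rest` $= (2, 2)$, $M = 3$) chosen for this example only. It shows that a flat vector of indices needs a trailing axis of size 1 to match `' M 1'`, and that row $i$ of the result equals `params[indices[i, 0]]`. It uses plain JAX indexing rather than calling the documented function.

```python
import jax.numpy as jnp

# Illustrative shapes (assumed for this example): N = 5, rest = (2, 2), M = 3.
params = jnp.arange(20.0).reshape(5, 2, 2)

# A flat vector of indices must first be given a trailing axis of size 1
# to match the documented ' M 1' shape of `indices`.
flat_ids = jnp.array([4, 1, 1])
indices = flat_ids[:, None]  # shape (3, 1)

# With index depth 1, gathering reduces to advanced indexing on axis 0.
out = params[indices[:, 0]]

for i in range(indices.shape[0]):
    # Row i of the output is the slice of params selected by indices[i, 0].
    assert (out[i] == params[indices[i, 0]]).all()

print(indices.shape, out.shape)  # (3, 1) (3, 2, 2)
```

Repeated indices are allowed: the same slice of `params` simply appears more than once in the output.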
Returns:
`Float[Array, ' M *rest']`: An array with a leading axis of length $M$ whose $i$-th slice is `params[indices[i, 0]]`.
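Because this utility lives in `gpjax.kernels.non_euclidean`, one plausible use, assumed here for illustration rather than taken from GPJax's source, is selecting rows of a graph Laplacian's eigenvector matrix for a batch of node inputs. The 4-node path graph and the names `adjacency`, `nodes` and `node_features` are hypothetical.

```python
import jax.numpy as jnp

# A 4-node path graph 0-1-2-3 (illustrative data).
adjacency = jnp.array(
    [
        [0.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 0.0],
    ]
)
# Combinatorial graph Laplacian L = D - A.
laplacian = jnp.diag(adjacency.sum(axis=1)) - adjacency

# Columns of `eigenvectors` are Laplacian eigenvectors, so row n is node n's
# spectral embedding and the leading axis has length N = 4 (one per node).
eigenvalues, eigenvectors = jnp.linalg.eigh(laplacian)

# Node inputs as an (M, 1) integer array, matching the documented shape.
nodes = jnp.array([[3], [0]])

# Gather one embedding row per input node: shape (M, N) = (2, 4).
node_features = eigenvectors[nodes[:, 0]]
print(node_features.shape)  # (2, 4)
```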