Utils
jax_gather_nd
jax_gather_nd(
params: Float[Array, " N *rest"],
indices: Int[Array, " M 1"],
) -> Float[Array, " M *rest"]
Slice a params array at a set of indices.
This is a reimplementation of TensorFlow's gather_nd function (tf.gather_nd).
Parameters:
- params (Float[Array, ' N *rest']): an arbitrary array whose leading axis has length N; this is the axis we slice along.
- indices (Int[Array, ' M 1']): an integer array of length M with values in the range [0, N); its value at index i selects the slice of params returned at index i of the output.
Returns:
- Float[Array, ' M *rest']: an arbitrary array whose leading axis has length M, with the same trailing axes as params.
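The gather described above can be sketched in plain JAX. This is a minimal illustration, not the module's own code: it assumes the trailing axis of indices holds one coordinate per indexed leading axis of params (size 1 in the signature above), and gather_rows is a hypothetical name.

```python
import jax.numpy as jnp
from jax import Array


def gather_rows(params: Array, indices: Array) -> Array:
    """Hypothetical helper mirroring tf.gather_nd semantics.

    The trailing axis of `indices` holds coordinates into the leading
    axes of `params`; each coordinate tuple selects one slice.
    """
    # Split the trailing axis into one integer array per indexed dimension.
    coords = tuple(indices[..., k] for k in range(indices.shape[-1]))
    # Advanced indexing gathers params[coords[0][m], ...] for every m.
    return params[coords]


# params has N = 4 rows, each with trailing shape (3,).
params = jnp.arange(12.0).reshape(4, 3)
# indices has shape (M, 1) with M = 2 and values in [0, 4).
indices = jnp.array([[2], [0]])
out = gather_rows(params, indices)
print(out.shape)  # (2, 3): rows 2 and 0 of params
```

With an index depth of 1 this reduces to params[indices[:, 0]], i.e. a row gather along the leading axis.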
calculate_heat_semigroup
calculate_heat_semigroup(
kernel: GraphKernel,
) -> Float[Array, 'N M']
Returns the rescaled heat semigroup, S.
Parameters:
- kernel (GraphKernel): an instance of the graph kernel.
Returns:
- Float[Array, 'N M']: the rescaled heat semigroup S.
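This page does not state the formula behind S. As a hedged sketch only, the code below assumes the textbook heat semigroup exp(-tL) of a graph Laplacian L, evaluated on its eigenvalues, then rescaled so the weights sum to the number of vertices and multiplied by a variance. The function name rescaled_heat_weights, the diffusion time t, the normalisation and the variance scaling are all assumptions; the library function may use a different spectral filter.

```python
import jax.numpy as jnp
from jax import Array


def rescaled_heat_weights(eigenvalues: Array, t: float, variance: float = 1.0) -> Array:
    """Hypothetical sketch: heat semigroup exp(-t * L) in the eigenbasis of L.

    The normalisation (weights sum to the number of vertices) and the
    variance scaling are assumptions, not taken from calculate_heat_semigroup.
    """
    # exp(-t L) acts on the i-th Laplacian eigenvector as exp(-t * lambda_i).
    s = jnp.exp(-t * eigenvalues)
    # Rescale so the weights sum to the number of vertices N ...
    s = s * (eigenvalues.shape[0] / jnp.sum(s))
    # ... then scale by the (assumed) kernel variance.
    return variance * s


# Laplacian of a 3-vertex path graph: L = D - A.
adjacency = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
laplacian = jnp.diag(adjacency.sum(axis=1)) - adjacency
eigenvalues, eigenvectors = jnp.linalg.eigh(laplacian)  # ascending eigenvalues
S = rescaled_heat_weights(eigenvalues, t=0.5, variance=2.0)
print(S.shape, float(S.sum()))  # sum is approximately variance * N = 6.0
```

Because eigh returns eigenvalues in ascending order, the weights decrease: smooth eigenvectors (small eigenvalues) receive most of the mass, which is the low-pass behaviour expected of a heat semigroup.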