
Utils

jax_gather_nd

jax_gather_nd(
    params: Float[Array, " N *rest"],
    indices: Int[Array, " M 1"],
) -> Float[Array, " M *rest"]

Gather slices of params along its leading axis at a set of indices.

This is a reimplementation of TensorFlow's gather_nd function (tf.gather_nd).

Parameters:

  • params (Float[Array, ' N *rest']) –

    an arbitrary array whose leading axis has length N; slices are taken along this axis.

  • indices (Int[Array, ' M 1']) –

    an integer array of shape (M, 1) with values in the range [0, N); the value in row i selects the slice of params returned at position i of the output.

Returns:

  • Float[Array, ' M *rest'] –

    An array whose leading axis has length M and whose trailing axes match those of params.
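A minimal sketch of the indexing this function performs, assuming it follows the tf.gather_nd semantics for an (M, 1) index array: each column of indices addresses one leading axis of params, so a single column selects whole rows. The name gather_rows is hypothetical and not part of the library.

```python
import jax.numpy as jnp


def gather_rows(params, indices):
    """Hypothetical sketch of gather_nd-style indexing (not the library's code).

    Splits the last axis of `indices` into a tuple so that each column
    indexes one leading axis of `params`. With a single column this
    selects rows, giving an array of shape (M, *rest).
    """
    tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    return params[tuple_indices]


# params has N = 4 rows, each with trailing shape (2,)
params = jnp.arange(8.0).reshape(4, 2)
# M = 3 indices, each a length-1 row; repeated indices are allowed
indices = jnp.array([[2], [0], [2]])

out = gather_rows(params, indices)  # shape (3, 2): rows 2, 0 and 2 of params
```

For a single index column this matches jnp.take(params, indices[:, 0], axis=0). The tuple-splitting form also covers index arrays with K columns, where each row addresses the first K axes of params, as tf.gather_nd does.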

calculate_heat_semigroup

calculate_heat_semigroup(
    kernel: GraphKernel,
) -> Float[Array, "N M"]

Compute the rescaled heat semigroup, S, for the given graph kernel.

Parameters:

  • kernel (GraphKernel) –

    the graph kernel instance for which S is computed.

Returns:

  • Float[Array, 'N M'] –

    The rescaled heat semigroup, S.
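For a graph with Laplacian L = U Λ U^T, the heat semigroup is e^{-tL} = U e^{-tΛ} U^T, so it can be computed by transforming the Laplacian eigenvalues. The sketch below does this in JAX. The diffusion time t, the rescaling that makes the transformed spectrum sum to the number of vertices, and the helper name rescaled_heat_spectrum are all assumptions for illustration; this page does not state the exact transform or rescaling that calculate_heat_semigroup applies.

```python
import jax.numpy as jnp


def rescaled_heat_spectrum(laplacian, t=1.0):
    """Hypothetical sketch: rescaled heat semigroup of a graph Laplacian.

    `t` (diffusion time) and the sum-to-N rescaling are illustrative
    assumptions, not the library's documented behaviour.
    """
    # Eigendecomposition L = U diag(lam) U^T (L is symmetric PSD)
    lam, U = jnp.linalg.eigh(laplacian)
    # The heat semigroup maps each eigenvalue lam to exp(-t * lam)
    s = jnp.exp(-t * lam)
    # Assumed rescaling: make the transformed spectrum sum to N vertices
    n = laplacian.shape[0]
    s = s * n / jnp.sum(s)
    return s, U


# Path graph on 3 vertices: 0 - 1 - 2
A = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
L = jnp.diag(A.sum(axis=1)) - A  # combinatorial Laplacian D - A

s, U = rescaled_heat_spectrum(L, t=0.5)
# Rebuild the rescaled operator U diag(s) U^T, proportional to exp(-t L)
K = U @ jnp.diag(s) @ U.T
```

Rescaling by the trace keeps the average diagonal entry of K at 1 for any t, which is one way to stop the overall scale of a diffusion-type kernel from depending on the diffusion time.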