
Non Conjugate Functions

gpjax.decision_making.test_functions.non_conjugate_functions

PoissonTestFunction dataclass

Test function for GPs utilising the Poisson likelihood. Function taken from https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.

Attributes:

Name | Type | Description
search_space | ContinuousSearchSpace | Search space for the function.

search_space = ContinuousSearchSpace(lower_bounds=jnp.array([-2.0]), upper_bounds=jnp.array([2.0])) class-attribute instance-attribute
__init__(search_space=ContinuousSearchSpace(lower_bounds=jnp.array([-2.0]), upper_bounds=jnp.array([2.0]))) -> None
generate_dataset(num_points: int, key: KeyArray) -> Dataset

Generate a toy dataset from the test function.

Parameters:

Name | Type | Description | Default
num_points | int | Number of points to sample. | required
key | KeyArray | JAX PRNG key. | required

Returns:

Name | Type | Description
Dataset | Dataset | Dataset of points sampled from the test function.

generate_test_points(num_points: int, key: KeyArray) -> Float[Array, 'N D']

Generate test points from the search space of the test function.

Parameters:

Name | Type | Description | Default
num_points | int | Number of points to sample. | required
key | KeyArray | JAX PRNG key. | required

Returns:

Type | Description
Float[Array, 'N D'] | Test points sampled from the search space.
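
As a rough illustration of what sampling test points from this search space could look like, the sketch below draws points inside the documented bounds of [-2.0, 2.0] using plain JAX. It assumes uniform sampling within the box, which the text above does not state, and `sample_test_points` is a hypothetical stand-in rather than the library's implementation.

```python
import jax.numpy as jnp
import jax.random as jr

# Bounds copied from the documented default search space.
lower_bounds = jnp.array([-2.0])
upper_bounds = jnp.array([2.0])


def sample_test_points(num_points, key):
    """Hypothetical stand-in: uniform draws inside the box, shape (N, D)."""
    dim = lower_bounds.shape[0]
    # ASSUMPTION: uniform sampling; the library may use another scheme.
    return jr.uniform(
        key,
        shape=(num_points, dim),
        minval=lower_bounds,
        maxval=upper_bounds,
    )


test_points = sample_test_points(num_points=5, key=jr.PRNGKey(0))
print(test_points.shape)
```

The returned array has one row per requested point and one column per search-space dimension, matching the `'N D'` shape annotation with D = 1 here.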

evaluate(x: Float[Array, 'N 1']) -> Integer[Array, 'N 1'] abstractmethod

Evaluate the test function at a set of points. Function taken from https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.

Parameters:

Name | Type | Description | Default
x | Float[Array, 'N 1'] | Points to evaluate the test function at. | required

Returns:

Type | Description
Integer[Array, 'N 1'] | Values of the test function at the points.
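
To make the integer-valued output concrete, here is a minimal sketch of a Poisson test function in plain JAX. The latent function `2.0 * sin(3x) + 0.5x` is assumed to be the one used in the linked Poisson example and should be checked against it. `latent` and `evaluate_sketch` are hypothetical names, and passing the key explicitly is a choice made for this sketch, since the documented `evaluate` takes only `x`.

```python
import jax.numpy as jnp
import jax.random as jr


def latent(x):
    # ASSUMPTION: latent function taken from the linked Poisson example.
    return 2.0 * jnp.sin(3.0 * x) + 0.5 * x


def evaluate_sketch(x, key):
    """Hypothetical stand-in: draw Poisson counts with rate exp(latent(x))."""
    rate = jnp.exp(latent(x))  # exponentiate so the rate is positive
    return jr.poisson(key, rate)  # integer counts, same shape as x


x = jnp.linspace(-2.0, 2.0, 5).reshape(-1, 1)  # shape (N, 1)
y = evaluate_sketch(x, jr.PRNGKey(42))
print(y.shape, y.dtype)
```

Because the observations are non-negative counts rather than real values, a Gaussian likelihood is not a natural fit, which is why this function is listed among the non-conjugate test functions.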