Non Conjugate Functions
gpjax.decision_making.test_functions.non_conjugate_functions
PoissonTestFunction (dataclass)
Test function for GPs that use a Poisson likelihood. The function is taken from https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
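A Poisson likelihood is not conjugate to a Gaussian process prior, so the posterior over the latent function has no closed form. That is the sense in which this test function is "non-conjugate". For reference, the Poisson log-likelihood of a count $y_i$ given a latent value $f_i = f(x_i)$ is shown below. It assumes an exponential inverse link $\lambda_i = \exp(f_i)$. The linked example exponentiates its latent function, but this is an assumption about the modelling setup, not a description of this class's internals.

```latex
y_i \mid f_i \sim \mathrm{Poisson}(\lambda_i), \qquad \lambda_i = \exp(f_i)
\log p(y_i \mid f_i) = y_i \log \lambda_i - \lambda_i - \log(y_i!) = y_i f_i - \exp(f_i) - \log(y_i!)
```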
Attributes:
Name | Type | Description |
---|---|---|
search_space | ContinuousSearchSpace | Search space for the function. |
search_space = ContinuousSearchSpace(lower_bounds=jnp.array([-2.0]), upper_bounds=jnp.array([2.0])) (class-attribute, instance-attribute)
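The sketch below illustrates what the search space above represents: an axis-aligned box from which points are drawn uniformly with `jax.random.uniform`. `BoxSearchSpace` and its `sample` method are hypothetical stand-ins written for this example. They are not GPJax's `ContinuousSearchSpace` implementation. Only the bounds come from the documented attribute.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import jax.random as jr


@dataclass
class BoxSearchSpace:
    """Hypothetical stand-in for ContinuousSearchSpace: an axis-aligned box."""

    lower_bounds: jax.Array  # shape (D,)
    upper_bounds: jax.Array  # shape (D,)

    @property
    def dimensionality(self) -> int:
        # D is the number of bounds given for each side of the box.
        return self.lower_bounds.shape[0]

    def sample(self, num_points: int, key: jax.Array) -> jax.Array:
        # Uniform draws inside the box, one row per point: shape (num_points, D).
        return jr.uniform(
            key,
            shape=(num_points, self.dimensionality),
            minval=self.lower_bounds,
            maxval=self.upper_bounds,
        )


# Bounds copied from the documented search_space attribute: the interval [-2, 2].
space = BoxSearchSpace(lower_bounds=jnp.array([-2.0]), upper_bounds=jnp.array([2.0]))
points = space.sample(num_points=5, key=jr.PRNGKey(0))
```

Because the bounds have length 1, the sampled points have shape (N, 1). This matches the `'N 1'` input shape that `evaluate` expects.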
generate_dataset(num_points: int, key: KeyArray) -> Dataset
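This page documents only the signature of `generate_dataset`. Assuming it draws inputs from the search space and labels them with `evaluate`, its behaviour can be sketched as follows. `ToyDataset`, `generate_dataset_sketch`, and `placeholder_evaluate` are hypothetical names. The placeholder evaluator is deliberately not the real test function. It exists only so the sketch runs on its own.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import jax.random as jr


class ToyDataset(NamedTuple):
    """Hypothetical stand-in for gpjax.Dataset: inputs X and targets y."""

    X: jax.Array  # shape (N, 1), float inputs
    y: jax.Array  # shape (N, 1), integer counts


def generate_dataset_sketch(
    num_points: int,
    key: jax.Array,
    evaluate: Callable[[jax.Array], jax.Array],
    lower: float = -2.0,
    upper: float = 2.0,
) -> ToyDataset:
    # Assumption: inputs are sampled uniformly from the 1D search space [lower, upper].
    X = jr.uniform(key, shape=(num_points, 1), minval=lower, maxval=upper)
    # Label each input with the test function's output.
    y = evaluate(X)
    return ToyDataset(X=X, y=y)


def placeholder_evaluate(x: jax.Array) -> jax.Array:
    # Placeholder only (NOT the real test function): rounds exp(x) to an integer count.
    return jnp.round(jnp.exp(x)).astype(jnp.int32)


data = generate_dataset_sketch(num_points=10, key=jr.PRNGKey(1), evaluate=placeholder_evaluate)
```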
generate_test_points(num_points: int, key: KeyArray) -> Float[Array, 'N D']
Generate test points from the search space of the test function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_points | int | Number of points to sample. | required |
key | KeyArray | JAX PRNG key. | required |
Returns:
Type | Description |
---|---|
Float[Array, 'N D'] | Test points sampled from the search space. |
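JAX's PRNG is deterministic, so calling a sampler twice with the same key returns the same points. Suppose `generate_test_points` and `generate_dataset` both draw inputs from the search space using the key you pass; that is an assumption here. In that case, reusing one key would make the test inputs identical to the training inputs whenever `num_points` matches. The sketch below shows the usual remedy of splitting the key first. The `sample_points` helper is hypothetical.

```python
import jax.numpy as jnp
import jax.random as jr

# Bounds of the documented search space: the interval [-2, 2].
lower, upper = jnp.array([-2.0]), jnp.array([2.0])


def sample_points(num_points, key):
    # Hypothetical helper: uniform draws from [lower, upper], shape (num_points, D) with D = 1.
    return jr.uniform(key, shape=(num_points, lower.shape[0]), minval=lower, maxval=upper)


key = jr.PRNGKey(0)
train_key, test_key = jr.split(key)

# Same key, same number of points: identical samples.
same_a = sample_points(20, train_key)
same_b = sample_points(20, train_key)

# Independent subkeys: distinct training and test inputs.
train_X = sample_points(20, train_key)
test_X = sample_points(20, test_key)
```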
evaluate(x: Float[Array, 'N 1']) -> Integer[Array, 'N 1'] (abstractmethod)
Evaluate the test function at a set of points. The function is taken from https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Float[Array, 'N D'] | Points to evaluate the test function at. The search space is one-dimensional, so D = 1. | required |
Returns:
Type | Description |
---|---|
Integer[Array, 'N 1'] | Values of the test function at the points. |
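The following is a minimal sketch of the evaluation step. It assumes the latent function from the linked example, $f(x) = 2\sin(3x) + 0.5x$, with counts drawn as $y \sim \mathrm{Poisson}(\exp(f(x)))$. Check the linked page before relying on these exact constants. The documented `evaluate(x)` takes only `x`, whereas the hypothetical `evaluate_sketch` also takes an explicit `key` for the Poisson draw. This makes the source of randomness visible.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def latent_function(x: jax.Array) -> jax.Array:
    # Assumed latent function, following the linked Poisson example.
    return 2.0 * jnp.sin(3.0 * x) + 0.5 * x


def evaluate_sketch(x: jax.Array, key: jax.Array) -> jax.Array:
    # The exponential (log) link keeps the Poisson rate strictly positive.
    rate = jnp.exp(latent_function(x))
    # Integer counts with the same shape as x, i.e. (N, 1).
    return jr.poisson(key, rate)


x = jnp.linspace(-2.0, 2.0, 50).reshape(-1, 1)
counts = evaluate_sketch(x, jr.PRNGKey(42))
```

The output dtype is an integer type, which is consistent with the `Integer[Array, 'N 1']` return annotation.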