
Decision Maker

gpjax.decision_making.decision_maker

AbstractDecisionMaker dataclass

Bases: ABC

AbstractDecisionMaker is an abstract base class which handles the core decision making loop, in which we sequentially decide on the points at which to query our function of interest. The decision making loop is split into two key steps, ask and tell. The ask step is typically used to decide which point to query next, and the tell step is typically used to update models and datasets with newly queried points. These steps can be combined in a 'run' loop, which alternates between asking which point to query next and, once the black-box function of interest has been evaluated at that point, telling the decision maker about the new observation.
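The ask/tell split can be illustrated with a minimal sketch. The classes below (`ToyDecisionMaker`, `GridSweeper`) are hypothetical and are not GPJax's API: datasets are plain lists of `(x, y)` pairs keyed by tag, and the concrete subclass simply sweeps a fixed grid instead of using a model.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyDecisionMaker(ABC):
    """Hypothetical ask-tell interface (a sketch, not GPJax's classes)."""

    # Observed (x, y) pairs, keyed by tag, e.g. "OBJECTIVE".
    datasets: Dict[str, List[tuple]] = field(default_factory=dict)

    @abstractmethod
    def ask(self) -> float:
        """Decide which point to query next."""

    @abstractmethod
    def tell(self, observations: Dict[str, List[tuple]]) -> None:
        """Record newly observed data (and update any models)."""


@dataclass
class GridSweeper(ToyDecisionMaker):
    """Trivial concrete subclass: query 0.0, 0.5, 1.0, 0.0, ... in turn."""

    grid: tuple = (0.0, 0.5, 1.0)

    def ask(self) -> float:
        n_seen = len(self.datasets.get("OBJECTIVE", []))
        return self.grid[n_seen % len(self.grid)]

    def tell(self, observations: Dict[str, List[tuple]]) -> None:
        for tag, pairs in observations.items():
            self.datasets.setdefault(tag, []).extend(pairs)


dm = GridSweeper()
for _ in range(3):
    x = dm.ask()                                   # ask: which point next?
    dm.tell({"OBJECTIVE": [(x, (x - 0.3) ** 2)]})  # tell: what we observed there
print([x for x, _ in dm.datasets["OBJECTIVE"]])    # [0.0, 0.5, 1.0]
```

Because `ask` and `tell` are abstract, the base class itself cannot be instantiated; only subclasses that define both steps can.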

Attributes:

Name Type Description
search_space AbstractSearchSpace

Search space over which we can evaluate the function(s) of interest.

posterior_handlers Dict[str, PosteriorHandler]

Dictionary of posterior handlers, which are used to update posteriors throughout the decision making loop. Note that the word posteriors is used for consistency with GPJax, but these objects are typically referred to as models in the model-based decision making literature. Tags are used to distinguish between posteriors. In a typical Bayesian optimisation setup one of the tags will be OBJECTIVE, defined in decision_making.utils.

datasets Dict[str, Dataset]

Dictionary of datasets, which are augmented with observations throughout the decision making loop. In a typical setup they are also used to update the posteriors, using the posterior_handlers. Tags are used to distinguish datasets, and correspond to tags in posterior_handlers.

key KeyArray

JAX random key, used to generate random numbers.

batch_size int

Number of points to query at each step of the decision making loop. Note that SinglePointUtilityFunctions are only capable of generating one point to be queried at each iteration of the decision making loop.

post_ask List[Callable]

List of functions to be executed after each ask step.

post_tell List[Callable]

List of functions to be executed after each tell step.

search_space: AbstractSearchSpace instance-attribute
posterior_handlers: Dict[str, PosteriorHandler] instance-attribute
datasets: Dict[str, Dataset] instance-attribute
key: KeyArray instance-attribute
batch_size: int instance-attribute
post_ask: List[Callable] instance-attribute
post_tell: List[Callable] instance-attribute
__init__(search_space: AbstractSearchSpace, posterior_handlers: Dict[str, PosteriorHandler], datasets: Dict[str, Dataset], key: KeyArray, batch_size: int, post_ask: List[Callable], post_tell: List[Callable]) -> None
__post_init__()

At initialisation we check that the posterior handlers and datasets are consistent (i.e. have the same tags), and then initialise the posteriors, optimizing them using the corresponding datasets.
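The initialisation check can be sketched as follows. The helpers `check_tags_consistent` and `init_posteriors` are hypothetical, and each "posterior handler" is assumed here to be a plain callable that fits a posterior to a dataset; GPJax's real `PosteriorHandler` has its own interface.

```python
from typing import Callable, Dict, Mapping


def check_tags_consistent(
    posterior_handlers: Mapping[str, object], datasets: Mapping[str, object]
) -> None:
    """Raise ValueError unless both mappings are keyed by exactly the same tags."""
    mismatched = set(posterior_handlers) ^ set(datasets)  # tags present in only one
    if mismatched:
        raise ValueError(f"Tags must match; mismatched tags: {sorted(mismatched)}")


def init_posteriors(
    posterior_handlers: Mapping[str, Callable[[object], object]],
    datasets: Mapping[str, object],
) -> Dict[str, object]:
    """Check tags, then fit one posterior per tag on its own dataset."""
    check_tags_consistent(posterior_handlers, datasets)
    return {tag: fit(datasets[tag]) for tag, fit in posterior_handlers.items()}


# Toy "handler": the fitted "posterior" is just the number of observations.
handlers = {"OBJECTIVE": lambda data: {"n_obs": len(data)}}
print(init_posteriors(handlers, {"OBJECTIVE": [1.0, 2.0]}))  # {'OBJECTIVE': {'n_obs': 2}}
```

Checking tags before fitting means a typo in a tag fails immediately at construction, rather than partway through the decision making loop.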

ask(key: KeyArray) -> Float[Array, 'B D'] abstractmethod

Get the point(s) to be queried next.

Parameters:

Name Type Description Default
key KeyArray

JAX PRNG key for controlling random state.

required

Returns:

Type Description
Float[Array, 'B D']

Float[Array, "B D"]: Point(s) to be queried next.

tell(observation_datasets: Mapping[str, Dataset], key: KeyArray)

Add newly observed data to datasets and update the corresponding posteriors.

Parameters:

Name Type Description Default
observation_datasets Mapping[str, Dataset]

Dictionary of datasets containing the newly observed data, keyed by the tags used in datasets.

required
key KeyArray

JAX PRNG key for controlling random state.

required
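A tell step of this kind can be sketched as "merge by tag, then refit". Everything below is hypothetical: `ToyDataset` is an immutable stand-in whose `+` concatenates observations, and `refit` maps each tag to a plain callable that rebuilds a toy posterior from the augmented dataset.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Mapping


@dataclass(frozen=True)
class ToyDataset:
    """Hypothetical immutable dataset; `+` concatenates observations."""

    X: tuple = ()
    y: tuple = ()

    def __add__(self, other: "ToyDataset") -> "ToyDataset":
        return ToyDataset(self.X + other.X, self.y + other.y)


def tell(
    datasets: Dict[str, ToyDataset],
    posteriors: Dict[str, object],
    refit: Mapping[str, Callable[[ToyDataset], object]],
    observation_datasets: Mapping[str, ToyDataset],
) -> None:
    """Merge new observations by tag, then refit the affected posteriors."""
    for tag, new_data in observation_datasets.items():
        if tag not in datasets:
            raise KeyError(f"Unknown tag {tag!r}")
        datasets[tag] = datasets[tag] + new_data     # augment the dataset
        posteriors[tag] = refit[tag](datasets[tag])  # update the posterior


datasets = {"OBJECTIVE": ToyDataset((0.0,), (1.0,))}
posteriors: Dict[str, object] = {}
refit = {"OBJECTIVE": lambda d: sum(d.y) / len(d.y)}  # toy "posterior": mean of y
tell(datasets, posteriors, refit, {"OBJECTIVE": ToyDataset((1.0,), (3.0,))})
print(posteriors["OBJECTIVE"])  # 2.0
```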
run(n_steps: int, black_box_function_evaluator: FunctionEvaluator) -> Mapping[str, Dataset]

Run the decision making loop continuously for n_steps. Each iteration is broken down into three main steps:

1. Call the ask method to get the point(s) to be queried next.
2. Call the black_box_function_evaluator to evaluate the black-box functions of interest at the point(s) chosen to be queried.
3. Call the tell method to update the datasets and posteriors with the newly observed data.

In addition to this, after the ask step, the functions in the post_ask list are executed, taking as arguments the decision maker and the point chosen to be queried next. Similarly, after the tell step, the functions in the post_tell list are executed, taking the decision maker as the sole argument.
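The loop and its hooks can be sketched with a hypothetical `LoopSketch` class (not GPJax's implementation). Its ask rule is a toy that queries 0, 1, 2, ... in turn; the point of the sketch is the order of calls and the arguments each hook receives.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class LoopSketch:
    """Hypothetical decision maker illustrating the run loop and its hooks."""

    observed: Dict[str, List[tuple]] = field(default_factory=lambda: {"OBJECTIVE": []})
    post_ask: List[Callable] = field(default_factory=list)
    post_tell: List[Callable] = field(default_factory=list)

    def ask(self) -> float:
        # Toy rule: query the integers 0, 1, 2, ... in turn.
        return float(len(self.observed["OBJECTIVE"]))

    def tell(self, observations: Dict[str, List[tuple]]) -> None:
        for tag, points in observations.items():
            self.observed[tag].extend(points)

    def run(self, n_steps: int, evaluator: Callable[[float], Dict[str, List[tuple]]]):
        for _ in range(n_steps):
            x = self.ask()               # 1. choose the next point
            for hook in self.post_ask:
                hook(self, x)            # post_ask hooks get (decision maker, point)
            observations = evaluator(x)  # 2. evaluate the black box
            self.tell(observations)      # 3. update datasets/models
            for hook in self.post_tell:
                hook(self)               # post_tell hooks get the decision maker only
        return self.observed


log = []
dm = LoopSketch(
    post_ask=[lambda maker, x: log.append(("ask", x))],
    post_tell=[lambda maker: log.append(("tell", len(maker.observed["OBJECTIVE"])))],
)
result = dm.run(2, lambda x: {"OBJECTIVE": [(x, x ** 2)]})
print(log)     # [('ask', 0.0), ('tell', 1), ('ask', 1.0), ('tell', 2)]
print(result)  # {'OBJECTIVE': [(0.0, 0.0), (1.0, 1.0)]}
```

Passing the whole decision maker to each hook lets a callback (for example, a plotting function) read any state it needs without the loop having to anticipate it.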

Parameters:

Name Type Description Default
n_steps int

Number of steps to run the decision making loop for.

required
black_box_function_evaluator FunctionEvaluator

Function evaluator which evaluates the black box functions of interest at supplied points.

required

Returns:

Type Description
Mapping[str, Dataset]

Dictionary of datasets containing the observations made throughout the decision making loop, as well as the initial data supplied when initialising the DecisionMaker.

UtilityDrivenDecisionMaker dataclass

Bases: AbstractDecisionMaker

UtilityDrivenDecisionMaker class which handles the core decision making loop in a typical model-based decision making setup. In this setup we use surrogate model(s) for the function(s) of interest, and define a utility function (often called the 'acquisition function' in the context of Bayesian optimisation) which characterises how useful it would be to query a given point within the search space given the data we have observed so far. This can then be used to decide which point(s) to query next.

The decision making loop is split into two key steps, ask and tell. The ask step forms a UtilityFunction from the current posteriors and datasets and returns the point which maximises it. It also stores the formed utility function(s) under the attribute self.current_utility_functions so that they can be used after ask has been called, for instance for plotting. The tell step adds newly queried points to the datasets and updates the posteriors.

This can be run as a typical ask-tell loop, or the run method can be used to run the decision making loop for a fixed number of steps. Moreover, the run method executes the functions in post_ask and post_tell after each ask and tell step respectively. This enables the user to add custom functionality, such as the ability to plot values of interest during the optimization process.
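The "build a utility, then maximise it" pattern can be sketched as follows. Both helpers are hypothetical: the builder uses an upper-confidence-bound utility (a swapped-in example, chosen only for illustration) over a toy posterior given as mean and standard-deviation callables, and the maximiser simply searches a finite candidate grid rather than the search space.

```python
from typing import Callable, Sequence

Utility = Callable[[float], float]


def build_ucb_utility(
    mean: Callable[[float], float], std: Callable[[float], float], beta: float = 2.0
) -> Utility:
    """Hypothetical builder: an upper-confidence-bound utility from a toy posterior."""
    return lambda x: mean(x) + beta * std(x)


def maximize_on_candidates(utility: Utility, candidates: Sequence[float]) -> float:
    """Hypothetical maximizer: return the candidate with the highest utility."""
    return max(candidates, key=utility)


# Toy posterior: mean peaks at x = 0.3, uncertainty grows with x.
utility = build_ucb_utility(mean=lambda x: -(x - 0.3) ** 2, std=lambda x: 0.5 * x)
candidates = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
x_next = maximize_on_candidates(utility, candidates)
print(x_next)  # 0.8: the uncertainty bonus pulls the choice away from the mean's peak
```

Separating the builder from the maximiser mirrors the split between utility_function_builder and utility_maximizer: either can be swapped without touching the other.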

Attributes:

Name Type Description
utility_function_builder AbstractUtilityFunctionBuilder

Object which builds utility functions from posteriors and datasets, to decide where to query next. In a typical Bayesian optimisation setup the point chosen to be queried next is the point which maximizes the utility function.

utility_maximizer AbstractUtilityMaximizer

Object which maximizes utility functions over the search space.

search_space: AbstractSearchSpace instance-attribute
posterior_handlers: Dict[str, PosteriorHandler] instance-attribute
datasets: Dict[str, Dataset] instance-attribute
key: KeyArray instance-attribute
batch_size: int instance-attribute
post_ask: List[Callable] instance-attribute
post_tell: List[Callable] instance-attribute
utility_function_builder: AbstractUtilityFunctionBuilder instance-attribute
utility_maximizer: AbstractUtilityMaximizer instance-attribute
tell(observation_datasets: Mapping[str, Dataset], key: KeyArray)

Add newly observed data to datasets and update the corresponding posteriors.

Parameters:

Name Type Description Default
observation_datasets Mapping[str, Dataset]

Dictionary of datasets containing the newly observed data, keyed by the tags used in datasets.

required
key KeyArray

JAX PRNG key for controlling random state.

required
run(n_steps: int, black_box_function_evaluator: FunctionEvaluator) -> Mapping[str, Dataset]

Run the decision making loop continuously for n_steps. Each iteration is broken down into three main steps:

1. Call the ask method to get the point(s) to be queried next.
2. Call the black_box_function_evaluator to evaluate the black-box functions of interest at the point(s) chosen to be queried.
3. Call the tell method to update the datasets and posteriors with the newly observed data.

In addition to this, after the ask step, the functions in the post_ask list are executed, taking as arguments the decision maker and the point chosen to be queried next. Similarly, after the tell step, the functions in the post_tell list are executed, taking the decision maker as the sole argument.

Parameters:

Name Type Description Default
n_steps int

Number of steps to run the decision making loop for.

required
black_box_function_evaluator FunctionEvaluator

Function evaluator which evaluates the black box functions of interest at supplied points.

required

Returns:

Type Description
Mapping[str, Dataset]

Dictionary of datasets containing the observations made throughout the decision making loop, as well as the initial data supplied when initialising the DecisionMaker.

__init__(search_space: AbstractSearchSpace, posterior_handlers: Dict[str, PosteriorHandler], datasets: Dict[str, Dataset], key: KeyArray, batch_size: int, post_ask: List[Callable], post_tell: List[Callable], utility_function_builder: AbstractUtilityFunctionBuilder, utility_maximizer: AbstractUtilityMaximizer) -> None
__post_init__()
ask(key: KeyArray) -> Float[Array, 'B D']

Get updated utility function(s) and return the point(s) which maximises it/them. This method also stores the utility function(s) in self.current_utility_functions so that they can be accessed after the ask function has been called. This is useful for non-deterministic utility functions, which may differ between calls to ask due to the splitting of self.key.

Note that in general SinglePointUtilityFunctions are only capable of generating one point to be queried at each iteration of the decision making loop (i.e. self.batch_size must be 1). However, Thompson sampling can be used in a batched setting by drawing a batch of different samples from the GP posterior. This is done by calling build_utility_function sequentially with different keys, and optimising each of the resulting samples in turn to obtain self.batch_size points to query next.
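Batched Thompson sampling, as described above, can be sketched with the standard library. This is a hypothetical illustration, not GPJax's code: a "posterior sample" is a quadratic whose peak location is drawn at random, `random.Random` sub-seeds loosely stand in for splitting a JAX key, and each sample is maximised over a finite candidate grid.

```python
import random
from typing import Callable, List, Sequence, Tuple

Utility = Callable[[float], float]


def draw_sample(rng: random.Random) -> Utility:
    """Hypothetical posterior sample: a quadratic whose peak location is random."""
    peak = rng.gauss(0.5, 0.2)
    return lambda x: -(x - peak) ** 2


def batched_thompson(
    seed: int, batch_size: int, candidates: Sequence[float]
) -> Tuple[List[float], List[Utility]]:
    """Draw one sample per batch element and maximize each in turn."""
    master = random.Random(seed)
    points: List[float] = []
    utilities: List[Utility] = []
    for _ in range(batch_size):
        # A fresh sub-seed per sample loosely mimics splitting a JAX key.
        sample_rng = random.Random(master.getrandbits(64))
        utility = draw_sample(sample_rng)
        utilities.append(utility)  # kept so samples can be inspected later
        points.append(max(candidates, key=utility))
    return points, utilities


candidates = [i / 20 for i in range(21)]
batch, samples = batched_thompson(seed=0, batch_size=3, candidates=candidates)
print(batch)  # three query points, one per posterior sample
```

Because each batch element comes from a different random sample, the points tend to spread out over plausible maximisers rather than collapsing onto a single one, which is what makes Thompson sampling usable with batch_size greater than 1.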

Parameters:

Name Type Description Default
key KeyArray

JAX PRNG key for controlling random state.

required

Returns:

Type Description
Float[Array, 'B D']

Float[Array, "B D"]: Point(s) to be queried next.