
Decision Maker

AbstractDecisionMaker dataclass

AbstractDecisionMaker(search_space, posterior_handlers, datasets, key, batch_size, post_ask, post_tell)

Bases: ABC

AbstractDecisionMaker is an abstract base class which handles the core decision making loop, in which we sequentially decide at which points to query the function of interest. The decision making loop is split into two key steps, ask and tell. The ask step is typically used to decide which point to query next. The tell step is typically used to update models and datasets with newly queried points. These steps can be combined in a 'run' loop, which alternates between asking which point to query next and, once the black-box function of interest has been evaluated at that point, telling the decision maker about the new observation.
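A minimal sketch of the ask-tell pattern, driven by hand rather than through run. Everything here is hypothetical (the class names ToyDecisionMaker and GridSweepDecisionMaker, the fixed sweep in ask, and the toy objective); it is plain Python rather than GPJax and only illustrates how ask proposes points and tell records what was observed under a tag.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

Observation = Tuple[float, float]  # (x, f(x)) for a 1-D toy problem


class ToyDecisionMaker(ABC):
    """Hypothetical, stripped-down analogue of AbstractDecisionMaker."""

    def __init__(self) -> None:
        # Observations keyed by tag, mirroring the `datasets` attribute.
        self.datasets: Dict[str, List[Observation]] = {"OBJECTIVE": []}

    @abstractmethod
    def ask(self) -> List[float]:
        """Return the point(s) to query next."""

    def tell(self, new_data: Dict[str, List[Observation]]) -> None:
        """Record newly observed data under the matching tag."""
        for tag, observations in new_data.items():
            self.datasets.setdefault(tag, []).extend(observations)


class GridSweepDecisionMaker(ToyDecisionMaker):
    """Queries 0.0, 0.25, 0.5, ... in order (no model involved)."""

    def ask(self) -> List[float]:
        n_seen = len(self.datasets["OBJECTIVE"])
        return [0.25 * n_seen]


def objective(x: float) -> float:
    return (x - 0.5) ** 2


dm = GridSweepDecisionMaker()
for _ in range(3):
    points = dm.ask()                                # ask: where next?
    observed = [(x, objective(x)) for x in points]   # evaluate the black box
    dm.tell({"OBJECTIVE": observed})                 # tell: record the result
print(dm.datasets["OBJECTIVE"])  # [(0.0, 0.25), (0.25, 0.0625), (0.5, 0.0)]
```

A real decision maker would base ask on fitted posteriors rather than a fixed sweep; the sweep just keeps the loop structure visible.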

Attributes:

  • search_space (AbstractSearchSpace) –

    Search space over which we can evaluate the function(s) of interest.

  • posterior_handlers (Dict[str, PosteriorHandler]) –

    dictionary of posterior handlers, which are used to update posteriors throughout the decision making loop. Note that the word posteriors is used for consistency with GPJax, but these objects are typically referred to as models in the model-based decision making literature. Tags are used to distinguish between posteriors. In a typical Bayesian optimisation setup one of the tags will be OBJECTIVE, defined in decision_making.utils.

  • datasets (Dict[str, Dataset]) –

    dictionary of datasets, which are augmented with observations throughout the decision making loop. In a typical setup they are also used to update the posteriors, using the posterior_handlers. Tags are used to distinguish datasets, and correspond to tags in posterior_handlers.

  • key (KeyArray) –

    JAX random key, used to generate random numbers.

  • batch_size (int) –

    Number of points to query at each step of the decision making loop. Note that, in general, SinglePointUtilityFunctions can only generate one point to be queried at each iteration of the decision making loop; the exception for Thompson sampling is described under UtilityDrivenDecisionMaker.ask.

  • post_ask (List[Callable]) –

    List of functions to be executed after each ask step.

  • post_tell (List[Callable]) –

    List of functions to be executed after each tell step.

ask abstractmethod

ask(key)

Get the point(s) to be queried next.

Parameters:

  • key (KeyArray) –

    JAX PRNG key for controlling random state.

Returns:

  • Float[Array, 'B D'] –

    Point(s) to be queried next.

tell

tell(observation_datasets, key)

Add newly observed data to datasets and update the corresponding posteriors.

Parameters:

  • observation_datasets (Mapping[str, Dataset]) –

    dictionary of datasets containing new observations. Tags are used to distinguish datasets, and correspond to tags in posterior_handlers and self.datasets.

  • key (KeyArray) –

    JAX PRNG key for controlling random state.
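A sketch of what tell does, under simplifying assumptions: datasets are held as plain (X, y) lists in a hypothetical ToyDataset rather than GPJax Dataset objects, the key argument is omitted, and the hypothetical MeanPosteriorHandler stands in for a PosteriorHandler by "fitting" only the mean of y. The point is the tag matching: each new dataset is appended to the stored dataset with the same tag, and the posterior with that tag is then refit on the augmented data.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyDataset:
    """Hypothetical stand-in for a Dataset: inputs X and outputs y."""
    X: List[List[float]] = field(default_factory=list)
    y: List[float] = field(default_factory=list)

    def __add__(self, other: "ToyDataset") -> "ToyDataset":
        # Concatenate observations, as when augmenting a dataset.
        return ToyDataset(self.X + other.X, self.y + other.y)


class MeanPosteriorHandler:
    """Hypothetical posterior handler whose 'posterior' is just mean(y)."""

    def update(self, dataset: ToyDataset) -> float:
        return sum(dataset.y) / len(dataset.y)


def tell(datasets: Dict[str, ToyDataset],
         handlers: Dict[str, MeanPosteriorHandler],
         posteriors: Dict[str, float],
         observation_datasets: Dict[str, ToyDataset]) -> None:
    for tag, new_data in observation_datasets.items():
        # The tag links new data, the stored dataset and the handler.
        datasets[tag] = datasets[tag] + new_data
        posteriors[tag] = handlers[tag].update(datasets[tag])


datasets = {"OBJECTIVE": ToyDataset([[0.0]], [1.0])}
handlers = {"OBJECTIVE": MeanPosteriorHandler()}
posteriors = {"OBJECTIVE": 1.0}
tell(datasets, handlers, posteriors, {"OBJECTIVE": ToyDataset([[1.0]], [3.0])})
print(len(datasets["OBJECTIVE"].y), posteriors["OBJECTIVE"])  # 2 2.0
```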

run

run(n_steps, black_box_function_evaluator)

Run the decision making loop continuously for n_steps. Each step is broken down into three main stages:

  1. Call the ask method to get the point(s) to be queried next.
  2. Call the black_box_function_evaluator to evaluate the black-box function(s) of interest at the chosen point(s).
  3. Call the tell method to update the datasets and posteriors with the newly observed data.

In addition to this, after the ask step, the functions in the post_ask list are executed, taking as arguments the decision maker and the point chosen to be queried next. Similarly, after the tell step, the functions in the post_tell list are executed, taking the decision maker as the sole argument.
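A sketch of the run loop with hooks, assuming a hypothetical LoopingDecisionMaker whose ask simply returns 0.0, 1.0, 2.0, ... in turn. It shows the calling convention described above: post_ask hooks receive the decision maker and the chosen points, post_tell hooks receive only the decision maker. Here the hooks just append to a log, but the same slots could hold plotting or checkpointing code.

```python
from typing import Callable, Dict, List


class LoopingDecisionMaker:
    """Hypothetical minimal decision maker exposing run() with hooks."""

    def __init__(self, post_ask: List[Callable], post_tell: List[Callable]) -> None:
        self.datasets: Dict[str, List[tuple]] = {"OBJECTIVE": []}
        self.post_ask = post_ask
        self.post_tell = post_tell
        self._next_x = 0.0

    def ask(self) -> List[float]:
        x = self._next_x
        self._next_x += 1.0
        return [x]

    def tell(self, observations: Dict[str, List[tuple]]) -> None:
        for tag, obs in observations.items():
            self.datasets[tag].extend(obs)

    def run(self, n_steps: int,
            evaluator: Callable[[List[float]], Dict[str, List[tuple]]]
            ) -> Dict[str, List[tuple]]:
        for _ in range(n_steps):
            points = self.ask()                 # 1. ask
            for hook in self.post_ask:
                hook(self, points)              # hook(decision maker, points)
            observations = evaluator(points)    # 2. evaluate the black box
            self.tell(observations)             # 3. tell
            for hook in self.post_tell:
                hook(self)                      # hook(decision maker)
        return self.datasets


log: List[str] = []
dm = LoopingDecisionMaker(
    post_ask=[lambda maker, pts: log.append(f"asked {pts}")],
    post_tell=[lambda maker: log.append(f"n_obs={len(maker.datasets['OBJECTIVE'])}")],
)
result = dm.run(2, lambda pts: {"OBJECTIVE": [(x, x * x) for x in pts]})
print(log)  # ['asked [0.0]', 'n_obs=1', 'asked [1.0]', 'n_obs=2']
```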

Parameters:

  • n_steps (int) –

    Number of steps to run the decision making loop for.

  • black_box_function_evaluator (FunctionEvaluator) –

    Function evaluator which evaluates the black box functions of interest at supplied points.

Returns:

  • Mapping[str, Dataset] –

    Dictionary of datasets containing the observations made throughout the decision making loop, as well as the initial data supplied when initialising the DecisionMaker.

UtilityDrivenDecisionMaker dataclass

UtilityDrivenDecisionMaker(search_space, posterior_handlers, datasets, key, batch_size, post_ask, post_tell, utility_function_builder, utility_maximizer)

Bases: AbstractDecisionMaker

UtilityDrivenDecisionMaker class which handles the core decision making loop in a typical model-based decision making setup. In this setup we use surrogate model(s) for the function(s) of interest, and define a utility function (often called the 'acquisition function' in the context of Bayesian optimisation) which characterises how useful it would be to query a given point within the search space given the data we have observed so far. This can then be used to decide which point(s) to query next.

The decision making loop is split into two key steps, ask and tell. The ask step forms utility function(s) from the current posteriors and datasets and returns the point(s) which maximise them. It also stores the formed utility function(s) under the attribute self.current_utility_functions so that they can be called, for instance for plotting, after ask has been called. The tell step adds newly queried points to the datasets and updates the posteriors.
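A sketch of the utility-driven ask step under stated assumptions: the hypothetical build_utility_function uses a space-filling criterion (distance from a candidate to its nearest observed input) as a stand-in for a model-based acquisition function, and the hypothetical maximize_on_grid evaluates the utility on a fixed grid of candidates. Neither is the GPJax API; together they only show the build, maximise, and keep-the-utility pattern, with the stored list mirroring the current_utility_functions attribute described under ask.

```python
from typing import Callable, List

UtilityFunction = Callable[[float], float]


def build_utility_function(observed_x: List[float]) -> UtilityFunction:
    """Hypothetical builder: favour points far from anything observed."""
    def utility(x: float) -> float:
        return min(abs(x - xo) for xo in observed_x)
    return utility


def maximize_on_grid(utility: UtilityFunction, candidates: List[float]) -> float:
    """Hypothetical maximizer: exhaustive search over a candidate grid."""
    return max(candidates, key=utility)


class ToyUtilityDrivenDecisionMaker:
    def __init__(self, observed_x: List[float], candidates: List[float]) -> None:
        self.observed_x = observed_x
        self.candidates = candidates
        self.current_utility_functions: List[UtilityFunction] = []

    def ask(self) -> float:
        utility = build_utility_function(self.observed_x)
        # Keep the utility so it can be inspected (e.g. plotted) after ask().
        self.current_utility_functions = [utility]
        return maximize_on_grid(utility, self.candidates)


grid = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
dm = ToyUtilityDrivenDecisionMaker(observed_x=[0.0, 1.0], candidates=grid)
x_next = dm.ask()
print(x_next)  # 0.5, the grid point farthest from both observations
```

Swapping the builder for one that uses posterior means and variances would turn this into a conventional acquisition-driven loop without changing the ask structure.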

This can be run as a typical ask-tell loop, or the run method can be used to run the decision making loop for a fixed number of steps. Moreover, the run method executes the functions in post_ask and post_tell after each ask and tell step respectively. This enables the user to add custom functionality, such as the ability to plot values of interest during the optimization process.

Attributes:

  • utility_function_builder (AbstractUtilityFunctionBuilder) –

    Object which builds utility functions from posteriors and datasets, to decide where to query next. In a typical Bayesian optimisation setup the point chosen to be queried next is the point which maximizes the utility function.

  • utility_maximizer (AbstractUtilityMaximizer) –

    Object which maximizes utility functions over the search space.

tell

tell(observation_datasets, key)

Add newly observed data to datasets and update the corresponding posteriors.

Parameters:

  • observation_datasets (Mapping[str, Dataset]) –

    dictionary of datasets containing new observations. Tags are used to distinguish datasets, and correspond to tags in posterior_handlers and self.datasets.

  • key (KeyArray) –

    JAX PRNG key for controlling random state.

run

run(n_steps, black_box_function_evaluator)

Run the decision making loop continuously for n_steps. Each step is broken down into three main stages:

  1. Call the ask method to get the point(s) to be queried next.
  2. Call the black_box_function_evaluator to evaluate the black-box function(s) of interest at the chosen point(s).
  3. Call the tell method to update the datasets and posteriors with the newly observed data.

In addition to this, after the ask step, the functions in the post_ask list are executed, taking as arguments the decision maker and the point chosen to be queried next. Similarly, after the tell step, the functions in the post_tell list are executed, taking the decision maker as the sole argument.

Parameters:

  • n_steps (int) –

    Number of steps to run the decision making loop for.

  • black_box_function_evaluator (FunctionEvaluator) –

    Function evaluator which evaluates the black box functions of interest at supplied points.

Returns:

  • Mapping[str, Dataset] –

    Dictionary of datasets containing the observations made throughout the decision making loop, as well as the initial data supplied when initialising the DecisionMaker.

ask

ask(key)

Get updated utility function(s) and return the point(s) which maximises it/them. This method also stores the utility function(s) in self.current_utility_functions so that they can be accessed after the ask function has been called. This is useful for non-deterministic utility functions, which may differ between calls to ask due to the splitting of self.key.

Note that, in general, SinglePointUtilityFunctions can only generate one point to be queried at each iteration of the decision making loop (i.e. self.batch_size must be 1). However, Thompson sampling can be used in a batched setting by drawing a batch of different samples from the GP posterior. This is done by calling build_utility_function sequentially with different keys, and optimising each of these individual samples in turn to obtain self.batch_size points to query next.
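A sketch of the batched Thompson sampling idea, under a simplifying assumption: instead of a GP, the hypothetical posterior is a set of independent Gaussians, one per candidate point (means mu, standard deviations sigma). Each iteration derives a fresh sub-generator from a parent generator (playing the role of splitting a JAX key), draws one sample of the function values over the candidates, and picks the candidate that maximises that sample. Different samples generally give different points, which is what makes a batch possible. The sketch assumes larger sampled values are better.

```python
import random
from typing import List


def thompson_batch(mu: List[float], sigma: List[float],
                   candidates: List[float], batch_size: int,
                   seed: int) -> List[float]:
    """Return batch_size points, one per posterior sample (toy model)."""
    parent = random.Random(seed)
    batch = []
    for _ in range(batch_size):
        # Analogue of splitting a key: derive an independent sub-generator.
        sub = random.Random(parent.getrandbits(64))
        # One "posterior sample": a draw of f at every candidate point.
        sample = [sub.gauss(m, s) for m, s in zip(mu, sigma)]
        # Maximise this particular sample over the candidates.
        best = max(range(len(candidates)), key=lambda i: sample[i])
        batch.append(candidates[best])
    return batch


candidates = [0.0, 0.25, 0.5, 0.75, 1.0]
mu = [0.0, 0.2, 0.5, 0.2, 0.0]
sigma = [0.3] * 5
points = thompson_batch(mu, sigma, candidates, batch_size=3, seed=0)
print(points)
```

With sigma set to zero every sample equals the mean, so all batch members collapse onto the same point; the spread in the batch comes entirely from posterior uncertainty.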

Parameters:

  • key (KeyArray) –

    JAX PRNG key for controlling random state.

Returns:

  • Float[Array, 'B D'] –

    Point(s) to be queried next.