diff --git a/tensorflow_probability/python/optimizer/bfgs.py b/tensorflow_probability/python/optimizer/bfgs.py index ea03553800..a157db7a4d 100644 --- a/tensorflow_probability/python/optimizer/bfgs.py +++ b/tensorflow_probability/python/optimizer/bfgs.py @@ -82,6 +82,7 @@ def minimize(value_and_gradients_function, validate_args=True, max_line_search_iterations=50, f_absolute_tolerance=0, + line_search_kwargs=None, name=None): """Applies the BFGS algorithm to minimize a differentiable function. @@ -177,6 +178,11 @@ def quadratic_loss_and_gradient(x): f_absolute_tolerance: Scalar `Tensor` of real dtype. If the absolute change in the objective value between one iteration and the next is smaller than this value, the algorithm is stopped. + line_search_kwargs: (Optional) Python `dict` of extra keyword arguments to + forward to the underlying `hager_zhang` line search, allowing the caller + to tune line search hyper-parameters such as `sufficient_decrease_param` + or `curvature_param`. See `tfp.optimizer.linesearch.hager_zhang` for the + full list of accepted parameters. name: (Optional) Python str. The name prefixed to the ops created by this function. If not supplied, the default name 'minimize' is used. @@ -286,7 +292,8 @@ def _body(state): next_state = bfgs_utils.line_search_step( current_state, value_and_gradients_function, actual_search_direction, tolerance, f_relative_tolerance, x_tolerance, stopping_condition, - max_line_search_iterations, f_absolute_tolerance) + max_line_search_iterations, f_absolute_tolerance, + line_search_kwargs=line_search_kwargs) # Update the inverse Hessian if needed and continue. return [_update_inv_hessian(current_state, next_state)] diff --git a/tensorflow_probability/python/optimizer/bfgs_test.py b/tensorflow_probability/python/optimizer/bfgs_test.py index 5a200dc5c0..a90122530e 100644 --- a/tensorflow_probability/python/optimizer/bfgs_test.py +++ b/tensorflow_probability/python/optimizer/bfgs_test.py @@ -645,6 +645,32 @@ def quadratic_with_spike(x): self.assertAllTrue(results.failed) self.assertAllFinite(results.position) + def test_line_search_kwargs_are_forwarded(self): + """`line_search_kwargs` reach `hager_zhang` and affect the search.""" + @_make_val_and_grad_fn + def rosenbrock(coord): + x, y = coord[0], coord[1] + return (1 - x)**2 + 100 * (y - x**2)**2 + + start = tf.constant([-1.2, 1.0]) + default_results = self.evaluate( + bfgs.minimize(rosenbrock, initial_position=start, tolerance=1e-5)) + small_step_results = self.evaluate( + bfgs.minimize( + rosenbrock, initial_position=start, tolerance=1e-5, + line_search_kwargs=dict(initial_step_size=0.1))) + + # Both runs should still converge to the Rosenbrock minimum. + self.assertTrue(default_results.converged) + self.assertTrue(small_step_results.converged) + self.assertArrayNear(small_step_results.position, + np.array([1.0, 1.0]), 1e-5) + # The default starts each line search at a unit step. Shrinking that to + # 0.1 must change the line search trajectory, and therefore the total + # number of objective evaluations. + self.assertNotEqual(default_results.num_objective_evaluations, + small_step_results.num_objective_evaluations) + if __name__ == '__main__': test_util.main() diff --git a/tensorflow_probability/python/optimizer/bfgs_utils.py b/tensorflow_probability/python/optimizer/bfgs_utils.py index 0495651e3f..889b93220f 100644 --- a/tensorflow_probability/python/optimizer/bfgs_utils.py +++ b/tensorflow_probability/python/optimizer/bfgs_utils.py @@ -148,7 +148,8 @@ def _is_negative_inf(x): def line_search_step(state, value_and_gradients_function, search_direction, grad_tolerance, f_relative_tolerance, x_tolerance, - stopping_condition, max_iterations, f_absolute_tolerance): + stopping_condition, max_iterations, f_absolute_tolerance, + line_search_kwargs=None): """Performs the line search step of the BFGS search procedure. Uses hager_zhang line search procedure to compute a suitable step size @@ -182,6 +183,12 @@ def line_search_step(state, value_and_gradients_function, search_direction, iterations of the hager_zhang line search algorithm f_absolute_tolerance: Scalar `Tensor` of real dtype. Specifies the tolerance for the absolute change in the objective value. + line_search_kwargs: (Optional) Python `dict` of extra keyword arguments to + forward to the underlying `hager_zhang` line search, allowing the caller + to tune line search hyper-parameters such as `initial_step_size`, + `sufficient_decrease_param` or `curvature_param`. Keys that collide with + arguments controlled by the outer optimization loop (`value_at_zero`, + `converged`, `max_iterations`) will raise a `TypeError`. Returns: A copy of the input state with the following fields updated: @@ -206,12 +213,14 @@ def line_search_step(state, value_and_gradients_function, search_direction, df=derivative_at_start_pt, full_gradient=state.objective_gradient) inactive = state.failed | state.converged + ls_kwargs = dict(line_search_kwargs) if line_search_kwargs else {} + ls_kwargs.setdefault('initial_step_size', _broadcast(1, state.position)) ls_result = hager_zhang( line_search_value_grad_func, - initial_step_size=_broadcast(1, state.position), value_at_zero=val_0, - converged=inactive, - max_iterations=max_iterations) # No search needed for these. + converged=inactive, # No search needed for these. + max_iterations=max_iterations, + **ls_kwargs) state_after_ls = update_fields( state,