From d0963a70ceb4d6801bbab24339a7cd9847d399b6 Mon Sep 17 00:00:00 2001 From: DvHuang <80681595@qq.com> Date: Thu, 16 Feb 2023 15:47:30 +0800 Subject: [PATCH] make sure transformer return past_key_values --- rl4lms/envs/text_generation/policy/causal_policy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rl4lms/envs/text_generation/policy/causal_policy.py b/rl4lms/envs/text_generation/policy/causal_policy.py index bdb41225..425f94a0 100644 --- a/rl4lms/envs/text_generation/policy/causal_policy.py +++ b/rl4lms/envs/text_generation/policy/causal_policy.py @@ -103,6 +103,10 @@ def _prepare_inputs_for_model( input_ids, **model_kwargs ) + """ Make sure to use the configuration in the configuration file""" + if model_inputs.get("use_cache", None) is None: + model_inputs['use_cache'] = self._generation_kwargs.get("use_cache", None) + if self._apply_model_parallel and unwrap_model(model).is_parallelizable: # if model is in parallel mode, move the tensors to the first device model_inputs = {