diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index e43db9ca45f0a..f9d8779c924f0 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -555,8 +555,6 @@ def setup(self, config: AlgorithmConfig): # Queue of data to be sent to the Learner. self.data_to_place_on_learner = [] - # The local mixin buffer (if required). - self.local_mixin_buffer = None self._batch_being_built = [] # @OldAPIStack # Create extra aggregation workers and assign each rollout worker to @@ -566,18 +564,17 @@ def setup(self, config: AlgorithmConfig): i: [] for i in range(self.config.num_learners or 1) } - # Create our local mixin buffer if the num of aggregation workers is 0. + # Create our local mixin buffer. if not self.config.enable_rl_module_and_learner: - if self.config.replay_proportion > 0.0: - self.local_mixin_buffer = MixInMultiAgentReplayBuffer( - capacity=( - self.config.replay_buffer_num_slots - if self.config.replay_buffer_num_slots > 0 - else 1 - ), - replay_ratio=self.config.replay_ratio, - replay_mode=ReplayMode.LOCKSTEP, - ) + self.local_mixin_buffer = MixInMultiAgentReplayBuffer( + capacity=( + self.config.replay_buffer_num_slots + if self.config.replay_buffer_num_slots > 0 + else 1 + ), + replay_ratio=self.config.replay_ratio, + replay_mode=ReplayMode.LOCKSTEP, + ) # This variable is used to keep track of the statistics from the most recent # update of the learner group @@ -1092,9 +1089,8 @@ def _process_experiences_old_api_stack( batch = batch.decompress_if_needed() # Only make a pass through the buffer, if replay proportion is > 0.0 (and # we actually have one). - if self.local_mixin_buffer: - self.local_mixin_buffer.add(batch) - batch = self.local_mixin_buffer.replay(_ALL_POLICIES) + self.local_mixin_buffer.add(batch) + batch = self.local_mixin_buffer.replay(_ALL_POLICIES) if batch: processed_batches.append(batch)