From 70c9d9d32ac857eb2ef26c3337e841a03c35e34c Mon Sep 17 00:00:00 2001 From: wenzhangliu Date: Fri, 15 Dec 2023 16:17:34 +0800 Subject: [PATCH] update qrdqn for tensorflow --- demo_tensorflow.py | 2 +- .../learners/qlearning_family/qrdqn_learner.py | 18 +++++++++++------- xuance/tensorflow/policies/deterministic.py | 2 +- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/demo_tensorflow.py b/demo_tensorflow.py index af96bb14..d0abfce9 100644 --- a/demo_tensorflow.py +++ b/demo_tensorflow.py @@ -4,7 +4,7 @@ def parse_args(): parser = argparse.ArgumentParser("Run a demo.") - parser.add_argument("--method", type=str, default="perdqn") + parser.add_argument("--method", type=str, default="qrdqn") parser.add_argument("--env", type=str, default="classic_control") parser.add_argument("--env-id", type=str, default="CartPole-v1") parser.add_argument("--test", type=int, default=0) diff --git a/xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py b/xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py index 0f6200fc..9abb9fad 100644 --- a/xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py +++ b/xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py @@ -5,14 +5,13 @@ class QRDQN_Learner(Learner): def __init__(self, policy: tk.Model, optimizer: tk.optimizers.Optimizer, - summary_writer: Optional[SummaryWriter] = None, device: str = "cpu:0", model_dir: str = "./", gamma: float = 0.99, sync_frequency: int = 100): self.gamma = gamma self.sync_frequency = sync_frequency - super(QRDQN_Learner, self).__init__(policy, optimizer, summary_writer, device, model_dir) + super(QRDQN_Learner, self).__init__(policy, optimizer, device, model_dir) def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch): self.iterations += 1 @@ -22,8 +21,8 @@ def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch): ter_batch = tf.convert_to_tensor(terminal_batch) with tf.GradientTape() as tape: - _, _, evalZ, _ = self.policy(obs_batch) - _, targetA, _, targetZ = self.policy(next_batch) + _, _, evalZ = self.policy(obs_batch) + _, targetA, targetZ = self.policy.target(next_batch) current_quantile = tf.math.reduce_sum(evalZ * tf.expand_dims(tf.one_hot(act_batch, evalZ.shape[1]), axis=-1), axis=1) target_quantile = tf.math.reduce_sum(targetZ * tf.expand_dims(tf.one_hot(targetA, evalZ.shape[1]), axis=-1), axis=1) target_quantile = tf.expand_dims(rew_batch, 1) + self.gamma * target_quantile * (1 - tf.expand_dims(ter_batch, 1)) @@ -41,6 +40,11 @@ def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch): self.policy.copy_target() lr = self.optimizer._decayed_lr(tf.float32) - self.writer.add_scalar("Qloss", loss.numpy(), self.iterations) - self.writer.add_scalar("predictQ", tf.math.reduce_mean(current_quantile).numpy(), self.iterations) - self.writer.add_scalar("lr", lr.numpy(), self.iterations) + + info = { + "Qloss": loss.numpy(), + "predictQ": tf.math.reduce_mean(current_quantile).numpy(), + "lr": lr.numpy() + } + + return info diff --git a/xuance/tensorflow/policies/deterministic.py b/xuance/tensorflow/policies/deterministic.py index d2f1d622..b44583b7 100644 --- a/xuance/tensorflow/policies/deterministic.py +++ b/xuance/tensorflow/policies/deterministic.py @@ -342,7 +342,7 @@ def __init__(self, self.target_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim, self.quantile_num, hidden_size, normalize, initialize, activation, device) - self.copy_target() + self.target_Zhead.set_weights(self.eval_Zhead.get_weights()) def call(self, observation: Union[np.ndarray, dict], **kwargs): outputs = self.representation(observation)