Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
baijinqiu committed Dec 15, 2023
2 parents 28b6e73 + 70c9d9d commit 1a89a60
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion demo_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def parse_args():
parser = argparse.ArgumentParser("Run a demo.")
parser.add_argument("--method", type=str, default="perdqn")
parser.add_argument("--method", type=str, default="qrdqn")
parser.add_argument("--env", type=str, default="classic_control")
parser.add_argument("--env-id", type=str, default="CartPole-v1")
parser.add_argument("--test", type=int, default=0)
Expand Down
18 changes: 11 additions & 7 deletions xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ class QRDQN_Learner(Learner):
def __init__(self,
policy: tk.Model,
optimizer: tk.optimizers.Optimizer,
summary_writer: Optional[SummaryWriter] = None,
device: str = "cpu:0",
model_dir: str = "./",
gamma: float = 0.99,
sync_frequency: int = 100):
self.gamma = gamma
self.sync_frequency = sync_frequency
super(QRDQN_Learner, self).__init__(policy, optimizer, summary_writer, device, model_dir)
super(QRDQN_Learner, self).__init__(policy, optimizer, device, model_dir)

def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
self.iterations += 1
Expand All @@ -22,8 +21,8 @@ def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
ter_batch = tf.convert_to_tensor(terminal_batch)

with tf.GradientTape() as tape:
_, _, evalZ, _ = self.policy(obs_batch)
_, targetA, _, targetZ = self.policy(next_batch)
_, _, evalZ = self.policy(obs_batch)
_, targetA, targetZ = self.policy.target(next_batch)
current_quantile = tf.math.reduce_sum(evalZ * tf.expand_dims(tf.one_hot(act_batch, evalZ.shape[1]), axis=-1), axis=1)
target_quantile = tf.math.reduce_sum(targetZ * tf.expand_dims(tf.one_hot(targetA, evalZ.shape[1]), axis=-1), axis=1)
target_quantile = tf.expand_dims(rew_batch, 1) + self.gamma * target_quantile * (1 - tf.expand_dims(ter_batch, 1))
Expand All @@ -41,6 +40,11 @@ def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
self.policy.copy_target()

lr = self.optimizer._decayed_lr(tf.float32)
self.writer.add_scalar("Qloss", loss.numpy(), self.iterations)
self.writer.add_scalar("predictQ", tf.math.reduce_mean(current_quantile).numpy(), self.iterations)
self.writer.add_scalar("lr", lr.numpy(), self.iterations)

info = {
"Qloss": loss.numpy(),
"predictQ": tf.math.reduce_mean(current_quantile).numpy(),
"lr": lr.numpy()
}

return info
2 changes: 1 addition & 1 deletion xuance/tensorflow/policies/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def __init__(self,
self.target_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim, self.quantile_num,
hidden_size,
normalize, initialize, activation, device)
self.copy_target()
self.target_Zhead.set_weights(self.eval_Zhead.get_weights())

def call(self, observation: Union[np.ndarray, dict], **kwargs):
outputs = self.representation(observation)
Expand Down

0 comments on commit 1a89a60

Please sign in to comment.