diff --git a/docs/source/documents/api/policies/gaussian_marl.rst b/docs/source/documents/api/policies/gaussian_marl.rst
index 51c7efb85..d28d5bd1d 100644
--- a/docs/source/documents/api/policies/gaussian_marl.rst
+++ b/docs/source/documents/api/policies/gaussian_marl.rst
@@ -1,4 +1,633 @@
Gaussian-MARL
=======================================
+
+**PyTorch:**
+
+.. py:class::
+ xuance.torch.policies.gaussian_marl.BasicQhead(state_dim, action_dim, n_agents, hidden_sizes, normalize, initialize, activation, device)
+
+ :param state_dim: The dimension of the input state.
+ :type state_dim: int
+ :param action_dim: The dimension of the actions.
+ :type action_dim: int
+ :param n_agents: The number of agents.
+ :type n_agents: int
+ :param hidden_sizes: The sizes of the hidden layers.
+ :type hidden_sizes: Sequence[int]
+ :param normalize: The method of normalization.
+ :type normalize: nn.Module
+ :param initialize: The initialization method for the parameters of the networks.
+ :type initialize: Callable
+ :param activation: The choice of activation function.
+ :type activation: nn.Module
+ :param device: The calculating device.
+ :type device: str
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.BasicQhead.forward(x)
+
+ :param x: The input tensor, i.e. the representation state concatenated with the one-hot agent identifiers.
+ :type x: torch.Tensor
+ :return: The evaluated Q-values of the input x.
+ :rtype: torch.Tensor
+
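+The Q-head consumes the representation state concatenated with a one-hot agent identifier, so its input width is ``state_dim + n_agents``. A minimal usage sketch (hypothetical sizes; assumes xuance is installed):
+
+.. code-block:: python
+
+    import torch
+    from xuance.torch.policies.gaussian_marl import BasicQhead
+
+    # Hypothetical dimensions, for illustration only.
+    state_dim, action_dim, n_agents = 8, 5, 3
+    qhead = BasicQhead(state_dim, action_dim, n_agents, hidden_sizes=[64, 64])
+
+    x = torch.rand(n_agents, state_dim + n_agents)  # state + one-hot agent IDs
+    q_values = qhead(x)                             # shape: (n_agents, action_dim)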
+
+.. py:class::
+ xuance.torch.policies.gaussian_marl.BasicQnetwork(action_space, n_agents, representation, hidden_size, normalize, initialize, activation, device)
+
+ :param action_space: The discrete action space of one agent.
+ :type action_space: Discrete
+ :param n_agents: The number of agents.
+ :type n_agents: int
+ :param representation: The representation module that encodes the raw observations.
+ :type representation: nn.Module
+ :param hidden_size: The sizes of the hidden layers of the Q-head.
+ :type hidden_size: Sequence[int]
+ :param normalize: The method of normalization.
+ :type normalize: nn.Module
+ :param initialize: The initialization method for the parameters of the networks.
+ :type initialize: Callable
+ :param activation: The choice of activation function.
+ :type activation: nn.Module
+ :param device: The calculating device.
+ :type device: str
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.BasicQnetwork.forward(observation, agent_ids)
+
+ :param observation: The observations of the agents.
+ :type observation: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :return: A tuple of the representation outputs, the greedy (argmax) actions, and the evaluated Q-values.
+ :rtype: tuple
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.BasicQnetwork.target_Q(observation, agent_ids)
+
+ :param observation: The observations of the agents.
+ :type observation: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :return: The Q-values evaluated by the target Q-head.
+ :rtype: torch.Tensor
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.BasicQnetwork.copy_target()
+
+ :return: None. The parameters of the online Q-head are hard-copied into the target Q-head.
+
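+A minimal usage sketch (hypothetical: ``IdentityRep`` is a stand-in representation and ``Discrete`` comes from gym.spaces; any module exposing an ``output_shapes`` dict with a ``'state'`` entry works):
+
+.. code-block:: python
+
+    import torch
+    import torch.nn as nn
+    from gym.spaces import Discrete
+    from xuance.torch.policies.gaussian_marl import BasicQnetwork
+
+    class IdentityRep(nn.Module):
+        """Stand-in representation that passes observations through."""
+        def __init__(self, dim):
+            super().__init__()
+            self.output_shapes = {'state': (dim,)}
+        def forward(self, x):
+            return {'state': x}
+
+    obs_dim, n_agents = 8, 3
+    policy = BasicQnetwork(Discrete(5), n_agents, IdentityRep(obs_dim),
+                           hidden_size=[64, 64])
+    obs = torch.rand(n_agents, obs_dim)
+    ids = torch.eye(n_agents)                         # one-hot agent identifiers
+    outputs, argmax_action, evalQ = policy(obs, ids)
+    policy.copy_target()                              # hard-sync the target Q-head
+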
+.. py:class::
+ xuance.torch.policies.gaussian_marl.ActorNet(state_dim, n_agents, action_dim, hidden_sizes, normalize, initialize, activation, device)
+
+ :param state_dim: The dimension of the input state.
+ :type state_dim: int
+ :param n_agents: The number of agents.
+ :type n_agents: int
+ :param action_dim: The dimension of the continuous actions.
+ :type action_dim: int
+ :param hidden_sizes: The sizes of the hidden layers.
+ :type hidden_sizes: Sequence[int]
+ :param normalize: The method of normalization.
+ :type normalize: nn.Module
+ :param initialize: The initialization method for the parameters of the networks.
+ :type initialize: Callable
+ :param activation: The choice of activation function.
+ :type activation: nn.Module
+ :param device: The calculating device.
+ :type device: str
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.ActorNet.forward(x)
+
+ :param x: The input tensor, i.e. the representation state concatenated with the one-hot agent identifiers.
+ :type x: torch.Tensor
+ :return: A diagonal Gaussian distribution over the continuous actions.
+ :rtype: DiagGaussianDistribution
+
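+The actor parameterizes a diagonal Gaussian policy: the network predicts the action mean, while a learnable, state-independent parameter holds the log standard deviation. A minimal sketch (hypothetical sizes; the sampling methods follow xuance's ``DiagGaussianDistribution`` and should be treated as assumptions):
+
+.. code-block:: python
+
+    import torch
+    from xuance.torch.policies.gaussian_marl import ActorNet
+
+    # Hypothetical dimensions, for illustration only.
+    state_dim, n_agents, action_dim = 8, 2, 4
+    actor = ActorNet(state_dim, n_agents, action_dim, hidden_sizes=[64, 64])
+
+    x = torch.rand(n_agents, state_dim + n_agents)  # state + one-hot agent IDs
+    dist = actor(x)                     # DiagGaussianDistribution
+    actions = dist.stochastic_sample()  # sample continuous actions
+    log_prob = dist.log_prob(actions)   # log-density of the sampled actions
+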
+.. py:class::
+ xuance.torch.policies.gaussian_marl.CriticNet(state_dim, n_agents, hidden_sizes, normalize, initialize, activation, device)
+
+ :param state_dim: The dimension of the input state.
+ :type state_dim: int
+ :param n_agents: The number of agents.
+ :type n_agents: int
+ :param hidden_sizes: The sizes of the hidden layers.
+ :type hidden_sizes: Sequence[int]
+ :param normalize: The method of normalization.
+ :type normalize: nn.Module
+ :param initialize: The initialization method for the parameters of the networks.
+ :type initialize: Callable
+ :param activation: The choice of activation function.
+ :type activation: nn.Module
+ :param device: The calculating device.
+ :type device: str
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.CriticNet.forward(x)
+
+ :param x: The input tensor of the critic network.
+ :type x: torch.Tensor
+ :return: The evaluated critic values of the input x.
+ :rtype: torch.Tensor
+
+.. py:class::
+ xuance.torch.policies.gaussian_marl.MAAC_Policy(action_space, n_agents, representation, mixer, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
+
+ :param action_space: The continuous action space of one agent.
+ :type action_space: Space
+ :param n_agents: The number of agents.
+ :type n_agents: int
+ :param representation: A pair of representation modules, one for the actor and one for the critic.
+ :type representation: nn.Module
+ :param mixer: The mixer that combines the individual values into a total team value.
+ :type mixer: VDN_mixer
+ :param actor_hidden_size: The sizes of the hidden layers of the actor network.
+ :type actor_hidden_size: Sequence[int]
+ :param critic_hidden_size: The sizes of the hidden layers of the critic network.
+ :type critic_hidden_size: Sequence[int]
+ :param normalize: The method of normalization.
+ :type normalize: nn.Module
+ :param initialize: The initialization method for the parameters of the networks.
+ :type initialize: Callable
+ :param activation: The choice of activation function.
+ :type activation: nn.Module
+ :param device: The calculating device.
+ :type device: str
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.MAAC_Policy.forward(observation, agent_ids, *rnn_hidden)
+
+ :param observation: The observations of the agents.
+ :type observation: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :param *rnn_hidden: The hidden states of the recurrent layers, used only when recurrence is enabled.
+ :type *rnn_hidden: torch.Tensor
+ :return: A tuple of the updated recurrent hidden states (None when recurrence is disabled) and the Gaussian policy distribution.
+ :rtype: tuple
+
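+``forward`` runs the actor path only and returns the Gaussian policy distribution; critic values are computed separately through ``get_values``. A sketch of the non-recurrent case (``policy``, ``obs``, and ``ids`` are assumed to be constructed already):
+
+.. code-block:: python
+
+    rnn_hidden, pi_dist = policy(obs, ids)  # rnn_hidden is None without recurrence
+    actions = pi_dist.stochastic_sample()   # sample from the Gaussian policy
+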
+.. py:function::
+ xuance.torch.policies.gaussian_marl.MAAC_Policy.get_values(critic_in, agent_ids, *rnn_hidden)
+
+ :param critic_in: The inputs of the critic representation network.
+ :type critic_in: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :param *rnn_hidden: The hidden states of the recurrent layers, used only when recurrence is enabled.
+ :type *rnn_hidden: torch.Tensor
+ :return: A tuple of the updated recurrent hidden states and the evaluated critic values.
+ :rtype: tuple
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.MAAC_Policy.value_tot(values_n, global_state)
+
+ :param values_n: The independent values of the n agents.
+ :type values_n: torch.Tensor
+ :param global_state: The global state of the environments.
+ :type global_state: torch.Tensor
+ :return: The total team value, mixed by the mixer if one is provided; otherwise values_n unchanged.
+ :rtype: torch.Tensor
+
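+When a mixer is provided, ``value_tot`` combines the independent values of the agents into one team value; without a mixer, the individual values pass through unchanged. A sketch of the call chain (``policy``, ``critic_in``, ``agent_ids``, and ``state`` are assumed to be constructed already):
+
+.. code-block:: python
+
+    rnn_hidden, values_n = policy.get_values(critic_in, agent_ids)
+    v_tot = policy.value_tot(values_n, global_state=state)  # mixed team value
+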
+.. py:class::
+ xuance.torch.policies.gaussian_marl.Basic_ISAC_policy(action_space, n_agents, representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
+
+ :param action_space: The continuous action space of one agent.
+ :type action_space: Space
+ :param n_agents: The number of agents.
+ :type n_agents: int
+ :param representation: The representation module that encodes the raw observations.
+ :type representation: nn.Module
+ :param actor_hidden_size: The sizes of the hidden layers of the actor network.
+ :type actor_hidden_size: Sequence[int]
+ :param critic_hidden_size: The sizes of the hidden layers of the critic network.
+ :type critic_hidden_size: Sequence[int]
+ :param normalize: The method of normalization.
+ :type normalize: nn.Module
+ :param initialize: The initialization method for the parameters of the networks.
+ :type initialize: Callable
+ :param activation: The choice of activation function.
+ :type activation: nn.Module
+ :param device: The calculating device.
+ :type device: str
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.Basic_ISAC_policy.forward(observation, agent_ids)
+
+ :param observation: The observations of the agents.
+ :type observation: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :return: A tuple of the representation outputs and the Gaussian policy distribution over actions.
+ :rtype: tuple
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.Basic_ISAC_policy.critic(observation, actions, agent_ids)
+
+ :param observation: The observations of the agents.
+ :type observation: torch.Tensor
+ :param actions: The actions of the agents.
+ :type actions: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :return: The critic values evaluated by the online critic network.
+ :rtype: torch.Tensor
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.Basic_ISAC_policy.target_critic(observation, actions, agent_ids)
+
+ :param observation: The observations of the agents.
+ :type observation: torch.Tensor
+ :param actions: The actions of the agents.
+ :type actions: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :return: The critic values evaluated by the target critic network.
+ :rtype: torch.Tensor
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.Basic_ISAC_policy.target_actor(observation, agent_ids)
+
+ :param observation: The observations of the agents.
+ :type observation: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :return: The Gaussian policy distribution produced by the target actor network.
+ :rtype: DiagGaussianDistribution
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.Basic_ISAC_policy.soft_update(tau)
+
+ :param tau: The soft update factor for the target networks.
+ :type tau: float
+ :return: None.
+
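+``soft_update`` applies Polyak averaging with factor :math:`\tau` to both target networks:
+
+.. math::
+
+    \theta^{-} \leftarrow (1 - \tau)\,\theta^{-} + \tau\,\theta,
+
+where :math:`\theta` denotes the parameters of the online actor and critic and :math:`\theta^{-}` those of their targets.
+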
+.. py:class::
+ xuance.torch.policies.gaussian_marl.MASAC_policy(action_space, n_agents, representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
+
+ :param action_space: The continuous action space of one agent.
+ :type action_space: Space
+ :param n_agents: The number of agents.
+ :type n_agents: int
+ :param representation: The representation module that encodes the raw observations.
+ :type representation: nn.Module
+ :param actor_hidden_size: The sizes of the hidden layers of the actor network.
+ :type actor_hidden_size: Sequence[int]
+ :param critic_hidden_size: The sizes of the hidden layers of the centralized critic network.
+ :type critic_hidden_size: Sequence[int]
+ :param normalize: The method of normalization.
+ :type normalize: nn.Module
+ :param initialize: The initialization method for the parameters of the networks.
+ :type initialize: Callable
+ :param activation: The choice of activation function.
+ :type activation: nn.Module
+ :param device: The calculating device.
+ :type device: str
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.MASAC_policy.critic(observation, actions, agent_ids)
+
+ :param observation: The observations of all the agents.
+ :type observation: torch.Tensor
+ :param actions: The actions of all the agents.
+ :type actions: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :return: The centralized critic values evaluated on the joint observations and actions.
+ :rtype: torch.Tensor
+
+.. py:function::
+ xuance.torch.policies.gaussian_marl.MASAC_policy.target_critic(observation, actions, agent_ids)
+
+ :param observation: The observations of all the agents.
+ :type observation: torch.Tensor
+ :param actions: The actions of all the agents.
+ :type actions: torch.Tensor
+ :param agent_ids: The one-hot identifiers of the agents.
+ :type agent_ids: torch.Tensor
+ :return: The centralized target-critic values evaluated on the joint observations and actions.
+ :rtype: torch.Tensor
+
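+Unlike ``Basic_ISAC_policy``, whose critic sees only an agent's own observation and action, ``MASAC_policy`` builds a centralized critic over the joint observations and actions of all agents. A shape sketch (hypothetical batch size and dimensions):
+
+.. code-block:: python
+
+    # Hypothetical shapes inside MASAC_policy.critic, with batch size 32,
+    # n_agents = 3, per-agent state size 8, and per-agent action size 4:
+    #   outputs_n: (32, 3, 8 * 3)  joint states, repeated for each agent
+    #   actions_n: (32, 3, 4 * 3)  joint actions, repeated for each agent
+    #   agent_ids: (32, 3, 3)      one-hot agent identifiers
+    #   critic_in: (32, 3, (8 + 4) * 3 + 3)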
+
+**TensorFlow:**
+
+
+**MindSpore:**
+
+
+Source Code
+-----------------
+
+.. tabs::
+
+ .. group-tab:: PyTorch
+
+ .. code-block:: python
+
+ import torch.distributions
+ from torch.distributions.multivariate_normal import MultivariateNormal
+
+ from xuance.torch.policies import *
+ from xuance.torch.utils import *
+
+
+ class BasicQhead(nn.Module):
+ def __init__(self,
+ state_dim: int,
+ action_dim: int,
+ n_agents: int,
+ hidden_sizes: Sequence[int],
+ normalize: Optional[ModuleType] = None,
+ initialize: Optional[Callable[..., torch.Tensor]] = None,
+ activation: Optional[ModuleType] = None,
+ device: Optional[Union[str, int, torch.device]] = None):
+ super(BasicQhead, self).__init__()
+ layers_ = []
+ input_shape = (state_dim + n_agents,)
+ for h in hidden_sizes:
+ mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device)
+ layers_.extend(mlp)
+ layers_.extend(mlp_block(input_shape[0], action_dim, None, None, None, device)[0])
+ self.model = nn.Sequential(*layers_)
+
+ def forward(self, x: torch.Tensor):
+ return self.model(x)
+
+
+ class BasicQnetwork(nn.Module):
+ def __init__(self,
+ action_space: Discrete,
+ n_agents: int,
+ representation: nn.Module,
+ hidden_size: Sequence[int] = None,
+ normalize: Optional[ModuleType] = None,
+ initialize: Optional[Callable[..., torch.Tensor]] = None,
+ activation: Optional[ModuleType] = None,
+ device: Optional[Union[str, int, torch.device]] = None):
+ super(BasicQnetwork, self).__init__()
+ self.action_dim = action_space.n
+ self.representation = representation
+ self.representation_info_shape = self.representation.output_shapes
+
+ self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim, n_agents,
+ hidden_size, normalize, initialize, activation, device)
+ self.target_Qhead = copy.deepcopy(self.eval_Qhead)
+
+ def forward(self, observation: torch.Tensor, agent_ids: torch.Tensor):
+ outputs = self.representation(observation)
+ q_inputs = torch.concat([outputs['state'], agent_ids], dim=-1)
+ evalQ = self.eval_Qhead(q_inputs)
+ argmax_action = evalQ.argmax(dim=-1, keepdim=False)
+ return outputs, argmax_action, evalQ
+
+ def target_Q(self, observation: torch.Tensor, agent_ids: torch.Tensor):
+ outputs = self.representation(observation)
+ q_inputs = torch.concat([outputs['state'], agent_ids], dim=-1)
+ return self.target_Qhead(q_inputs)
+
+ def copy_target(self):
+ for ep, tp in zip(self.eval_Qhead.parameters(), self.target_Qhead.parameters()):
+ tp.data.copy_(ep)
+
+
+ class ActorNet(nn.Module):
+ def __init__(self,
+ state_dim: int,
+ n_agents: int,
+ action_dim: int,
+ hidden_sizes: Sequence[int],
+ normalize: Optional[ModuleType] = None,
+ initialize: Optional[Callable[..., torch.Tensor]] = None,
+ activation: Optional[ModuleType] = None,
+ device: Optional[Union[str, int, torch.device]] = None):
+ super(ActorNet, self).__init__()
+ self.device = device
+ layers = []
+ input_shape = (state_dim + n_agents,)
+ for h in hidden_sizes:
+ mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device)
+ layers.extend(mlp)
+ # Output layer: input_shape[0] holds the width of the last hidden layer
+ # after the loop; hidden_sizes[0] would break for non-uniform hidden sizes.
+ layers.append(nn.Linear(input_shape[0], action_dim, device=device))
+ self.mu = nn.Sequential(*layers)
+ self.log_std = nn.Parameter(-torch.ones((action_dim,), device=device))
+ self.dist = DiagGaussianDistribution(action_dim)
+
+ def forward(self, x: torch.Tensor):
+ self.dist.set_param(self.mu(x), self.log_std.exp())
+ return self.dist
+
+
+ class CriticNet(nn.Module):
+ def __init__(self,
+ state_dim: int,
+ n_agents: int,
+ hidden_sizes: Sequence[int],
+ normalize: Optional[ModuleType] = None,
+ initialize: Optional[Callable[..., torch.Tensor]] = None,
+ activation: Optional[ModuleType] = None,
+ device: Optional[Union[str, int, torch.device]] = None
+ ):
+ super(CriticNet, self).__init__()
+ layers = []
+ input_shape = (state_dim + n_agents,)
+ for h in hidden_sizes:
+ mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device)
+ layers.extend(mlp)
+ layers.extend(mlp_block(input_shape[0], 1, None, None, initialize, device)[0])
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor):
+ return self.model(x)
+
+
+ class MAAC_Policy(nn.Module):
+ """
+ MAAC_Policy: Multi-Agent Actor-Critic Policy with Gaussian policies
+ """
+
+ def __init__(self,
+ action_space: Space,
+ n_agents: int,
+ representation: nn.Module,
+ mixer: Optional[VDN_mixer] = None,
+ actor_hidden_size: Sequence[int] = None,
+ critic_hidden_size: Sequence[int] = None,
+ normalize: Optional[ModuleType] = None,
+ initialize: Optional[Callable[..., torch.Tensor]] = None,
+ activation: Optional[ModuleType] = None,
+ device: Optional[Union[str, int, torch.device]] = None,
+ **kwargs):
+ super(MAAC_Policy, self).__init__()
+ self.device = device
+ self.action_dim = action_space.shape[0]
+ self.n_agents = n_agents
+ self.representation = representation[0]
+ self.representation_critic = representation[1]
+ self.representation_info_shape = self.representation.output_shapes
+ self.lstm = kwargs["rnn"] == "LSTM"
+ self.use_rnn = bool(kwargs["use_recurrent"])
+ self.actor = ActorNet(self.representation.output_shapes['state'][0], n_agents, self.action_dim,
+ actor_hidden_size, normalize, initialize, activation, device)
+ dim_input_critic = self.representation_critic.output_shapes['state'][0]
+ self.critic = CriticNet(dim_input_critic, n_agents, critic_hidden_size,
+ normalize, initialize, activation, device)
+ self.mixer = mixer
+ self.pi_dist = None
+
+ def forward(self, observation: torch.Tensor, agent_ids: torch.Tensor,
+ *rnn_hidden: torch.Tensor, **kwargs):
+ if self.use_rnn:
+ outputs = self.representation(observation, *rnn_hidden)
+ rnn_hidden = (outputs['rnn_hidden'], outputs['rnn_cell'])
+ else:
+ outputs = self.representation(observation)
+ rnn_hidden = None
+ actor_input = torch.concat([outputs['state'], agent_ids], dim=-1)
+ self.pi_dist = self.actor(actor_input)
+ return rnn_hidden, self.pi_dist
+
+ def get_values(self, critic_in: torch.Tensor, agent_ids: torch.Tensor,
+ *rnn_hidden: torch.Tensor, **kwargs):
+ shape_obs = critic_in.shape
+ # get representation features
+ if self.use_rnn:
+ batch_size, n_agent, episode_length, dim_obs = tuple(shape_obs)
+ outputs = self.representation_critic(critic_in.reshape(-1, episode_length, dim_obs), *rnn_hidden)
+ outputs['state'] = outputs['state'].view(batch_size, n_agent, episode_length, -1)
+ rnn_hidden = (outputs['rnn_hidden'], outputs['rnn_cell'])
+ else:
+ batch_size, n_agent, dim_obs = tuple(shape_obs)
+ outputs = self.representation_critic(critic_in.reshape(-1, dim_obs))
+ outputs['state'] = outputs['state'].view(batch_size, n_agent, -1)
+ rnn_hidden = None
+ # get critic values
+ critic_in = torch.concat([outputs['state'], agent_ids], dim=-1)
+ v = self.critic(critic_in)
+ return rnn_hidden, v
+
+ def value_tot(self, values_n: torch.Tensor, global_state=None):
+ if global_state is not None:
+ global_state = torch.as_tensor(global_state).to(self.device)
+ return values_n if self.mixer is None else self.mixer(values_n, global_state)
+
+
+ class Basic_ISAC_policy(nn.Module):
+ def __init__(self,
+ action_space: Space,
+ n_agents: int,
+ representation: nn.Module,
+ actor_hidden_size: Sequence[int],
+ critic_hidden_size: Sequence[int],
+ normalize: Optional[ModuleType] = None,
+ initialize: Optional[Callable[..., torch.Tensor]] = None,
+ activation: Optional[ModuleType] = None,
+ device: Optional[Union[str, int, torch.device]] = None
+ ):
+ super(Basic_ISAC_policy, self).__init__()
+ self.action_dim = action_space.shape[0]
+ self.n_agents = n_agents
+ self.representation = representation
+ self.representation_info_shape = self.representation.output_shapes
+
+ self.actor_net = ActorNet(representation.output_shapes['state'][0], n_agents, self.action_dim,
+ actor_hidden_size, normalize, initialize, activation, device)
+ dim_input_critic = representation.output_shapes['state'][0] + self.action_dim
+ self.critic_net = CriticNet(dim_input_critic, n_agents, critic_hidden_size,
+ normalize, initialize, activation, device)
+ self.target_actor_net = copy.deepcopy(self.actor_net)
+ self.target_critic_net = copy.deepcopy(self.critic_net)
+ self.parameters_actor = list(self.representation.parameters()) + list(self.actor_net.parameters())
+ self.parameters_critic = self.critic_net.parameters()
+
+ def forward(self, observation: torch.Tensor, agent_ids: torch.Tensor):
+ outputs = self.representation(observation)
+ actor_in = torch.concat([outputs['state'], agent_ids], dim=-1)
+ act = self.actor_net(actor_in)
+ return outputs, act
+
+ def critic(self, observation: torch.Tensor, actions: torch.Tensor, agent_ids: torch.Tensor):
+ outputs = self.representation(observation)
+ critic_in = torch.concat([outputs['state'], actions, agent_ids], dim=-1)
+ return self.critic_net(critic_in)
+
+ def target_critic(self, observation: torch.Tensor, actions: torch.Tensor, agent_ids: torch.Tensor):
+ outputs = self.representation(observation)
+ critic_in = torch.concat([outputs['state'], actions, agent_ids], dim=-1)
+ return self.target_critic_net(critic_in)
+
+ def target_actor(self, observation: torch.Tensor, agent_ids: torch.Tensor):
+ outputs = self.representation(observation)
+ actor_in = torch.concat([outputs['state'], agent_ids], dim=-1)
+ return self.target_actor_net(actor_in)
+
+ def soft_update(self, tau=0.005):
+ for ep, tp in zip(self.actor_net.parameters(), self.target_actor_net.parameters()):
+ tp.data.mul_(1 - tau)
+ tp.data.add_(tau * ep.data)
+ for ep, tp in zip(self.critic_net.parameters(), self.target_critic_net.parameters()):
+ tp.data.mul_(1 - tau)
+ tp.data.add_(tau * ep.data)
+
+
+ class MASAC_policy(Basic_ISAC_policy):
+ def __init__(self,
+ action_space: Space,
+ n_agents: int,
+ representation: nn.Module,
+ actor_hidden_size: Sequence[int],
+ critic_hidden_size: Sequence[int],
+ normalize: Optional[ModuleType] = None,
+ initialize: Optional[Callable[..., torch.Tensor]] = None,
+ activation: Optional[ModuleType] = None,
+ device: Optional[Union[str, int, torch.device]] = None
+ ):
+ super(MASAC_policy, self).__init__(action_space, n_agents, representation,
+ actor_hidden_size, critic_hidden_size,
+ normalize, initialize, activation, device)
+ dim_input_critic = (representation.output_shapes['state'][0] + self.action_dim) * self.n_agents
+ self.critic_net = CriticNet(dim_input_critic, n_agents, critic_hidden_size,
+ normalize, initialize, activation, device)
+ self.target_critic_net = copy.deepcopy(self.critic_net)
+ self.parameters_critic = self.critic_net.parameters()
+
+ def critic(self, observation: torch.Tensor, actions: torch.Tensor, agent_ids: torch.Tensor):
+ bs = observation.shape[0]
+ outputs_n = self.representation(observation)['state'].view(bs, 1, -1).expand(-1, self.n_agents, -1)
+ actions_n = actions.view(bs, 1, -1).expand(-1, self.n_agents, -1)
+ critic_in = torch.concat([outputs_n, actions_n, agent_ids], dim=-1)
+ return self.critic_net(critic_in)
+
+ def target_critic(self, observation: torch.Tensor, actions: torch.Tensor, agent_ids: torch.Tensor):
+ bs = observation.shape[0]
+ outputs_n = self.representation(observation)['state'].view(bs, 1, -1).expand(-1, self.n_agents, -1)
+ actions_n = actions.view(bs, 1, -1).expand(-1, self.n_agents, -1)
+ critic_in = torch.concat([outputs_n, actions_n, agent_ids], dim=-1)
+ return self.target_critic_net(critic_in)
+
+
+
+
+ .. group-tab:: TensorFlow
+
+ .. code-block:: python
+
+
+ .. group-tab:: MindSpore
+
+ .. code-block:: python