QMIX

QMIX (Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning) is a value-based method that trains decentralized policies in a centralized, end-to-end fashion. It employs a mixing network that estimates joint action-values as a complex non-linear combination of per-agent values, each conditioned only on local observations. The mixing network's weights are produced by hypernetworks that take the global state as input and are constrained to be non-negative, so the joint action-value is monotonic in every agent's value; as a result, the greedy joint action can be recovered by letting each agent act greedily on its own value, which keeps decentralized execution consistent with centralized training.

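The mixing step can be made concrete with a small PyTorch sketch. This is a minimal illustration of the monotonic mixer described above, not the exact class used by this codebase; the name MonotonicMixer and settings such as embed_dim=32 are assumptions made for the example. Hypernetworks map the global state to the mixer's weights, and taking the absolute value of those weights keeps them non-negative, which enforces dQ_tot/dQ_a >= 0 for every agent a.

import torch as th
import torch.nn as nn

class MonotonicMixer(nn.Module):
    """Mixes per-agent Q-values into Q_tot, monotonically in each input (illustrative sketch)."""
    def __init__(self, n_agents, state_dim, embed_dim=32):
        super().__init__()
        self.n_agents = n_agents
        self.embed_dim = embed_dim
        # Hypernetworks: the global state produces the mixer's weights and biases
        self.hyper_w1 = nn.Linear(state_dim, n_agents * embed_dim)
        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_w2 = nn.Linear(state_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(state_dim, embed_dim), nn.ReLU(),
                                      nn.Linear(embed_dim, 1))

    def forward(self, agent_qs, state):
        # agent_qs: (batch, n_agents), state: (batch, state_dim)
        bs = agent_qs.size(0)
        agent_qs = agent_qs.view(bs, 1, self.n_agents)
        # abs() keeps the mixing weights non-negative, so Q_tot is monotonic in each agent's Q
        w1 = th.abs(self.hyper_w1(state)).view(bs, self.n_agents, self.embed_dim)
        b1 = self.hyper_b1(state).view(bs, 1, self.embed_dim)
        hidden = nn.functional.elu(th.bmm(agent_qs, w1) + b1)
        w2 = th.abs(self.hyper_w2(state)).view(bs, self.embed_dim, 1)
        b2 = self.hyper_b2(state).view(bs, 1, 1)
        q_tot = th.bmm(hidden, w2) + b2
        return q_tot.view(bs, 1)

In the training code further below, self.mixer and self.target_mixer play this role, mapping the per-agent Q-values and the global state to Q_tot.
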
  • Experience replay: ✔️

  • Target network: ✔️

  • Gradient clipping: ✔️

  • Reward clipping: ❌

  • Prioritized Experience Replay (PER): ✔️

  • Ornstein–Uhlenbeck noise: ❌

Code Snippet

import torch as th
# build_td_lambda_targets / build_q_lambda_targets are helper functions from the
# project's utilities (e.g. utils.rl_utils in PyMARL-style codebases).

def train(self, batch, t_env: int, episode_num: int, per_weight=None):
    # Drop the final timestep so rewards/actions line up with the TD targets
    rewards = batch["reward"][:, :-1]
    actions = batch["actions"][:, :-1]
    terminated = batch["terminated"][:, :-1].float()
    mask = batch["filled"][:, :-1].float()
    # A timestep is valid only if the episode had not already terminated
    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
    avail_actions = batch["avail_actions"]

    # Roll the agents' recurrent Q-network forward over the whole episode
    self.mac.agent.train()
    mac_out = []
    self.mac.init_hidden(batch.batch_size)
    for t in range(batch.max_seq_length):
        agent_outs = self.mac.forward(batch, t=t)
        mac_out.append(agent_outs)
    mac_out = th.stack(mac_out, dim=1)  # Concat over time

    # Pick out the Q-values of the actions actually taken
    chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim
    chosen_action_qvals_ = chosen_action_qvals  # Kept aside (e.g. for updating PER priorities)

    with th.no_grad():
        # Roll the target network over the episode to get bootstrap values
        self.target_mac.agent.train()
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time

        # Double Q-learning: pick greedy actions with the online network
        # (unavailable actions masked out), evaluate them with the target network
        mac_out_detach = mac_out.clone().detach()
        mac_out_detach[avail_actions == 0] = -9999999
        cur_max_actions = mac_out_detach.max(dim=3, keepdim=True)[1]
        target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)

        # Mix the per-agent target values into a joint target value
        target_max_qvals = self.target_mixer(target_max_qvals, batch["state"])

        if getattr(self.args, 'q_lambda', False):
            # Q(lambda) targets: bootstrap from the target Q-values of the actions taken
            qvals = th.gather(target_mac_out, 3, batch["actions"]).squeeze(3)
            qvals = self.target_mixer(qvals, batch["state"])

            targets = build_q_lambda_targets(rewards, terminated, mask, target_max_qvals, qvals,
                                             self.args.gamma, self.args.td_lambda)
        else:
            # Standard TD(lambda) targets built from the (double-Q) target values
            targets = build_td_lambda_targets(rewards, terminated, mask, target_max_qvals,
                                              self.args.n_agents, self.args.gamma, self.args.td_lambda)

    # Mix the chosen per-agent Q-values into Q_tot with the online mixer
    chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])

    td_error = (chosen_action_qvals - targets.detach())
    td_error2 = 0.5 * td_error.pow(2)

    # Zero out padded timesteps before averaging the loss
    mask = mask.expand_as(td_error2)
    masked_td_error = td_error2 * mask

    if self.use_per:
        # Scale the loss by the importance-sampling weights from prioritized replay
        per_weight = th.from_numpy(per_weight).unsqueeze(-1).to(device=self.device)
        masked_td_error = masked_td_error.sum(1) * per_weight

    loss = masked_td_error.sum() / mask.sum()

    # Optimise with gradient-norm clipping
    self.optimiser.zero_grad()
    loss.backward()
    grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
    self.optimiser.step()
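
The learner above delegates target construction to build_td_lambda_targets. As a rough reference, the sketch below follows the PyMARL-style implementation of that helper; treat the exact shapes and signature as an assumption rather than something documented on this page. It builds lambda-returns backwards through the episode, blending one-step bootstrapped targets with longer returns and zeroing out padded timesteps.

import torch as th

def build_td_lambda_targets(rewards, terminated, mask, target_qs, n_agents, gamma, td_lambda):
    # target_qs covers one more timestep than rewards/terminated/mask
    ret = target_qs.new_zeros(*target_qs.shape)
    # Bootstrap from the final target value only if the episode never terminated
    ret[:, -1] = target_qs[:, -1] * (1 - th.sum(terminated, dim=1))
    # Backwards recursion: blend the one-step bootstrap with the longer lambda-return,
    # masking out steps beyond the filled length of each episode
    for t in range(ret.shape[1] - 2, -1, -1):
        ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] * (
            rewards[:, t] + (1 - td_lambda) * gamma * target_qs[:, t + 1] * (1 - terminated[:, t]))
    # Drop the final step so the targets align with the T-1 reward steps
    return ret[:, 0:-1]

With td_lambda = 0 this reduces to ordinary one-step (double-)Q targets; with td_lambda = 1 it approaches the full Monte-Carlo return over the filled part of the episode.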

Parameters

Networks

class elegantrl.agents.net.Critic(*args: Any, **kwargs: Any)