QMix¶
QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning is a value-based method that trains decentralised policies in a centralised, end-to-end fashion. QMIX employs a mixing network that estimates the joint action-value as a complex non-linear combination of per-agent values, each of which conditions only on local observations. The joint value is structurally constrained to be monotonic in the per-agent values (the mixing weights are kept non-negative, and are produced by hypernetworks conditioned on the global state), so the greedy joint action can be recovered by each agent maximising its own value (see the sketch after the feature list below).
Experience replay: ✔️
Target network: ✔️
Gradient clipping: ✔️
Reward clipping: ❌
Prioritized Experience Replay (PER): ✔️
Ornstein–Uhlenbeck noise: ❌
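A minimal sketch of the monotonic mixing network described above. The module name MonotonicMixer, the embedding size, and the exact hypernetwork layout are illustrative assumptions rather than this repository's implementation; the key point is that taking the absolute value of the hypernetwork-generated weights keeps dQ_tot/dQ_a non-negative.

import torch as th
import torch.nn as nn
import torch.nn.functional as F

class MonotonicMixer(nn.Module):
    """Mixes per-agent Q-values into Q_tot; monotonicity is enforced by
    taking the absolute value of the hypernetwork-generated weights."""

    def __init__(self, n_agents: int, state_dim: int, embed_dim: int = 32):
        super().__init__()
        self.n_agents = n_agents
        self.embed_dim = embed_dim
        # Hypernetworks: the global state produces the mixing weights/biases
        self.hyper_w1 = nn.Linear(state_dim, n_agents * embed_dim)
        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_w2 = nn.Linear(state_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(state_dim, embed_dim),
                                      nn.ReLU(),
                                      nn.Linear(embed_dim, 1))

    def forward(self, agent_qs: th.Tensor, states: th.Tensor) -> th.Tensor:
        # agent_qs: (batch, time, n_agents); states: (batch, time, state_dim)
        bs, t, _ = agent_qs.shape
        states = states.reshape(-1, states.shape[-1])
        agent_qs = agent_qs.reshape(-1, 1, self.n_agents)
        # First layer: non-negative weights keep dQ_tot/dQ_a >= 0
        w1 = th.abs(self.hyper_w1(states)).view(-1, self.n_agents, self.embed_dim)
        b1 = self.hyper_b1(states).view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs, w1) + b1)
        # Second layer: scalar output Q_tot, with a state-dependent bias
        w2 = th.abs(self.hyper_w2(states)).view(-1, self.embed_dim, 1)
        b2 = self.hyper_b2(states).view(-1, 1, 1)
        q_tot = th.bmm(hidden, w2) + b2
        return q_tot.view(bs, t, 1)

Given per-agent chosen-action values of shape (batch, time, n_agents) and global states, calling mixer(chosen_action_qvals, states) returns Q_tot of shape (batch, time, 1), matching how self.mixer and self.target_mixer are used in the training snippet below.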
Code Snippet¶
def train(self, batch, t_env: int, episode_num: int, per_weight=None):
    # Assumes `import torch as th` and the target-building helpers
    # (build_td_lambda_targets / build_q_lambda_targets) are in scope.
    rewards = batch["reward"][:, :-1]
    actions = batch["actions"][:, :-1]
    terminated = batch["terminated"][:, :-1].float()
    mask = batch["filled"][:, :-1].float()
    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
    avail_actions = batch["avail_actions"]

    # Estimate per-agent Q-values with the online network
    self.mac.agent.train()
    mac_out = []
    self.mac.init_hidden(batch.batch_size)
    for t in range(batch.max_seq_length):
        agent_outs = self.mac.forward(batch, t=t)
        mac_out.append(agent_outs)
    mac_out = th.stack(mac_out, dim=1)  # Concat over time

    # Pick the Q-values for the actions taken by each agent
    chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim
    chosen_action_qvals_ = chosen_action_qvals  # Per-agent values before mixing (kept for e.g. logging)

    # Calculate the Q-values needed for the targets (no gradients)
    with th.no_grad():
        self.target_mac.agent.train()
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)
        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time

        # Double Q-learning: greedy actions come from the online network
        # (with unavailable actions masked out) and are evaluated by the target network
        mac_out_detach = mac_out.clone().detach()
        mac_out_detach[avail_actions == 0] = -9999999
        cur_max_actions = mac_out_detach.max(dim=3, keepdim=True)[1]
        target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)

        # Mix per-agent target values into Q_tot with the target mixer
        target_max_qvals = self.target_mixer(target_max_qvals, batch["state"])

        if getattr(self.args, 'q_lambda', False):
            qvals = th.gather(target_mac_out, 3, batch["actions"]).squeeze(3)
            qvals = self.target_mixer(qvals, batch["state"])
            targets = build_q_lambda_targets(rewards, terminated, mask, target_max_qvals, qvals,
                                             self.args.gamma, self.args.td_lambda)
        else:
            targets = build_td_lambda_targets(rewards, terminated, mask, target_max_qvals,
                                              self.args.n_agents, self.args.gamma, self.args.td_lambda)

    # Mix the chosen per-agent Q-values into Q_tot
    chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])

    # 0.5 * squared TD error, masked to the filled timesteps
    td_error = (chosen_action_qvals - targets.detach())
    td_error2 = 0.5 * td_error.pow(2)

    mask = mask.expand_as(td_error2)
    masked_td_error = td_error2 * mask

    # Importance weights from prioritized experience replay (if enabled)
    if self.use_per:
        per_weight = th.from_numpy(per_weight).unsqueeze(-1).to(device=self.device)
        masked_td_error = masked_td_error.sum(1) * per_weight

    loss = L_td = masked_td_error.sum() / mask.sum()

    # Optimise with gradient-norm clipping
    self.optimiser.zero_grad()
    loss.backward()
    grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
    self.optimiser.step()
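The snippet combines several standard components: greedy actions are selected with the online network and evaluated with the target network (Double Q-learning), targets are built with TD(λ) (or Q(λ) when q_lambda is set), the squared TD error is masked to filled timesteps, PER importance weights rescale the loss when use_per is enabled, and gradients are clipped by norm before the optimiser step.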