Source code for elegantrl.agents.AgentTD3

import torch
from typing import Tuple
from copy import deepcopy
from torch import Tensor

from elegantrl.train.config import Config
from elegantrl.train.replay_buffer import ReplayBuffer
from elegantrl.agents.AgentBase import AgentBase
from elegantrl.agents.net import Actor, CriticTwin


class AgentTD3(AgentBase):
    """Twin Delayed DDPG (TD3) algorithm.
    "Addressing Function Approximation Error in Actor-Critic Methods". 2018.
    """

    def __init__(self, net_dims: [int], state_dim: int, action_dim: int, gpu_id: int = 0, args: Config = Config()):
        self.act_class = getattr(self, 'act_class', Actor)
        self.cri_class = getattr(self, 'cri_class', CriticTwin)
        super().__init__(net_dims=net_dims, state_dim=state_dim, action_dim=action_dim, gpu_id=gpu_id, args=args)
        self.act_target = deepcopy(self.act)
        self.cri_target = deepcopy(self.cri)

        self.explore_noise_std = getattr(args, 'explore_noise_std', 0.05)  # standard deviation of exploration noise
        self.policy_noise_std = getattr(args, 'policy_noise_std', 0.10)  # standard deviation of target policy smoothing noise
        self.update_freq = getattr(args, 'update_freq', 2)  # delayed update frequency for the actor and target networks

        self.act.explore_noise_std = self.explore_noise_std  # assign explore_noise_std for agent.act.get_action(state)

    def update_net(self, buffer: ReplayBuffer) -> Tuple[float, ...]:
        with torch.no_grad():
            states, actions, rewards, undones = buffer.add_item
            self.update_avg_std_for_normalization(
                states=states.reshape((-1, self.state_dim)),
                returns=self.get_cumulative_rewards(rewards=rewards, undones=undones).reshape((-1,))
            )

        '''update network'''
        obj_critics = 0.0
        obj_actors = 0.0

        update_times = int(buffer.add_size * self.repeat_times)
        assert update_times >= 1
        for update_c in range(update_times):
            obj_critic, state = self.get_obj_critic(buffer, self.batch_size)
            obj_critics += obj_critic.item()
            self.optimizer_update(self.cri_optimizer, obj_critic)
            self.soft_update(self.cri_target, self.cri, self.soft_update_tau)

            if update_c % self.update_freq == 0:  # delayed policy update
                action_pg = self.act(state)  # action for the policy gradient
                obj_actor = self.cri_target(state, action_pg).mean()  # using cri_target is more stable than cri
                obj_actors += obj_actor.item()
                self.optimizer_update(self.act_optimizer, -obj_actor)
                self.soft_update(self.act_target, self.act, self.soft_update_tau)
        return obj_critics / update_times, obj_actors / update_times

    def get_obj_critic_raw(self, buffer: ReplayBuffer, batch_size: int) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            states, actions, rewards, undones, next_ss = buffer.sample(batch_size)  # next_ss: next states

            next_as = self.act_target.get_action_noise(next_ss, self.policy_noise_std)  # smoothed next actions
            next_qs = self.cri_target.get_q_min(next_ss, next_as)  # clipped double-Q: minimum of the twin target critics
            q_labels = rewards + undones * self.gamma * next_qs  # TD target

        q1, q2 = self.cri.get_q1_q2(states, actions)
        obj_critic = self.criterion(q1, q_labels) + self.criterion(q2, q_labels)  # twin critics
        return obj_critic, states

    def get_obj_critic_per(self, buffer: ReplayBuffer, batch_size: int) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            states, actions, rewards, undones, next_ss, is_weights, is_indices = buffer.sample_for_per(batch_size)
            # is_weights, is_indices: importance sampling `weights, indices` from Prioritized Experience Replay (PER)

            next_as = self.act_target.get_action_noise(next_ss, self.policy_noise_std)
            next_qs = self.cri_target.get_q_min(next_ss, next_as)
            q_labels = rewards + undones * self.gamma * next_qs

        q1, q2 = self.cri.get_q1_q2(states, actions)
        td_errors = self.criterion(q1, q_labels) + self.criterion(q2, q_labels)
        obj_critic = (td_errors * is_weights).mean()
        buffer.td_error_update_for_per(is_indices.detach(), td_errors.detach())
        return obj_critic, states
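
Below is a minimal usage sketch of how this agent might be constructed and queried for actions. It assumes the default Config from elegantrl.train.config; net_dims, state_dim, action_dim, and the state tensor shape are placeholder values for illustration, and the environment, ReplayBuffer, and training loop (which would call update_net) are omitted.

import torch
from elegantrl.train.config import Config
from elegantrl.agents.AgentTD3 import AgentTD3

args = Config()                      # default hyperparameters
args.explore_noise_std = 0.05        # read via getattr() in AgentTD3.__init__
args.policy_noise_std = 0.10         # target policy smoothing noise
args.update_freq = 2                 # delayed actor/target update period

state_dim, action_dim = 4, 2         # placeholder dimensions
agent = AgentTD3(net_dims=[64, 64], state_dim=state_dim, action_dim=action_dim,
                 gpu_id=-1, args=args)  # gpu_id=-1 is assumed to fall back to CPU

state = torch.zeros((1, state_dim), device=agent.device)  # assumed (batch, state_dim) shape
noisy_action = agent.act.get_action(state)  # exploration action with Gaussian noise
greedy_action = agent.act(state)            # deterministic action from the policy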