Source code for elegantrl.agents.AgentSAC

import math
import torch
from typing import List, Tuple
from copy import deepcopy
from torch import Tensor

from elegantrl.agents.AgentBase import AgentBase
from elegantrl.agents.net import ActorSAC, ActorFixSAC, CriticTwin
from elegantrl.train.config import Config
from elegantrl.train.replay_buffer import ReplayBuffer


class AgentSAC(AgentBase):
    def __init__(self, net_dims: List[int], state_dim: int, action_dim: int, gpu_id: int = 0, args: Config = Config()):
        self.act_class = getattr(self, 'act_class', ActorSAC)
        self.cri_class = getattr(self, 'cri_class', CriticTwin)
        super().__init__(net_dims=net_dims, state_dim=state_dim, action_dim=action_dim, gpu_id=gpu_id, args=args)
        self.cri_target = deepcopy(self.cri)

        self.alpha_log = torch.tensor((-1,), dtype=torch.float32, requires_grad=True, device=self.device)  # trainable
        self.alpha_optimizer = torch.optim.AdamW((self.alpha_log,), lr=self.learning_rate * 4)
        self.target_entropy = getattr(args, 'target_entropy', action_dim)

    def update_net(self, buffer: ReplayBuffer) -> Tuple[float, ...]:
        with torch.no_grad():
            states, actions, rewards, undones = buffer.add_item
            self.update_avg_std_for_normalization(
                states=states.reshape((-1, self.state_dim)),
                returns=self.get_cumulative_rewards(rewards=rewards, undones=undones).reshape((-1,))
            )

        '''update network'''
        obj_critics = 0.0
        obj_actors = 0.0
        alphas = 0.0

        update_times = int(buffer.add_size * self.repeat_times)
        assert update_times >= 1
        for _ in range(update_times):
            '''objective of critic (loss function of critic)'''
            obj_critic, state = self.get_obj_critic(buffer, self.batch_size)
            obj_critics += obj_critic.item()
            self.optimizer_update(self.cri_optimizer, obj_critic)
            self.soft_update(self.cri_target, self.cri, self.soft_update_tau)

            '''objective of alpha (temperature parameter automatic adjustment)'''
            action_pg, log_prob = self.act.get_action_logprob(state)  # policy gradient
            obj_alpha = (self.alpha_log * (self.target_entropy - log_prob).detach()).mean()
            self.optimizer_update(self.alpha_optimizer, obj_alpha)

            '''objective of actor'''
            alpha = self.alpha_log.exp().detach()
            alphas += alpha.item()
            with torch.no_grad():
                self.alpha_log[:] = self.alpha_log.clamp(-16, 2)

            q_value_pg = self.cri_target(state, action_pg).mean()
            obj_actor = (q_value_pg - log_prob * alpha).mean()
            obj_actors += obj_actor.item()
            self.optimizer_update(self.act_optimizer, -obj_actor)
        return obj_critics / update_times, obj_actors / update_times, alphas / update_times

    def get_obj_critic_raw(self, buffer: ReplayBuffer, batch_size: int) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            states, actions, rewards, undones, next_ss = buffer.sample(batch_size)  # next_ss: next states

            next_as, next_logprobs = self.act.get_action_logprob(next_ss)  # next actions
            next_qs = self.cri_target.get_q_min(next_ss, next_as)  # next q values

            alpha = self.alpha_log.exp().detach()
            q_labels = rewards + undones * self.gamma * (next_qs - next_logprobs * alpha)

        q1, q2 = self.cri.get_q1_q2(states, actions)
        obj_critic = self.criterion(q1, q_labels) + self.criterion(q2, q_labels)  # twin critics
        return obj_critic, states

    def get_obj_critic_per(self, buffer: ReplayBuffer, batch_size: int) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            states, actions, rewards, undones, next_ss, is_weights, is_indices = buffer.sample_for_per(batch_size)

            next_as, next_logprobs = self.act.get_action_logprob(next_ss)
            next_qs = self.cri_target.get_q_min(next_ss, next_as)

            alpha = self.alpha_log.exp().detach()
            q_labels = rewards + undones * self.gamma * (next_qs - next_logprobs * alpha)

        q1, q2 = self.cri.get_q1_q2(states, actions)
        td_errors = self.criterion(q1, q_labels) + self.criterion(q2, q_labels)
        obj_critic = (td_errors * is_weights).mean()
        buffer.td_error_update_for_per(is_indices.detach(), td_errors.detach())
        return obj_critic, states
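
As a minimal, standalone sketch of the soft Bellman target that get_obj_critic_raw builds above, q_label = reward + undone * gamma * (min(Q1', Q2') - alpha * log_prob'), the snippet below uses only plain PyTorch; the batch size, values, and shapes are invented for illustration and nothing here depends on ElegantRL's buffer or network classes.

# Illustration only: made-up tensors standing in for a sampled batch.
import torch

batch_size = 4
gamma = 0.99                                 # discount factor
alpha = 0.2                                  # entropy temperature, i.e. exp(alpha_log)
rewards = torch.rand(batch_size, 1)
undones = torch.ones(batch_size, 1)          # 1.0 while the episode has not terminated
next_q_min = torch.rand(batch_size, 1)       # stand-in for the min over the twin target critics
next_logprob = torch.randn(batch_size, 1)    # stand-in for log pi(a'|s') of the sampled next action
q_labels = rewards + undones * gamma * (next_q_min - alpha * next_logprob)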
class AgentModSAC(AgentSAC):  # Modified SAC using reliable_lambda and the Two Time-scale Update Rule (TTUR)
    def __init__(self, net_dims: List[int], state_dim: int, action_dim: int, gpu_id: int = 0, args: Config = Config()):
        self.act_class = getattr(self, "act_class", ActorFixSAC)
        super().__init__(net_dims=net_dims, state_dim=state_dim, action_dim=action_dim, gpu_id=gpu_id, args=args)
        self.obj_c = 1.0  # for reliable_lambda

    def update_net(self, buffer: ReplayBuffer) -> Tuple[float, ...]:
        with torch.no_grad():
            states, actions, rewards, undones = buffer.add_item
            self.update_avg_std_for_normalization(
                states=states.reshape((-1, self.state_dim)),
                returns=self.get_cumulative_rewards(rewards=rewards, undones=undones).reshape((-1,))
            )

        '''update network'''
        obj_critics = 0.0
        obj_actors = 0.0
        alphas = 0.0

        update_times = int(buffer.add_size * self.repeat_times)
        assert update_times >= 1
        update_a = 0
        for update_c in range(1, update_times + 1):
            '''objective of critic (loss function of critic)'''
            obj_critic, state = self.get_obj_critic(buffer, self.batch_size)
            obj_critics += obj_critic.item()
            self.optimizer_update(self.cri_optimizer, obj_critic)
            self.soft_update(self.cri_target, self.cri, self.soft_update_tau)

            self.obj_c = 0.995 * self.obj_c + 0.005 * obj_critic.item()  # for reliable_lambda
            reliable_lambda = math.exp(-self.obj_c ** 2)  # for reliable_lambda
            if update_a / update_c < 1 / (2 - reliable_lambda):  # auto TTUR
                update_a += 1  # count actor updates so the ratio gate above takes effect

                '''objective of alpha (temperature parameter automatic adjustment)'''
                action_pg, log_prob = self.act.get_action_logprob(state)  # policy gradient
                obj_alpha = (self.alpha_log * (self.target_entropy - log_prob).detach()).mean()
                self.optimizer_update(self.alpha_optimizer, obj_alpha)

                '''objective of actor'''
                alpha = self.alpha_log.exp().detach()
                alphas += alpha.item()
                with torch.no_grad():
                    self.alpha_log[:] = self.alpha_log.clamp(-16, 2)

                q_value_pg = self.cri_target(state, action_pg).mean()
                obj_actor = (q_value_pg - log_prob * alpha).mean()
                obj_actors += obj_actor.item()
                self.optimizer_update(self.act_optimizer, -obj_actor)
        return obj_critics / update_times, obj_actors / update_times, alphas / update_times
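
A hedged sketch of the reliable_lambda gate used by AgentModSAC.update_net above: the critic loss is tracked with an exponential moving average, mapped to reliable_lambda = exp(-obj_c ** 2), and an actor/alpha update is only performed while the count of actor updates stays below update_c / (2 - reliable_lambda). The loss values fed in below are fabricated purely to show how the gate behaves; no ElegantRL objects are involved.

# Illustration only: a made-up critic-loss curve driving the auto-TTUR gate.
import math

obj_c = 1.0          # running EMA estimate of the critic loss
update_a = 0         # number of actor updates performed so far
fake_critic_losses = [0.9, 0.7, 0.5, 0.4, 0.3, 0.2]

for update_c, critic_loss in enumerate(fake_critic_losses, start=1):
    obj_c = 0.995 * obj_c + 0.005 * critic_loss      # same EMA as AgentModSAC
    reliable_lambda = math.exp(-obj_c ** 2)          # in (0, 1]; larger when the critic loss is small
    if update_a / update_c < 1 / (2 - reliable_lambda):  # auto TTUR gate
        update_a += 1                                # an actor/alpha update would happen here
    print(f"critic step {update_c}: reliable_lambda={reliable_lambda:.3f}, actor updates={update_a}")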