Source code for

import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions.normal import Normal


class QNetBase(nn.Module):  # nn.Module is a standard PyTorch Network
    """Shared base class for the discrete-action Q-networks in this file.

    Stores the state/action dimensions, the exploration probability read by
    subclasses' ``get_action``, and non-trainable normalization statistics:
    ``state_avg``/``state_std`` for the network input and
    ``value_avg``/``value_std`` for re-scaling the network output.
    Subclasses assign ``self.net`` (or their own sub-networks).
    """

    def __init__(self, state_dim: int, action_dim: int):
        # nn.Module.__init__ must run before any nn.Parameter / submodule is assigned.
        super().__init__()
        self.explore_rate = 0.125  # exploration probability used by subclasses' get_action
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.net = None  # build_mlp(dims=[state_dim + action_dim, *dims, 1]), assigned by subclasses

        # Stored as Parameters (so they follow .to(device) and the state_dict) but frozen
        # with requires_grad=False so the optimizer never updates them.
        self.state_avg = nn.Parameter(torch.zeros((state_dim,)), requires_grad=False)
        self.state_std = nn.Parameter(torch.ones((state_dim,)), requires_grad=False)
        self.value_avg = nn.Parameter(torch.zeros((1,)), requires_grad=False)
        self.value_std = nn.Parameter(torch.ones((1,)), requires_grad=False)

    def state_norm(self, state: Tensor) -> Tensor:
        """Standardize ``state`` with the stored per-dimension mean and std."""
        return (state - self.state_avg) / self.state_std

    def value_re_norm(self, value: Tensor) -> Tensor:
        """Map a normalized network output back to the original value scale."""
        return value * self.value_std + self.value_avg

class QNet(QNetBase):
    """DQN network: an MLP mapping a state to one Q value per discrete action."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net = build_mlp(dims=[state_dim, *dims, action_dim])
        layer_init_with_orthogonal(self.net[-1], std=0.1)  # small-scale init of the output layer

    def forward(self, state: Tensor) -> Tensor:
        """Return Q values (one column per action) on the original value scale."""
        state = self.state_norm(state)
        value = self.net(state)
        value = self.value_re_norm(value)
        return value  # Q values for multiple actions

    def get_action(self, state: Tensor) -> Tensor:
        """Epsilon-greedy action indices of shape ``(state.shape[0], 1)``.

        One random draw decides for the whole batch: if it exceeds
        ``explore_rate`` every row takes its argmax-Q action, otherwise every row
        gets a uniformly random action index.
        """
        state = self.state_norm(state)
        if self.explore_rate < torch.rand(1):
            action = self.net(state).argmax(dim=1, keepdim=True)
        else:
            # Create on state's device so both branches return tensors on the same device.
            action = torch.randint(self.action_dim, size=(state.shape[0], 1), device=state.device)
        return action
class QNetDuel(QNetBase):  # Dueling DQN
    """Dueling DQN: a shared state encoder followed by two output heads.

    NOTE(review): ``net_adv`` outputs ONE value per state, while ``net_val``
    outputs one value per action. The names are therefore swapped relative to
    the usual V(s) / A(s, a) notation. ``forward`` still uses the standard
    dueling combination: scalar head + (per-action head - its mean over actions).
    """

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net_state = build_mlp(dims=[state_dim, *dims])
        self.net_adv = build_mlp(dims=[dims[-1], 1])  # advantage value
        self.net_val = build_mlp(dims=[dims[-1], action_dim])  # Q value

        layer_init_with_orthogonal(self.net_adv[-1], std=0.1)
        layer_init_with_orthogonal(self.net_val[-1], std=0.1)

    def forward(self, state: Tensor) -> Tensor:
        """Return dueling Q values (one column per action) on the original value scale."""
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val = self.net_val(s_enc)  # q value
        q_adv = self.net_adv(s_enc)  # advantage value
        value = q_val - q_val.mean(dim=1, keepdim=True) + q_adv  # dueling Q value
        value = self.value_re_norm(value)
        return value

    def get_action(self, state: Tensor) -> Tensor:
        """Epsilon-greedy action indices of shape ``(state.shape[0], 1)``.

        One random draw decides for the whole batch whether to act greedily.
        """
        state = self.state_norm(state)
        if self.explore_rate < torch.rand(1):
            s_enc = self.net_state(state)  # encoded state
            # The scalar head and the mean-centering shift all actions of a row equally,
            # so argmax over net_val alone matches argmax of forward() (given value_std > 0).
            q_val = self.net_val(s_enc)
            action = q_val.argmax(dim=1, keepdim=True)
        else:
            # Create on state's device so both branches return tensors on the same device.
            action = torch.randint(self.action_dim, size=(state.shape[0], 1), device=state.device)
        return action
class QNetTwin(QNetBase):  # Double DQN
    """Double DQN: a shared state encoder with two independent per-action Q heads."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net_state = build_mlp(dims=[state_dim, *dims])
        self.net_val1 = build_mlp(dims=[dims[-1], action_dim])  # Q value 1
        self.net_val2 = build_mlp(dims=[dims[-1], action_dim])  # Q value 2
        self.soft_max = nn.Softmax(dim=1)

        layer_init_with_orthogonal(self.net_val1[-1], std=0.1)
        layer_init_with_orthogonal(self.net_val2[-1], std=0.1)

    def forward(self, state: Tensor) -> Tensor:
        """Return the first head's Q values on the original value scale.

        Re-normalized to match ``get_q1_q2`` and the other Q-networks' ``forward``.
        """
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val = self.net_val1(s_enc)  # q value
        return self.value_re_norm(q_val)  # one group of Q values

    def get_q1_q2(self, state: Tensor) -> (Tensor, Tensor):
        """Return both heads' Q values, each on the original value scale."""
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val1 = self.value_re_norm(self.net_val1(s_enc))  # q value 1
        q_val2 = self.value_re_norm(self.net_val2(s_enc))  # q value 2
        return q_val1, q_val2  # two groups of Q values

    def get_action(self, state: Tensor) -> Tensor:
        """Action indices of shape ``(state.shape[0], 1)``.

        One random draw decides for the whole batch. Either every row takes its
        argmax-Q action, or every row samples from softmax(raw Q values of head 1).
        """
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val = self.net_val1(s_enc)  # q value
        if self.explore_rate < torch.rand(1):
            action = q_val.argmax(dim=1, keepdim=True)
        else:
            a_prob = self.soft_max(q_val)
            action = torch.multinomial(a_prob, num_samples=1)
        return action
class QNetTwinDuel(QNetBase):  # D3QN: Dueling Double DQN
    """Dueling Double DQN: a shared encoder with two dueling Q heads.

    Each head pairs a per-action output (``net_val*``) with a single per-state
    output (``net_adv*``) and combines them in ``_dueling``.
    """

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net_state = build_mlp(dims=[state_dim, *dims])
        self.net_adv1 = build_mlp(dims=[dims[-1], 1])  # advantage value 1
        self.net_val1 = build_mlp(dims=[dims[-1], action_dim])  # Q value 1
        self.net_adv2 = build_mlp(dims=[dims[-1], 1])  # advantage value 2
        self.net_val2 = build_mlp(dims=[dims[-1], action_dim])  # Q value 2
        self.soft_max = nn.Softmax(dim=1)

        layer_init_with_orthogonal(self.net_adv1[-1], std=0.1)
        layer_init_with_orthogonal(self.net_val1[-1], std=0.1)
        layer_init_with_orthogonal(self.net_adv2[-1], std=0.1)
        layer_init_with_orthogonal(self.net_val2[-1], std=0.1)

    @staticmethod
    def _dueling(q_val: Tensor, q_adv: Tensor) -> Tensor:
        """Dueling combination: mean-centered per-action output plus per-state scalar."""
        return q_val - q_val.mean(dim=1, keepdim=True) + q_adv

    def forward(self, state: Tensor) -> Tensor:
        """Return head-1 dueling Q values on the original value scale."""
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        value = self._dueling(self.net_val1(s_enc), self.net_adv1(s_enc))  # one dueling Q value
        return self.value_re_norm(value)

    def get_q1_q2(self, state: Tensor) -> (Tensor, Tensor):
        """Return both heads' dueling Q values, each on the original value scale."""
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_duel1 = self.value_re_norm(self._dueling(self.net_val1(s_enc), self.net_adv1(s_enc)))
        q_duel2 = self.value_re_norm(self._dueling(self.net_val2(s_enc), self.net_adv2(s_enc)))
        return q_duel1, q_duel2  # two dueling Q values

    def get_action(self, state: Tensor) -> Tensor:
        """Action indices of shape ``(state.shape[0], 1)``.

        Only ``net_val1`` is evaluated. Adding the per-state scalar and
        mean-centering shift every action of a row equally, which changes neither
        argmax nor softmax.
        """
        state = self.state_norm(state)
        s_enc = self.net_state(state)  # encoded state
        q_val = self.net_val1(s_enc)  # q value
        if self.explore_rate < torch.rand(1):
            action = q_val.argmax(dim=1, keepdim=True)
        else:
            a_prob = self.soft_max(q_val)
            action = torch.multinomial(a_prob, num_samples=1)
        return action
"""Actor (policy network)""" class ActorBase(nn.Module): def __init__(self, state_dim: int, action_dim: int): super().__init__() self.state_dim = state_dim self.action_dim = action_dim = None # build_mlp(dims=[state_dim, *dims, action_dim]) self.explore_noise_std = None # standard deviation of exploration action noise self.ActionDist = torch.distributions.normal.Normal self.state_avg = nn.Parameter(torch.zeros((state_dim,)), requires_grad=False) self.state_std = nn.Parameter(torch.ones((state_dim,)), requires_grad=False) def state_norm(self, state: Tensor) -> Tensor: return (state - self.state_avg) / self.state_std
class Actor(ActorBase):
    """Deterministic actor: an MLP whose output is squashed into (-1, 1) by tanh."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net = build_mlp(dims=[state_dim, *dims, action_dim])
        layer_init_with_orthogonal(self.net[-1], std=0.1)

        self.explore_noise_std = 0.1  # standard deviation of exploration action noise

    def forward(self, state: Tensor) -> Tensor:
        """Return the deterministic action in (-1, 1)."""
        state = self.state_norm(state)
        return self.net(state).tanh()  # action.tanh()

    def get_action(self, state: Tensor) -> Tensor:  # for exploration
        """Return an exploratory action using the default ``explore_noise_std``."""
        return self.get_action_noise(state, self.explore_noise_std)

    def get_action_noise(self, state: Tensor, action_std: float) -> Tensor:
        """Return tanh(action) plus Gaussian noise, clipped to [-1, 1].

        The noise is clipped to [-0.5, 0.5] before it is added.
        """
        state = self.state_norm(state)
        action = self.net(state).tanh()
        noise = (torch.randn_like(action) * action_std).clamp(-0.5, 0.5)
        return (action + noise).clamp(-1.0, 1.0)
class ActorSAC(ActorBase):
    """SAC actor: a diagonal Gaussian policy whose samples are squashed by tanh."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net_s = build_mlp(dims=[state_dim, *dims], if_raw_out=False)  # network of encoded state
        self.net_a = build_mlp(dims=[dims[-1], action_dim * 2])  # the average and log_std of action
        layer_init_with_orthogonal(self.net_a[-1], std=0.1)

    def _get_action_dist(self, state: Tensor) -> Normal:
        """Return the pre-tanh Gaussian for a batch of raw (un-normalized) states."""
        state = self.state_norm(state)
        s_enc = self.net_s(state)  # encoded state
        a_avg, a_std_log = self.net_a(s_enc).chunk(2, dim=1)
        a_std = a_std_log.clamp(-16, 2).exp()  # bound log-std to keep std in [e^-16, e^2]
        return Normal(a_avg, a_std)

    def forward(self, state: Tensor) -> Tensor:
        """Return the deterministic action, tanh of the Gaussian mean."""
        state = self.state_norm(state)
        s_enc = self.net_s(state)  # encoded state
        a_avg = self.net_a(s_enc)[:, :self.action_dim]
        return a_avg.tanh()  # action

    def get_action(self, state: Tensor) -> Tensor:
        """Return a stochastic action in (-1, 1) sampled with the re-parameterization trick."""
        return self._get_action_dist(state).rsample().tanh()  # action (re-parameterize)

    def get_action_logprob(self, state: Tensor) -> (Tensor, Tensor):
        """Return ``(tanh(action), log_prob)``, with log_prob summed over action dims.

        The Gaussian log-density is evaluated at the SAMPLED pre-tanh action
        (previously it was evaluated at the mean). It is then corrected by
        -log(1 - tanh(u)^2) for the change of variables a = tanh(u). The
        0.000001 epsilon keeps that log finite when |tanh(u)| -> 1.
        """
        dist = self._get_action_dist(state)
        action = dist.rsample()
        action_tanh = action.tanh()

        logprob = dist.log_prob(action)
        logprob -= (-action_tanh.pow(2) + 1.000001).log()  # fix logprob using the derivative of action.tanh()
        return action_tanh, logprob.sum(1)
class ActorFixSAC(ActorSAC):
    """SAC actor with a numerically stable tanh log-probability correction."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(dims=dims, state_dim=state_dim, action_dim=action_dim)
        self.soft_plus = torch.nn.Softplus()

    def get_action_logprob(self, state: Tensor) -> (Tensor, Tensor):
        """Return ``(tanh(action), log_prob)``, with log_prob summed over action dims.

        The Gaussian log-density is evaluated at the SAMPLED pre-tanh action
        (previously it was evaluated at the mean). The tanh correction uses the
        identity log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)), which
        stays finite without an epsilon.
        """
        state = self.state_norm(state)
        s_enc = self.net_s(state)  # encoded state
        a_avg, a_std_log = self.net_a(s_enc).chunk(2, dim=1)
        a_std = a_std_log.clamp(-16, 2).exp()

        dist = Normal(a_avg, a_std)
        action = dist.rsample()

        logprob = dist.log_prob(action)
        logprob -= 2 * (math.log(2) - action - self.soft_plus(action * -2))  # fix logprob using SoftPlus
        return action.tanh(), logprob.sum(1)
class ActorPPO(ActorBase):
    """PPO actor: a Gaussian policy with a state-independent, trainable log-std."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net = build_mlp(dims=[state_dim, *dims, action_dim])
        layer_init_with_orthogonal(self.net[-1], std=0.1)

        self.action_std_log = nn.Parameter(torch.zeros((1, action_dim)), requires_grad=True)  # trainable parameter

    def forward(self, state: Tensor) -> Tensor:
        """Return the deterministic action, tanh of the Gaussian mean."""
        state = self.state_norm(state)
        return self.net(state).tanh()  # action.tanh()

    def get_action(self, state: Tensor) -> (Tensor, Tensor):  # for exploration
        """Sample an unsquashed action and return ``(action, log_prob summed over dims)``."""
        state = self.state_norm(state)
        action_avg = self.net(state)
        action_std = self.action_std_log.exp()

        dist = self.ActionDist(action_avg, action_std)
        action = dist.sample()
        logprob = dist.log_prob(action).sum(1)
        return action, logprob

    def get_logprob_entropy(self, state: Tensor, action: Tensor) -> (Tensor, Tensor):
        """Return the log-prob of ``action`` and the policy entropy, both summed over dims."""
        state = self.state_norm(state)
        action_avg = self.net(state)
        action_std = self.action_std_log.exp()

        dist = self.ActionDist(action_avg, action_std)
        logprob = dist.log_prob(action).sum(1)
        entropy = dist.entropy().sum(1)
        return logprob, entropy

    @staticmethod
    def convert_action_for_env(action: Tensor) -> Tensor:
        """Squash a raw Gaussian sample into (-1, 1) before sending it to the env."""
        return action.tanh()
class ActorDiscretePPO(ActorBase):
    """PPO actor for discrete actions: softmax over MLP scores, sampled via Categorical."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net = build_mlp(dims=[state_dim, *dims, action_dim])
        layer_init_with_orthogonal(self.net[-1], std=0.1)

        self.ActionDist = torch.distributions.Categorical
        self.soft_max = nn.Softmax(dim=-1)

    def forward(self, state: Tensor) -> Tensor:
        """Return the greedy action index for each row."""
        state = self.state_norm(state)
        a_prob = self.net(state)  # action_prob without softmax
        return a_prob.argmax(dim=1)  # get the indices of discrete action

    def get_action(self, state: Tensor) -> (Tensor, Tensor):
        """Sample one action index per row and return ``(action, log_prob)``."""
        state = self.state_norm(state)
        a_prob = self.soft_max(self.net(state))
        a_dist = self.ActionDist(a_prob)  # positional argument = probs
        action = a_dist.sample()
        logprob = a_dist.log_prob(action)
        return action, logprob

    def get_logprob_entropy(self, state: Tensor, action: Tensor) -> (Tensor, Tensor):
        """Return the log-prob of ``action`` and the per-row entropy.

        ``action`` is expected to have a trailing dimension of size 1, which is
        squeezed away. Its dtype is presumably an integer type; TODO confirm
        against the caller.
        """
        state = self.state_norm(state)
        a_prob = self.soft_max(self.net(state))
        dist = self.ActionDist(a_prob)
        logprob = dist.log_prob(action.squeeze(1))
        entropy = dist.entropy()
        return logprob, entropy

    @staticmethod
    def convert_action_for_env(action: Tensor) -> Tensor:
        """Cast sampled action indices to int64 for the environment."""
        return action.long()
"""Critic (value network)""" class CriticBase(nn.Module): # todo state_norm, value_norm def __init__(self, state_dim: int, action_dim: int): super().__init__() self.state_dim = state_dim self.action_dim = action_dim = None # build_mlp(dims=[state_dim + action_dim, *dims, 1]) self.state_avg = nn.Parameter(torch.zeros((state_dim,)), requires_grad=False) self.state_std = nn.Parameter(torch.ones((state_dim,)), requires_grad=False) self.value_avg = nn.Parameter(torch.zeros((1,)), requires_grad=False) self.value_std = nn.Parameter(torch.ones((1,)), requires_grad=False) def state_norm(self, state: Tensor) -> Tensor: return (state - self.state_avg) / self.state_std # todo state_norm def value_re_norm(self, value: Tensor) -> Tensor: return value * self.value_std + self.value_avg # todo value_norm
class Critic(CriticBase):
    """Q(s, a) critic: an MLP over the concatenation of state and action."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net = build_mlp(dims=[state_dim + action_dim, *dims, 1])
        layer_init_with_orthogonal(self.net[-1], std=0.5)

    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        """Return one Q value per row (the trailing size-1 dim is squeezed away)."""
        state = self.state_norm(state)
        values = self.net(torch.cat((state, action), dim=1))
        values = self.value_re_norm(values)
        return values.squeeze(dim=1)  # q value
class CriticTwin(CriticBase):  # shared parameter
    """Twin Q critic: one MLP whose two output columns are the two Q estimates."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net = build_mlp(dims=[state_dim + action_dim, *dims, 2])
        layer_init_with_orthogonal(self.net[-1], std=0.5)

    def _get_values(self, state: Tensor, action: Tensor) -> Tensor:
        """Return both Q estimates as two columns, on the original value scale."""
        state = self.state_norm(state)
        values = self.net(torch.cat((state, action), dim=1))
        return self.value_re_norm(values)

    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        """Return the mean of the two Q estimates per row."""
        return self._get_values(state, action).mean(dim=1)  # mean Q value

    def get_q_min(self, state: Tensor, action: Tensor) -> Tensor:
        """Return the smaller of the two Q estimates per row."""
        return torch.min(self._get_values(state, action), dim=1)[0]  # min Q value

    def get_q1_q2(self, state: Tensor, action: Tensor) -> (Tensor, Tensor):
        """Return the two Q estimates as separate 1-D tensors."""
        values = self._get_values(state, action)
        return values[:, 0], values[:, 1]  # two Q values
class CriticPPO(CriticBase):
    """State-value critic V(s) used by PPO."""

    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net = build_mlp(dims=[state_dim, *dims, 1])
        layer_init_with_orthogonal(self.net[-1], std=0.5)

    def forward(self, state: Tensor) -> Tensor:
        """Return one state value per row (the trailing size-1 dim is squeezed away)."""
        state = self.state_norm(state)
        value = self.net(state)
        value = self.value_re_norm(value)
        return value.squeeze(1)  # q value
"""utils""" def build_mlp(dims: [int], activation: nn = None, if_raw_out: bool = True) -> nn.Sequential: """ build MLP (MultiLayer Perceptron) dims: the middle dimension, `dims[-1]` is the output dimension of this network activation: the activation function if_remove_out_layer: if remove the activation function of the output layer. """ if activation is None: activation = nn.ReLU net_list = [] for i in range(len(dims) - 1): net_list.extend([nn.Linear(dims[i], dims[i + 1]), activation()]) if if_raw_out: del net_list[-1] # delete the activation function of the output layer to keep raw output return nn.Sequential(*net_list) def layer_init_with_orthogonal(layer, std=1.0, bias_const=1e-6): torch.nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) class NnReshape(nn.Module): def __init__(self, *args): super().__init__() self.args = args def forward(self, x): return x.view((x.size(0),) + self.args) class DenseNet(nn.Module): # plan to hyper-param: layer_number def __init__(self, lay_dim): super().__init__() self.dense1 = nn.Sequential(nn.Linear(lay_dim * 1, lay_dim * 1), nn.Hardswish()) self.dense2 = nn.Sequential(nn.Linear(lay_dim * 2, lay_dim * 2), nn.Hardswish()) self.inp_dim = lay_dim self.out_dim = lay_dim * 4 def forward(self, x1): # x1.shape==(-1, lay_dim*1) x2 =, self.dense1(x1)), dim=1) return (x2, self.dense2(x2)), dim=1 ) # x3 # x2.shape==(-1, lay_dim*4) class ConvNet(nn.Module): # pixel-level state encoder def __init__(self, inp_dim, out_dim, image_size=224): super().__init__() if image_size == 224: = nn.Sequential( # size==(batch_size, inp_dim, 224, 224) nn.Conv2d(inp_dim, 32, (5, 5), stride=(2, 2), bias=False), nn.ReLU(inplace=True), # size=110 nn.Conv2d(32, 48, (3, 3), stride=(2, 2)), nn.ReLU(inplace=True), # size=54 nn.Conv2d(48, 64, (3, 3), stride=(2, 2)), nn.ReLU(inplace=True), # size=26 nn.Conv2d(64, 96, (3, 3), stride=(2, 2)), nn.ReLU(inplace=True), # size=12 nn.Conv2d(96, 128, (3, 3), stride=(2, 2)), 
nn.ReLU(inplace=True), # size=5 nn.Conv2d(128, 192, (5, 5), stride=(1, 1)), nn.ReLU(inplace=True), # size=1 NnReshape(-1), # size (batch_size, 1024, 1, 1) ==> (batch_size, 1024) nn.Linear(192, out_dim), # size==(batch_size, out_dim) ) elif image_size == 112: = nn.Sequential( # size==(batch_size, inp_dim, 112, 112) nn.Conv2d(inp_dim, 32, (5, 5), stride=(2, 2), bias=False), nn.ReLU(inplace=True), # size=54 nn.Conv2d(32, 48, (3, 3), stride=(2, 2)), nn.ReLU(inplace=True), # size=26 nn.Conv2d(48, 64, (3, 3), stride=(2, 2)), nn.ReLU(inplace=True), # size=12 nn.Conv2d(64, 96, (3, 3), stride=(2, 2)), nn.ReLU(inplace=True), # size=5 nn.Conv2d(96, 128, (5, 5), stride=(1, 1)), nn.ReLU(inplace=True), # size=1 NnReshape(-1), # size (batch_size, 1024, 1, 1) ==> (batch_size, 1024) nn.Linear(128, out_dim), # size==(batch_size, out_dim) ) else: assert image_size in {224, 112} def forward(self, x): # assert x.shape == (batch_size, inp_dim, image_size, image_size) x = x.permute(0, 3, 1, 2) x = x / 128.0 - 1.0 return @staticmethod def check(): inp_dim = 3 out_dim = 32 batch_size = 2 image_size = [224, 112][1] # from import Conv2dNet net = ConvNet(inp_dim, out_dim, image_size) image = torch.ones((batch_size, image_size, image_size, inp_dim), dtype=torch.uint8) * 255 print(image.shape) output = net(image) print(output.shape)