"""Source code for elegantrl.train.replay_buffer."""

import os
import math
import torch
from typing import Tuple
from torch import Tensor

from elegantrl.train.config import Config

class ReplayBuffer:  # for off-policy
    """Ring-buffer replay memory for off-policy agents with vectorized environments.

    Storage tensors are allocated up front with a leading time axis of length
    ``max_size`` and a second axis of length ``num_envs``::

        states  (max_size, num_envs, state_dim)
        actions (max_size, num_envs, action_dim)
        rewards (max_size, num_envs)
        undones (max_size, num_envs)

    ``self.p`` is the next write position on the time axis; once a write runs past
    ``max_size`` the buffer wraps around to index 0 and ``if_full`` becomes True.
    With ``if_use_per`` one ``SumTree`` of priorities is kept per environment for
    Prioritized Experience Replay (proportional variant).
    """

    def __init__(self,
                 max_size: int,
                 state_dim: int,
                 action_dim: int,
                 gpu_id: int = 0,
                 num_envs: int = 1,
                 if_use_per: bool = False,
                 args: Config = Config()):  # NOTE: default built once at class definition; only read via getattr
        """
        :param max_size: number of time steps stored per environment (ring capacity).
        :param state_dim: size of one state vector.
        :param action_dim: size of one action vector.
        :param gpu_id: CUDA device index; a negative value (or no CUDA) selects the CPU.
        :param num_envs: number of parallel environments stored side by side.
        :param if_use_per: enable Prioritized Experience Replay.
        :param args: config; only ``per_alpha`` and ``per_beta`` are read from it.
        """
        self.p = 0  # pointer: next write position on the time axis
        self.if_full = False
        self.cur_size = 0  # number of valid time steps per env
        self.add_size = 0  # time steps written by the last `update`
        self.add_item = None  # the last items passed to `update`
        self.max_size = max_size
        self.num_envs = num_envs
        self.device = torch.device(f"cuda:{gpu_id}" if (torch.cuda.is_available() and (gpu_id >= 0)) else "cpu")

        self.states = torch.empty((max_size, num_envs, state_dim), dtype=torch.float32, device=self.device)
        self.actions = torch.empty((max_size, num_envs, action_dim), dtype=torch.float32, device=self.device)
        self.rewards = torch.empty((max_size, num_envs), dtype=torch.float32, device=self.device)
        self.undones = torch.empty((max_size, num_envs), dtype=torch.float32, device=self.device)

        self.if_use_per = if_use_per
        if if_use_per:
            self.sum_trees = [SumTree(buf_len=max_size) for _ in range(num_envs)]
            # PER (Schaul et al., 2016), Section 4:
            #   alpha, beta = 0.7, 0.5 for the rank-based variant
            #   alpha, beta = 0.6, 0.4 for the proportional variant
            self.per_alpha = getattr(args, 'per_alpha', 0.6)  # alpha = (Uniform:0, Greedy:1)
            self.per_beta = getattr(args, 'per_beta', 0.4)  # beta = (no IS correction:0, full IS correction:1)
        else:
            self.sum_trees = None
            self.per_alpha = None
            self.per_beta = None

    def update(self, items: Tuple[Tensor, ...]):
        """Write a chunk of transitions at the pointer, wrapping around at ``max_size``.

        :param items: ``(states, actions, rewards, undones)`` with a leading time axis of
            length ``add_size`` followed by the env axis, matching the storage tensors
            (e.g. ``rewards`` is ``(add_size, num_envs)``). Assumes ``add_size <= max_size``.
        """
        self.add_item = items
        states, actions, rewards, undones = items
        self.add_size = rewards.shape[0]

        p = self.p + self.add_size  # pointer after this write (may exceed max_size)
        if p > self.max_size:
            self.if_full = True
            p0 = self.p
            p1 = self.max_size
            p2 = self.max_size - self.p  # how many items fit before the end of the ring
            p = p - self.max_size  # how many items wrap around to the front

            self.states[p0:p1], self.states[0:p] = states[:p2], states[-p:]
            self.actions[p0:p1], self.actions[0:p] = actions[:p2], actions[-p:]
            self.rewards[p0:p1], self.rewards[0:p] = rewards[:p2], rewards[-p:]
            self.undones[p0:p1], self.undones[0:p] = undones[:p2], undones[-p:]
        else:
            self.states[self.p:p] = states
            self.actions[self.p:p] = actions
            self.rewards[self.p:p] = rewards
            self.undones[self.p:p] = undones

        if self.if_use_per:
            # Slots written by this call, computed from the PRE-update pointer: `p` above may
            # already have been wrapped, which would make `arange(self.p, p)` empty.
            data_ids = torch.arange(self.p, self.p + self.add_size, dtype=torch.long) % self.max_size
            for sum_tree in self.sum_trees:  # same slots for every env; new data gets the max priority
                sum_tree.update_ids(data_ids=data_ids, prob=10.)

        self.p = p
        self.cur_size = self.max_size if self.if_full else self.p

    def sample(self, batch_size: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Uniformly sample ``batch_size`` transitions across time steps and envs.

        :return: ``(state, action, reward, undone, next_state)``, each with leading dim ``batch_size``.
        """
        sample_len = self.cur_size - 1  # the newest step has no successor yet

        ids = torch.randint(sample_len * self.num_envs, size=(batch_size,), requires_grad=False)
        offsets = torch.fmod(ids, sample_len)  # ids % sample_len: age offset from the oldest step
        ids1 = torch.div(ids, sample_len, rounding_mode='floor')  # ids // sample_len: env index

        # After wrap-around the oldest step is at `self.p` (mod max_size). Counting from it
        # makes `ids0 + 1` the chronologically next step and never pairs the newest step
        # with the oldest one as its "next_state".
        oldest = self.p % self.max_size if self.if_full else 0
        ids0 = torch.fmod(offsets + oldest, self.max_size)
        next_ids0 = torch.fmod(ids0 + 1, self.max_size)

        return (self.states[ids0, ids1],
                self.actions[ids0, ids1],
                self.rewards[ids0, ids1],
                self.undones[ids0, ids1],
                self.states[next_ids0, ids1],)  # next_state

    def sample_for_per(self, batch_size: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Sample ``batch_size`` transitions with proportional prioritization.

        Each env contributes ``batch_size // num_envs`` samples drawn from its own SumTree.

        :return: ``(state, action, reward, undone, next_state, is_weights, is_indices)``.
            ``is_indices`` holds the per-env data ids, grouped env by env, which is the
            layout ``td_error_update_for_per`` expects.
        """
        beg = -self.max_size  # leaf slice of the tree used for the min-priority normalizer
        end = (self.cur_size - self.max_size) if (self.cur_size < self.max_size) else -1

        assert batch_size % self.num_envs == 0
        sub_batch_size = batch_size // self.num_envs

        indices_per_env = []
        weights_per_env = []
        for sum_tree in self.sum_trees:
            _is_indices, _is_weights = sum_tree.important_sampling(sub_batch_size, beg, end, self.per_beta)
            indices_per_env.append(_is_indices)
            weights_per_env.append(_is_weights)

        is_indices = torch.hstack(indices_per_env).to(self.device)
        is_weights = torch.hstack(weights_per_env).to(self.device)

        ids0 = is_indices  # time index (data id) inside the env's ring
        ids1 = torch.arange(self.num_envs, device=self.device).repeat_interleave(sub_batch_size)  # env index
        # wrap so the last slot cannot index past the end of the storage
        # NOTE(review): the newest slot may still be sampled, and its successor is not written yet.
        next_ids0 = torch.fmod(ids0 + 1, self.max_size)

        return (
            self.states[ids0, ids1],
            self.actions[ids0, ids1],
            self.rewards[ids0, ids1],
            self.undones[ids0, ids1],
            self.states[next_ids0, ids1],  # next_state
            is_weights,  # important sampling weights
            is_indices,  # important sampling indices
        )

    def td_error_update_for_per(self, is_indices: Tensor, td_error: Tensor):
        """Refresh priorities of the transitions returned by ``sample_for_per``.

        :param is_indices: data ids from ``sample_for_per`` (``sub_batch_size`` per env, env by env).
        :param td_error: TD errors aligned with ``is_indices``; presumably already ``detach().abs()``-ed
            by the caller. TODO confirm.
        """
        prob = td_error.clamp(1e-8, 10).pow(self.per_alpha)  # priority = |td|^alpha, bounded away from 0

        batch_size = td_error.shape[0]
        sub_batch_size = batch_size // self.num_envs
        for env_i, sum_tree in enumerate(self.sum_trees):
            slice_i = env_i * sub_batch_size
            slice_j = slice_i + sub_batch_size
            sum_tree.update_ids(is_indices[slice_i:slice_j].cpu(), prob[slice_i:slice_j].cpu())

    def save_or_load_history(self, cwd: str, if_save: bool):
        """Save the buffer to, or load it from, ``{cwd}/replay_buffer_{name}.pth``.

        Saved tensors are written oldest-first. Loading happens only when all four files
        exist; afterwards the pointer sits right after the loaded data.
        """
        item_names = (
            (self.states, "states"),
            (self.actions, "actions"),
            (self.rewards, "rewards"),
            (self.undones, "undones"),
        )

        if if_save:
            for item, name in item_names:
                if self.cur_size == self.p:  # not wrapped: data already in order
                    buf_item = item[:self.cur_size]
                else:  # wrapped: oldest part is [p, cur_size), newest is [0, p)
                    buf_item = torch.vstack((item[self.p:self.cur_size], item[0:self.p]))
                file_path = f"{cwd}/replay_buffer_{name}.pth"
                print(f"| buffer.save_or_load_history(): Save {file_path}")
                torch.save(buf_item, file_path)

        elif all(os.path.isfile(f"{cwd}/replay_buffer_{name}.pth") for _, name in item_names):
            max_sizes = []
            for item, name in item_names:
                file_path = f"{cwd}/replay_buffer_{name}.pth"
                print(f"| buffer.save_or_load_history(): Load {file_path}")
                # NOTE(review): torch.load unpickles; only load files from a trusted `cwd`.
                buf_item = torch.load(file_path)

                max_size = buf_item.shape[0]
                item[:max_size] = buf_item
                max_sizes.append(max_size)
            assert all(max_size == max_sizes[0] for max_size in max_sizes)

            # Place the pointer after the loaded data so the next `update` appends instead of overwriting it.
            self.cur_size = self.p = max_sizes[0]
            self.if_full = self.cur_size == self.max_size

            if self.if_use_per:  # loaded data would otherwise have zero priority
                data_ids = torch.arange(self.cur_size, dtype=torch.long)
                for sum_tree in self.sum_trees:
                    sum_tree.update_ids(data_ids=data_ids, prob=10.)
class SumTree:
    """Binary sum tree for proportional Prioritized Experience Replay.

    Contributor: Github GyChou, Github mississippiu
    Reference: Schaul et al., "Prioritized Experience Replay", ICLR 2016.

    The tree is a flat array of ``2 * buf_len - 1`` nodes. Node ``i`` has children
    ``2i + 1`` and ``2i + 2``. The last ``buf_len`` nodes are leaves, and leaf
    ``buf_len - 1 + data_id`` stores the priority of replay slot ``data_id``. Every
    internal node holds the sum of its two children, so ``tree[0]`` is the total.

    When ``buf_len`` is not a power of two, the leaves sit at two different depths.
    For that reason every loop below walks until it reaches the root or a leaf,
    rather than a fixed number of levels.

    Tree index:
        0       -> storing priority sum
       / \\
      1   2
     / \\ / \\
    3  4 5  6   -> storing priority for transitions (buf_len=4)
    Array storage: [0, 1, 2, 3, 4, 5, 6]
    """

    def __init__(self, buf_len: int):
        self.buf_len = buf_len  # replay buffer len (number of leaves)
        self.max_len = (buf_len - 1) + buf_len  # parent_nodes_num + leaf_nodes_num
        self.depth = math.ceil(math.log2(self.max_len))  # kept for compatibility; loops do not rely on it

        self.tree = torch.zeros(self.max_len, dtype=torch.float32)

    def update_id(self, data_id: int, prob=10):  # 10 is max_prob
        """Set the priority of one slot and add the change to every ancestor up to the root."""
        tree_id = data_id + self.buf_len - 1

        delta = prob - self.tree[tree_id]
        self.tree[tree_id] = prob
        while tree_id > 0:  # the root (index 0) has no parent
            tree_id = (tree_id - 1) // 2
            self.tree[tree_id] += delta

    def update_ids(self, data_ids: Tensor, prob: Tensor = 10.):  # 10 is max_prob
        """Set priorities of many slots at once, then recompute all affected ancestors.

        :param data_ids: long tensor of slot ids in ``[0, buf_len)``.
        :param prob: scalar or tensor of priorities aligned with ``data_ids``.
        """
        l_ids = data_ids.long() + self.buf_len - 1
        self.tree[l_ids] = prob

        # Move one step up per iteration and recompute each parent from its children.
        # A node is recomputed again each time it reappears higher up a longer path.
        # Its last recomputation therefore happens after all its children are final,
        # even when the updated leaves sit at mixed depths.
        child_ids = l_ids
        while True:
            child_ids = child_ids[child_ids > 0]  # the root has no parent
            if child_ids.numel() == 0:
                break
            p_ids = torch.div(child_ids - 1, 2, rounding_mode='floor').unique()  # parent indices
            self.tree[p_ids] = self.tree[2 * p_ids + 1] + self.tree[2 * p_ids + 2]
            child_ids = p_ids

    def get_leaf_id_and_value(self, v) -> Tuple[int, float]:
        """Descend from the root to the leaf whose cumulative-priority interval contains ``v``.

        :return: ``(leaf_id, leaf_value)``, where ``leaf_id`` is a tree index (not a data id).
        """
        v = float(v)
        p_id = 0
        first_leaf = self.buf_len - 1  # nodes below this index are internal
        while p_id < first_leaf:
            l_id = 2 * p_id + 1  # the left child
            r_id = l_id + 1  # the right child
            left = float(self.tree[l_id])
            right = float(self.tree[r_id])
            # Never step into a zero-mass subtree; float rounding in `v` could otherwise
            # land on an empty (never written) leaf.
            if right <= 0.0 or (left > 0.0 and v <= left):
                p_id = l_id
            else:
                v -= left
                p_id = r_id
        return p_id, float(self.tree[p_id])  # leaf_id and leaf_value

    def important_sampling(self, batch_size: int, beg: int, end: int, per_beta: float) -> Tuple[Tensor, Tensor]:
        """Draw ``batch_size`` slots proportionally to priority and compute IS weights.

        :param beg: start of the ``self.tree`` slice used for the min-priority normalizer.
        :param end: end of that slice.
        :param per_beta: importance-sampling exponent.
        :return: ``(indices, weights)``. Indices are data ids in ``[0, buf_len)``.
            Weights are ``(p_i / p_min) ** -per_beta``, so the lowest-priority slot gets weight 1.
        """
        # Stratified sampling: one uniform draw inside each of `batch_size` equal segments of the total.
        values = (torch.arange(batch_size) + torch.rand(batch_size)) * (self.tree[0] / batch_size)

        # get proportional prioritization
        leaf_ids, leaf_values = list(zip(*[self.get_leaf_id_and_value(v) for v in values.tolist()]))
        leaf_ids = torch.tensor(leaf_ids, dtype=torch.long)
        leaf_values = torch.tensor(leaf_values, dtype=torch.float32)

        indices = leaf_ids - (self.buf_len - 1)  # tree index -> data id
        assert 0 <= indices.min()
        assert indices.max() < self.buf_len

        prob_ary = leaf_values / self.tree[beg:end].min()
        weights = torch.pow(prob_ary, -per_beta)
        return indices, weights