-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsac.py
More file actions
202 lines (161 loc) · 7.66 KB
/
sac.py
File metadata and controls
202 lines (161 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import numpy as np
import torch
from torch import nn
from torch.distributions.normal import Normal
import torch.nn.functional as F
from mmcv.runner.optimizer import build_optimizer
from ..builder import (AGENTS, build_buffer, build_network)
from .td3 import TwinCritic
# Clamp range applied to the actor's predicted log standard deviation
# (see GaussianActor.forward) before it is exponentiated.
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class GaussianActor(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    Wraps a network built from ``network_cfg``. The network output is split
    along its last dimension into ``[mu, log_std]``, each of size
    ``network_cfg['out_channels'] // 2``.
    """

    def __init__(self, network_cfg):
        super().__init__()
        self.network = build_network(network_cfg)
        # First half of the outputs are means, last half are log-stds.
        self.num_actions = network_cfg['out_channels'] // 2

    def forward(self, state):
        """Return ``(mu, std)`` of the pre-squash Gaussian for ``state``.

        The split is taken along the last dimension, so any number of leading
        batch dimensions (including none) is supported. ``log_std`` is clamped
        to ``[LOG_STD_MIN, LOG_STD_MAX]`` before exponentiation.
        """
        net_out = self.network(state)
        mu = net_out[..., :self.num_actions]
        log_std = net_out[..., -self.num_actions:]
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mu, log_std.exp()

    def action(self, state, stochastic=True):
        """Return the squashed action ``tanh(u)``.

        ``u`` is drawn with ``sample()`` (not reparameterized) when
        ``stochastic`` is true, and is the mean ``mu`` otherwise.
        """
        mu, std = self.forward(state)
        # Pre-squash action distribution.
        pi_dist = Normal(mu, std)
        u = pi_dist.sample() if stochastic else mu
        return torch.tanh(u)

    def sample(self, state):
        """Draw a reparameterized squashed action and its log-probability.

        Returns ``(action, log_prob)`` where ``action = tanh(u)`` with
        ``u = pi_dist.rsample()`` (so gradients flow back to the network), and
        ``log_prob`` is summed over the last (action) dimension.
        """
        mu, std = self.forward(state)
        pi_dist = Normal(mu, std)
        u = pi_dist.rsample()
        # Change of variables for a = tanh(u) (appendix C of the SAC paper,
        # arXiv 1801.01290): log pi(a) = log N(u) - sum log(1 - tanh(u)^2).
        # log(1 - tanh(u)^2) is rewritten as 2 * (log 2 - u - softplus(-2u)),
        # which avoids evaluating 1 - tanh(u)^2 directly (it underflows to 0
        # when tanh saturates). Both sums use the same axis.
        log_prob = pi_dist.log_prob(u).sum(dim=-1)
        log_prob = log_prob - (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(dim=-1)
        return torch.tanh(u), log_prob
@AGENTS.register_module()
class SAC:
    """Soft Actor-Critic agent with twin critics and automatic alpha tuning.

    Components: a tanh-squashed Gaussian actor, a twin critic with a
    Polyak-averaged target copy, a replay buffer, and a learnable entropy
    temperature ``alpha`` optimized through ``log_alpha``.
    """

    def __init__(self,
                 num_states,
                 num_actions,
                 actor=dict(type='MLP'),
                 critic=dict(type='MLP'),
                 buffer=dict(capacity=2000, batch_size=128),
                 actor_optimizer=dict(type='Adam', lr=1e-3),
                 critic_optimizer=dict(type='Adam', lr=1e-3),
                 alpha_optimizer=dict(type='Adam', lr=1e-3),
                 alpha0=0.1,
                 gamma=0.9,
                 explore_rate=0.3,
                 polyak=0.99,
                 target_update_iters=1,
                 start_steps=100,
                 **kwargs):
        self.num_actions = num_actions
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Actor: outputs 2 * num_actions values (mean and log-std per action).
        # The config dicts are copied so the (shared) defaults are not mutated.
        actor_cfg = actor.copy()
        actor_cfg['in_channels'] = num_states
        actor_cfg['out_channels'] = 2 * num_actions
        self.actor = GaussianActor(actor_cfg).to(self.device)
        self.actor_optimizer = build_optimizer(self.actor, actor_optimizer)

        # Twin critics Q(s, a) -> 1 value each, plus a target copy that is
        # updated by Polyak averaging in update_target_networks().
        critic_cfg = critic.copy()
        critic_cfg['in_channels'] = num_states + num_actions
        critic_cfg['out_channels'] = 1
        self.critic = TwinCritic(critic_cfg).to(self.device)
        self.critic_target = TwinCritic(critic_cfg).to(self.device)
        self.critic_optimizer = build_optimizer(self.critic, critic_optimizer)

        # Replay memory used to store and sample transitions.
        self.memory = build_buffer(buffer)

        # Entropy temperature. log(alpha) is the optimized parameter so alpha
        # stays positive; it is kept float32 (np.log would give float64).
        self.alpha = alpha0
        self.log_alpha = torch.tensor(float(np.log(alpha0)), dtype=torch.float32,
                                      requires_grad=True, device=self.device)
        self.target_entropy = -self.num_actions
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_optimizer['lr'])

        # Agent hyper-parameters.
        self.gamma = gamma
        self.explore_rate = explore_rate
        self.polyak = polyak
        self.target_update_iters = target_update_iters
        self.start_steps = start_steps
        self.learn_step_counter = 0

        # Loss used for both critic regressions.
        self.loss_func = nn.MSELoss()
        self._init_weights()

    def _init_weights(self):
        """Start the target critic as an exact copy of the critic."""
        self.critic_target.load_state_dict(self.critic.state_dict())

    def update_target_networks(self):
        """Polyak-average the critic into the target critic.

        target <- polyak * target + (1 - polyak) * critic
        """
        with torch.no_grad():
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_target.parameters()):
                target_param.mul_(self.polyak).add_(param, alpha=1 - self.polyak)

    def store_transition(self, state, action, reward, new_state, done):
        """Append one transition to the replay memory."""
        self.memory.addMemory(state, action, reward, new_state, done)

    def act(self, state, is_train=False):
        """Select an action for one state and return it as a flat numpy array.

        While training and during the first ``start_steps`` learning steps,
        with probability ``explore_rate`` the action is drawn uniformly from
        [-1, 1] per dimension to improve early exploration. Otherwise the actor
        is queried: a stochastic sample when ``is_train`` is true, ``tanh(mu)``
        otherwise. ``state`` is presumably a single unbatched observation
        (a batch dimension of 1 is added here) — TODO confirm with callers.
        """
        # np.random.rand is uniform on [0, 1), so this branch fires with
        # probability explore_rate (the former randn draw did not).
        if (is_train and self.learn_step_counter < self.start_steps
                and np.random.rand() < self.explore_rate):
            return np.random.uniform(low=-1.0, high=1.0, size=self.num_actions)
        state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            action = self.actor.action(state_t, stochastic=is_train)
        return action.cpu().numpy().flatten()

    def learn(self, state, action, reward, new_state, done):
        """Store one transition, then run one SAC update on a replay minibatch.

        The critics are updated on every call. The actor, the target critic and
        alpha are updated once every ``target_update_iters`` calls.
        """
        self.store_transition(state, action, reward, new_state, done)
        states, actions, rewards, next_states, finals = \
            self.memory.getMiniBatch(device=self.device)

        self._update_critic(states, actions, rewards, next_states, finals)

        if self.learn_step_counter % self.target_update_iters == 0:
            log_prob = self._update_actor(states)
            self.update_target_networks()
            self._update_alpha(log_prob)

        self.learn_step_counter += 1

    def _update_critic(self, states, actions, rewards, next_states, finals):
        """Take one gradient step on both critics towards the soft Bellman target."""
        q1_eval, q2_eval = self.critic(states, actions)
        with torch.no_grad():
            q_target = self.get_critic_targets(rewards, next_states, finals)
        loss_critic = self.loss_func(q1_eval, q_target) + self.loss_func(q2_eval, q_target)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

    def _update_actor(self, states):
        """Take one actor gradient step and return the batch log-probs used.

        Minimizes mean(alpha * log pi(a|s) - min(Q1, Q2)(s, a)) with a drawn
        by the reparameterized sampler, i.e. maximizes the soft Q-value. The
        critic gradients this backward pass leaves behind are cleared by
        critic_optimizer.zero_grad() before the next critic step.
        """
        pred_actions, log_prob = self.actor.sample(states)
        q1_val, q2_val = self.critic(states, pred_actions)
        # Critic outputs assumed [batch, 1]; drop only the trailing dim so a
        # batch of size 1 keeps its batch axis.
        q_val = torch.min(q1_val, q2_val).squeeze(-1)
        loss_actor = (self.alpha * log_prob - q_val).mean()
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()
        return log_prob

    def _update_alpha(self, log_prob):
        """Take one gradient step on log(alpha) towards the target entropy."""
        loss_alpha = -self.log_alpha * (log_prob.detach().mean() + self.target_entropy)
        self.alpha_optimizer.zero_grad()
        loss_alpha.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.detach().exp()

    def get_critic_targets(self, rewards, next_states, finals):
        """Bootstrap the soft Bellman target for the critics.

        Qt = r + gamma * (1 - done) * (min(Qt1, Qt2)(s', a') - alpha * log pi(a'|s'))

        a' is sampled from the current actor (there is no target actor here);
        only the critic has a target copy. The caller wraps this in
        ``torch.no_grad()``. Returns ``q_target.unsqueeze(1)``, i.e.
        [batch_size, 1] assuming rewards/finals are 1-D [batch_size] tensors
        — TODO confirm against the buffer.
        """
        next_actions, next_log_prob = self.actor.sample(next_states)
        q1_target, q2_target = self.critic_target(next_states, next_actions)
        q_target_next = torch.min(q1_target, q2_target).squeeze(-1)
        soft_value = q_target_next - self.alpha * next_log_prob
        q_target = rewards + self.gamma * (1 - finals) * soft_value
        return q_target.unsqueeze(1)