Theoretical Basis and Code Implementation of Dueling DQN [PyTorch + Pendulum-v0]
2022-07-19 03:20:00 【lucky-wz】
Dueling DQN Theoretical basis
Dueling DQN is an algorithm that improves on DQN. Its main contribution is a network structure that represents the value function in a more fine-grained way, which gives the model better performance.
We start from the following decomposition, which defines a new quantity:

$$q(s_t,a_t)=v(s_t)+A(s_t,a_t)$$
In other words, the state-action value function $q$ can be decomposed into a state value function $v$ plus an advantage function $A$. Since

$$\mathbb{E}_{a_t}[q(s_t,a_t)]=v(s_t),$$

whenever the state-action values are not all equal, some values $q(s,a)$ must be higher than the state value $v(s)$ and some must be lower. The advantage function therefore expresses how the current action compares with average performance: if the action is better than average, the advantage is positive; otherwise it is negative.
Given such a natural decomposition, we can reflect it in the model design: keeping the main body of the network unchanged, we split the original single output into two outputs. One head outputs $v$, a one-dimensional scalar; the other head outputs $A$, with the same dimension as the number of actions. The two parts are then summed to recover the original $q$ value.
After changing the output structure, only small modifications to the model are needed: the front part of the network stays the same, while the back part goes from one output head to two, whose results are finally merged into a single output.
However, this decomposition alone does not give good results, because for a fixed $q$ value there are infinitely many possible combinations of $v$ and $A$ (for example, adding an arbitrary constant $C$ to $v$ and subtracting the same $C$ from every $A$ value leaves $q$ unchanged, which makes training unstable), and only a small fraction of these combinations are reasonable and close to the true values. To make the decomposition of $q$ into $v$ and $A$ unique, we need to constrain the advantage function $A$. Clearly, the expected value of $A$ is 0:

$$\mathbb{E}_a[A(s_t,a_t)]=\mathbb{E}_a[q(s_t,a_t)-v(s_t)]=v(s_t)-v(s_t)=0$$
We can therefore constrain the $A$ output, for example by rewriting the formula as:

$$q(s_t,a_t)=v(s_t)+\Big(A(s_t,a_t)-\frac{1}{|\mathcal{A}|}\sum_{a'}A(s_t,a')\Big)$$

Subtracting from each $A$ value the mean of all $A$ values in the current state enforces the zero-expectation constraint above, which stabilizes the $v$ and $A$ outputs.
Another option is to subtract the maximum $A$ value in the current state:

$$q(s_t,a_t)=v(s_t)+\Big(A(s_t,a_t)-\max_{a'}A(s_t,a')\Big)$$
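As an illustration (my addition, not part of the original post), here is a minimal sketch of the two aggregation variants, assuming a PyTorch tensor of advantages `A` with shape `(batch, num_actions)` and a state value `V` with shape `(batch, 1)`:

```python
import torch

def dueling_q_mean(V: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    # q = v + (A - mean_a A): subtract the per-state mean advantage
    return V + A - A.mean(dim=1, keepdim=True)

def dueling_q_max(V: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    # q = v + (A - max_a A): subtract the per-state maximum advantage
    return V + A - A.max(dim=1, keepdim=True).values

V = torch.randn(4, 1)   # hypothetical state values for a batch of 4
A = torch.randn(4, 11)  # hypothetical advantages for 11 discrete actions
print(dueling_q_mean(V, A).shape)  # torch.Size([4, 11])
print(dueling_q_max(V, A).shape)   # torch.Size([4, 11])
```

The VAnet implementation later in this post uses the mean-subtraction version.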
This decomposition brings several benefits:
- It yields not only the $q$ value for a given state and action, but also the $v$ value and the $A$ value. If the $v$ value is needed elsewhere, it is available without training a separate network.
- Because the network has an explicit output for the $v$ function, every update explicitly updates $v$, so the update frequency of the $v$ function increases deterministically.
- From the perspective of network training, instead of fitting $|A|$ values ranging over $[0,\infty)$, we now fit a single value in $[0,\infty)$ together with $|A|$ values whose mean is 0 and whose actual range is $[-C,C]$. The latter is clearly friendlier and easier for the network to learn.
- For some reinforcement learning problems, the range of the $A$ values is much smaller than that of the $v$ values, so training the two separately makes it easier to preserve the ordering of actions. Because the range of $A$ is small, $A$ is more sensitive to model updates, so each update is more likely to take the relative changes among actions into account and will not accidentally break the existing action ordering. For example, in a car-driving game where the regions the agent attends to are highlighted in orange, when there is no car ahead the available actions make little difference and the agent focuses mainly on the state value; when a car appears ahead and the agent needs to overtake, the agent starts to attend to the differences in the advantage values of different actions.
Dueling DQN Code implementation
Dueling DQN differs from DQN only in the network structure, so most of the DQN code can be reused. We define VAnet, a network that combines the state value function and the advantage function.
```python
import torch.nn as nn
import torch.nn.functional as F


class Qnet(nn.Module):
    """Ordinary Q-network: a single head that outputs one Q value per action."""
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Qnet, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, s):
        return self.layer(s)


class VAnet(nn.Module):
    """Dueling network: a shared trunk followed by separate V and A heads."""
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(VAnet, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)    # shared part of the network
        self.fc_A = nn.Linear(hidden_dim, action_dim)  # advantage head
        self.fc_V = nn.Linear(hidden_dim, 1)           # state-value head

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        A = self.fc_A(hidden)
        V = self.fc_V(hidden)
        Q = V + A - A.mean(1).view(-1, 1)  # Q is computed from V and the mean-subtracted A
        return Q
```
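As a quick sanity check (my addition, not from the original post), the following hypothetical snippet instantiates VAnet with the dimensions used later in this post (the 3-dimensional Pendulum-v0 observation and 11 discretized actions; the hidden size of 128 is an assumption) and verifies the output shape:

```python
import torch

# Hypothetical dimensions matching the Pendulum-v0 setup below:
# 3-dimensional observation, 128 hidden units (assumed), 11 discretized actions.
net = VAnet(state_dim=3, hidden_dim=128, action_dim=11)
states = torch.randn(32, 3)   # a fake batch of 32 states
q_values = net(states)
print(q_values.shape)         # torch.Size([32, 11])
```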
The DQN class below implements vanilla DQN and, via the dqn_type flag, Double DQN and Dueling DQN.
```python
import random

import gym
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm

# ReplayBuffer and dis_to_con are assumed to be defined elsewhere
# (a possible sketch of both is given after this listing).


class DQN:
    def __init__(self, args):
        self.args = args
        self.hidden_dim = args.hidden_size
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.gamma = args.gamma                  # discount factor
        self.epsilon = args.epsilon              # epsilon-greedy exploration rate
        self.target_update = args.target_update  # target-network update frequency
        self.count = 0                           # counter recording the number of updates
        self.num_episodes = args.num_episodes
        self.minimal_size = args.minimal_size
        self.dqn_type = args.dqn_type
        self.env = gym.make(args.env_name)
        random.seed(args.seed)
        np.random.seed(args.seed)
        self.env.seed(args.seed)
        torch.manual_seed(args.seed)
        self.replay_buffer = ReplayBuffer(args.buffer_size)
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = 11  # discretize the continuous action into 11 discrete actions
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        #################################################################################
        if self.dqn_type == "DuelingDQN":  # Dueling DQN uses a different network structure
            self.q_net = VAnet(self.state_dim, self.hidden_dim, self.action_dim).to(self.device)
            self.target_q_net = VAnet(self.state_dim, self.hidden_dim, self.action_dim).to(self.device)
        else:
            self.q_net = Qnet(self.state_dim, self.hidden_dim, self.action_dim).to(self.device)
            self.target_q_net = Qnet(self.state_dim, self.hidden_dim, self.action_dim).to(self.device)
        #################################################################################
        self.optimizer = Adam(self.q_net.parameters(), lr=self.lr)

    def select_action(self, state):  # epsilon-greedy action selection
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action

    def max_q_value(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        return self.q_net(state).max().item()

    def update(self, transition):
        states = torch.tensor(transition["states"], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition["actions"], dtype=torch.int64).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition["rewards"], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition["next_states"], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition["dones"], dtype=torch.float).view(-1, 1).to(self.device)
        q_values = self.q_net(states).gather(1, actions)  # Q values of the taken actions
        # maximum Q value of the next state
        #################################################################################
        if self.dqn_type == 'DoubleDQN':
            max_action = self.q_net(next_states).max(1)[1].view(-1, 1)
            max_next_q_values = self.target_q_net(next_states).gather(1, max_action)
        else:  # DQN
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
        #################################################################################
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # TD target
        loss = torch.mean(F.mse_loss(q_values, q_targets))  # mean-squared-error loss
        self.optimizer.zero_grad()  # gradients accumulate by default in PyTorch, so reset them explicitly
        loss.backward()             # backpropagate
        self.optimizer.step()
        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())  # update the target network
        self.count += 1

    def train_DQN(self):
        return_list = []
        max_q_value_list = []
        max_q_value = 0
        for i in range(10):
            with tqdm(total=int(self.num_episodes / 10), desc=f'Iteration {i}') as pbar:
                for episode in range(self.num_episodes // 10):
                    episode_return = 0
                    state = self.env.reset()
                    while True:
                        action = self.select_action(state)
                        max_q_value = self.max_q_value(state) * 0.005 + max_q_value * 0.995  # smoothing
                        max_q_value_list.append(max_q_value)  # record the smoothed maximum Q value of each state
                        action_continuous = dis_to_con(action, self.env, self.action_dim)
                        next_state, reward, done, _ = self.env.step([action_continuous])
                        self.replay_buffer.add(state, action, reward, next_state, done)
                        if self.replay_buffer.size() > self.minimal_size:
                            s, a, r, s_, d = self.replay_buffer.sample(self.batch_size)
                            transitions = {"states": s, "actions": a, "rewards": r,
                                           "next_states": s_, "dones": d}
                            self.update(transitions)
                        state = next_state
                        episode_return += reward
                        if done:
                            break
                    return_list.append(episode_return)
                    if (episode + 1) % 10 == 0:
                        pbar.set_postfix({
                            "episode": f"{self.num_episodes / 10 * i + episode + 1}",
                            "return": f"{np.mean(return_list[-10:]):.3f}"
                        })
                    pbar.update(1)
        return return_list, max_q_value_list
```
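The listing above relies on a ReplayBuffer class and a dis_to_con function that are not shown in the post. Below is one possible minimal sketch of both, written only to match how they are called above (a buffer with add/size/sample, and a helper mapping a discrete action index back onto the continuous torque range); the implementations used by the original author may differ.

```python
import collections
import random

import numpy as np


class ReplayBuffer:
    """A minimal FIFO experience replay buffer (assumed interface: add / size / sample)."""
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def size(self):
        return len(self.buffer)

    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return np.array(states), actions, rewards, np.array(next_states), dones


def dis_to_con(discrete_action, env, action_dim):
    """Map a discrete action index in [0, action_dim - 1] onto the continuous action range."""
    action_low = env.action_space.low[0]    # -2.0 for Pendulum-v0
    action_high = env.action_space.high[0]  # +2.0 for Pendulum-v0
    return action_low + (discrete_action / (action_dim - 1)) * (action_high - action_low)
```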
Code run results
From the results of running the code, we can see that, compared with the traditional DQN, Dueling DQN learns more stably when there are many action choices and also attains a larger maximum return. This is consistent with the principle of Dueling DQN: as the action space grows, its advantage over DQN becomes more pronounced.
Overall, Dueling DQN learns the differences between actions well and is very effective in environments with large action spaces.
References:
- 《Hands-on Reinforcement Learning》
- 《The Essence of Reinforcement Learning》
This post will be continuously updated; corrections are welcome!