PPO-KL散度近端策略优化玩cartpole游戏

news/2024/9/17 3:56:05

 

其实KL散度在这个游戏里的作用不大,游戏的action比较简单,不像LM里的action是一个很大的向量,可以直接用surr1,最大化surr1,实验测试确实是这样,而且KL的系数不能给太大,否则惩罚力度太大,action model 和ref model产生的action其实分布的差距并不太大

 

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pygame
import sys
from collections import deque# 定义策略网络
class PolicyNetwork(nn.Module):def __init__(self):super(PolicyNetwork, self).__init__()self.fc = nn.Sequential(nn.Linear(4, 2),nn.Tanh(),nn.Linear(2, 2),  # CartPole的动作空间为2nn.Softmax(dim=-1))def forward(self, x):return self.fc(x)# 定义值网络
class ValueNetwork(nn.Module):def __init__(self):super(ValueNetwork, self).__init__()self.fc = nn.Sequential(nn.Linear(4, 2),nn.Tanh(),nn.Linear(2, 1))def forward(self, x):return self.fc(x)# 经验回放缓冲区
class RolloutBuffer:def __init__(self):self.states = []self.actions = []self.rewards = []self.dones = []self.log_probs = []def store(self, state, action, reward, done, log_prob):self.states.append(state)self.actions.append(action)self.rewards.append(reward)self.dones.append(done)self.log_probs.append(log_prob)def clear(self):self.states = []self.actions = []self.rewards = []self.dones = []self.log_probs = []def get_batch(self):return (torch.tensor(self.states, dtype=torch.float),torch.tensor(self.actions, dtype=torch.long),torch.tensor(self.rewards, dtype=torch.float),torch.tensor(self.dones, dtype=torch.bool),torch.tensor(self.log_probs, dtype=torch.float))# PPO更新函数
def ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer, epochs=100, gamma=0.99, clip_param=0.2):states, actions, rewards, dones, old_log_probs = buffer.get_batch()returns = []advantages = []G = 0adv = 0dones = dones.to(torch.int)# print(dones)for reward, done, value in zip(reversed(rewards), reversed(dones), reversed(value_net(states))):if done:G = 0adv = 0G = reward + gamma * G  #蒙特卡洛回溯G值delta = reward + gamma * value.item() * (1 - done) - value.item()  #TD差分# adv = delta + gamma * 0.95 * adv * (1 - done)  #adv = delta + adv*(1-done)returns.insert(0, G)advantages.insert(0, adv)returns = torch.tensor(returns, dtype=torch.float)  #价值advantages = torch.tensor(advantages, dtype=torch.float)advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  #add baselinefor _ in range(epochs):action_probs = policy_net(states)dist = torch.distributions.Categorical(action_probs)new_log_probs = dist.log_prob(actions)ratio = (new_log_probs - old_log_probs).exp()KL = new_log_probs.exp()*(new_log_probs - old_log_probs).mean()   #KL散度 p*log(p/p')#下面三行是核心surr1 = ratio * advantagesPPO1,PPO2 = True,False# print(surr1,KL*500)if PPO1 == True:actor_loss = -(surr1 - KL).mean()if PPO2 == True:surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantagesactor_loss = -torch.min(surr1, surr2).mean()optimizer_policy.zero_grad()actor_loss.backward()optimizer_policy.step()value_loss = (returns - value_net(states)).pow(2).mean()optimizer_value.zero_grad()value_loss.backward()optimizer_value.step()# 初始化环境和模型
env = gym.make('CartPole-v1')
policy_net = PolicyNetwork()
value_net = ValueNetwork()
optimizer_policy = optim.Adam(policy_net.parameters(), lr=3e-4)
optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)
buffer = RolloutBuffer()# Pygame初始化
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()draw_on = False
# 训练循环
state = env.reset()
for episode in range(10000):  # 训练轮次done = Falsestate = state[0]step= 0while not done:step+=1state_tensor = torch.FloatTensor(state).unsqueeze(0)action_probs = policy_net(state_tensor)   #旧policy推理数据dist = torch.distributions.Categorical(action_probs)action = dist.sample()log_prob = dist.log_prob(action)next_state, reward, done, _ ,_ = env.step(action.item())buffer.store(state, action.item(), reward, done, log_prob)state = next_state# 实时显示for event in pygame.event.get():if event.type == pygame.QUIT:pygame.quit()sys.exit()if draw_on:# 清屏并重新绘制
            screen.fill((0, 0, 0))cart_x = int(state[0] * 100 + 300)  # 位置转换为屏幕坐标pygame.draw.rect(screen, (0, 128, 255), (cart_x, 300, 50, 30))pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * np.sin(state[2])), 300 - int(50 * np.cos(state[2]))), 5)pygame.display.flip()clock.tick(60)if step >2000:draw_on = Trueppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer)buffer.clear()state = env.reset()print(f'Episode {episode} completed , reward:  {step}.')# 结束训练
env.close()
pygame.quit()

 

效果:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/31987.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

steam发行问题

非常重要,最新steam营销推广 https://store.steampowered.com/news/group/4145017/view/4191238396458987547

软件设计模式概念篇

创建型模式 1、创建型模式(Creational Pattern)对类的实例化过程进行了抽象,能够将软件模块中对象的创建和对象的使用分离。 2、为了使软件的结构更加清晰,外界对于这些对象只需要知道它们共同的接口,而不需要清楚其具体的实现细节,使整个系统的设计更加符合单一职责原则。…

mysql中explain命令详解

前言 我们可以使用 explain 命令来查看 SQL 语句的执行计划,从而帮助我们优化慢查询。 使用注意:使用的 mysql 版本为 8.0.28数据准备 CREATE TABLE `tb_product2` (`id` bigint NOT NULL AUTO_INCREMENT COMMENT 商品ID,`name` varchar(20) DEFAULT NULL COMMENT 商品名称,`…

vasp极化计算

为什么我算一个结构,理论上应该是右极化态的,为什么只有离子极化,没有电子极化?是铁电相构建有问题还是计算的数据有问题?

超线程/同步多线程(HT/SMT)技术

超线程/同步多线程(HT/SMT)技术 虽然现在超线程(Hyper-Threading)被大家广泛接受,并把所有一个物理核心上有多个虚拟核心的技术都叫做超线程,但这其实是Intel的一个营销名称。而实际上这一类技术的(学术/技术)通行名称是同步多线程(SMT,Simultaneous Multithreading)…

Linux基础-文件特殊权限

# day13今日安排默写昨日作业讲解文件权限篇综合知识脑图特殊权限(了解)linux提供的12个特殊权限 默认的9位权限 rwx rwx rwx还有三个隐藏的特殊权限,如下 suid 比如 /usr/bin/passwdsgidsbit 特殊权限对照表类别 suid sgid sticky字符表示 S S T出现位置 用户权限位x 用户…

【django学习-28】列表界面模板下载与上传文件

前言,我们在实际项目开发过程中,经常有列表界面,有上传功能,并且支持先下载模板,后上传 1.实现效果与前端展示<form method="post" enctype="multipart/form-data" action="/depart/multi/">{% csrf_token %}<div class="for…

AoPS - Chapter 3 More Triangles

本章主要讲解正弦定理、余弦定理、海伦公式、Stewarts Theorem。本章主要讲解正弦定理、余弦定理、海伦公式、Stewarts Theorem。 本文在没有特殊说明时,默认在 \(\triangle ABC\) 中:\(a\) 为角 \(A\) 对边,\(b\) 为角 \(B\) 对边,\(c\) 为角 \(C\) 对边。 \(r\) 为内切圆…