强化学习笔记+代码(二):SARSA算法原理和Agent实现

迈不过友情╰ 2023-07-17 15:59 571阅读 0赞

本文主要整理和参考了李宏毅的强化学习系列课程和莫烦python的强化学习教程
本系列主要分几个部分进行介绍

  1. 强化学习背景介绍
  2. SARSA算法原理和Agent实现
  3. Q-learning算法原理和Agent实现
  4. DQN算法原理和Agent实现(tensorflow)
  5. Double-DQN、Dueling DQN算法原理和Agent实现(tensorflow)
  6. Policy Gradients算法原理和Agent实现(tensorflow)
  7. Actor-Critic、A2C、A3C算法原理和Agent实现(tensorflow)

一、SARSA算法原理

上一篇内容奖励,强化学习的主要功能就是让agent学习尽可能好的动作action,使其后续获得的奖励尽可能的大。
假设在时刻t时,处于状态长期奖励为:
在这里插入图片描述
其中 r t + n r_{t+n} rt+n为t+n时获得的局部奖励, γ γ γ为下一期奖励传递到当期的衰减因子。则状态s的价值为 G t G_t Gt的条件期望:
在这里插入图片描述
同样定义状态s时执行动作a的价值也为 G t G_t Gt的条件期望:
在这里插入图片描述
可知: E ( Q π ( s , a ) ) = V π ( s ) E(Q_π(s,a))=V_π(s) E(Qπ(s,a))=Vπ(s)
下面直接给出深度学习最核心的Bellman 公式
在这里插入图片描述
Bellman 公式十分的直观, V π ( s ) V_π(s) Vπ(s)=根据动作走到下个状态的奖励+下个状态长期价值*衰减值的期望。价值会不断的传递,因此可以看出 V π ( s ) V_π(s) Vπ(s)衡量的是状态s的长期价值。
因为RL需要是Agent不断变强,就可以理解为让状态或让动作的价值不断变大,因此会选择根据如下方式获得新的 V π ( s ) V_π(s) Vπ(s)和 Q π ( s , a ) Q_π(s,a) Qπ(s,a)
在这里插入图片描述
上面几个公式是强化学习的精髓

SARSA算法是一种典型性的value-based和on-policy的算法。下面直接给出SARSA的算法
在这里插入图片描述
SARSA的算法中有几个需要注意的地方下面用彩色标记标记出
在这里插入图片描述
由于与环境进行了实际的交互,因此会执行动作a后到达状态s’,agent会根据状态s’直接执行动作a’。上面算法图中红框为新的状态s和动作a的价值,即 Q ∗ ( s , a ) Q_*(s,a) Q∗(s,a),蓝色框是旧的 Q ( s , a ) Q(s,a) Q(s,a),可以看出 Q ( s , a ) Q(s,a) Q(s,a)的更新方法与梯度下降十分的类似。还需要注意到,在循环中黄色框,循环末尾将s’估值给s,将动作a’赋值给a,带到下一次循环中进行运算,说明agent真的与环境进行实际交互。这是算法为on-policy的关键。

二、SARSA代码

此处直接参考莫烦python的强化学习教程进行代码编写,在基础上说明每一行代码的用途
1.environment的编写
首先RL需要一个环境,因为我们控制不了环境(比如下围棋时我们不不能改变棋盘的大小,何落子方式,只能只能在范围内落在线与线之间的交叉点上),这个环境是不可以改变的,因此后面的Q-learning也将沿用此环境。通常不同的问题有不同环境,我们真正需要关注的是agent即算法逻辑的编写。
此处以走方格为例编写一个environment
在这里插入图片描述
其中红色点为当前所在位置,走到黄色点获得奖励1,走到黑色点获得奖励-1

  1. """
  2. Reinforcement learning maze example.
  3. Red rectangle: explorer.
  4. Black rectangles: hells [reward = -1].
  5. Yellow bin circle: paradise [reward = +1].
  6. All other states: ground [reward = 0].
  7. This script is the environment part of this example. The RL is in RL_brain.py.
  8. View more on my tutorial page: https://morvanzhou.github.io/tutorials/
  9. """
  10. import numpy as np
  11. import time
  12. import sys
  13. if sys.version_info.major == 2:
  14. import Tkinter as tk
  15. else:
  16. import tkinter as tk
  17. UNIT = 40 # pixels
  18. MAZE_H = 4 # grid height
  19. MAZE_W = 4 # grid width
  20. class Maze(tk.Tk, object):
  21. def __init__(self):
  22. super(Maze, self).__init__()
  23. self.action_space = ['u', 'd', 'l', 'r']
  24. self.n_actions = len(self.action_space)
  25. self.title('maze')
  26. self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
  27. self._build_maze()
  28. def _build_maze(self):
  29. self.canvas = tk.Canvas(self, bg='white',
  30. height=MAZE_H * UNIT,
  31. width=MAZE_W * UNIT)
  32. # create grids
  33. for c in range(0, MAZE_W * UNIT, UNIT):
  34. x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
  35. self.canvas.create_line(x0, y0, x1, y1)
  36. for r in range(0, MAZE_H * UNIT, UNIT):
  37. x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
  38. self.canvas.create_line(x0, y0, x1, y1)
  39. # create origin
  40. origin = np.array([20, 20])
  41. # hell
  42. hell1_center = origin + np.array([UNIT * 2, UNIT])
  43. self.hell1 = self.canvas.create_rectangle(
  44. hell1_center[0] - 15, hell1_center[1] - 15,
  45. hell1_center[0] + 15, hell1_center[1] + 15,
  46. fill='black')
  47. # hell
  48. hell2_center = origin + np.array([UNIT, UNIT * 2])
  49. self.hell2 = self.canvas.create_rectangle(
  50. hell2_center[0] - 15, hell2_center[1] - 15,
  51. hell2_center[0] + 15, hell2_center[1] + 15,
  52. fill='black')
  53. # create oval
  54. oval_center = origin + UNIT * 2
  55. self.oval = self.canvas.create_oval(
  56. oval_center[0] - 15, oval_center[1] - 15,
  57. oval_center[0] + 15, oval_center[1] + 15,
  58. fill='yellow')
  59. # create red rect
  60. self.rect = self.canvas.create_rectangle(
  61. origin[0] - 15, origin[1] - 15,
  62. origin[0] + 15, origin[1] + 15,
  63. fill='red')
  64. # pack all
  65. self.canvas.pack()
  66. def reset(self):
  67. self.update()
  68. time.sleep(0.5)
  69. self.canvas.delete(self.rect)
  70. origin = np.array([20, 20])
  71. self.rect = self.canvas.create_rectangle(
  72. origin[0] - 15, origin[1] - 15,
  73. origin[0] + 15, origin[1] + 15,
  74. fill='red')
  75. # return observation
  76. return self.canvas.coords(self.rect)
  77. def step(self, action):
  78. s = self.canvas.coords(self.rect)
  79. base_action = np.array([0, 0])
  80. if action == 0: # up
  81. if s[1] > UNIT:
  82. base_action[1] -= UNIT
  83. elif action == 1: # down
  84. if s[1] < (MAZE_H - 1) * UNIT:
  85. base_action[1] += UNIT
  86. elif action == 2: # right
  87. if s[0] < (MAZE_W - 1) * UNIT:
  88. base_action[0] += UNIT
  89. elif action == 3: # left
  90. if s[0] > UNIT:
  91. base_action[0] -= UNIT
  92. self.canvas.move(self.rect, base_action[0], base_action[1]) # move agent
  93. s_ = self.canvas.coords(self.rect) # next state
  94. # reward function
  95. if s_ == self.canvas.coords(self.oval):
  96. reward = 1
  97. done = True
  98. s_ = 'terminal'
  99. elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
  100. reward = -1
  101. done = True
  102. s_ = 'terminal'
  103. else:
  104. reward = 0
  105. done = False
  106. return s_, reward, done
  107. def render(self):
  108. time.sleep(0.1)
  109. self.update()
  110. def update():
  111. for t in range(10):
  112. s = env.reset()
  113. while True:
  114. env.render()
  115. a = 1
  116. s, r, done = env.step(a)
  117. if done:
  118. break
  119. if __name__ == '__main__':
  120. env = Maze()
  121. env.after(100, update)
  122. env.mainloop()

2.agent的编写
通常不同的问题有不同环境,我们真正需要关注的是agent即算法逻辑的编写。

  1. from maze_env import Maze #即为上面的environment
  2. import numpy as np
  3. import pandas as pd
  4. #RL的父类定义
  5. class RL(object):
  6. #初始化
  7. #actions为可选动作, learning_rate为学习率,reward_decay为传递奖励是的递减系数gamma,1-e_greed为随机选择其他动作的概率
  8. def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
  9. self.actions = actions
  10. self.lr = learning_rate
  11. self.gamma = reward_decay
  12. self.epsilon = e_greedy
  13. #初始化qtable,行为observation的state, 列为当前状态可以选择的action(对于所有列,可以选择的action一样)
  14. self.q_table = pd.DataFrame(columns = self.actions, dtype=np.float64)
  15. def choose_action(self, observation):
  16. self.check_state_exist(observation) #检查当前状态是否存在,不存在就添加这个状态
  17. if np.random.uniform() < self.epsilon:
  18. state_action = self.q_table.loc[observation, :] #找到当前状态可以选择的动作
  19. #由于初始化或更新后一个状态下的动作值可能是相同的,为了避免每次都选择相同动作,用random.choice在值最大的action中损及选择一个
  20. action = np.random.choice(state_action[state_action==np.max(state_action)].index)
  21. else:
  22. action = np.random.choice(self.actions) #0.1的几率随机选择动作
  23. return action
  24. def check_state_exist(self, state):
  25. if state not in self.q_table.index:
  26. #若找不到该obversation的转态,则添加该状态到新的qtable
  27. #新的state的动作的q初始值赋值为0,列名为dataframe的列名,index为state
  28. self.q_table = self.q_table.append(pd.Series([0]*len(self.actions), index=self.q_table.columns, name=state))
  29. #不同方式的学习方法不同,用可变参数,直接pass
  30. def learning(self, *args):
  31. pass
  32. class SarsaTable(RL): #继承上面的RL
  33. #初始化
  34. #参数自己定义,含义继承父类RL
  35. #类方法choose_action、check_state_exist自动继承RL,参数不变
  36. def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
  37. super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
  38. def learning(self, s, a,r, s_, a_):
  39. self.check_state_exist(s_) #检查动作后状态s_是否存在
  40. q_old = self.q_table.loc[s, a] #旧的q[s,a]值
  41. if s_!='terminal':
  42. #取下个状态s_和动作a_下q值
  43. q_predict = self.q_table.loc[s_, a_]
  44. q_new = r+self.gamma*q_predict #计算新的值
  45. else:
  46. q_new = r
  47. self.q_table.loc[s,a] = q_old - self.lr*(q_new - q_old) #根据更新公式更新,类似于梯度下降
  48. def update():
  49. for episode in range(100):
  50. #初始化环境
  51. observation = env.reset()
  52. #根据当前状态选行为
  53. action = RL.choose_action(str(observation))
  54. while True:
  55. # 刷新环境
  56. env.render()
  57. # 在环境中采取行为, 获得下一个 state_ (obervation_), reward, 和是否终止
  58. observation_, reward, done = env.step(action)
  59. #根据observation_选择observation_下应该选择的动作action_
  60. action_ = RL.choose_action(str(observation_))
  61. #从当前状态state,当前动作action,奖励r,执行动作后state_,state_下的action_,(s,a,r,s,a)
  62. RL.learning(str(observation), action, reward, str(observation_), action_)
  63. # 将下一个当成下一步的 state (observation) and action。
  64. #与qlearning的却别是sarsa在observation_下真正执行了动作action_,供下次使用
  65. #而qlearning中下次状态observation_时还要重新选择action_
  66. observation = observation_
  67. action = action_
  68. # 终止时跳出循环
  69. if done:
  70. break
  71. # 大循环完毕
  72. print('game over')
  73. env.destroy()
  74. if __name__ == "__main__":
  75. env = Maze()
  76. #Sarsa和SarsaLambda的调用方式一模一样
  77. #RL = SarsaTable(actions=list(range(env.n_actions)))
  78. RL = SarsaLambdaTable(actions=list(range(env.n_actions)))
  79. env.after(100, update)
  80. env.mainloop()

发表评论

表情:
评论列表 (有 0 条评论,571人围观)

还没有评论,来说两句吧...

相关阅读