# 强化学习实战：用 Python 代码可视化不同策略下的状态访问分布（附 Jupyter Notebook）
在强化学习领域，理解智能体如何探索环境是算法设计的核心。想象你正在训练一个游戏 AI：为什么有些策略能让角色快速通关，而另一些却让角色困在某个区域反复徘徊？这种差异不仅体现在最终得分上，更直观地反映在智能体访问环境状态的概率分布中。本文将带你用 Python 代码把这些抽象概念转化为可视化的热力图和三维柱状图，让理论跃然屏上。

## 1. 环境搭建与基础概念可视化

### 1.1 创建自定义 Gymnasium 环境

我们先构建一个简单的网格世界环境。这个 5×5 的迷宫包含：

- 普通格子：移动后获得 -1 奖励
- 陷阱格子：移动后获得 -10 奖励
- 终点格子：移动后获得 20 奖励并终止回合

注意：为了让后续代码保持简洁，这个环境沿用了旧版 Gym 的简化接口。`reset()` 只返回观测，`step()` 返回 `(obs, reward, done, info)` 四元组，而不是 Gymnasium 标准的 `(obs, info)` 和五元组。若要配合 `gymnasium.make`、包装器或第三方算法库使用，需要改为标准接口。

```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np


class GridWorldEnv(gym.Env):
    """A 5x5 grid world with a fixed start (0, 0), traps at (1, 1)/(3, 3) and a goal at (4, 4).

    To keep the tutorial short, ``reset`` returns only the observation and
    ``step`` returns the old Gym-style 4-tuple ``(obs, reward, done, info)``
    rather than Gymnasium's ``(obs, info)`` / 5-tuple API.
    """

    def __init__(self):
        super().__init__()
        self.size = 5
        self.action_space = spaces.Discrete(4)  # 0: up, 1: down, 2: left, 3: right
        self.observation_space = spaces.Discrete(self.size ** 2)
        self.trap_positions = [(1, 1), (3, 3)]
        self.goal_position = (4, 4)

    def _get_obs(self):
        """Flatten the (row, col) position into a single state index."""
        return self.state[0] * self.size + self.state[1]

    def reset(self, *, seed=None, options=None):
        """Put the agent back on the fixed start cell and return its state index.

        ``seed``/``options`` are accepted so Gymnasium-style calls such as
        ``env.reset(seed=0)`` work; ``seed`` is forwarded to ``gym.Env.reset``
        to seed ``self.np_random``.
        """
        super().reset(seed=seed)
        self.state = (0, 0)
        return self._get_obs()

    def step(self, action):
        """Apply ``action``; a move that would leave the grid keeps the agent in place."""
        x, y = self.state
        if action == 0:    # up
            x = max(0, x - 1)
        elif action == 1:  # down
            x = min(self.size - 1, x + 1)
        elif action == 2:  # left
            y = max(0, y - 1)
        else:              # right
            y = min(self.size - 1, y + 1)
        self.state = (x, y)

        done = self.state == self.goal_position
        if self.state in self.trap_positions:
            reward = -10  # traps are penalized but do not end the episode
        elif done:
            reward = 20
        else:
            reward = -1
        return self._get_obs(), reward, done, {}
```

### 1.2 状态访问分布的理论实现

状态访问分布的定义为：

$$
\nu^\pi(s) = (1-\gamma)\sum_{t=0}^{\infty} \gamma^t P_t^\pi(s)
$$

其中 $P_t^\pi(s)$ 是智能体在策略 $\pi$ 下第 $t$ 步处于状态 $s$ 的概率。系数 $(1-\gamma)$ 用于归一化，使所有状态的值之和为 1。

我们可以用蒙特卡洛方法近似计算。注意，本环境到达终点即终止回合，终止之后不再计数。因此估计出的各状态值之和会略小于 1，终点格子的值为 0。

```python
def compute_visitation(env, policy, episodes=1000, gamma=0.99, max_steps=10_000):
    """Monte Carlo estimate of the discounted state visitation distribution.

    Args:
        env: environment whose ``reset()`` returns an integer state and whose
            ``step()`` returns ``(state, reward, done, info)``.
        policy: callable mapping a state index to an action.
        episodes: number of sampled episodes to average over.
        gamma: discount factor.
        max_steps: per-episode step cap so a policy that never reaches the
            goal cannot loop forever; ``None`` disables the cap.

    Returns:
        Array of length ``env.observation_space.n`` estimating
        ``(1 - gamma) * E[sum_t gamma**t * 1{s_t = s}]``.
    """
    visitation = np.zeros(env.observation_space.n)
    for _ in range(episodes):
        state = env.reset()
        done = False
        discount = 1.0  # gamma ** t, updated incrementally
        t = 0
        while not done and (max_steps is None or t < max_steps):
            visitation[state] += discount
            action = policy(state)
            state, _, done, _ = env.step(action)
            discount *= gamma
            t += 1
    return (1 - gamma) * visitation / episodes
```

## 2.
对比三种典型策略的表现

下面所有代码共用同一个环境实例：

```python
env = GridWorldEnv()
```

### 2.1 随机策略基准测试

```python
def random_policy(state):
    """Pick one of the four actions uniformly at random, ignoring the state."""
    return np.random.choice(4)


random_visits = compute_visitation(env, random_policy)
```

### 2.2 规避陷阱的保守策略

保守策略在“向下”和“向右”之间随机选择，但会排除任何会踏入陷阱的动作。无论从上方还是左侧靠近陷阱，都能完全避开 (1,1) 和 (3,3)：

```python
def cautious_policy(state):
    """Drift down/right at random, never choosing a move whose landing cell is a trap."""
    x, y = divmod(state, env.size)  # row, column
    last = env.size - 1
    safe_moves = []
    # Checking the landing cell covers traps approached from above as well as from the left.
    if x < last and (x + 1, y) not in env.trap_positions:
        safe_moves.append(1)  # down
    if y < last and (x, y + 1) not in env.trap_positions:
        safe_moves.append(3)  # right
    # Only the goal cell has no down/right move left, and the episode ends there.
    return np.random.choice(safe_moves) if safe_moves else 0


cautious_visits = compute_visitation(env, cautious_policy)
```

### 2.3 激进的最短路径策略

这条路径先沿左边缘一路向下，再沿下边缘一路向右，恰好不经过任何陷阱，8 步即可到达终点：

```python
def aggressive_policy(state):
    """Go straight down to the bottom row, then straight right to the goal."""
    x, y = divmod(state, env.size)
    if x < env.size - 1:
        return 1  # down first
    if y < env.size - 1:
        return 3  # then right
    return 0      # only reached at the goal, where the episode has already ended


aggressive_visits = compute_visitation(env, aggressive_policy)
```

## 3. 可视化分析与对比

### 3.1 热力图绘制

Matplotlib 的默认字体不含中文字形，需要先指定中文字体，否则中文标题会显示为方框：

```python
import seaborn as sns
import matplotlib.pyplot as plt

# Default Matplotlib fonts have no CJK glyphs; list common Chinese fonts so titles render.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Noto Sans CJK SC',
                                   'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign renderable with CJK fonts


def plot_visitation(visits, title):
    """Draw a flat state-visitation vector as an annotated square heatmap."""
    side = int(round(np.sqrt(visits.size)))  # 25 states -> 5x5 grid
    plt.figure(figsize=(8, 6))
    ax = sns.heatmap(visits.reshape(side, side), annot=True, fmt='.3f', cmap='YlOrRd')
    ax.set_title(title)
    plt.show()


plot_visitation(random_visits, '随机策略状态访问分布')
plot_visitation(cautious_visits, '保守策略状态访问分布')
plot_visitation(aggressive_visits, '激进策略状态访问分布')
```

### 3.2 三维柱状图对比

```python
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers the 3D projection on older Matplotlib
from matplotlib.patches import Patch


def plot_3d_comparison():
    """Compare the three policies' visitation vectors as side-by-side 3D bar rows."""
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    n_states = len(random_visits)
    xpos = np.arange(n_states)
    zpos = np.zeros(n_states)
    depth = 0.25  # bar depth along y; smaller than the 0.3 row spacing so rows don't overlap
    series = [
        (random_visits, -0.3, 'r', '随机'),
        (cautious_visits, 0.0, 'g', '保守'),
        (aggressive_visits, 0.3, 'b', '激进'),
    ]
    for visits, offset, color, _ in series:
        ax.bar3d(xpos, np.full(n_states, offset), zpos, 0.5, depth,
                 visits * 100,  # scale to percent for readability
                 color=color, alpha=0.5)

    ax.set_xticks(xpos)
    ax.set_yticks([offset + depth / 2 for _, offset, _, _ in series])
    ax.set_yticklabels([label for *_, label in series])
    ax.set_xlabel('状态编号')
    ax.set_ylabel('策略类型')
    ax.set_zlabel('访问频率(%)')
    # bar3d collections may not yield usable legend entries, so use proxy patches.
    ax.legend(handles=[Patch(color=color, alpha=0.5, label=label)
                       for _, _, color, label in series])
    plt.show()


plot_3d_comparison()
```

## 4.
高级应用与优化技巧

### 4.1 不同折扣因子的对比实验

折扣因子 γ 的选择会显著影响访问分布。γ 越小，分布越集中在起点附近；γ 越接近 1，分布越能反映整条轨迹。由于各子图的颜色刻度相互独立，比较时请关注分布形状，而不是颜色深浅：

```python
gammas = [0.5, 0.9, 0.99, 0.999]
plt.figure(figsize=(12, 3))
for i, gamma in enumerate(gammas, start=1):
    visits = compute_visitation(env, random_policy, gamma=gamma)
    plt.subplot(1, len(gammas), i)
    sns.heatmap(visits.reshape(env.size, env.size), cbar=False)
    plt.title(f'γ={gamma}')
plt.tight_layout()
plt.show()
```

### 4.2 占用度量的实际应用

占用度量 $\rho^\pi(s,a)$ 与状态访问分布的关系为：

$$
\rho^\pi(s,a) = \nu^\pi(s)\,\pi(a \mid s), \qquad \nu^\pi(s) = \sum_a \rho^\pi(s,a)
$$

```python
def compute_occupancy(env, policy, episodes=1000, gamma=0.99, max_steps=10_000):
    """Monte Carlo estimate of the discounted occupancy measure rho(s, a).

    Same sampling scheme as ``compute_visitation`` but also records the action
    taken, so ``occupancy.sum(axis=1)`` estimates the state visitation
    distribution. ``max_steps`` caps each episode (``None`` disables the cap).
    """
    occupancy = np.zeros((env.observation_space.n, env.action_space.n))
    for _ in range(episodes):
        state = env.reset()
        done = False
        discount = 1.0  # gamma ** t, updated incrementally
        t = 0
        while not done and (max_steps is None or t < max_steps):
            action = policy(state)
            occupancy[state, action] += discount
            state, _, done, _ = env.step(action)
            discount *= gamma
            t += 1
    return (1 - gamma) * occupancy / episodes


occupancy = compute_occupancy(env, cautious_policy)
```

### 4.3 从占用度量恢复策略

根据定理 2（策略与其占用度量一一对应），可以由占用度量反推出策略：$\pi(a \mid s) = \rho(s,a) / \sum_{a'} \rho(s,a')$。对于从未访问过的状态，用均匀分布补齐：

```python
def recover_policy(occupancy):
    """Return pi(a|s) = rho(s, a) / sum_a' rho(s, a'); rows with zero mass become uniform."""
    occupancy = np.asarray(occupancy, dtype=float)
    totals = occupancy.sum(axis=1, keepdims=True)
    uniform = np.full_like(occupancy, 1.0 / occupancy.shape[1])
    # Divide only where the state was visited; unvisited rows keep the uniform fill.
    return np.divide(occupancy, totals, out=uniform, where=totals > 0)


recovered_policy = recover_policy(occupancy)
```

## 5. 实战建议与性能优化

当处理更大的状态空间时，直接计算可能遇到内存和速度问题。这时可以采用：

- **稀疏矩阵存储**：使用 `scipy.sparse` 矩阵存储访问计数；
- **并行化采样**：利用 `multiprocessing` 并行运行多个 episode；
- **增量式计算**：对于非稳态策略，采用指数移动平均更新访问分布。

并行采样时有两点需要注意：

- 每个子进程需要独立的随机种子。在 fork 启动方式下（Python 3.14 之前 Linux 上的默认方式），子进程会继承主进程 NumPy 全局随机数生成器的状态。若不重新播种，各个 worker 会采样出完全相同的轨迹。
- 在 spawn 启动方式下（Windows、macOS 的默认方式），子进程必须能导入策略函数。Notebook 中定义的函数通常做不到这一点，需要把环境和策略放进单独的 `.py` 模块，并用 `if __name__ == "__main__":` 保护入口。

```python
from multiprocessing import Pool


def parallel_visitation(args):
    """Worker: run episodes and return the *unnormalized* discounted visit sums.

    ``args`` is ``(env, policy, gamma, episodes)`` or
    ``(env, policy, gamma, episodes, seed)``. A non-None seed re-seeds NumPy's
    global RNG (which the policies above use) so forked workers draw
    independent trajectories.
    """
    env, policy, gamma, episodes, *rest = args
    if rest and rest[0] is not None:
        np.random.seed(rest[0])
    local_visits = np.zeros(env.observation_space.n)
    for _ in range(episodes):
        state = env.reset()
        done = False
        discount = 1.0  # gamma ** t, updated incrementally
        while not done:
            local_visits[state] += discount
            action = policy(state)
            state, _, done, _ = env.step(action)
            discount *= gamma
    return local_visits


def fast_compute_visitation(env, policy, total_episodes=10000, gamma=0.99, workers=4, seed=None):
    """Parallel Monte Carlo estimate of the discounted state visitation distribution.

    Episodes are split as evenly as possible across ``workers`` processes (the
    remainder of ``total_episodes / workers`` is distributed, not dropped), and
    each worker gets an independent seed spawned from ``seed`` (fresh OS
    entropy when ``None``).

    Raises:
        ValueError: if ``workers`` is less than 1.
    """
    if workers < 1:
        raise ValueError("workers must be >= 1")
    base, extra = divmod(total_episodes, workers)
    counts = [base + (i < extra) for i in range(workers)]  # sums to total_episodes
    child_seeds = [child.generate_state(1)[0]
                   for child in np.random.SeedSequence(seed).spawn(workers)]
    tasks = [(env, policy, gamma, n, s) for n, s in zip(counts, child_seeds)]
    with Pool(workers) as pool:
        results = pool.map(parallel_visitation, tasks)
    return (1 - gamma) * np.sum(results, axis=0) / total_episodes
```