Use the reinforcement-learning SAC algorithm to implement heading control for me; it should be built on top of a PID controller.
Below is a complete code example of heading control built with the reinforcement-learning SAC algorithm on top of a PID controller: SAC outputs scaling factors for the PID gains, the PID computes the rudder angle, and the example covers the environment definition, SAC training, and evaluation with plots.
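Concretely, at each control step the two parts combine as follows (notation as in the code below: a_p, a_i, a_d are the three SAC action components, e / de / ie are the heading error, its derivative and its integral, and Kp0, Ki0, Kd0 are the baseline PID gains):

    Kp = a_p * Kp0,   Ki = a_i * Ki0,   Kd = a_d * Kd0
    delta = clip(Kp*e + Ki*ie + Kd*de, -delta_max, +delta_max)

SAC never commands the rudder directly; it only rescales the baseline gains, and the PID law produces the rudder command delta.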
pip install torch "stable-baselines3[extra]>=2.0" gymnasium numpy matplotlib   # SB3 >= 2.0 uses the Gymnasium API shown below
import numpy as np
import gymnasium as gym
from gymnasium import spaces
class HeadingEnv(gym.Env):
    """
    Simplified ship heading-keeping environment:
    - state  s = [e, de, ie]  (heading error, its derivative, its integral)
    - action a = scaling factors for [k_p, k_i, k_d], each in [0.1, 3]
    - reward r = -(|e| + 0.1*|delta|), where delta is the rudder angle
    """
    metadata = {"render_modes": ["human"]}

    def __init__(self,
                 dt=0.2,
                 target_psi=0.0,
                 max_rudder=np.deg2rad(30),
                 episode_len=400):
        super().__init__()
        self.dt = dt
        self.target_psi = target_psi
        self.max_rudder = max_rudder
        self.ep_len = episode_len
        # First-order Nomoto model parameters (example values)
        self.K = 0.8   # gain
        self.T = 15.0  # time constant
        # Observation space: [e, de, ie]
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32)
        # Action space: PID scaling factors
        self.action_space = spaces.Box(low=0.1, high=3.0, shape=(3,), dtype=np.float32)
        # Baseline PID gains (e.g. from manual tuning)
        self.base_kp, self.base_ki, self.base_kd = 2.0, 0.5, 8.0
        self.reset()

    def _nomoto(self, psi, r, delta):
        """First-order Nomoto model: T * r_dot + r = K * delta."""
        r_dot = (self.K * delta - r) / self.T
        psi_dot = r
        return psi + psi_dot * self.dt, r + r_dot * self.dt

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.t = 0
        # Random initial heading offset (here up to +/-40 deg) so there is a
        # non-trivial error for the controller to regulate back to target_psi
        self.psi = float(self.np_random.uniform(-np.deg2rad(40), np.deg2rad(40)))
        self.r = 0.0
        self.e_prev = self.target_psi - self.psi
        self.ie = 0.0
        return self._get_state(), {}

    def _get_state(self):
        e = self.target_psi - self.psi
        de = (e - self.e_prev) / self.dt
        return np.array([e, de, self.ie], dtype=np.float32)

    def step(self, action):
        self.t += 1
        kp_scale, ki_scale, kd_scale = action
        kp = kp_scale * self.base_kp
        ki = ki_scale * self.base_ki
        kd = kd_scale * self.base_kd
        # PID rudder command
        e, de, ie = self._get_state()
        delta = kp * e + ki * ie + kd * de
        delta = float(np.clip(delta, -self.max_rudder, self.max_rudder))
        # Integral anti-windup: only accumulate while the rudder is not saturated
        if abs(delta) < self.max_rudder - 0.01:
            self.ie += e * self.dt
        # Dynamics update
        self.psi, self.r = self._nomoto(self.psi, self.r, delta)
        # Reward: penalise heading error and rudder effort
        reward = -(abs(e) + 0.1 * abs(delta))
        self.e_prev = e
        truncated = self.t >= self.ep_len  # episode ends on the time limit only
        return self._get_state(), float(reward), False, truncated, {"delta": delta}

    def render(self):
        pass
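Optionally, the environment can be sanity-checked before any training: SB3's check_env verifies the Gymnasium interface, and a short rollout with all scaling factors fixed at 1 exercises the baseline PID against the Nomoto dynamics. This is a minimal sketch; the variable names (sanity_env, total_reward) are illustrative only.

from stable_baselines3.common.env_checker import check_env

check_env(HeadingEnv())  # raises if the observation/action/step API is inconsistent

# One episode with the unscaled baseline PID (all three scaling factors = 1)
sanity_env = HeadingEnv()
obs, info = sanity_env.reset(seed=0)
total_reward = 0.0
for _ in range(sanity_env.ep_len):
    obs, reward, terminated, truncated, info = sanity_env.step(np.array([1.0, 1.0, 1.0]))
    total_reward += reward
    if terminated or truncated:
        break
print(f"Baseline PID: episode reward {total_reward:.1f}, final heading error {obs[0]:.3f} rad")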
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv
import torch
env = DummyVecEnv([lambda: HeadingEnv()])
model = SAC("MlpPolicy",
            env,
            learning_rate=3e-4,
            buffer_size=100_000,
            batch_size=256,
            tau=0.005,
            gamma=0.99,
            policy_kwargs=dict(net_arch=[256, 256]),
            verbose=1)
model.learn(total_timesteps=100_000)
model.save("sac_pid_heading")
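Before plotting a detailed rollout, SB3's evaluate_policy gives a quick quantitative check of the trained model (a short sketch; 10 evaluation episodes is an arbitrary choice):

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"Evaluation over 10 episodes: mean reward {mean_reward:.1f} +/- {std_reward:.1f}")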
import matplotlib.pyplot as plt
env = HeadingEnv()
model = SAC.load("sac_pid_heading")
obs, info = env.reset()
psi_hist = []
delta_hist = []
rewards = []
for _ in range(400):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    psi_hist.append(env.psi)
    delta_hist.append(info["delta"])
    rewards.append(reward)
    if terminated or truncated:
        break
plt.subplot(3, 1, 1)
plt.title("Heading (rad)")
plt.plot(psi_hist, label='psi')
plt.axhline(env.target_psi, color='r', linestyle='--')
plt.legend()
plt.subplot(3, 1, 2)
plt.title("Rudder (rad)")
plt.plot(delta_hist)
plt.subplot(3, 1, 3)
plt.title("Reward per step")
plt.plot(rewards)
plt.tight_layout()
plt.show()
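As a rough comparison against the untuned baseline, the same rollout can be repeated with the scaling factors fixed at 1, seeding the environment so both runs start from the same initial heading. This sketch reuses env and model from above; rollout is a hypothetical helper defined here for illustration.

# Compare the learned gain schedule with the fixed baseline PID from the same start.
def rollout(policy_fn, seed=0):
    obs, info = env.reset(seed=seed)
    total = 0.0
    for _ in range(env.ep_len):
        obs, reward, terminated, truncated, info = env.step(policy_fn(obs))
        total += reward
        if terminated or truncated:
            break
    return total

sac_return = rollout(lambda o: model.predict(o, deterministic=True)[0])
pid_return = rollout(lambda o: np.array([1.0, 1.0, 1.0]))
print(f"Return with SAC-tuned gains: {sac_return:.1f}   with fixed baseline gains: {pid_return:.1f}")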
One-sentence summary
"Make the SAC action = the PID gain-tuning factors and the PID output = the rudder angle, then let SAC learn, in the ship heading-keeping task, how to tune those gains automatically."