
YOPO代码运行及功能解析
简介:
项目地址:https://github.com/TJU-Aerial-Robotics/YOPO
YOPO提出了一个基于学习的规划器,用于在障碍密集环境中进行自主导航,将(i)感知和映射,(ii)前端路径搜索和(iii)经典方法的后端优化集成到单个网络中。
Learning-based Planner:基于学习的规划师:考虑到导航问题的多模态性质,并避免围绕初始值的局部最小值,我们的方法采用一组运动原语作为锚,以覆盖搜索空间,并预测原始变量的偏移量和分数以进一步改进(如单级对象检测器YOLO)。
Training Strategy:训练策略:与在强化学习中通过试错作为模型演示或通过试错进行探索相比,我们直接将轨迹成本的梯度(例如从ESDF)反向传播到网络的权重,这简单,直接,准确且独立于序列(无在线模拟器交互或渲染)。
测试方法:
您可以使用我们提供的预训练权重测试策略。YOPO/saved/YOPO_1/epoch50.pth
. .
1.启动控制器和动态模拟器
有关控制器的详细介绍,请参阅 Controller_Introduction
cd Controller
source devel/setup.bash
roslaunch so3_quadrotor_simulator simulator_attitude_control.launch
2.启动环境和传感器模拟器
有关模拟器的详细介绍,请参阅 Simulator_Introduction。随机森林的示例可以在 random_forest.png 中找到
cd Simulator
source devel/setup.bash
rosrun sensor_simulator sensor_simulator_cuda
您可以参考config.yaml来修改传感器(例如,相机和激光雷达参数)和环境(例如,mazide_type和障碍物密度)。对于泛化,在森林(类型5)中训练的策略可以零射转移到3D Perlin(1型)。
3.启动 YOPO Planner
您可以参考 traj_opt.yaml 来修改飞行速度(给定的权重在 6 米/秒时经过预训练,在 0 - 6 米/秒之间的速度平稳执行)。
cd YOPO
conda activate yopo
python test_yopo_ros.py --trial=1 --epoch=50
4.可视化
启动 RVIZ 以可视化图像和轨迹。
cd YOPO
rviz -d yopo.rviz
可以点击2D Nav Goal
在RVIZ作为目标(地图是无限的,所以目标是自由的),就像下面的GIF(Flightmare模拟器)。
整体流程
YOPO 采用端到端的架构设计,核心由三部分组成:
感知模块:处理深度图像,提取环境特征(如障碍物位置、距离)。
规划模块:结合当前状态(速度、加速度、目标位置)和环境特征,生成多条候选轨迹。
评估模块:为每条轨迹打分,选择最优轨迹执行。
这种设计将传统的 “感知 - 规划 - 控制” 管道简化为 “感知 - 规划” 两阶段,通过神经网络直接学习从感知到规划的映射,提高了实时性和适应性。
二、数据流程:从输入到输出的完整链路
plaintext
输入数据 → 预处理 → 特征提取 → 轨迹预测 → 轨迹评估 → 输出最优轨迹
1. 输入数据
深度图像:无人机相机采集的环境深度信息(预处理为 [0,1] 范围的浮点数)。
观测状态:机体坐标系下的速度(v_xyz)、加速度(a_xyz)和目标位置(goal_xyz)。
2. 预处理(YOPODataset
)
数据加载:读取深度图像和对应的位置 / 姿态数据,按训练 / 验证集划分。
图像预处理:调整尺寸、归一化(除以 65535),扩展通道维度(符合 [C,H,W] 格式)。
状态生成:随机生成符合物理约束的速度、加速度和目标位置,增强数据多样性。
3. 特征提取(YopoNetwork
)
图像特征:通过
YopoBackbone
提取深度图像的空间特征(如障碍物分布)。状态特征:当前状态(速度、加速度、目标)直接作为特征(可通过
state_backbone
扩展)。特征融合:拼接图像特征和状态特征,输入到
YopoHead
。
4. 轨迹预测
候选轨迹生成:网络预测多个候选轨迹的终点状态(位置、速度、加速度)。
轨迹分数:为每个候选轨迹分配分数,表示其优劣程度(基于平滑度、安全性和目标接近度)。
5. 轨迹评估与选择
损失计算:通过
YOPOLoss
评估轨迹的平滑度、安全性和目标接近度。最优轨迹选择:训练时通过分数损失引导网络学习评估轨迹;推理时选择分数最高的轨迹执行。
三、核心组件及其作用
1. 数据集类(YOPODataset
)
作用:管理训练 / 验证数据,生成多样化的训练样本。
关键功能:
深度图像预处理(尺寸调整、归一化)。
随机状态生成(速度、加速度符合物理约束,目标位置符合任务分布)。
坐标转换(世界坐标系↔机体坐标系)。
2. 神经网络(YopoNetwork
)
作用:实现从深度图像和状态到轨迹的映射。
关键组件:
YopoBackbone
:提取深度图像特征。YopoHead
:融合特征并预测轨迹终点状态和分数。StateTransform
:状态归一化和坐标转换,确保输入输出一致性。
3. 损失函数(YOPOLoss
)
作用:评估轨迹质量,引导网络学习。
损失组成:
平滑度损失:惩罚剧烈的加速度变化,确保轨迹平滑。
安全损失:基于深度图像预测碰撞风险,惩罚不安全轨迹。
目标损失:衡量轨迹终点与目标的距离,引导朝向目标。
4. 训练器(YopoTrainer
)
作用:管理模型训练流程,包括数据加载、前向传播、反向传播、参数更新和评估。
关键功能:
批量数据加载与并行处理。
多轮训练与验证。
梯度计算与参数优化(AdamW 优化器)。
日志记录(TensorBoard 可视化)和模型保存。
四、训练策略与优化
监督学习:通过
YOPOLoss
计算预测轨迹与期望轨迹的差异,优化模型参数。损失权重平衡:
轨迹损失:平滑度、安全性和目标损失的加权和,确保生成的轨迹既安全又高效。
分数损失:训练网络准确评估轨迹质量,使预测分数与实际损失一致。
正则化:
使用 AdamW 优化器(带权重衰减)减少过拟合。
梯度裁剪(虽未显式使用,但配置了
max_grad_norm
)防止梯度爆炸。
训练流程:
按批次加载数据,执行前向传播和反向传播。
定期在验证集上评估模型性能,监控泛化能力。
保存检查点,支持中断后恢复训练。
五、应用场景与优势
应用场景:无人机自主导航、避障、目标跟踪等实时决策任务。
核心优势:
端到端学习:直接从原始感知数据(深度图像)生成控制指令,减少人工设计特征的工作量。
实时性:通过单次前向传播生成多条候选轨迹,避免传统规划算法的高计算开销。
适应性:通过大量训练数据学习复杂环境下的导航策略,泛化能力强。
物理约束感知:通过精心设计的状态采样和损失函数,确保生成的轨迹符合无人机物理特性
代码解析:
YOPO部分
config/config.py
这段代码实现了一个全局配置管理类,用于加载和管理 YAML 配置文件。通过封装配置数据,提供了统一的访问接口,使参数在整个项目中保持一致性:
1. 配置管理类 Config
class Config:
def __init__(self):
# 获取当前脚本所在目录
base_dir = os.path.dirname(os.path.abspath(__file__))
# 加载YAML配置文件
self._data = YAML().load(open(os.path.join(base_dir, "traj_opt.yaml"), 'r'))
# 设置和计算派生参数
self._data["train"] = True
self._data["goal_length"] = 2.0 * self._data['radio_range']
self._data["sgm_time"] = 2 * self._data["radio_range"] / self._data["vel_max_train"]
self._data["traj_num"] = self._data['horizon_num'] * self._data['vertical_num'] * self._data["radio_num"]
功能:初始化配置对象,加载 YAML 文件并计算派生参数。
关键步骤:
通过
os.path
获取当前脚本的绝对路径,定位配置文件traj_opt.yaml
。使用
ruamel.yaml
加载 YAML 文件内容到self._data
字典。设置训练模式为开启(
train=True
)。计算派生参数:
goal_length
:基于radio_range
的两倍。sgm_time
:基于radio_range
和最大训练速度vel_max_train
。traj_num
:基于水平、垂直和径向轨迹数量的乘积。
2. 魔法方法实现
def __getitem__(self, key):
return self._data[key]
def __setitem__(self, key, value):
self._data[key] = value
功能:使配置对象支持字典式的键值访问。
__getitem__
:允许通过cfg[key]
获取配置值。__setitem__
:允许通过cfg[key] = value
修改配置值。
3. 全局配置实例
cfg = Config()
功能:创建全局配置对象,使配置在整个项目中可访问。
使用示例
# 导入全局配置
from config_module import cfg
# 获取参数
radio_range = cfg["radio_range"] # 从YAML读取
goal_length = cfg["goal_length"] # 计算得到的值(2*radio_range)
# 修改参数
cfg["train"] = False # 切换到测试模式
配置文件结构(核心配置文件)
用于无人机轨迹规划与控制算法的训练和测试参数配置,涵盖速度限制、损失权重、数据集路径、轨迹采样等核心参数。以下是按功能分类的详细解析:
一、速度与加速度限制(核心运动参数)
# 测试时的速度(可修改)
velocity: 6.0
# 训练时的最大速度和加速度
# 注意:测试时的实际值可能不同,但需与训练保持一致性(参考primitive.py)
vel_max_train: 6.0
acc_max_train: 6.0
作用:定义无人机的运动学约束,确保训练与测试的一致性。
细节:
velocity
:测试阶段无人机的实际飞行速度(可根据场景调整)。vel_max_train
/acc_max_train
:训练时的最大速度和加速度,用于生成符合物理约束的训练数据,直接影响轨迹规划的可行性(如YOPOLoss
中速度缩放系数的计算)。一致性要求:训练与测试的速度量级需匹配,否则可能导致模型在测试时失效(如训练时用低速数据,测试时高速飞行会超出模型预期)。
二、损失函数权重(控制优化目标)
# 单位速度下的损失权重(可在tensorboard中可视化)
wg: 0.12 # 引导损失(guidance)权重
ws: 10.0 # 平滑度损失(smoothness)权重
wc: 0.1 # 碰撞损失(collision)权重
作用:平衡三个核心损失函数的优先级,直接影响轨迹优化方向。
与代码关联:
在
YOPOLoss
类中,这些权重会被加载并用于计算加权损失(self.smoothness_weight = cfg["ws"]
等)。权重含义:
wg
:引导损失权重(值越大,轨迹越偏向快速朝向目标)。ws
:平滑度损失权重(值越大,轨迹越平滑,减少急加速 / 急转向)。wc
:碰撞损失权重(值越大,避障优先级越高,可能牺牲部分效率)。
调优原则:根据任务需求调整(如室内避障场景可增大
wc
,室外高速场景可增大ws
)。
三、数据集配置
# 数据集路径与图像尺寸
dataset_path: "../dataset"
image_height: 96
image_width: 160
作用:指定训练数据的存储路径和输入图像的尺寸。
与代码关联:
dataset_path
:对应数据集生成代码(如之前的main()
函数)中保存图像和标签的路径。image_height
/image_width
:定义输入神经网络的图像尺寸(需与数据集图像一致,否则会导致维度不匹配)。
四、轨迹采样与规划参数
# 轨迹与primitive参数(轨迹采样相关)
horizon_num: 5 # 水平方向轨迹数量
vertical_num: 3 # 垂直方向轨迹数量
horizon_camera_fov: 90.0 # 水平相机视场角(度)
vertical_camera_fov: 60.0 # 垂直相机视场角(度)
horizon_anchor_fov: 30.0 # 水平锚点视场角(度)
vertical_anchor_fov: 30.0 # 垂直锚点视场角(度)
radio_range: 5.0 # 径向范围(规划 horizon = 2 * radio_range)
radio_num: 1 # 径向轨迹数量(当前仅支持1)
作用:定义轨迹采样的维度和范围,影响候选轨迹的覆盖度。
关键参数解析:
horizon_num
/vertical_num
/radio_num
:水平、垂直、径向的轨迹采样数量,总轨迹数为三者乘积(如5×3×1=15
条候选轨迹)。radio_range
:径向基础范围,规划的轨迹总长度为2×radio_range
(如 5m→轨迹长度 10m)。视场角参数(
horizon_camera_fov
等):限制轨迹采样的角度范围,确保轨迹在相机可见范围内(避免规划超出感知范围的轨迹)。
五、安全惩罚参数
# 安全惩罚参数(碰撞损失计算)
d0: 1.2 # 安全距离基准
r: 0.6 # 惩罚系数
作用:定义碰撞损失(
SafetyLoss
)中的安全距离和惩罚强度。与安全损失关联:
d0
:无人机与障碍物的最小安全距离(小于该距离时触发碰撞惩罚)。r
:惩罚系数(值越大,对接近障碍物的轨迹惩罚越重)。
示例:若无人机与障碍物距离为
d
,当d < d0
时,损失随(d0 - d)×r
增大而增加。
六、状态采样分布(训练数据生成)
# 单位状态采样的分布参数(用于生成训练数据)
vx_mean_unit: 0.4 # x方向速度均值
vy_mean_unit: 0.0 # y方向速度均值
vz_mean_unit: 0.0 # z方向速度均值
vx_std_unit: 2.0 # x方向速度标准差
vy_std_unit: 0.45 # y方向速度标准差
vz_std_unit: 0.3 # z方向速度标准差
ax_mean_unit: 0.0 # x方向加速度均值
ay_mean_unit: 0.0 # y方向加速度均值
az_mean_unit: 0.0 # z方向加速度均值
ax_std_unit: 0.5 # x方向加速度标准差
ay_std_unit: 0.5 # y方向加速度标准差
az_std_unit: 0.3 # z方向加速度标准差
goal_pitch_std: 10.0 # 目标俯仰角标准差(度)
goal_yaw_std: 20.0 # 目标偏航角标准差(度)
作用:生成训练数据时,随机采样无人机的速度、加速度和目标角度,确保数据覆盖多样化的运动状态。
细节:
均值(
mean
):采样的中心值(如vx_mean_unit=0.4
→x 方向速度集中在 0.4m/s 附近)。标准差(
std
):采样的离散程度(值越大,数据覆盖范围越广)。目标角度标准差:控制目标位置的随机性(如
goal_yaw_std=20
→目标在 ±20 度范围内随机分布)。
七、地图扩展参数
# 地图扩展参数(避免轨迹超出边界)
map_expand_min: [0, 0, 0.2] # 地图最小扩展量(x,y,z)
map_expand_max: [0, 0, 6.0] # 地图最大扩展量(x,y,z)
作用:扩展地图边界,避免轨迹规划时超出原始地图范围(尤其 z 轴,防止将天空误判为障碍物)。
细节:
主要针对 z 轴(高度)扩展:
map_expand_min[2]=0.2
→地图底部向下扩展 0.2m;map_expand_max[2]=6.0
→地图顶部向上扩展 6.0m。目的:解决原始地图可能未覆盖全部可行区域的问题(如无人机需要爬升至高于地图上限的高度避障)。
总结
这份配置文件是无人机轨迹规划算法的核心参数中枢,通过调整这些参数可实现:
平衡轨迹的平滑性、安全性和目标导向性(损失权重)。
适配不同速度场景(速度与加速度限制)。
控制训练数据的分布和范围(采样参数、数据集路径)。
确保轨迹规划的安全性和可行性(安全距离、地图扩展)。
实际使用中需重点关注IMPORTANT
标记的参数,确保训练与测试的一致性(如vel_max_train
与测试velocity
的量级匹配),否则可能导致模型性能下降。
control_msg/_PositionCommand.py
这段代码是 ROS(机器人操作系统)中自动生成的消息类型定义文件,对应quadrotor_msgs/PositionCommand
消息,用于向无人机发送位置、速度、姿态等控制指令
功能:定义了一个继承自genpy.Message
的消息类,用于标准化无人机控制指令的格式。
核心作用:在 ROS 节点间传递结构化的控制指令,确保发送方和接收方对数据格式达成一致
轨迹状态常量
代码中定义了一组轨迹状态的枚举常量,用于标识轨迹的执行状态:
# 伪常量定义
TRAJECTORY_STATUS_EMPTY = 0 # 轨迹为空(未初始化)
TRAJECTORY_STATUS_READY = 1 # 轨迹就绪(可执行)
TRAJECTORY_STATUS_COMPLETED = 3 # 轨迹已完成
TRAJECTROY_STATUS_ABORT = 4 # 轨迹中止(如遇障碍物)
TRAJECTORY_STATUS_ILLEGAL_START = 5 # 起始点非法(如在障碍物内)
TRAJECTORY_STATUS_ILLEGAL_FINAL = 6 # 终点非法
TRAJECTORY_STATUS_IMPOSSIBLE = 7 # 轨迹无法执行(如超出物理限制)
作用:通过
trajectory_flag
字段传递轨迹的状态信息,方便上层控制器判断轨迹执行情况。
loss/loss_function.py
代码定义了一个综合损失函数类YOPOLoss
,用于无人机轨迹优化任务。该类整合了平滑度损失(Smoothness Loss)、安全损失(Safety Loss) 和目标损失(Goal Loss),并通过矩阵运算实现轨迹参数与损失的映射,最终输出加权的综合损失。
1. 类初始化与核心参数(__init__
方法)
def __init__(self):
super(YOPOLoss, self).__init__()
self.sgm_time = cfg["sgm_time"] # 轨迹段时间(从配置文件读取)
self.device = th.device("cuda" if th.cuda.is_available() else "cpu") # 设备选择(GPU/CPU)
self._C, self._B, self._L, self._R = self.qp_generation() # 生成优化相关矩阵
self._R = self._R.to(self.device) # 转移到目标设备
self._L = self._L.to(self.device)
vel_scale = cfg["vel_max_train"] / 1.0 # 速度尺度(用于权重归一化)
# 损失权重(从配置文件读取)
self.smoothness_weight = cfg["ws"]
self.safety_weight = cfg["wc"]
self.goal_weight = cfg["wg"]
self.denormalize_weight(vel_scale) # 权重归一化
# 初始化子损失函数
self.smoothness_loss = SmoothnessLoss(self._R)
self.safety_loss = SafetyLoss(self._L)
self.goal_loss = GuidanceLoss()
# 打印实际损失权重
...
核心功能:初始化损失函数的基础参数、矩阵和子损失函数,完成权重的物理意义归一化。
关键设计:通过矩阵运算(
qp_generation
)建立轨迹参数与损失的映射关系,为后续损失计算提供数学基础。
2. 二次规划矩阵生成(qp_generation
方法)
def qp_generation(self):
# 论文中的映射矩阵A:将多项式系数映射到位置、速度、加速度
A = th.zeros((6, 6))
for i in range(3):
A[2 * i, i] = math.factorial(i) # 位置相关系数(0,2,4行)
for j in range(i, 6):
# 速度相关系数(1,3,5行):导数公式
A[2 * i + 1, j] = math.factorial(j) / math.factorial(j - i) * (self.sgm_time ** (j - i))
# 海森矩阵H(论文中的Q):用于平滑度损失(加加速度jerk的积分)
H = th.zeros((6, 6))
for i in range(3, 6): # 关注3-5阶系数(对应加加速度)
for j in range(3, 6):
# 加加速度平方的时间积分,用于计算平滑度损失
H[i, j] = i*(i-1)*(i-2)*j*(j-1)*(j-2)/(i+j-5) * (self.sgm_time ** (i+j-5))
return self.stack_opt_dep(A, H)
矩阵 A 的作用:
建立多项式轨迹系数与轨迹状态(位置、速度、加速度)的映射。例如,通过多项式系数计算任意时刻的位置、速度、加速度。矩阵 H 的作用:
描述轨迹平滑度的代价(加加速度的平方积分),是二次规划中的目标函数矩阵,用于后续平滑度损失计算。
3. 矩阵转换与堆叠(stack_opt_dep
方法)
def stack_opt_dep(self, A, Q):
Ct = th.zeros((6, 6))
Ct[[0, 2, 4, 1, 3, 5], [0, 1, 2, 3, 4, 5]] = 1 # 状态重排矩阵
_C = th.transpose(Ct, 0, 1) # 转置重排矩阵
B = th.inverse(A) # A的逆矩阵:从状态反推多项式系数
B_T = th.transpose(B, 0, 1) # B的转置
_L = B @ Ct # 状态到多项式系数的映射(用于安全损失)
_R = _C @ B_T @ Q @ B @ Ct # 平滑度损失的二次项矩阵(用于平滑度损失)
return _C, B, _L, _R
功能:通过矩阵运算将原始多项式系数与轨迹状态(位置、速度、加速度)关联,生成后续损失计算所需的映射矩阵
_L
(安全损失)和_R
(平滑度损失)。数学意义:将轨迹优化问题转化为二次规划形式,通过矩阵乘法快速计算损失,提高效率。
4. 权重归一化(denormalize_weight
方法)
def denormalize_weight(self, vel_scale):
# 平滑度损失:与加加速度jerk的5次方相关,速度缩放n倍时,jerk缩放n^6,时间缩放1/n → 总缩放n^5
self.smoothness_weight = self.smoothness_weight / vel_scale ** 5
# 安全损失:与时间积分相关,速度缩放n倍时,时间缩放1/n → 总缩放1/n,需乘以n抵消
self.safety_weight = self.safety_weight * vel_scale
# 目标损失:与距离相关,不受速度缩放影响
self.goal_weight = self.goal_weight
核心逻辑:根据物理量纲调整权重,确保不同速度场景下损失权重的物理意义一致,简化参数调优。
举例:当无人机速度提高时,平滑度损失的权重自动降低(避免过度约束),安全损失权重自动提高(确保避障时间足够)。
5. 损失计算主入口(forward
方法)
def forward(self, state, prediction, goal, map_id):
# 转换状态和预测的维度顺序:适配子损失函数的输入格式
Df = state.permute(0, 2, 1) # 固定参数(当前状态)
Dp = prediction.permute(0, 2, 1) # 决策参数(预测轨迹)
# 初始化各损失
smoothness_cost = th.tensor(0.0, device=self.device, requires_grad=True)
safety_cost = th.tensor(0.0, device=self.device, requires_grad=True)
goal_cost = th.tensor(0.0, device=self.device, requires_grad=True)
# 计算各子损失(根据权重是否为0决定是否计算)
if self.smoothness_weight > 0:
smoothness_cost = self.smoothness_loss(Df, Dp)
if self.safety_weight > 0:
safety_cost = self.safety_loss(Df, Dp, map_id)
if self.goal_weight > 0:
goal_cost = self.goal_loss(Df, Dp, goal)
# 返回加权后的各损失
return (self.smoothness_weight * smoothness_cost,
self.safety_weight * safety_cost,
self.goal_weight * goal_cost)
功能:整合三个子损失函数的计算结果,返回加权后的综合损失。
输入输出:
输入:当前状态(
state
)、预测轨迹(prediction
)、目标位置(goal
)、地图 ID(map_id
)。输出:三个加权损失(平滑度、安全、目标),可用于反向传播优化轨迹参数。
policy/models/poly_solver.py(底层规划)
包含无人机轨迹规划的核心组件:五阶多项式轨迹生成器和偏航角(yaw)计算函数,用于生成平滑轨迹并控制无人机朝向。以下是详细分析:
一、五阶多项式轨迹生成器(Poly5Solver
类)
五阶多项式是无人机轨迹规划中常用的工具,因其可满足位置、速度、加速度的初始和末端约束,生成平滑轨迹(加加速度 jerk 连续)。
1. 类初始化(__init__
方法)
def __init__(self, pos0, vel0, acc0, pos1, vel1, acc1, Tf):
# 初始与末端状态矩阵(位置、速度、加速度约束)
State_Mat = np.array([pos0, vel0, acc0, pos1, vel1, acc1])
t = Tf # 轨迹总时间
# 多项式系数求解矩阵(预定义的逆矩阵,用于从状态约束计算系数)
Coef_inv = np.array([
[1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[0, 0, 1/2, 0, 0, 0],
[-10/t**3, -6/t**2, -3/(2*t), 10/t**3, -4/t**2, 1/(2*t)],
[15/t**4, 8/t**3, 3/(2*t**2), -15/t**4, 7/t**3, -1/t**2],
[-6/t**5, -3/t**4, -1/(2*t**3), 6/t**5, -3/t**4, 1/(2*t**3)]
])
# 计算多项式系数A(A0-A5,对应t^0到t^5的系数)
self.A = np.dot(Coef_inv, State_Mat)
核心功能:根据初始状态(
pos0
,vel0
,acc0
)、末端状态(pos1
,vel1
,acc1
)和轨迹总时间Tf
,求解五阶多项式的系数。数学基础:
五阶多项式形式为:
p(t)=A0+A1t+A2t2+A3t3+A4t4+A5t5
其导数(速度、加速度)需满足初始和末端约束:初始时刻(
t=0
):p(0)=pos0, p˙(0)=vel0, p¨(0)=acc0末端时刻(
t=Tf
):p(Tf)=pos1, p˙(Tf)=vel1, p¨(Tf)=acc1
这些约束构成线性方程组,通过预定义的逆矩阵Coef_inv
快速求解系数self.A
(避免重复求逆,提高效率)。
2. 轨迹状态计算方法
def get_position(self, t):
# 位置:p(t) = A0 + A1 t + A2 t² + A3 t³ + A4 t⁴ + A5 t⁵
return self.A[0] + self.A[1]*t + self.A[2]*t**2 + self.A[3]*t**3 + self.A[4]*t**4 + self.A[5]*t**5
def get_velocity(self, t):
# 速度:v(t) = p’(t) = A1 + 2A2 t + 3A3 t² + 4A4 t³ + 5A5 t⁴
return self.A[1] + 2*self.A[2]*t + 3*self.A[3]*t**2 + 4*self.A[4]*t**3 + 5*self.A[5]*t**4
def get_acceleration(self, t):
# 加速度:a(t) = p''(t) = 2A2 + 6A3 t + 12A4 t² + 20A5 t³
return 2*self.A[2] + 6*self.A[3]*t + 12*self.A[4]*t**2 + 20*self.A[5]*t**3
def get_jerk(self, t):
# 加加速度(jerk):j(t) = p'''(t) = 6A3 + 24A4 t + 60A5 t²
return 6*self.A[3] + 24*self.A[4]*t + 60*self.A[5]*t**2
def get_snap(self, t):
# 急动(snap):s(t) = p''''(t) = 24A4 + 120A5 t
return 24*self.A[4] + 120*self.A[5]*t
功能:根据时间
t
计算轨迹在该时刻的位置、速度、加速度、加加速度(jerk)和急动(snap)。意义:
五阶多项式的高阶导数(jerk、snap)连续,可生成平滑轨迹,减少无人机机械损耗(如电机振动),这也是YOPOLoss
中平滑度损失的物理基础。
二、多轨迹五阶多项式生成器(Polys5Solver
类)
class Polys5Solver:
def __init__(self, pos0, vel0, acc0, pos1, vel1, acc1, Tf):
N = len(pos1) # 轨迹数量
# 状态矩阵适配多轨迹(N条轨迹,每行对应一个约束,列对应不同轨迹)
State_Mat = np.array([[pos0]*N, [vel0]*N, [acc0]*N, pos1, vel1, acc1])
t = Tf
# 与Poly5Solver相同的系数矩阵(复用约束求解逻辑)
Coef_inv = np.array([...]) # 同Poly5Solver
# 计算N条轨迹的系数(self.A形状为(6, N),每行对应A0-A5,每列对应一条轨迹)
self.A = np.dot(Coef_inv, State_Mat)
def get_position(self, t):
t = np.atleast_1d(t) # 确保t是数组(支持批量计算)
# 计算所有轨迹在时间t的位置(利用广播机制)
result = (self.A[0][:, np.newaxis] + self.A[1][:, np.newaxis]*t +
self.A[2][:, np.newaxis]*t**2 + self.A[3][:, np.newaxis]*t**3 +
self.A[4][:, np.newaxis]*t**4 + self.A[5][:, np.newaxis]*t**5)
return result.flatten() # 展平为一维数组
功能:生成多条五阶多项式轨迹(用于批量计算或可视化),与
Poly5Solver
的区别在于支持同时处理N
条轨迹。应用场景:在轨迹规划中需要生成多条候选轨迹(如之前配置文件中的
horizon_num×vertical_num
条)时,可高效计算所有轨迹的位置,用于后续损失评估或选择最优轨迹。
三、偏航角计算函数(calculate_yaw
)
偏航角(yaw)控制无人机的朝向,需兼顾速度方向(减少空气阻力)和目标方向(便于感知目标),同时受最大偏航率约束(物理限制)。
1. 核心逻辑
def calculate_yaw(vel_dir, goal_dir, last_yaw, dt, max_yaw_rate=0.3):
YAW_DOT_MAX_PER_SEC = max_yaw_rate * np.pi # 最大偏航角速度(弧度/秒)
# 归一化速度方向和目标方向
vel_dir = vel_dir / (np.linalg.norm(vel_dir) + 1e-5)
goal_dist = np.linalg.norm(goal_dir)
goal_dir = goal_dir / (goal_dist + 1e-5)
# 动态计算目标方向的权重(角度差越大,目标权重越高)
goal_yaw = np.arctan2(goal_dir[1], goal_dir[0]) # 目标方向的偏航角
delta_yaw = goal_yaw - last_yaw
delta_yaw = (delta_yaw + np.pi) % (2*np.pi) - np.pi # 归一化到[-π, π]
weight = 6 * abs(delta_yaw) / np.pi # 权重随角度差增大而增大(0到6)
# 期望方向 = 速度方向 + 权重×目标方向(平衡两者)
dir_des = vel_dir + weight * goal_dir
# 计算期望偏航角
yaw_temp = np.arctan2(dir_des[1], dir_des[0]) if goal_dist > 0.2 else last_yaw
max_yaw_change = YAW_DOT_MAX_PER_SEC * dt # 时间步内的最大偏航角变化
# 确保偏航角变化不超过最大限制(处理角度环绕问题)
# ... 角度调整逻辑(见下方解析) ...
return yaw, yawdot # 当前偏航角和偏航角速度
2. 关键步骤解析
方向归一化:将
vel_dir
(速度方向)和goal_dir
(目标方向)归一化为单位向量,确保权重计算公平。动态权重:
权重weight
与当前偏航角和目标偏航角的差值delta_yaw
成正比(最大为 6),实现:当无人机朝向与目标方向偏差小时,主要跟随速度方向(减少阻力)。
当偏差大时,目标方向权重增加(快速转向目标,便于感知)。
最大偏航率约束:
无人机的偏航角变化量不能超过max_yaw_change
(最大角速度×时间步
),避免物理上不可行的急转。若期望偏航角
yaw_temp
与上一时刻偏航角last_yaw
的差值超过限制,则将变化量钳位到最大值。处理角度环绕问题(如从
π
到-π
的过渡,避免计算错误)。
这段代码是无人机轨迹规划的底层核心工具:
Poly5Solver
和Polys5Solver
通过五阶多项式生成平滑轨迹,满足运动学约束;calculate_yaw
计算合理的偏航角,平衡物理限制与任务需求。
三者共同支撑起从轨迹生成到姿态控制的完整流程,是连接高层规划(如目标点设置)与底层执行(如电机控制)的关键环节。
policy/models/primitive(轨迹生成)
定义了无人机轨迹规划中用于生成候选轨迹晶格(Lattice) 的核心类,包括轨迹参数配置(LatticeParam
)和晶格候选轨迹生成(LatticePrimitive
)。晶格轨迹是预先定义的一系列候选轨迹方向,用于快速生成符合物理约束的可行轨迹,是高效轨迹规划的基础。
一、LatticeParam
类:轨迹参数配置
该类用于从配置文件加载并计算轨迹规划的核心参数,为晶格轨迹生成提供基础约束。
1. 核心参数初始化
(class LatticeParam:
def __init__(self):
# 计算缩放比例(测试时根据实际速度调整,保持与训练的一致性)
ratio = 1.0 if cfg["train"] else cfg["velocity"] / cfg["vel_max_train"]
# 物理约束参数(根据比例调整)
self.vel_max = ratio * cfg["vel_max_train"] # 最大速度
self.acc_max = ratio * ratio * cfg["acc_max_train"] # 最大加速度(与速度平方成正比)
self.segment_time = cfg["sgm_time"] / ratio # 轨迹段时间(速度越高,时间越短)
# 晶格分布参数
self.horizon_num = cfg["horizon_num"] # 水平方向轨迹数量
self.vertical_num = cfg["vertical_num"] # 垂直方向轨迹数量
self.radio_num = cfg["radio_num"] # 径向轨迹数量
self.traj_num = cfg["traj_num"] # 总轨迹数量(三者乘积)
# 视场角参数(限制轨迹方向在相机可见范围内)
self.horizon_fov = cfg["horizon_camera_fov"] # 水平视场角(度)
self.vertical_fov = cfg["vertical_camera_fov"] # 垂直视场角(度)
self.horizon_anchor_fov = cfg["horizon_anchor_fov"] # 水平锚点视场角
self.vertical_anchor_fov = cfg["vertical_anchor_fov"] # 垂直锚点视场角
# 径向范围
self.radio_range = cfg["radio_range"] # 单段径向范围(总范围为2×radio_range)
# 打印关键参数
print("---------- Param --------")
print(f"| {'max speed':<12} = {round(self.vel_max, 1):>6} |")
print(f"| {'max accel':<12} = {round(self.acc_max, 1):>6} |")
print(f"| {'traj time':<12} = {round(self.segment_time, 1):>6} |")
print(f"| {'max radio':<12} = {round(2 * self.radio_range, 1):>6} |")
print("-------------------------")
核心功能:从配置文件加载参数,并根据训练 / 测试模式动态调整物理约束参数,为晶格轨迹生成提供基础配置。
关键参数解析:
比例调整(
ratio
):
当处于测试模式时,根据实际速度(cfg["velocity"]
)与训练时的最大速度(cfg["vel_max_train"]
)的比例,调整最大速度、加速度和轨迹段时间,确保物理约束的一致性(如高速时轨迹段时间缩短,避免轨迹过长)。物理约束:
vel_max
(最大速度)、acc_max
(最大加速度)确保轨迹符合无人机的运动学极限;segment_time
(轨迹段时间)定义单段轨迹的持续时间。晶格分布:
horizon_num
(水平方向数量)、vertical_num
(垂直方向数量)、radio_num
(径向数量)决定候选轨迹的分布密度,总数量为traj_num = horizon_num × vertical_num × radio_num
。
二、LatticePrimitive
类:晶格候选轨迹生成
该类继承自LatticeParam
,生成具体的晶格候选轨迹(Primitives
),即一系列预设方向的候选轨迹终点,用于轨迹规划中的候选轨迹采样。
1. 单例模式设计
class LatticePrimitive(LatticeParam):
_instance = None # 单例实例
@classmethod
def get_instance(self):
if self._instance is None:
self._instance = self() # 首次调用时初始化
return self._instance
功能:通过单例模式(
_instance
)确保全局仅存在一个LatticePrimitive
实例,避免重复生成晶格轨迹(节省计算资源)。
2. 晶格轨迹生成逻辑
def __init__(self):
super().__init__()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 设备选择
# 计算水平/垂直方向的角度间隔(弧度)
if self.horizon_num == 1:
direction_diff = 0 # 仅1条水平轨迹时,角度差为0
else:
# 水平视场角(度转弧度)平均分配到horizon_num个间隔
direction_diff = (self.horizon_fov / 180.0 * torch.pi) / self.horizon_num
if self.vertical_num == 1:
altitude_diff = 0 # 仅1条垂直轨迹时,角度差为0
else:
# 垂直视场角平均分配到vertical_num个间隔
altitude_diff = (self.vertical_fov / 180.0 * torch.pi) / self.vertical_num
radio_diff = self.radio_range / self.radio_num # 径向间隔(单位距离)
# 存储晶格节点的位置、角度、旋转矩阵
lattice_pos_list = [] # 位置:[N, 3](N=traj_num)
lattice_angle_list = [] # 角度:[N, 2](水平角、垂直角)
lattice_Rbp_list = [] # 旋转矩阵:[N, 3, 3](从机体坐标系到世界坐标系)
# 生成晶格节点(径向→垂直→水平的嵌套循环)
for h in range(0, self.radio_num): # 径向(距离)
for i in range(0, self.vertical_num): # 垂直方向(高度)
for j in range(0, self.horizon_num): # 水平方向(方位)
# 1. 计算径向距离
search_radio = (h + 1) * radio_diff # 第h个径向的距离(从近到远)
# 2. 计算水平角(alpha)和垂直角(beta)
# 水平角:居中分布(如视场角90度,3个间隔→-45°, 0°, 45°)
alpha = torch.tensor(-direction_diff * (self.horizon_num - 1) / 2 + j * direction_diff)
# 垂直角:居中分布(如视场角60度,3个间隔→-30°, 0°, 30°)
beta = torch.tensor(-altitude_diff * (self.vertical_num - 1) / 2 + i * altitude_diff)
# 3. 计算该角度下的位置坐标(球坐标→直角坐标)
pos_node = torch.tensor([
torch.cos(beta) * torch.cos(alpha) * search_radio, # x
torch.cos(beta) * torch.sin(alpha) * search_radio, # y
torch.sin(beta) * search_radio # z
])
# 4. 计算旋转矩阵(从机体坐标系到世界坐标系)
# 欧拉角(ZYX顺序):水平角(alpha)→垂直角(-beta)→滚动角(0)
Rotation = R.from_euler('ZYX', [alpha, -beta, 0.0], degrees=False)
R_matrix = torch.tensor(Rotation.as_matrix()) # 转换为旋转矩阵
# 5. 存储参数
lattice_pos_list.append(pos_node)
lattice_angle_list.append(torch.tensor([alpha, beta]))
lattice_Rbp_list.append(R_matrix)
# 转换为Tensor并移动到目标设备(GPU/CPU)
self.lattice_pos_node = torch.stack(lattice_pos_list).to(dtype=torch.float32, device=device)
self.lattice_angle_node = torch.stack(lattice_angle_list).to(dtype=torch.float32, device=device)
self.lattice_Rbp_node = torch.stack(lattice_Rbp_list).to(dtype=torch.float32, device=device)
# 锚点视场角的半角(用于后续角度约束)
self.yaw_diff = 0.5 * self.horizon_anchor_fov / 180.0 * torch.pi
self.pitch_diff = 0.5 * self.vertical_anchor_fov / 180.0 * torch.pi
核心功能:生成均匀分布在水平、垂直、径向三个维度的候选轨迹终点(晶格节点),每个节点包含位置、角度和旋转矩阵,用于轨迹规划时快速采样候选轨迹。
关键步骤解析:
角度间隔计算:
direction_diff
:水平方向的角度间隔(如 90 度视场角,3 条轨迹→30 度 / 间隔)。altitude_diff
:垂直方向的角度间隔(如 60 度视场角,3 条轨迹→20 度 / 间隔)。
角度分布采用居中对称(如 3 条轨迹对应-θ/2, 0, θ/2
),确保覆盖整个视场。
位置计算(球坐标→直角坐标):
每个晶格节点的位置通过球坐标转换得到:径向距离:
search_radio = (h+1)×radio_diff
(从近到远,共radio_num
级)。水平角
alpha
(方位角)和垂直角beta
(仰角 / 俯角)决定方向。直角坐标转换公式:
x=r⋅cos(β)⋅cos(α)
y=r⋅cos(β)⋅sin(α)
z=r⋅sin(β)
旋转矩阵:
通过欧拉角(ZYX 顺序)转换为旋转矩阵,用于将机体坐标系下的轨迹转换到世界坐标系,确保轨迹方向与无人机朝向一致。存储与设备适配:
生成的位置、角度、旋转矩阵转换为 PyTorch Tensor 并移动到 GPU/CPU,便于后续与神经网络(如轨迹预测模型)交互。
3. 晶格参数访问方法
def getStateLattice(self, id=None):
# 获取晶格节点的位置(id为None时返回所有节点)
if id is not None:
return self.lattice_pos_node[id, :]
else:
return self.lattice_pos_node
def getAngleLattice(self, id=None):
# 获取晶格节点的角度(水平角alpha、垂直角beta)
if id is not None:
return self.lattice_angle_node[id, 0], self.lattice_angle_node[id, 1]
else:
return self.lattice_angle_node[:, 0], self.lattice_angle_node[:, 1]
def getRotation(self, id=None):
# 获取晶格节点的旋转矩阵
if id is not None:
return self.lattice_Rbp_node[id]
else:
return self.lattice_Rbp_node
功能:提供接口供其他模块(如轨迹优化器)获取晶格节点的位置、角度和旋转矩阵,用于生成候选轨迹或评估轨迹可行性。
4. 图像网格与晶格 ID 转换
def convert_ImageGrid_LatticeID(self, id):
return self.traj_num - id - 1
功能:转换图像网格索引到晶格 ID。
注释中提到图像网格的索引布局为 “行优先,左下为原点”(如 3×3 网格的索引 0 在左下,8 在右上),而晶格 ID 的生成顺序可能相反,因此需要通过该方法反转索引,确保图像中的网格位置与晶格轨迹对应。
三、晶格轨迹的应用场景
晶格候选轨迹是基于采样的轨迹规划的核心,其作用包括:
候选轨迹采样:生成覆盖水平、垂直、径向的均匀分布候选轨迹,确保轨迹规划时能快速覆盖所有可能的运动方向。
避障与目标平衡:通过多方向的候选轨迹,为避障(选择无碰撞方向)和目标跟踪(选择朝向目标方向)提供选择空间。
高效评估:预先生成的晶格轨迹可通过损失函数(如
YOPOLoss
)快速评估每条轨迹的代价(平滑度、安全性、目标导向性),选择最优轨迹执行。
总结
LatticeParam
和LatticePrimitive
类共同构成了无人机轨迹规划的候选轨迹生成模块:
LatticeParam
负责参数配置,从配置文件加载并调整物理约束和分布参数。LatticePrimitive
通过单例模式生成具体的晶格候选轨迹,基于水平、垂直、径向三个维度的均匀分布,生成符合运动学约束和视场角限制的候选轨迹终点。
这些晶格轨迹为后续的轨迹优化(如损失评估、最优轨迹选择)提供了基础,是连接轨迹规划算法与无人机物理约束的关键组件。
policy/models/state_transform.py
实现了无人机轨迹规划中的坐标变换与状态转换功能,主要用于在不同坐标系(机体坐标系、世界坐标系、晶格坐标系)之间转换轨迹状态,并处理数据归一化。这些转换是轨迹规划算法的核心环节,确保生成的轨迹符合物理约束并正确表达。
一、StateTransform
类:状态转换核心类
该类负责无人机状态在不同坐标系间的转换,以及数据的归一化 / 反归一化处理。
1. 初始化与依赖
class StateTransform:
def __init__(self):
self.lattice_primitive = LatticePrimitive.get_instance() # 获取晶格轨迹实例
self.goal_length = cfg['goal_length'] # 目标距离阈值
功能:初始化时获取晶格轨迹配置(如方向、角度),并设置目标距离阈值,为后续转换提供基础参数。
2. 预测结果转终点状态(pred_to_endstate
)
def pred_to_endstate(self, endstate_pred: torch.Tensor) -> torch.Tensor:
"""
将预测结果转换为机体坐标系下的终点状态
输入: [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h]
输出: [batch; px py pz vx vy vz ax ay az; primitive_v; primitive_h] (机体坐标系)
"""
B, N = endstate_pred.shape[0], endstate_pred.shape[2] * endstate_pred.shape[3]
# 调整维度:[B, 9, 3, 5] -> [B, 15, 9]
endstate_pred = endstate_pred.permute(0, 2, 3, 1).reshape(B, N, 9)
# 获取晶格角度和旋转矩阵(处理顺序翻转,因lattice与grid顺序相反)
yaw, pitch = self.lattice_primitive.getAngleLattice() # [15]
yaw = yaw.flip(0)[None, :].expand(B, -1) # [B, 15]
pitch = pitch.flip(0)[None, :].expand(B, -1)
Rbp = self.lattice_primitive.getRotation().flip(0) # [15, 3, 3]
Rbp = Rbp[None, :, :, :].expand(B, -1, -1, -1) # [B, 15, 3, 3]
# 解析预测值:偏航角增量、俯仰角增量、径向距离
delta_yaw = endstate_pred[:, :, 0] * self.lattice_primitive.yaw_diff # [B, 15]
delta_pitch = endstate_pred[:, :, 1] * self.lattice_primitive.pitch_diff
radio = (endstate_pred[:, :, 2] + 1.0) * self.lattice_primitive.radio_range
# 计算终点位置(球坐标→直角坐标)
cos_pitch = torch.cos(pitch + delta_pitch)
endstate_x = cos_pitch * torch.cos(yaw + delta_yaw) * radio
endstate_y = cos_pitch * torch.sin(yaw + delta_yaw) * radio
endstate_z = torch.sin(pitch + delta_pitch) * radio
endstate_p = torch.stack([endstate_x, endstate_y, endstate_z], dim=-1) # [B, 15, 3]
# 解析速度和加速度(并应用最大限制)
endstate_vp = endstate_pred[:, :, 3:6] * self.lattice_primitive.vel_max # [B, 15, 3]
endstate_ap = endstate_pred[:, :, 6:9] * self.lattice_primitive.acc_max # [B, 15, 3]
# 将速度和加速度从晶格坐标系转换到机体坐标系(通过旋转矩阵)
endstate_vb = torch.matmul(Rbp, endstate_vp.unsqueeze(-1)).squeeze(-1) # [B, 15, 3]
endstate_ab = torch.matmul(Rbp, endstate_ap.unsqueeze(-1)).squeeze(-1)
# 合并位置、速度、加速度
endstate = torch.cat([endstate_p, endstate_vb, endstate_ab], dim=-1) # [B, 15, 9]
# 恢复原始维度:[B, 15, 9] -> [B, 9, 3, 5]
endstate = endstate.permute(0, 2, 1).reshape(B, 9, 3, 5)
return endstate
核心功能:将神经网络预测的轨迹参数(角度增量、径向距离、速度、加速度)转换为机体坐标系下的终点状态(位置、速度、加速度)。
关键步骤:
数据预处理:调整输入张量维度,适应批量处理。
晶格参数获取:从晶格轨迹配置中获取基础角度和旋转矩阵,并处理顺序翻转(因索引布局差异)。
状态解析:
从预测值中提取角度增量(
delta_yaw
,delta_pitch
)和径向距离(radio
)。计算终点位置(球坐标→直角坐标转换)。
解析速度和加速度,并应用物理约束(乘以最大速度 / 加速度)。
坐标转换:通过旋转矩阵将速度和加速度从晶格坐标系转换到机体坐标系。
3. CPU 版本的状态转换(pred_to_endstate_cpu
)
def pred_to_endstate_cpu(self, endstate_pred: np.ndarray, lattice_id: torch.Tensor) -> np.ndarray:
"""
测试阶段使用的CPU版本(比GPU版本快10倍)
输入: 预测值、晶格ID
输出: [B; px py pz vx vy vz ax ay az] (机体坐标系)
"""
# 解析预测值
delta_yaw = endstate_pred[:, 0] * self.lattice_primitive.yaw_diff
delta_pitch = endstate_pred[:, 1] * self.lattice_primitive.pitch_diff
radio = (endstate_pred[:, 2] + 1.0) * self.lattice_primitive.radio_range
# 获取晶格角度(转换为NumPy数组)
yaw, pitch = self.lattice_primitive.getAngleLattice(lattice_id)
yaw, pitch = yaw.cpu().numpy(), pitch.cpu().numpy()
# 计算终点位置(NumPy版本)
endstate_x = np.cos(pitch + delta_pitch) * np.cos(yaw + delta_yaw) * radio
endstate_y = np.cos(pitch + delta_pitch) * np.sin(yaw + delta_yaw) * radio
endstate_z = np.sin(pitch + delta_pitch) * radio
endstate_p = np.stack((endstate_x, endstate_y, endstate_z), axis=1)
# 解析速度和加速度
endstate_vp = endstate_pred[:, 3:6] * self.lattice_primitive.vel_max
endstate_ap = endstate_pred[:, 6:9] * self.lattice_primitive.acc_max
# 坐标转换(NumPy版本)
Rpb = self.lattice_primitive.getRotation(lattice_id).cpu().numpy()
endstate_vb = np.matmul(Rpb, endstate_vp[:, :, np.newaxis]).squeeze(-1)
endstate_ab = np.matmul(Rpb, endstate_ap[:, :, np.newaxis]).squeeze(-1)
return np.concatenate((endstate_p, endstate_vb, endstate_ab), axis=1)
功能:与
pred_to_endstate
相同,但使用 NumPy 在 CPU 上执行,适用于测试阶段(避免 GPU 数据传输开销,提高速度)。
4. 观测数据预处理(prepare_input
)
def prepare_input(self, obs):
"""
将观测数据转换为晶格坐标系(机体坐标系 → 晶格坐标系 → 机体坐标系)
输入: [batch; vx, vy, yz, ax, ay, az, gx, gy, gz] (机体坐标系)
输出: [batch; vx, vy, yz, ax, ay, az, gx, gy, gz; primitive_v; primitive_h] (晶格坐标系)
"""
B, N = obs.shape[0], self.lattice_primitive.traj_num
# 获取所有晶格旋转矩阵并调整顺序
Rbp_all = self.lattice_primitive.getRotation().flip(0) # [N, 3, 3]
# 调整观测数据维度:[B, 9] -> [B, 3, 3]
obs = obs.view(B, 3, 3)
# 扩展维度以支持批量处理
obs_exp = obs[:, None, :, :].expand(B, N, 3, 3) # [B, N, 3, 3]
Rbp_exp = Rbp_all[None, :, :, :].expand(B, N, 3, 3)
# 通过矩阵乘法执行坐标变换
transformed = torch.matmul(obs_exp, Rbp_exp) # [B, N, 3, 3]
# 调整输出维度
transformed_flat = transformed.view(B, N, 9) # [B, N, 9]
out = transformed_flat.permute(0, 2, 1).contiguous() # [B, 9, N]
out = out.view(B, 9, self.lattice_primitive.vertical_num, self.lattice_primitive.horizon_num)
return out
功能:将观测数据(如当前速度、加速度、目标方向)从机体坐标系转换到晶格坐标系,为神经网络输入做准备。
关键操作:通过批量矩阵乘法(
torch.matmul
)高效处理多个晶格方向的坐标变换。
5. 数据归一化与反归一化
def unnormalize_obs(self, vel_acc):
# 反归一化:乘以最大速度/加速度
vel_acc[:, 0:3] = vel_acc[:, 0:3] * self.lattice_primitive.vel_max
vel_acc[:, 3:6] = vel_acc[:, 3:6] * self.lattice_primitive.acc_max
return vel_acc
def normalize_obs(self, vel_acc_goal):
# 归一化:除以最大速度/加速度
vel_acc_goal[:, 0:3] = vel_acc_goal[:, 0:3] / self.lattice_primitive.vel_max
vel_acc_goal[:, 3:6] = vel_acc_goal[:, 3:6] / self.lattice_primitive.acc_max
# 目标方向归一化(长度至少为goal_length)
goal_norm = vel_acc_goal[:, 6:9].norm(dim=1, keepdim=True)
vel_acc_goal[:, 6:9] = vel_acc_goal[:, 6:9] / goal_norm.clamp(min=self.goal_length)
return vel_acc_goal
功能:
unnormalize_obs
:将网络输出的归一化速度 / 加速度恢复为实际物理值。normalize_obs
:将输入的速度 / 加速度归一化到 [0,1] 区间,并确保目标方向向量长度合理。
二、坐标系转换工具函数
这些独立函数用于在世界坐标系和机体坐标系之间转换状态。
1. 旋转变换(rotate_body2world
)
def rotate_body2world(rot_wb, pos_b):
"""
使用旋转矩阵rot_wb将pos_b从机体坐标系转换到世界坐标系
rot_wb: (..., 3, 3)
pos_b: (..., 3)
"""
pos_w = torch.matmul(rot_wb, pos_b.unsqueeze(-1)).squeeze(-1)
return pos_w
功能:通过矩阵乘法实现向量旋转(不考虑平移)。
2. 坐标变换(transform_body2world
)
def transform_body2world(rot_wb, t_w, pos_b):
"""
使用旋转矩阵rot_wb和位置向量t_w将pos_b从机体坐标系转换到世界坐标系
rot_wb: (..., 3, 3)
t_w: (..., 3)
pos_b: (..., 3)
"""
return rotate_body2world(rot_wb, pos_b) + t_w
功能:综合旋转和平移,实现完整坐标变换(
pos_w = R·pos_b + t_w
)。
3. 状态变换(state_body2world
)
def state_body2world(pos_w, rot_wb, pos_b, vel_b, acc_b):
"""
将机体坐标系下的状态(位置、速度、加速度)转换到世界坐标系
"""
pos_b = transform_body2world(rot_wb, pos_w, pos_b)
vel_b = rotate_body2world(rot_wb, vel_b)
acc_b = rotate_body2world(rot_wb, acc_b)
return pos_b, vel_b, acc_b
功能:同时转换位置、速度、加速度,确保状态一致性(速度和加速度仅需旋转,无需平移)。
三、坐标变换的应用场景
轨迹规划:
将神经网络预测的晶格参数转换为实际轨迹终点状态(
pred_to_endstate
)。在不同坐标系间转换状态,确保轨迹符合物理约束(如最大速度、加速度)。
数据预处理 / 后处理:
归一化输入数据(
normalize_obs
),提高神经网络训练稳定性。反归一化输出数据(
unnormalize_obs
),恢复物理单位。
多传感器融合:
在世界坐标系和机体坐标系间转换观测数据,统一数据表达,便于融合处理。
总结
StateTransform
类和相关工具函数构成了无人机轨迹规划中的坐标变换与状态处理核心模块:
实现了多坐标系(机体、世界、晶格)间的灵活转换,确保轨迹规划的物理一致性。
通过归一化 / 反归一化处理,优化神经网络训练和预测效果。
提供 CPU 和 GPU 版本的转换函数,平衡计算效率和灵活性,适用于训练和测试不同场景。
这些功能是连接高层规划算法(如轨迹预测)和底层控制(如电机指令)的关键桥梁,确保生成的轨迹既符合物理约束,又能高效执行。
policy/models/yopo_dataset.py
定义了一个用于无人机轨迹规划模型的数据集类YOPODataset
,继承自 PyTorch 的Dataset
类。该类用于加载、预处理和采样无人机的深度图像、位置、姿态及随机生成的运动状态与目标位置,为模型训练和验证提供标准化输入。以下是详细分析:
一、类初始化(__init__
方法)
该方法负责加载数据集配置、划分训练 / 验证集、读取数据文件并预处理,是数据集的核心初始化逻辑。
1. 配置参数加载
def __init__(self, mode='train', val_ratio=0.1):
super(YOPODataset, self).__init__()
# 图像参数
self.height = int(cfg["image_height"])
self.width = int(cfg["image_width"])
# 运动状态采样参数(速度、加速度的分布参数)
self.vel_max = cfg["vel_max_train"]
self.acc_max = cfg["acc_max_train"]
self.vx_lognorm_mean = np.log(1 - cfg["vx_mean_unit"]) # x方向速度的对数正态分布均值
self.vx_logmorm_sigma = np.log(cfg["vx_std_unit"]) # x方向速度的对数正态分布标准差
self.v_mean = np.array([cfg["vx_mean_unit"], cfg["vy_mean_unit"], cfg["vz_mean_unit"]]) # 速度均值
self.v_std = np.array([cfg["vx_std_unit"], cfg["vy_std_unit"], cfg["vz_std_unit"]]) # 速度标准差
self.a_mean = np.array([cfg["ax_mean_unit"], cfg["ay_mean_unit"], cfg["az_mean_unit"]]) # 加速度均值
self.a_std = np.array([cfg["ax_std_unit"], cfg["ay_std_unit"], cfg["az_std_unit"]]) # 加速度标准差
# 目标参数
self.goal_length = cfg['goal_length']
self.goal_pitch_std = cfg["goal_pitch_std"] # 目标俯仰角标准差(度)
self.goal_yaw_std = cfg["goal_yaw_std"] # 目标偏航角标准差(度)
...
功能:从配置文件加载数据集所需的核心参数,包括图像尺寸、运动学约束(最大速度 / 加速度)、状态采样分布参数(均值、标准差)和目标生成参数。
关键参数意义:
速度和加速度的分布参数(
v_mean
,v_std
,a_mean
,a_std
)决定了随机状态的采样范围,确保生成的状态符合无人机物理特性。目标角度标准差(
goal_pitch_std
,goal_yaw_std
)控制目标位置的随机性,影响模型对不同方向目标的泛化能力。
2. 数据集加载与划分
# 数据集路径与文件夹
base_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(base_dir, "../", cfg["dataset_path"])
self.img_list, self.map_idx, self.positions, self.quaternions = [], [], np.empty((0, 3), dtype=np.float32), np.empty((0, 4), dtype=np.float32)
datafolders = [f.path for f in os.scandir(data_dir) if f.is_dir()]
datafolders.sort(key=lambda x: int(os.path.basename(x))) # 按文件夹名称排序
# 遍历每个数据文件夹
for data_idx in range(len(datafolders)):
datafolder = datafolders[data_idx]
# 读取图像文件名并按序号排序(确保与标签对齐)
image_file_names = [filename for filename in os.listdir(datafolder) if os.path.splitext(filename)[1] == '.png']
image_file_names.sort(key=lambda x: int(x.split('.')[0].split("_")[1]))
# 读取姿态数据(位置和四元数)
states = np.loadtxt(data_dir + f"/pose-{data_idx}.csv", delimiter=',', skiprows=1).astype(np.float32)
positions = states[:, 0:3] # 位置(x,y,z)
quaternions = states[:, 3:7] # 四元数(姿态)
# 划分训练集和验证集
file_names_train, file_names_val, positions_train, positions_val, quaternions_train, quaternions_val = train_test_split(
image_file_names, positions, quaternions, test_size=val_ratio, random_state=0)
# 根据模式加载对应的数据
if mode == 'train':
images = [cv2.imread(datafolder + "/" + filename, -1).astype(np.float32) for filename in file_names_train]
self.img_list.extend(images)
self.positions = np.vstack((self.positions, positions_train.astype(np.float32)))
self.quaternions = np.vstack((self.quaternions, quaternions_train.astype(np.float32)))
elif mode == 'valid':
... # 类似训练集加载逻辑
self.map_idx.extend([data_idx] * len(images)) # 记录每个样本对应的地图索引
# 图像预处理: resize、归一化、扩展维度
self.img_list = [np.expand_dims(
cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_NEAREST) / 65535.0,
axis=0)
for img in self.img_list]
功能:加载数据集文件夹中的深度图像、位置和姿态数据,按比例划分为训练集和验证集,并对图像进行预处理(尺寸调整、归一化)。
关键细节:
图像排序:通过文件名提取序号并排序,确保图像与对应的位置 / 姿态标签对齐(避免数据错乱)。
图像预处理:
cv2.resize
:将图像调整为配置文件指定的尺寸(image_height
,image_width
)。归一化:深度图像原始数据为
int16
(0-65535 对应 0-20 米),除以 65535.0 转换为 [0,1] 范围的float32
。扩展维度:添加通道维度(
np.expand_dims(..., axis=0)
),符合深度学习模型的输入格式([C, H, W])。
二、样本获取(__getitem__
方法)
该方法返回数据集中的单个样本,包含预处理后的深度图像、位置、旋转矩阵、观测数据和地图索引。
def __getitem__(self, item):
# 生成随机的速度和加速度状态(机体坐标系)
vel_b, acc_b = self._get_random_state()
# 生成随机目标位置(在机体坐标系中)
q_wxyz = self.quaternions[item, :] # 四元数(wxyz格式)
R_WB = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]]) # 转换为旋转矩阵(世界→机体)
euler_angles = R_WB.as_euler('ZYX', degrees=False) # 欧拉角(yaw(z), pitch(y), roll(x))
R_wB = R.from_euler('ZYX', [0, euler_angles[1], euler_angles[2]], degrees=False) # 忽略偏航角的旋转矩阵
goal_w = self._get_random_goal() # 世界坐标系中的目标
goal_b = R_wB.inv().apply(goal_w) # 转换到机体坐标系
# 组合观测数据(速度、加速度、目标位置,机体坐标系)
random_obs = np.hstack((vel_b, acc_b, goal_b)).astype(np.float32)
rot_wb = R_WB.as_matrix().astype(np.float32) # 世界→机体的旋转矩阵
# 返回样本:深度图像、位置、旋转矩阵、观测数据、地图索引
return self.img_list[item], self.positions[item], rot_wb, random_obs, self.map_idx[item]
核心功能:为每个样本生成随机的运动状态(速度、加速度)和目标位置,结合原始数据(图像、位置、姿态)构建模型输入。
坐标变换细节:
四元数转旋转矩阵:使用
scipy.spatial.transform.Rotation
将四元数(q_wxyz
)转换为旋转矩阵(R_WB
),描述世界坐标系到机体坐标系的变换。目标位置转换:先在世界坐标系生成目标(
goal_w
),再通过旋转矩阵的逆(R_wB.inv()
)转换到机体坐标系(goal_b
),确保目标位置相对于无人机自身的方向正确。
三、随机状态生成(_get_random_state
方法)
生成符合物理约束的随机速度和加速度(机体坐标系),用于模拟无人机的运动状态。
def _get_random_state(self):
# 生成速度(x分量用对数正态分布,右偏;其他分量用正态分布)
while True:
vel = self.vel_max * (self.v_mean + self.v_std * np.random.randn(3)) # 基础正态分布
# x方向速度:对数正态分布(确保右偏,符合无人机多向前运动的特性)
right_skewed_vx = -1
while right_skewed_vx < 0: # 确保x方向速度为正(向前)
right_skewed_vx = self.vel_max * np.random.lognormal(mean=self.vx_lognorm_mean, sigma=self.vx_logmorm_sigma, size=None)
right_skewed_vx = -right_skewed_vx + 1.2 * self.vel_max # 调整偏移,确保覆盖最大速度
vel[0] = right_skewed_vx
if np.linalg.norm(vel) < 1.2 * self.vel_max: # 过滤速度过大的异常值
break
# 生成加速度(正态分布,限制在最大加速度范围内)
while True:
acc = self.acc_max * (self.a_mean + self.a_std * np.random.randn(3))
if np.linalg.norm(acc) < 1.2 * self.acc_max: # 过滤加速度过大的异常值
break
return vel, acc
设计逻辑:
速度分布:x 方向(前进方向)使用对数正态分布(右偏),确保大多数情况下速度为正(符合无人机向前运动的常见场景);y、z 方向使用正态分布,允许左右、上下运动。
物理约束:通过
while
循环和范数检查,确保速度和加速度不超过最大限制的 1.2 倍(允许小范围超调,增加数据多样性,但过滤极端异常值)。
四、随机目标生成(_get_random_goal
方法)
生成随机的目标位置(世界坐标系),用于模拟无人机需要到达的目标。
def _get_random_goal(self):
# 生成目标的俯仰角和偏航角(正态分布,度→弧度)
goal_pitch_angle = np.random.normal(0.0, self.goal_pitch_std) # 俯仰角(上下方向)
goal_yaw_angle = np.random.normal(0.0, self.goal_yaw_std) # 偏航角(左右方向)
goal_pitch_angle, goal_yaw_angle = np.radians(goal_pitch_angle), np.radians(goal_yaw_angle)
# 计算目标方向向量(球坐标→直角坐标)
goal_w_dir = np.array([
np.cos(goal_yaw_angle) * np.cos(goal_pitch_angle), # x
np.sin(goal_yaw_angle) * np.cos(goal_pitch_angle), # y
np.sin(goal_pitch_angle) # z
])
# 10%概率生成近距离目标(增加接近目标时的样本多样性)
random_near = np.random.rand()
if random_near < 0.1:
goal_w_dir = random_near * 10 * goal_w_dir # 近距离(原长度的0~1倍)
return self.goal_length * goal_w_dir # 缩放至目标长度
功能:生成随机方向的目标位置,模拟不同方向和距离的任务需求。
多样性设计:
角度分布:偏航角和俯仰角基于正态分布,确保目标主要分布在无人机前方(符合视觉感知范围)。
距离变化:10% 概率生成近距离目标(原长度的 0~1 倍),补充无人机接近目标时的场景数据,提高模型在终点附近的精度。
五、辅助方法:数据分布分析
1. print_data
方法
打印速度、加速度的 95% 分布范围和目标角度范围,帮助开发者了解数据采样的覆盖度。
2. plot_sample_distribution
方法
可视化目标方向角(偏航、俯仰)、速度和加速度的分布,验证采样是否符合预期(如是否覆盖主要运动状态和目标方向)。
六、数据集的作用与应用场景
YOPODataset
是无人机轨迹规划模型的数据输入接口,其核心作用是:
数据标准化:将原始深度图像、位置、姿态数据转换为模型可直接使用的格式(如归一化、维度调整)。
多样性增强:通过随机生成速度、加速度和目标位置,扩展训练数据的覆盖范围,提高模型的泛化能力(如应对不同运动状态和目标方向)。
坐标一致性:确保所有状态(速度、加速度、目标)在机体坐标系下统一,简化模型对相对运动的学习。
该数据集通常与DataLoader
结合使用,批量加载数据用于训练无人机的感知 - 规划模型(如基于深度图像预测最优轨迹)。
总结
YOPODataset
是一个功能完整的无人机轨迹规划数据集类,涵盖数据加载、预处理、随机状态生成和坐标转换等核心功能。其设计兼顾了物理约束(速度、加速度限制)和任务多样性(随机目标、运动状态),为模型训练提供了高质量、多样化的输入数据。通过该数据集,模型可以学习从深度图像和当前状态预测合理轨迹的能力,适用于自主导航、避障等无人机任务。
policy/models/yoppo_network.py
定义了一个用于无人机轨迹规划的神经网络模型YopoNetwork
,该模型结合深度图像(环境感知)和无人机状态信息(运动状态),预测轨迹的终点状态(位置、速度、加速度),是连接感知与规划的核心组件。以下是详细分析:
1. 类初始化(__init__
方法)
def __init__(
self,
observation_dim=9, # 观测维度:v_xyz(3) + a_xyz(3) + goal_xyz(3)
output_dim=10, # 输出维度:x_pva(3) + y_pva(3) + z_pva(3) + score(1)
hidden_state=64, # 隐藏层特征维度
):
super(YopoNetwork, self).__init__()
self.state_transform = StateTransform() # 状态转换工具(坐标变换、归一化)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 设备选择
# 网络组件
self.image_backbone = YopoBackbone(hidden_state) # 图像特征提取骨干网络
self.state_backbone = nn.Sequential() # 状态特征提取骨干网络(预留,当前为空)
self.yopo_head = YopoHead(hidden_state + observation_dim, output_dim) # 输出头(融合特征并预测)
核心功能:初始化网络的核心组件,定义输入输出维度,关联状态转换工具(
StateTransform
)。关键组件解析:
state_transform
:复用之前定义的StateTransform
类,负责状态的归一化、坐标转换(如机体坐标系↔晶格坐标系),确保输入输出的一致性。image_backbone
:YopoBackbone
(未展示代码)是用于提取深度图像特征的骨干网络(可能包含卷积层),输出维度为hidden_state
(如 64 维特征)。state_backbone
:状态特征提取网络,当前为空(nn.Sequential()
),预留扩展空间(未来可添加全连接层处理状态特征)。yopo_head
:输出头网络,接收图像特征与状态特征的融合结果,输出最终预测(终点状态 + 轨迹分数)。维度设计:
hidden_state + observation_dim
表示融合后的特征维度(图像特征 + 状态特征),output_dim=10
对应 9 维轨迹参数(位置、速度、加速度各 3 维)+1 维轨迹分数。
2. 前向传播(forward
方法)
def forward(self, depth: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
"""前向传播:输入深度图像和观测状态,输出预测结果"""
# 提取深度图像特征
depth_feature = self.image_backbone(depth) # 形状:[batch_size, hidden_state, ...]
# 提取状态特征(当前为空操作,直接返回原始状态)
obs_feature = self.state_backbone(obs) # 形状:[batch_size, observation_dim, ...]
# 融合图像特征和状态特征
input_tensor = torch.cat((obs_feature, depth_feature), 1) # 按维度1拼接(特征维度融合)
# 预测输出
output = self.yopo_head(input_tensor) # 形状:[batch_size, output_dim, ...]
# 处理输出:终点状态(tanh激活,范围[-1,1])和轨迹分数(softplus激活,确保非负)
endstate = torch.tanh(output[:, :9]) # 前9维:[batch, 9, vertical_num, horizon_num]
score = torch.nn.functional.softplus(output[:, 9]) # 第10维:[batch, vertical_num, horizon_num]
return endstate, score
核心功能:定义网络的前向计算流程,将深度图像和观测状态映射为轨迹终点状态和分数。
关键步骤解析:
特征提取:
image_backbone
处理深度图像(如通过卷积层提取环境特征,如障碍物位置、距离等)。state_backbone
当前未处理状态(obs_feature
即原始观测),未来可添加全连接层增强状态特征表达。
特征融合:通过
torch.cat
在特征维度(维度 1)拼接图像特征和状态特征,使网络同时利用环境信息和运动状态。输出处理:
endstate
:前 9 维用tanh
激活,将预测值限制在 [-1,1] 范围,后续通过StateTransform
反归一化到实际物理范围。score
:第 10 维用softplus
激活(softplus(x) = ln(1 + e^x)
),确保输出非负,用于评估轨迹的优劣(分数越高,轨迹越优)。
3. 推理流程(inference
方法)
def inference(self, depth: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
"""推理流程:用于训练时的完整处理链,从原始状态到机体坐标系下的终点状态"""
# 步骤1:归一化观测状态(速度、加速度、目标位置)
obs = self.state_transform.normalize_obs(obs) # 归一化到[0,1]范围,稳定训练
# 步骤2:将状态转换到晶格坐标系(适配晶格轨迹的分布)
obs = self.state_transform.prepare_input(obs) # 形状调整为晶格坐标系下的格式
# 步骤3:前向传播,得到预测结果
endstate_pred, score_pred = self.forward(depth, obs) # 预测值(晶格坐标系)
# 步骤4:将预测结果转换为机体坐标系下的终点状态(物理单位)
endstate = self.state_transform.pred_to_endstate(endstate_pred) # 反归一化+坐标转换
return endstate, score_pred
核心功能:封装训练时的完整数据处理流程,从原始输入到最终可用的轨迹终点状态,确保输入输出的坐标和单位一致性。
与
forward
的区别:forward
仅实现网络的前向计算,而inference
包含完整的预处理(归一化、坐标转换)和后处理(反转换),是实际训练 / 推理时调用的接口。关键作用:
通过state_transform
的方法,将原始观测状态转换为网络可处理的格式(归一化 + 晶格坐标系),再将网络输出的抽象预测值转换为物理意义明确的机体坐标系下的终点状态(位置、速度、加速度),实现 “原始输入→网络预测→物理输出” 的端到端流程。
4. 梯度打印辅助方法(print_grad
)
def print_grad(self, grad):
print("grad of hook: ", grad)
功能:调试用的钩子函数,用于打印梯度信息,帮助分析网络训练中的梯度流动(如是否存在梯度消失或爆炸)。
5. 网络的核心作用与数据流向
YopoNetwork
是无人机感知 - 规划一体化的核心模型,其数据流向如下:
输入:
- 深度图像(depth):环境的深度信息(如障碍物距离)
- 观测状态(obs):机体坐标系下的速度(v_xyz)、加速度(a_xyz)、目标位置(goal_xyz)
处理流程:
1. 观测状态 → 归一化 → 转换到晶格坐标系(state_transform)
2. 深度图像 → 提取特征(image_backbone)
3. 晶格坐标系状态 + 图像特征 → 融合 → 预测(yopo_head)
4. 预测结果 → 转换到机体坐标系 → 输出终点状态和轨迹分数
输出:
- endstate:机体坐标系下的轨迹终点状态(位置、速度、加速度)
- score_pred:各候选轨迹的分数(用于选择最优轨迹)
6. 应用场景
该网络主要用于无人机的在线轨迹规划:
输入实时采集的深度图像(感知环境)和无人机当前状态(速度、加速度、目标位置),输出多条候选轨迹的终点状态和分数。
结合损失函数(如
YOPOLoss
),通过分数选择最优轨迹(平滑度高、无碰撞、朝向目标),控制无人机执行。
总结
YopoNetwork
是一个融合环境感知(深度图像)和运动状态的轨迹预测网络,通过状态转换工具实现坐标和单位的一致性处理,最终输出物理意义明确的轨迹终点状态。其设计体现了 “感知 - 规划” 一体化的思路,是无人机自主导航系统的核心组件,负责将原始感知数据转化为可执行的轨迹指令。
policy/models/yopo_trainer.py
定义了无人机轨迹规划模型YopoNetwork
的训练策略类YopoTrainer
,实现了完整的模型训练流程,包括数据加载、前向传播、损失计算、参数优化、模型评估和日志记录,是将模型从理论设计转化为可用模型的核心组件。以下是详细分析:
1. 类初始化(__init__
方法)
python
def __init__(
self,
learning_rate=0.001,
batch_size=32,
loss_weight=[],
tensorboard_path=None,
checkpoint_path=None,
save_on_exit=False,
):
self.batch_size = batch_size # 批处理大小
self.max_grad_norm = 0.1 # 最大梯度范数(用于梯度裁剪,未直接使用)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 设备选择
self.loss_weight = loss_weight # 轨迹损失和分数损失的权重
if save_on_exit: # 注册程序退出时的模型保存函数
self._exit_func = atexit.register(self.save_model)
# 日志配置
self.progress_log = Progress() # 进度条显示
self.tensorboard_path = self.get_next_log_path(tensorboard_path) # 日志路径
self.tensorboard_log = SummaryWriter(log_dir=self.tensorboard_path) # TensorBoard日志
# 轨迹数量(从配置文件读取)
self.traj_num = cfg['traj_num']
# 初始化网络
print("Loading network...")
self.policy = YopoNetwork() # 实例化YopoNetwork
self.policy = self.policy.to(self.device) # 移动到目标设备
# 加载检查点(若存在)
try:
state_dict = torch.load(checkpoint_path, weights_only=True)
self.policy.load_state_dict(state_dict)
print("Checkpoint loaded successfully")
except FileNotFoundError:
print("Training from scratch")
# 初始化损失函数
self.yopo_loss = YOPOLoss() # 综合损失函数
# 初始化优化器(AdamW,带权重衰减的Adam)
self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True)
# 初始化数据集加载器
print("Loading Dataset...")
self.train_dataloader = DataLoader( # 训练集加载器
YOPODataset(mode='train'),
batch_size=self.batch_size,
shuffle=True,
num_workers=1,
pin_memory=True
)
self.val_dataloader = DataLoader( # 验证集加载器
YOPODataset(mode='valid'),
batch_size=self.batch_size,
shuffle=False,
num_workers=1,
pin_memory=True
)
print("Dataset Loaded!")
核心功能:初始化训练所需的所有组件,建立 “数据集→模型→损失函数→优化器” 的完整训练链路,并配置日志记录和模型保存机制。
关键组件解析:
设备配置:自动选择 GPU(优先)或 CPU,提高训练效率。
日志系统:使用
rich.progress
显示训练进度条,SummaryWriter
记录损失曲线(TensorBoard 可视化)。模型与检查点:加载
YopoNetwork
模型,支持从检查点恢复训练(避免从头训练)。数据加载:通过
DataLoader
批量加载YOPODataset
的训练集和验证集,pin_memory=True
加速 GPU 数据传输。优化器:使用
AdamW
(带权重衰减的 Adam),减少过拟合风险,fused=True
启用融合优化(提高 GPU 效率)。
2. 训练主循环(train
方法)
python
def train(self, epoch, save_interval=None):
with self.progress_log: # 启动进度条上下文
total_progress = self.progress_log.add_task("Training", total=epoch) # 总进度条
for self.epoch_i in range(epoch): # 遍历训练轮次
self.policy.train() # 切换到训练模式(启用dropout、批归一化更新)
self.train_one_epoch(self.epoch_i, total_progress) # 训练一个epoch
self.eval_one_epoch(self.epoch_i) # 评估一个epoch
# 按间隔保存模型
if save_interval is not None and (self.epoch_i + 1) % save_interval == 0:
self.progress_log.console.log("Saving model...")
policy_path = self.tensorboard_path + f"/epoch{self.epoch_i + 1}.pth"
torch.save(self.policy.state_dict(), policy_path)
self.progress_log.console.log("Train YOPO Finish!")
self.progress_log.remove_task(total_progress) # 移除总进度条
核心功能:控制训练的整体流程,循环执行 “训练一个 epoch→评估一个 epoch→按需保存模型”,直到完成指定轮次。
关键逻辑:
模式切换:通过
self.policy.train()
启用训练模式(确保 BatchNorm 等层正常更新)。进度跟踪:使用
rich.progress
的任务管理功能,实时显示总训练进度和每个 epoch 的进度。模型保存:按
save_interval
指定的间隔保存模型权重,便于后续分析或继续训练。
3. 单轮训练(train_one_epoch
方法)
python
def train_one_epoch(self, epoch: int, total_progress):
one_epoch_progress = self.progress_log.add_task(f"Epoch: {epoch}", total=len(self.train_dataloader)) # 单epoch进度条
inspect_interval = max(1, len(self.train_dataloader) // 16) # 日志打印间隔
# 初始化损失记录列表和计时
traj_losses, score_losses, smooth_losses, safety_losses, goal_losses, start_time = [], [], [], [], [], time.time()
for step, (depth, pos, rot, obs_b, map_id) in enumerate(self.train_dataloader):
if depth.shape[0] != self.batch_size: continue # 跳过不完整的批次
self.optimizer.zero_grad() # 清零梯度
# 前向传播并计算损失
trajectory_loss, score_loss, smooth_cost, safety_cost, goal_cost = self.forward_and_compute_loss(depth, pos, rot, obs_b, map_id)
# 总损失 = 轨迹损失权重×轨迹损失 + 分数损失权重×分数损失
loss = self.loss_weight[0] * trajectory_loss + self.loss_weight[1] * score_loss
# 反向传播与参数更新
loss.backward() # 计算梯度
self.optimizer.step() # 更新参数
# 记录损失
traj_losses.append(self.loss_weight[0] * trajectory_loss.item())
score_losses.append(self.loss_weight[1] * score_loss.item())
smooth_losses.append(self.loss_weight[0] * smooth_cost.item())
safety_losses.append(self.loss_weight[0] * safety_cost.item())
goal_losses.append(self.loss_weight[0] * goal_cost.item())
# 按间隔打印日志并写入TensorBoard
if step % inspect_interval == inspect_interval - 1:
batch_fps = inspect_interval / (time.time() - start_time) # 计算帧率
self.progress_log.console.log(
f"Epoch: {epoch}, Traj Loss: {np.mean(traj_losses):.3g}, "
f"Score Loss: {np.mean(score_losses):.3g} "
f"Batch FPS: {batch_fps:.3g}"
)
# 写入TensorBoard
global_step = epoch * len(self.train_dataloader) + step
self.tensorboard_log.add_scalar("Train/TrajLoss", np.mean(traj_losses), global_step)
self.tensorboard_log.add_scalar("Train/ScoreLoss", np.mean(score_losses), global_step)
self.tensorboard_log.add_scalar("Detail/SmoothLoss", np.mean(smooth_losses), global_step)
self.tensorboard_log.add_scalar("Detail/SafetyLoss", np.mean(safety_losses), global_step)
self.tensorboard_log.add_scalar("Detail/GoalLoss", np.mean(goal_losses), global_step)
# 重置损失记录和计时
traj_losses, score_losses, smooth_losses, safety_losses, goal_losses, start_time = [], [], [], [], [], time.time()
# 更新进度条
self.progress_log.update(one_epoch_progress, advance=1)
self.progress_log.update(total_progress, advance=1 / len(self.train_dataloader))
self.progress_log.remove_task(one_epoch_progress) # 移除单epoch进度条
核心功能:实现单个 epoch 的训练逻辑,包括数据加载、前向传播、损失计算、反向传播、参数更新和日志记录。
关键步骤:
梯度管理:
self.optimizer.zero_grad()
清零上一轮的梯度,避免梯度累积错误。损失计算:调用
forward_and_compute_loss
获取轨迹损失和分数损失,组合为总损失。参数更新:
loss.backward()
计算梯度,self.optimizer.step()
更新模型参数。日志记录:按
inspect_interval
间隔计算平均损失并写入 TensorBoard,同时打印帧率(监控训练效率)。
4. 单轮评估(eval_one_epoch
方法)
python
@torch.inference_mode() # 禁用梯度计算,加速评估并节省内存
def eval_one_epoch(self, epoch: int):
one_epoch_progress = self.progress_log.add_task(f"Eval: {epoch}", total=len(self.val_dataloader)) # 评估进度条
traj_losses, score_losses = [], [] # 记录评估损失
for step, (depth, pos, rot, obs_b, map_id) in enumerate(self.val_dataloader):
if depth.shape[0] != self.batch_size: continue # 跳过不完整批次
# 计算损失(不反向传播)
trajectory_loss, score_loss, _, _, _ = self.forward_and_compute_loss(depth, pos, rot, obs_b, map_id)
# 记录损失
traj_losses.append(self.loss_weight[0] * trajectory_loss.item())
score_losses.append(self.loss_weight[1] * score_loss.item())
self.progress_log.update(one_epoch_progress, advance=1)
# 打印评估结果并写入TensorBoard
self.progress_log.console.log(
f"Eval: {epoch}, Traj Loss: {np.mean(traj_losses):.3g}, Score Loss: {np.mean(score_losses):.3g} "
)
self.tensorboard_log.add_scalar("Eval/TrajLoss", np.mean(traj_losses), epoch)
self.tensorboard_log.add_scalar("Eval/ScoreLoss", np.mean(score_losses), epoch)
self.progress_log.remove_task(one_epoch_progress)
核心功能:在验证集上评估模型性能,计算损失但不更新参数,用于监控模型泛化能力(避免过拟合)。
关键优化:
@torch.inference_mode()
:禁用梯度计算,大幅提高评估速度并减少内存占用。损失记录:仅记录轨迹损失和分数损失,不计算中间损失(简化评估逻辑)。
泛化监控:通过对比训练损失和验证损失,判断模型是否过拟合(如验证损失持续上升而训练损失下降)。
5. 前向传播与损失计算(forward_and_compute_loss
方法)
python
def forward_and_compute_loss(self, depth, pos, rot, obs_b, map_id):
# 将数据移动到目标设备
depth, pos, rot, obs_b, map_id = [x.to(self.device) for x in [depth, pos, rot, obs_b, map_id]]
# 1. 预处理:将机体坐标系状态转换到世界坐标系
goal_w, start_vel_w, start_acc_w = state_body2world(pos, rot, obs_b[:, 6:9], obs_b[:, 0:3], obs_b[:, 3:6])
start_state_w = torch.stack([pos, start_vel_w, start_acc_w], dim=1) # [B, 3, 3]:位置、速度、加速度
# 2. 前向传播:模型预测终点状态和分数
endstate, score = self.policy.inference(depth, obs_b) # endstate: [B, 9, V, H]; score: [B, V, H]
# 3. 后处理:展平数据以批量计算损失
endstate_flat = endstate.permute(0, 2, 3, 1).reshape(self.batch_size * self.traj_num, 9) # [B*V*H, 9]
score_flat = score.reshape(self.batch_size * self.traj_num) # [B*V*H]
# 扩展初始状态和目标到每个候选轨迹
pos_expanded = pos.repeat_interleave(self.traj_num, dim=0) # [B*V*H, 3]
rot_expanded = rot.repeat_interleave(self.traj_num, dim=0) # [B*V*H, 3, 3]
start_state_w = start_state_w.repeat_interleave(self.traj_num, dim=0) # [B*V*H, 3, 3]
goal_w = goal_w.repeat_interleave(self.traj_num, dim=0) # [B*V*H, 3]
# 4. 将预测的终点状态转换到世界坐标系
end_pos_w, end_vel_w, end_acc_w = state_body2world(
pos_expanded, rot_expanded,
endstate_flat[:, 0:3], # 终点位置(机体坐标系)
endstate_flat[:, 3:6], # 终点速度(机体坐标系)
endstate_flat[:, 6:9] # 终点加速度(机体坐标系)
)
end_state_w = torch.stack([end_pos_w, end_vel_w, end_acc_w], dim=1) # [B*V*H, 3, 3]
# 5. 计算损失
smooth_cost, safety_cost, goal_cost = self.yopo_loss(start_state_w, end_state_w, goal_w, map_id) # 每个轨迹的损失
trajectory_loss = (smooth_cost + safety_cost + goal_cost).mean() # 轨迹损失(平均)
# 分数损失:预测分数与实际损失的平滑L1损失(引导分数反映轨迹优劣)
score_label = (smooth_cost + safety_cost + goal_cost).clone().detach() # 标签:实际损失(不参与梯度)
score_loss = F.smooth_l1_loss(score_flat, score_label) # 分数损失
return trajectory_loss, score_loss, smooth_cost.mean(), safety_cost.mean(), goal_cost.mean()
核心功能:实现 “数据预处理→模型预测→后处理→损失计算” 的完整链路,是连接模型输出与损失函数的关键。
关键步骤解析:
坐标转换:通过
state_body2world
将机体坐标系的状态(速度、加速度、目标)转换到世界坐标系,确保损失计算在统一坐标系下进行。模型推理:调用
self.policy.inference
,完成 “归一化→坐标转换→前向传播→反转换” 的完整推理流程,得到世界坐标系下的终点状态。维度调整:通过
permute
和reshape
将多维输出展平为批次形式,便于批量计算所有候选轨迹的损失(提高效率)。损失计算:
轨迹损失:调用
YOPOLoss
计算平滑度、安全性和目标损失,求和后取平均。分数损失:将轨迹的实际损失作为标签,通过平滑 L1 损失训练模型的分数预测(使分数能反映轨迹优劣,便于后续选择最优轨迹)。
6. 模型保存与日志路径管理
save_model
方法:训练结束或程序退出时保存模型权重,避免意外中断导致训练成果丢失。get_next_log_path
方法:自动创建新的日志目录(如YOPO_0
、YOPO_1
),避免覆盖已有日志,便于对比不同实验结果。
总结
YopoTrainer
是YopoNetwork
模型的完整训练框架,实现了从数据加载到模型评估的全流程自动化。其核心价值在于:
流程完整性:整合数据预处理、模型推理、损失计算、参数优化和日志记录,形成闭环训练系统。
工程实用性:通过进度条、TensorBoard 日志和模型 checkpoint,便于监控训练过程和调试模型。
泛化保障:通过训练集更新参数、验证集评估泛化能力,平衡模型的拟合能力和泛化能力。
该类是将无人机轨迹规划模型从理论设计转化为可用系统的关键,确保模型能够从数据中学习到 “感知(深度图像)→状态(速度 / 加速度)→轨迹(终点状态)” 的映射关系,最终实现自主导航功能。