Python-Point Cloud 系列(二)——PointNet论文复现

阳光穿透心脏的1/2处 2023-06-30 08:59 118阅读 0赞

文章目录

  • Point Net 介绍
    • 网络结构
      • 旋转变换矩阵
  • 源码解析(详细)
    • 源码下载
    • 源码目录结构
    • PointNet.py
    • ModelNetDataLoader.py
  • 复现程序
    • 原文数据集下载
  • 其他
  • 未完待完善和补充…
  • 迁移自己的数据库
    • 数据集载入

前言:
参考博客1 参考内容:
参考博客2 参考内容:

Point Net 介绍

原文参考
python-tensorflow
python-pytroch

网络结构

旋转变换矩阵

在这里插入图片描述
文章中采用了一个3 × \times × 3 矩阵对单个点云数据进行旋转变换。矩阵的参数是在网络训练的过程中自动调整的,也就是说,在训练过程中,网络会自动旋转待分类的点云对象,以达到一个合适的效果。其实也可以看成是一个特征转换层,只是从物理层面上是一个几何旋转变化。
其实就是将原始数据流分为两股,一股去训练一个网络,这个网络的输出是3 × \times × 3的矩阵,另外一股则
直接与这个矩阵相乘,起到旋转的结果。
论文主体中没有过多的解释,而是放到了补充材料中 补充材料

另外,作者从低维(3dim)得到启发,对转换后的高维特征也做了旋转标定,如64维的特征乘一个64 × \times × 64 的矩阵即可达到高维空间旋转的效果。 为了使变换矩阵靠近正交矩阵,作者对矩阵参数添加了正则项。

这部分还有部分不理解
1、为什么在数据输入的时候已经过旋转标定,在转换后的高维特征空间中还需要旋转标定(实验试错而得?)
2、看到很多文章说是转到正面,但是我还是不理解这个正面是从哪里看出来的(可能跟正交矩阵有关系?这部分可能还需要线性相关的知识。
3、为什么后面的64维的变换矩阵需要添加正则项以达到正交矩阵?
参考博客 参考其他博客,了解到旋转矩阵是正交矩阵的一种。
4、为什么对3维旋转变换矩阵不添加正则项以保证其正交性?和高维一样?

欧式空间V中的正交变换只包含:
(1)旋转
(2)反射
(3)旋转+反射的组合(即瑕旋转)

源码解析(详细)

源码下载

下载链接
把下载后的zip文件解压缩。

源码目录结构

解压后的文件目录如下图。其中PointNet主体的代码在“models”这个文件夹中。
【】在这里插入图片描述
“models”文件夹里面包括了原始pointnet的代码以及针对分类、分割、检测的代码,其中pointnet.py里面包含了最基础的模型代码。
如下图
在这里插入图片描述

PointNet.py

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.parallel
  4. import torch.utils.data
  5. from torch.autograd import Variable
  6. import numpy as np
  7. import torch.nn.functional as F
  8. """ STN: Spatial Transformer Networks 空间转换网络 """
  9. # 这边实现的是三维空间转换网络。
  10. class STN3d(nn.Module):
  11. def __init__(self, channel):
  12. super(STN3d, self).__init__()
  13. self.conv1 = torch.nn.Conv1d(channel, 64, 1)
  14. self.conv2 = torch.nn.Conv1d(64, 128, 1)
  15. self.conv3 = torch.nn.Conv1d(128, 1024, 1)
  16. self.fc1 = nn.Linear(1024, 512)
  17. self.fc2 = nn.Linear(512, 256)
  18. self.fc3 = nn.Linear(256, 9)
  19. self.relu = nn.ReLU()
  20. self.bn1 = nn.BatchNorm1d(64)
  21. self.bn2 = nn.BatchNorm1d(128)
  22. self.bn3 = nn.BatchNorm1d(1024)
  23. self.bn4 = nn.BatchNorm1d(512)
  24. self.bn5 = nn.BatchNorm1d(256)
  25. def forward(self, x):
  26. batchsize = x.size()[0] # 第一个维度是batch的数量
  27. x = F.relu(self.bn1(self.conv1(x)))
  28. x = F.relu(self.bn2(self.conv2(x)))
  29. x = F.relu(self.bn3(self.conv3(x)))
  30. x = torch.max(x, 2, keepdim=True)[0]
  31. x = x.view(-1, 1024) # 转换为列为1024但行不定的数据
  32. x = F.relu(self.bn4(self.fc1(x)))
  33. x = F.relu(self.bn5(self.fc2(x)))
  34. x = self.fc3(x)
  35. iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
  36. batchsize, 1) # 生成3x3的单位矩阵,但是是一行的形式方便计算
  37. if x.is_cuda:
  38. iden = iden.cuda()
  39. x = x + iden # 这边加起来是什么意思?为什么要加单位矩阵,这边是对应论文说初始化为对角单位阵
  40. x = x.view(-1, 3, 3) # 转换为3x3的矩阵
  41. return x
  42. class STNkd(nn.Module):
  43. def __init__(self, k=64):
  44. super(STNkd, self).__init__()
  45. self.conv1 = torch.nn.Conv1d(k, 64, 1)
  46. self.conv2 = torch.nn.Conv1d(64, 128, 1)
  47. self.conv3 = torch.nn.Conv1d(128, 1024, 1)
  48. self.fc1 = nn.Linear(1024, 512)
  49. self.fc2 = nn.Linear(512, 256)
  50. self.fc3 = nn.Linear(256, k * k)
  51. self.relu = nn.ReLU()
  52. self.bn1 = nn.BatchNorm1d(64)
  53. self.bn2 = nn.BatchNorm1d(128)
  54. self.bn3 = nn.BatchNorm1d(1024)
  55. self.bn4 = nn.BatchNorm1d(512)
  56. self.bn5 = nn.BatchNorm1d(256)
  57. self.k = k
  58. def forward(self, x):
  59. batchsize = x.size()[0]
  60. x = F.relu(self.bn1(self.conv1(x)))
  61. x = F.relu(self.bn2(self.conv2(x)))
  62. x = F.relu(self.bn3(self.conv3(x)))
  63. x = torch.max(x, 2, keepdim=True)[0]
  64. x = x.view(-1, 1024)
  65. x = F.relu(self.bn4(self.fc1(x)))
  66. x = F.relu(self.bn5(self.fc2(x)))
  67. x = self.fc3(x)
  68. iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(
  69. batchsize, 1)
  70. if x.is_cuda:
  71. iden = iden.cuda()
  72. x = x + iden
  73. x = x.view(-1, self.k, self.k)
  74. return x
  75. """高维映射网络,即将单个点云点映射到多维空间的网络,以避免后续的最大池化过度地损失信息"""
  76. class PointNetEncoder(nn.Module):
  77. def __init__(self, global_feat=True, feature_transform=False, channel=3):
  78. super(PointNetEncoder, self).__init__()
  79. self.stn = STN3d(channel) # 3维空间转换矩阵
  80. self.conv1 = torch.nn.Conv1d(channel, 64, 1)
  81. self.conv2 = torch.nn.Conv1d(64, 128, 1)
  82. self.conv3 = torch.nn.Conv1d(128, 1024, 1)
  83. self.bn1 = nn.BatchNorm1d(64)
  84. self.bn2 = nn.BatchNorm1d(128)
  85. self.bn3 = nn.BatchNorm1d(1024)
  86. self.global_feat = global_feat # 全局特侦标志
  87. self.feature_transform = feature_transform # 是否对高维特征进行旋转变换标定
  88. if self.feature_transform:
  89. self.fstn = STNkd(k=64) # 高维空间变换矩阵
  90. def forward(self, x):
  91. # B:样本的一个批次大小,batch;D:点的维度 3 (x,y,z) dim ; N:点的数量 (1024) number
  92. # 即这边一次输入24个样本,一个样本含有1024个点云点, 一个点云点为3维(x,y,z)
  93. B, D, N = x.size() # [24, 3, 1024]
  94. trans = self.stn(x) # 得到3维旋转转换矩阵
  95. x = x.transpose(2, 1) # 将2轴和1轴对调, 相当于[24,1024,3]
  96. if D > 3: # 这边是是特征点的话,不只有3维(x,y,z),可能为多维
  97. x, feature = x.split(3, dim=2) # 从维度2上按照3块分开。就是将高维特征按照3份分开
  98. x = torch.bmm(x, trans) # 将3维点云数据进行旋转变换
  99. if D > 3:
  100. x = torch.cat([x, feature], dim=2)
  101. x = x.transpose(2, 1) # 将2轴和1轴再对调,??
  102. x = F.relu(self.bn1(self.conv1(x))) # 进行第一次卷积、标准化、激活、得到64维的数据
  103. """————————————————(2020/1/18)——————————————————"""
  104. # 下面是第二层卷积层处理
  105. if self.feature_transform: # 如果需要对中间的特征进行旋转标定的话
  106. trans_feat = self.fstn(x) # 得到特征空间的旋转矩阵
  107. x = x.transpose(2, 1) # 将1轴和2轴对调
  108. x = torch.bmm(x, trans_feat) # 将特征数据进行旋转转换
  109. x = x.transpose(2, 1) # 将2轴再次和1轴对调
  110. else:
  111. trans_feat = None
  112. pointfeat = x # 旋转矫正过后的特征
  113. x = F.relu(self.bn2(self.conv2(x))) # 第二次卷积 输出维128
  114. x = self.bn3(self.conv3(x)) # 第三次卷积 输出维1024
  115. x = torch.max(x, 2, keepdim=True)[0] # 进行最大池化处理,只返回最大的数,不返回索引([0]是数值,[1]是索引)
  116. x = x.view(-1, 1024) # 把x reshape为 1024列的行数不定矩阵,这边的-1指的就是行数不定。
  117. if self.global_feat: # 是否为全局特征
  118. return x, trans, trans_feat # 返回特征数据x,3维旋转矩阵,多维旋转矩阵
  119. else:
  120. x = x.view(-1, 1024, 1).repeat(1, 1, N) # 多扩展了一个维度是为了和局部特征统一维度,方便后面的连接,然后复制成与局部特征一样的数量
  121. return torch.cat([x, pointfeat], 1), trans, trans_feat # 这边对应点云分割算法中,将全局特征与局部特征连接。
  122. """这边是高维特征空间转换举证的正则项,大致的意思的把这个转换矩阵乘上其转置阵再减去单位阵,取剩下差值的均值为损失函数"""
  123. def feature_transform_reguliarzer(trans):
  124. d = trans.size()[1] # 矩阵维度
  125. I = torch.eye(d)[None, :, :] # 生成同维度的对角单位阵
  126. if trans.is_cuda: # 是否采用Cuda加速
  127. I = I.cuda()
  128. # 损失函数,将变换矩阵乘自身转置然后减单位阵,取结果的元素均值为损失函数,因为正交阵乘其转置为单位阵。 这边不需要取绝对值或者L2吗?
  129. # A*(A'-I) = A*A'- A*I = I - A*I | A’: 矩阵A的转置
  130. loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1) - I), dim=(1, 2)))
  131. return loss

ModelNetDataLoader.py

  1. import numpy as np
  2. import warnings
  3. import os
  4. from torch.utils.data import Dataset
  5. ASTYPE = np.array([cls]).astype(np.int32)
  6. INDEX_ = self.classes[self.datapath[index][0]]
  7. warnings.filterwarnings('ignore')
  8. def pc_normalize(pc): # 简单的标准化
  9. centroid = np.mean(pc, axis=0) # 形心,设微元体积为单位体积。
  10. pc = pc - centroid # 去偏移,将坐标系原点转换到形心位置,坐标系只平移不旋转
  11. m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) # 将同一行的元素取平方相加,再开方,取最大。 sqrt(x^2+y^2+z^2)
  12. pc = pc / m # 归一化,归一化操作似乎会丢失物品的尺寸大小信息? 因为每个样本的m不一样。
  13. return pc
  14. def farthest_point_sample(point, npoint): # 最远点的提取
  15. """ Input: xyz: pointcloud data, [N, D] npoint: number of samples Return: centroids: sampled pointcloud index, [npoint, D] """
  16. N, D = point.shape
  17. xyz = point[:, :3]
  18. centroids = np.zeros((npoint,)) # 重心
  19. distance = np.ones((N,)) * 1e10
  20. farthest = np.random.randint(0, N)
  21. for i in range(npoint):
  22. centroids[i] = farthest
  23. centroid = xyz[farthest, :]
  24. dist = np.sum((xyz - centroid) ** 2, -1)
  25. mask = dist < distance
  26. distance[mask] = dist[mask]
  27. farthest = np.argmax(distance, -1)
  28. point = point[centroids.astype(np.int32)]
  29. return point
  30. # 制作这个类的重点在于生成一个列表,这个列表的元素为(path_sample x,lable x)的形式,重要的是生成路径与标签的列表
  31. # 也不一定需要制作路径的列表,可能制作路径的列表会比较不占内存,每次只把需要的数据加载进来而已。
  32. # 可以直接制作数据与标签的列表,可以按索引进行连接。
  33. class ModelNetDataLoader(Dataset): # 自己的数据集类子类,需要集成父类 Dataset
  34. # 要求override __len__和__getitem__,前面提供数据集大小、后者支持整数索引(?)
  35. def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
  36. self.root = root # 根目录
  37. self.npoints = npoint # 每个实例的点云数
  38. self.uniform = uniform # 统一
  39. self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') # cat 数据库类别的路径 (40类)
  40. self.cat = [line.rstrip() for line in open(self.catfile)] # 打开txt文件,读取每一行,并用rstrip()删除每行末尾的空格
  41. self.classes = dict(zip(self.cat, range(len(self.cat)))) # 生成类别字典{类别1:0,类别2:1,...类别40:39}
  42. self.normal_channel = normal_channel # 是否为标准的通道?标准通道指的是什么
  43. shape_ids = { }
  44. shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] # 加载训练集为一个列表,列表放在一个字典里面
  45. shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] # 加载测试集....
  46. assert (split == 'train' or split == 'test')
  47. shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] # a = 'vase_0513', a.split('_') = ['vase', '0513'] 为什么要用'_'/join 呢?
  48. # list of (shape_name, shape_txt_file_path) tuple
  49. self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
  50. in range(len(shape_ids[split]))] # 生成数据路径列表 [(数据标签,数据路径1),(数据标签,数据路径2)] 重点在这一步,生成的是路径与标签的列表
  51. print('The size of %s data is %d' % (split, len(self.datapath)))
  52. self.cache_size = cache_size # how many data points to cache in memory
  53. self.cache = { } # from index to (point_set, cls) tuple
  54. def __len__(self): # 必须要有重载
  55. return len(self.datapath) # 路径个数代表样本个数
  56. def _get_item(self, index): # 这边的任务主要是写好读取一个样本例子的示范代码,包括数据的初步预处理,如对齐什么的,返回一个(样本,label),
  57. # 这边通过开辟了一个缓存区,使用的数据的时候先判断在不在缓存区里面,如果在则直接使用,不再的话再初始化载入,按理说也可以直接把所有的数据一次性加载进来,然后按照索引读取。
  58. if index in self.cache:
  59. point_set, cls = self.cache[index]
  60. else:
  61. fn = self.datapath[index]
  62. cls = INDEX_ # 不大理解这边的意思
  63. cls = ASTYPE
  64. point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) # 样本矩阵为(10000,6)格式表示每个样本具有10000个点云点,每一列的意义为x,y,z,r,g,b
  65. if self.uniform:
  66. point_set = farthest_point_sample(point_set, self.npoints)
  67. else:
  68. point_set = point_set[0:self.npoints, :]
  69. point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) # 将x,y,z 标准化
  70. if not self.normal_channel: # 应该是说不需要彩色信息的时候,只取前面的3列数据
  71. point_set = point_set[:, 0:3]
  72. if len(self.cache) < self.cache_size:
  73. self.cache[index] = (point_set, cls)
  74. return point_set, cls
  75. def __getitem__(self, index): # 必须要有重载,否则会报错
  76. return self._get_item(index)
  77. if __name__ == '__main__':
  78. import torch
  79. DATA_PATH = 'E:\\0onedrive2\\OneDrive - hit.edu.cn\\1图片数据库\\modelnet40_normal_resampled\\'
  80. data = ModelNetDataLoader(DATA_PATH, split='train', uniform=False, normal_channel=True, )
  81. DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
  82. for point, label in DataLoader:
  83. print(point.shape)
  84. print(label.shape)

复现程序

原文数据集下载

其他

越学习源码和论文,越发现pointnet的基本单元是通过将(x,y,z)编码为1024高维冗余数据,再解码为256数据,最后再根据目标需求,连接相应的全连接层,例如T-net想要得到一个3x3的矩阵,他就连接节点数为9的全连接层,要分类的话就连接对应分类数的全连接层.


未完待完善和补充…

迁移自己的数据库

我的数据库下载

数据集载入

发表评论

表情:
评论列表 (有 0 条评论,118人围观)

还没有评论,来说两句吧...

相关阅读

    相关 PointNet 论文阅读笔记

    摘要 点云是几何数据结构的一种重要类型。 由于格式不规则,大多数研究人员将此类数据转换为规则的3D voxel网格或图像集合。 但是,这使数据变得不必要地庞大,并导致了问