PointNet++ code explanation (VII): the PointNetSetAbstractionMsg layer
2022-07-19 05:44:00 【weixin_ forty-two million seven hundred and seven thousand and 】
Multi-scale patterns are captured by applying grouping layers at several different scales and then extracting the features of each scale with a corresponding PointNet. The features from the different scales are concatenated to form a multi-scale feature. The SA layer that does this uses Multi-Scale Grouping (MSG).
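The grouping-and-concatenation idea can be illustrated without any deep learning library. The following is a minimal sketch in plain Python, not the PointNet++ implementation: `ball_query` and `msg_features` are hypothetical helper names, the per-axis maximum is only a stand-in for the shared MLP followed by max pooling, and the sample points, radii and sample counts are made up for illustration.

```python
import math


def ball_query(radius, k, points, center):
    """Return exactly k indices of points within `radius` of `center`.

    Assumes `center` is itself one of `points`, so at least one hit exists.
    If fewer than k points are found, the first hit is repeated as padding.
    """
    hits = [i for i, p in enumerate(points) if math.dist(p, center) <= radius][:k]
    return hits + [hits[0]] * (k - len(hits))


def msg_features(points, centers, radius_list, nsample_list):
    """Build one concatenated multi-scale feature vector per center."""
    features = []
    for center in centers:
        feature = []
        for radius, k in zip(radius_list, nsample_list):
            idx = ball_query(radius, k, points, center)
            # Coordinates of the grouped points relative to the center.
            local = [[p - c for p, c in zip(points[i], center)] for i in idx]
            # Stand-in for "shared MLP + max pool": per-axis maximum.
            feature.extend(max(axis) for axis in zip(*local))
        features.append(feature)
    return features


points = [(0.0, 0.0), (0.05, 0.0), (0.15, 0.1), (0.3, 0.3), (1.0, 1.0)]
feats = msg_features(points, [(0.0, 0.0)], [0.1, 0.2, 0.5], [2, 4, 4])
print(len(feats[0]))  # 3 scales x 2 axes = 6
```

Each radius contributes its own slice of the final vector, so a larger radius adds context from farther neighbours without replacing the fine-scale slice.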
Most of the layer is the same as the ordinary SA layer, but here radius_list is a list, for example [0.1, 0.2, 0.4]. A ball query is run for each radius, the point cloud features obtained at each radius are stored in new_points_list, and they are concatenated at the end. The code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

# index_points, farthest_point_sample and query_ball_point are helper functions
# used below; they are assumed to be defined in the same module.


class PointNetSetAbstractionMsg(nn.Module):
    """
    PointNet Set Abstraction (SA) module with Multi-Scale Grouping (MSG)
    Args:
        npoint: int -- number of points sampled by farthest point sampling
        radius_list: list of float -- search radius of the local region at each scale
        nsample_list: list of int -- how many points in each local region at each scale
        in_channel: int -- number of input feature channels (xyz not included)
        mlp_list: list of list of int -- output sizes of the per-point MLP at each scale
    forward() returns:
        new_xyz: [B, 3, npoint] tensor
        new_points: [B, sum_k{mlp_list[k][-1]}, npoint] tensor
    """

    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        # One stack of 1x1 conv + batch norm layers per scale.
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            # +3 because the relative xyz coordinates are appended to the features.
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)  # [B, N, C]
        if points is not None:
            points = points.permute(0, 2, 1)  # [B, N, D]

        B, N, C = xyz.shape
        S = self.npoint
        # Sample S centroids with farthest point sampling.
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))  # [B, S, C]
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            # Group K neighbours around every centroid at this radius.
            group_idx = query_ball_point(radius, K, xyz, new_xyz)  # [B, S, K]
            grouped_xyz = index_points(xyz, group_idx)  # [B, S, K, C]
            # Make the coordinates relative to the centroid.
            grouped_xyz -= new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points = F.relu(bn(conv(grouped_points)))
            # Max pool over the K neighbours of each centroid.
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)  # [B, C, S]
        # Concatenate the features of all scales along the channel dimension.
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat
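The channel sizes that follow from this code can be checked by hand. The sketch below is only an illustration: `msg_channels` is a hypothetical helper, and the `in_channel` and `mlp_list` values are example settings chosen for the arithmetic, not values given in this article. It mirrors two facts visible in the code above: the first convolution of every scale receives in_channel + 3 channels, and the final concatenation along dim=1 yields the sum of the last MLP width of each scale.

```python
def msg_channels(in_channel, mlp_list):
    """Return (input channels of the first conv, channels of new_points_concat)."""
    # Relative xyz (3 values) is concatenated to the input features.
    first_conv_in = in_channel + 3
    # Each scale contributes the width of its last MLP layer.
    concat_out = sum(mlp[-1] for mlp in mlp_list)
    return first_conv_in, concat_out


# Example values (assumed): no extra input features, three scales.
first_in, out_channels = msg_channels(0, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
print(first_in, out_channels)  # 3 320
```

The second number is the value to pass as in_channel to whichever layer consumes new_points_concat.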