Pointnet++ code explanation (V): the sample_and_group function and the sample_and_group_all function
Sampling + Grouping is mainly used to split the whole point cloud into local groups, so that PointNet can extract local and global features from each group separately. Sampling + Grouping relies on the functions analyzed in the previous posts and is split into two functions, sample_and_group and sample_and_group_all; the difference is that sample_and_group_all treats all points as a single group.
Implementation steps of sample_and_group:
- First use farthest_point_sample to run farthest point sampling (FPS) and obtain the indices of the sampled points, then use index_points to gather those points from the original cloud as new_xyz
- Use query_ball_point and index_points to partition the original cloud into npoint ball regions centered on new_xyz, each containing nsample sampled points
- Subtract the region's center from each point in the region
- If each point carries extra feature channels, concatenate the new features with the old ones; otherwise return the old features directly
sample_and_group_all treats all points as a single group, i.e. it simply adds a dimension of length 1; it likewise concatenates the extra features when they are present.
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint: number of samples taken by FPS
        radius: radius of each local (ball) region
        nsample: max number of points sampled inside each local region
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]; the D channels exclude the x, y, z coordinates
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    # Run farthest point sampling on the original cloud to get the sample indices
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    torch.cuda.empty_cache()
    # Gather the FPS samples from the original points via index_points;
    # new_xyz holds the region centers, with shape [B, S, 3]
    new_xyz = index_points(xyz, fps_idx)
    torch.cuda.empty_cache()
    # idx: [B, npoint, nsample], indices of the nsample points in each of the npoint ball regions
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    torch.cuda.empty_cache()
    # grouped_xyz: [B, npoint, nsample, C];
    # gather the nsample points of every group from the original points via index_points
    grouped_xyz = index_points(xyz, idx)
    torch.cuda.empty_cache()
    # Subtract the center: normalize each region's points relative to its center
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    torch.cuda.empty_cache()
    # If the points carry extra feature channels, concatenate the new features with the old ones;
    # otherwise return the old features directly
    if points is not None:
        # Gather the non-coordinate feature channels of each group's points via index_points
        grouped_points = index_points(points, idx)
        # dim=-1 concatenates along the last dimension, equivalent to dim=3 here
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points
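A minimal usage sketch (the toy tensor sizes below are made up for illustration; farthest_point_sample, query_ball_point, and index_points are assumed to be defined as in the previous posts of this series):

import torch

# Toy input: batch of 2 clouds with 1024 points, 3 coordinates plus 6 extra feature channels
xyz = torch.rand(2, 1024, 3)
points = torch.rand(2, 1024, 6)
new_xyz, new_points = sample_and_group(npoint=512, radius=0.2, nsample=32,
                                       xyz=xyz, points=points)
print(new_xyz.shape)     # torch.Size([2, 512, 3])     -- 512 region centers
print(new_points.shape)  # torch.Size([2, 512, 32, 9]) -- 32 points per region, 3+6 channels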
def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    Treats all points as a single group, i.e. simply adds a dimension of length 1.
    """
    device = xyz.device
    B, N, C = xyz.shape
    # new_xyz is the group center, represented by the origin
    new_xyz = torch.zeros(B, 1, C).to(device)
    # Subtracting the center from each point would normalize the group, but since the
    # center is the origin the result is just grouped_xyz itself
    grouped_xyz = xyz.view(B, 1, N, C)
    # If the points carry extra feature channels, concatenate the new features with the old ones;
    # otherwise return the old features directly
    if points is not None:
        # view(B, 1, N, -1): -1 is inferred automatically, so this equals view(B, 1, N, D)
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points
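An analogous shape check for sample_and_group_all (same made-up toy tensors as above):

xyz = torch.rand(2, 1024, 3)
points = torch.rand(2, 1024, 6)
new_xyz, new_points = sample_and_group_all(xyz, points)
print(new_xyz.shape)     # torch.Size([2, 1, 3])       -- a single center at the origin
print(new_points.shape)  # torch.Size([2, 1, 1024, 9]) -- all N points form one group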
torch.cuda.empty_cache():
PyTorch reclaims memory we no longer use automatically, much like Python's reference-counting mechanism: once the data in a block of memory is no longer referenced by any variable, that block is released. One thing to note, though: when such memory is released, the release is not visible through the nvidia-smi command. For example:
device = torch.device('cuda:0')
# Define two tensors
dummy_tensor_4 = torch.randn(120, 3, 512, 512).float().to(device)  # 120*3*512*512*4/1000/1000 = 377.49 MB
dummy_tensor_5 = torch.randn(80, 3, 512, 512).float().to(device)   # 80*3*512*512*4/1000/1000 = 251.66 MB
# Then release them
dummy_tensor_4 = dummy_tensor_4.cpu()
dummy_tensor_5 = dummy_tensor_5.cpu()
# Although the memory has been released at this point, nvidia-smi still shows it as occupied
torch.cuda.empty_cache()
# Only after executing this line does the release become visible in nvidia-smi
The PyTorch developers have explained this as well: the released memory is still usable by PyTorch; it just does not show up in nvidia-smi.
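This can also be observed from inside Python by comparing torch.cuda.memory_allocated() (memory held by live tensors) with torch.cuda.memory_reserved() (memory held by PyTorch's caching allocator, which is what nvidia-smi reports). A minimal sketch, assuming a CUDA device is available:

import torch

device = torch.device('cuda:0')
x = torch.randn(120, 3, 512, 512, device=device)
print(torch.cuda.memory_allocated(device))  # ~377 MB held by the live tensor
print(torch.cuda.memory_reserved(device))   # >= allocated; this is what nvidia-smi sees
del x  # drop the last reference; the block goes back to PyTorch's cache
print(torch.cuda.memory_allocated(device))  # back near 0
print(torch.cuda.memory_reserved(device))   # still ~377 MB: cached, not returned to the driver
torch.cuda.empty_cache()
print(torch.cuda.memory_reserved(device))   # now released, and nvidia-smi reflects it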