save_ply.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import numpy as np
  2. import open3d as o3d
  3. import os
  4. import time
  5. import json
  6. import argparse
  7. def save_point_cloud_by_class(coords_file, preds_file, classes_to_show, save_pcd_path):
  8. """
  9. 加载点云坐标和预测类别,筛选指定类别,并保存为.ply文件。
  10. """
  11. # 1. 定义类别名称和颜色映射
  12. CLASS_NAMES = [
  13. 'refrigerator', 'desk', 'curtain', 'sofa', 'bookshelf', 'bed',
  14. 'table', 'window', 'cabinet', 'door', 'chair', 'floor', 'wall',
  15. 'sink', 'toilet', 'bathtub', 'shower curtain', 'picture', 'counter'
  16. ]
  17. COLOR_MAP = np.array([
  18. [174, 199, 232], [255, 127, 14], [152, 223, 138], [214, 39, 40],
  19. [148, 103, 189], [255, 187, 120], [140, 86, 75], [152, 223, 138],
  20. [152, 223, 138], [152, 223, 138], [196, 156, 148], [127, 127, 127],
  21. [152, 223, 138], [188, 189, 34], [219, 219, 141], [227, 119, 194],
  22. [31, 119, 180], [255, 152, 150], [82, 84, 163]
  23. ])
  24. # 2. 加载数据
  25. try:
  26. print(f"正在加载坐标文件: {coords_file}")
  27. coords = np.load(coords_file)
  28. print(f"正在加载预测文件: {preds_file}")
  29. predictions = np.load(preds_file)
  30. except FileNotFoundError as e:
  31. print(f"错误: 找不到文件 {e.filename}。请确保文件路径正确。")
  32. return
  33. print(f"原始点云数量: {len(coords)}")
  34. if len(coords) != len(predictions):
  35. print("警告: 坐标点数和预测标签数不匹配!")
  36. return
  37. # 3. 根据 'classes_to_show' 筛选点云
  38. if isinstance(classes_to_show, str):
  39. target_classes = [classes_to_show]
  40. elif isinstance(classes_to_show, list):
  41. target_classes = classes_to_show
  42. else:
  43. print("错误: 'classes_to_show' 参数必须是字符串或列表。")
  44. return
  45. target_indices = [CLASS_NAMES.index(cn) for cn in target_classes if cn in CLASS_NAMES]
  46. if not target_indices:
  47. print(f"错误: 类别 '{classes_to_show}' 无效或未在 CLASS_NAMES 中找到。")
  48. return
  49. print(f"正在筛选类别: {[CLASS_NAMES[i] for i in target_indices]}")
  50. mask = np.isin(predictions, target_indices)
  51. coords = coords[mask]
  52. predictions = predictions[mask]
  53. if coords.shape[0] == 0:
  54. print(f"警告: 在场景中没有找到属于类别 {target_classes} 的点。")
  55. return
  56. print(f"筛选后剩余点云数量: {len(coords)}")
  57. # 4. 创建Open3D点云对象
  58. pcd = o3d.geometry.PointCloud()
  59. pcd.points = o3d.utility.Vector3dVector(coords)
  60. pcd.colors = o3d.utility.Vector3dVector(COLOR_MAP[predictions] / 255.0)
  61. # 5. 保存筛选后的点云
  62. if not save_pcd_path:
  63. print("警告: 未提供 'save_pcd_path',不保存文件。")
  64. return
  65. print(f"正在保存筛选后的点云到: {save_pcd_path}")
  66. try:
  67. o3d.io.write_point_cloud(save_pcd_path, pcd, write_ascii=True)
  68. print("保存成功。")
  69. except Exception as e:
  70. print(f"错误: 保存点云文件失败: {e}")
  71. if __name__ == "__main__":
  72. parser = argparse.ArgumentParser(
  73. description="输入单个场景文件夹用于结果点云保存",
  74. formatter_class=argparse.RawTextHelpFormatter
  75. )
  76. parser.add_argument(
  77. '-i',
  78. '--input_folder',
  79. type=str,
  80. required=True,
  81. help='指定输入场景的文件夹路径。'
  82. )
  83. args = parser.parse_args()
  84. scene_folder = args.input_folder
  85. scenece = os.path.basename(scene_folder)
  86. coords_file = os.path.join(scene_folder, 'scene/val/process_data/coord.npy')
  87. preds_file = os.path.join(scene_folder, 'output/pred.npy')
  88. output_dir = os.path.join(scene_folder, 'wall_ply')
  89. os.makedirs(output_dir, exist_ok=True)
  90. print(f"\n--- 正在处理场景: {scenece} ---")
  91. save_point_cloud_by_class(
  92. coords_file=coords_file,
  93. preds_file=preds_file,
  94. classes_to_show=['wall', 'window', 'door', 'cabinet'],
  95. save_pcd_path=f'{output_dir}/wall_instances.ply'
  96. )