save_result.py

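"""Cluster per-point semantic predictions into object instances and project them onto a floor plan.

Pipeline, as implemented below: DBSCAN clustering per predicted class to obtain raw
3D instances and their axis-aligned bounding boxes, projection of those boxes onto
the 2D floor-plan image, size filtering plus NMS (with a dedicated merge step for
beds) in pixel space, and finally an annotated image and JSON files as output.
"""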
import numpy as np
import open3d as o3d
import os
import time
import json
import cv2
from itertools import groupby
import argparse
# ==============================================================================
# Helper functions for the optimization stage
# ==============================================================================
def get_class_specific_dbscan_params(class_name):
    """Return class-specific DBSCAN hyper-parameters."""
    default_eps = 0.25
    default_min_points = 150
    params = {
        'bed': {'eps': 0.23, 'min_points': 100},
        'sofa': {'eps': 0.3, 'min_points': 300},
        'table': {'eps': 0.3, 'min_points': 300},
        'desk': {'eps': 0.3, 'min_points': 300},
        'bookshelf': {'eps': 0.3, 'min_points': 300},
        'chair': {'eps': 0.2, 'min_points': 100},
        'refrigerator': {'eps': 0.25, 'min_points': 200},
        'cabinet': {'eps': 0.3, 'min_points': 200},
        'door': {'eps': 0.2, 'min_points': 100}
    }
    config = params.get(class_name, {'eps': default_eps, 'min_points': default_min_points})
    return config['eps'], config['min_points']
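# For example, get_class_specific_dbscan_params('bed') returns (0.23, 100),
# while an unlisted label such as 'counter' falls back to (0.25, 150).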
# ==============================================================================
# 2D projection and drawing functions
# ==============================================================================
def is_box_dimension_plausible_2d(box_extent_2d, class_name):
    """Check whether a 2D bounding box has plausible dimensions (in meters)."""
    plausible_ranges_2d = {
        'bed': ([1.2, 0.7], [3.0, 2.8]),  # significantly relaxed range for beds
        'sofa': ([1.0, 0.7], [4.0, 1.8]),
        'table': ([0.5, 0.5], [3.0, 1.5]),
        'desk': ([0.8, 0.5], [2.5, 1.2]),
        'bookshelf': ([0.5, 0.2], [2.5, 0.8]),
        'chair': ([0.3, 0.3], [1.2, 1.2]),
        'refrigerator': ([0.5, 0.5], [1.2, 1.2]),
        'cabinet': ([0.4, 0.3], [3.0, 1.0]),
        'door': ([0.6, 0.05], [1.2, 0.3]),
        'window': ([0.4, 0.05], [3.0, 0.4])
    }
    if class_name not in plausible_ranges_2d:
        return True
    min_dims, max_dims = plausible_ranges_2d[class_name]
    # Compare the sorted extents against the sorted bounds so the check is
    # independent of the box orientation on the floor plan.
    sorted_extent = sorted(box_extent_2d)
    sorted_min = sorted(min_dims)
    sorted_max = sorted(max_dims)
    for i in range(2):
        if not (sorted_min[i] <= sorted_extent[i] <= sorted_max[i]):
            return False
    return True
def calculate_2d_iou(box1, box2):
    """Compute the IoU of two 2D boxes stored as [min_x, min_y, max_x, max_y]."""
    b1 = box1['bbox_2d_pixels']
    b2 = box2['bbox_2d_pixels']
    xA = max(b1[0], b2[0])
    yA = max(b1[1], b2[1])
    xB = min(b1[2], b2[2])
    yB = min(b1[3], b2[3])
    interArea = max(0, xB - xA) * max(0, yB - yA)
    if interArea == 0:
        return 0.0
    box1Area = (b1[2] - b1[0]) * (b1[3] - b1[1])
    box2Area = (b2[2] - b2[0]) * (b2[3] - b2[1])
    unionArea = float(box1Area + box2Area - interArea)
    if unionArea == 0:
        return 0.0
    return interArea / unionArea
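# Worked example: boxes {'bbox_2d_pixels': [0, 0, 2, 2]} and
# {'bbox_2d_pixels': [1, 1, 3, 3]} intersect in a 1x1 region, their union
# covers 4 + 4 - 1 = 7 pixels, so calculate_2d_iou(...) ≈ 0.143.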
def post_process_in_2d(instances_with_pixel_boxes, x_m_per_px, y_m_per_px, iou_threshold=0.5):
    """Apply size filtering and non-maximum suppression (NMS) to instances in 2D pixel space."""
    # 1. Size filtering
    plausible_instances = []
    for inst in instances_with_pixel_boxes:
        px_box = inst['bbox_2d_pixels']  # [min_x, min_y, max_x, max_y]
        px_width = px_box[2] - px_box[0]
        px_height = px_box[3] - px_box[1]
        metric_width = px_width * x_m_per_px
        metric_height = px_height * y_m_per_px
        extent_2d = [metric_width, metric_height]
        if is_box_dimension_plausible_2d(extent_2d, inst['label']):
            plausible_instances.append(inst)
        else:
            print(f" - Filtered out a '{inst['label']}' instance with implausible 2D size: {[f'{x:.2f}' for x in extent_2d]}")
    if not plausible_instances:
        return []
    # 2. Post-process per class
    final_instances = []
    plausible_instances.sort(key=lambda x: x['label'])
    for class_name, group in groupby(plausible_instances, key=lambda x: x['label']):
        class_instances = list(group)
        # --- Special merging logic for beds ---
        if class_name == 'bed':
            if not class_instances:
                continue
            # Build adjacency matrix for overlapping beds
            num_instances = len(class_instances)
            adj_matrix = np.zeros((num_instances, num_instances))
            for i in range(num_instances):
                for j in range(i, num_instances):
                    # Use a low threshold to merge any overlap
                    if calculate_2d_iou(class_instances[i], class_instances[j]) > 0.05:
                        adj_matrix[i, j] = 1
                        adj_matrix[j, i] = 1
            # Find connected components (groups of overlapping boxes) via BFS
            visited = [False] * num_instances
            groups = []
            for i in range(num_instances):
                if not visited[i]:
                    component = []
                    q = [i]
                    visited[i] = True
                    while q:
                        u = q.pop(0)
                        component.append(u)
                        for v in range(num_instances):
                            if adj_matrix[u, v] == 1 and not visited[v]:
                                visited[v] = True
                                q.append(v)
                    groups.append(component)
            # Merge each group into a single instance
            merged_instances = []
            for group_indices in groups:
                instances_in_group = [class_instances[i] for i in group_indices]
                # Create the merged bounding box
                min_x = min(inst['bbox_2d_pixels'][0] for inst in instances_in_group)
                min_y = min(inst['bbox_2d_pixels'][1] for inst in instances_in_group)
                max_x = max(inst['bbox_2d_pixels'][2] for inst in instances_in_group)
                max_y = max(inst['bbox_2d_pixels'][3] for inst in instances_in_group)
                # Aggregate the score and pick a representative instance for the metadata
                total_score = sum(inst['score'] for inst in instances_in_group)
                representative_instance = max(instances_in_group, key=lambda x: x['score'])
                new_instance = representative_instance.copy()
                new_instance['bbox_2d_pixels'] = [min_x, min_y, max_x, max_y]
                new_instance['score'] = total_score
                merged_instances.append(new_instance)
            final_instances.extend(merged_instances)
            print(f" - Class 'bed': merged {len(class_instances)} candidates into {len(merged_instances)} final instances.")
        # --- Standard NMS for all other classes ---
        else:
            class_instances.sort(key=lambda x: x['score'], reverse=True)
            kept_instances = []
            while class_instances:
                best_inst = class_instances.pop(0)
                kept_instances.append(best_inst)
                remaining_instances = []
                for other_inst in class_instances:
                    iou = calculate_2d_iou(best_inst, other_inst)
                    if iou < iou_threshold:
                        remaining_instances.append(other_inst)
                    else:
                        print(f" - 2D NMS: suppressed a '{class_name}' instance with IoU {iou:.2f} against a better one.")
                class_instances = remaining_instances
            final_instances.extend(kept_instances)
            print(f" - Class '{class_name}': {len(kept_instances)} valid instances remain after 2D filtering and NMS.")
    return final_instances
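# Note: the 'score' used for merging and NMS above is the per-instance point
# count assigned in visualize_point_cloud_segmentation(), so suppression keeps
# the cluster backed by the most 3D points rather than a classifier confidence.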
def build_floor_transform_matrix(j_info: dict, floor_id: int):
    """Return the matrix mapping world (x, y) to normalized floor-plan coordinates, plus the image resolution."""
    tab = [[0.0] * 3 for _ in range(3)]
    res_width = None
    res_height = None
    for in_json in j_info.get("floors", []):
        if in_json.get("id") != floor_id:
            continue
        res_width = in_json.get("resolution", {}).get("width")
        res_height = in_json.get("resolution", {}).get("height")
        bound = in_json.get("bound", {})
        x_min, x_max = bound.get("x_min"), bound.get("x_max")
        y_min, y_max = bound.get("y_min"), bound.get("y_max")
        tab[0][0] = x_max - x_min
        tab[0][2] = x_min
        tab[1][1] = y_min - y_max
        tab[1][2] = y_max
        tab[2][2] = 1.0
        break
    if res_width is None:
        return np.identity(3).tolist(), None, None
    tab_array = np.array(tab, dtype=np.float64)
    if np.linalg.det(tab_array) == 0:
        raise ValueError("The matrix is singular and cannot be inverted.")
    return np.linalg.inv(tab_array).tolist(), res_width, res_height
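# The matrix returned above is the inverse of the forward mapping from
# normalized floor-plan coordinates (u, v) in [0, 1] to world coordinates:
#     x = u * (x_max - x_min) + x_min
#     y = v * (y_min - y_max) + y_max
# so it sends homogeneous world points [x, y, 1] to [u, v, 1].
# A minimal sketch of the floor JSON this function expects (field names taken
# from the lookups above; the concrete values are illustrative only):
#     {"floors": [{"id": 0,
#                  "resolution": {"width": 1024, "height": 768},
#                  "bound": {"x_min": -5.0, "x_max": 5.0,
#                            "y_min": -4.0, "y_max": 4.0}}]}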
def process_and_draw_bboxes(picture_name, floor_path, instances_path, floor_id, output_image_path, output_json_path):
    try:
        img = cv2.imread(picture_name)
        if img is None:
            raise FileNotFoundError(f"Could not load background image: {picture_name}")
        with open(floor_path, 'r', encoding='utf-8') as f:
            j_info = json.load(f)
        with open(instances_path, 'r', encoding='utf-8') as f:
            raw_bbox_data = json.load(f)
        matrix, res_w, res_h = build_floor_transform_matrix(j_info, floor_id)
        if res_w is None:
            raise ValueError(f"No floor with id {floor_id} found in {floor_path}.")
        M = np.array(matrix, dtype=np.float64)
        floor_info = next((f for f in j_info.get("floors", []) if f.get("id") == floor_id), None)
        bound = floor_info.get("bound", {})
        x_m_per_px = (bound.get("x_max") - bound.get("x_min")) / res_w
        y_m_per_px = abs(bound.get("y_max") - bound.get("y_min")) / res_h
        instances_with_pixel_boxes = []
        for item in raw_bbox_data:
            corners = item.get("corners", [])
            if len(corners) < 4:
                continue
            # Project the (x, y) of the first four box corners into pixel space
            points_2d = []
            for i in range(4):
                norm_pt = M @ np.array([corners[i][0], corners[i][1], 1.0])
                points_2d.append([int(norm_pt[0] * res_w), int(norm_pt[1] * res_h)])
            x_coords, y_coords = [p[0] for p in points_2d], [p[1] for p in points_2d]
            new_item = item.copy()
            new_item['bbox_2d_pixels'] = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
            instances_with_pixel_boxes.append(new_item)
        print("\nStarting post-processing in 2D space...")
        filtered_bbox_data = post_process_in_2d(instances_with_pixel_boxes, x_m_per_px, y_m_per_px)
        print("2D post-processing finished.")
        instances_2d_data = []
        for item in filtered_bbox_data:
            min_x, min_y, max_x, max_y = item['bbox_2d_pixels']
            label = item["label"]
            color_bgr = (item["color"][2], item["color"][1], item["color"][0])
            instances_2d_data.append({"label": label, "color": item["color"], "bbox_2d": item['bbox_2d_pixels']})
            cv2.rectangle(img, (min_x, min_y), (max_x, max_y), color_bgr, 2)
            font = cv2.FONT_HERSHEY_SIMPLEX
            (text_w, text_h), _ = cv2.getTextSize(label, font, 0.5, 1)
            label_y = min_y - 10 if min_y - 10 > text_h else min_y + text_h + 10
            cv2.rectangle(img, (min_x, label_y - text_h - 5), (min_x + text_w, label_y + 5), color_bgr, -1)
            cv2.putText(img, label, (min_x, label_y), font, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
        os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
        os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
        cv2.imwrite(output_image_path, img)
        with open(output_json_path, 'w', encoding='utf-8') as f:
            json.dump(instances_2d_data, f, indent=4)
        print(f"\nDone! 2D results saved to: {output_image_path} and {output_json_path}")
        return output_json_path, output_image_path
    except Exception as e:
        print(f"An error occurred: {e}")
        return None, None
# ==============================================================================
# Main pipeline function
# ==============================================================================
def visualize_point_cloud_segmentation(coords_file, preds_file, classes_to_show='all',
                                       classes_to_ignore=None,
                                       save_pcd_path=None,
                                       save_3d_json_path=None,
                                       if_save_ply=False,
                                       if_save_vision=False):
    CLASS_NAMES = [
        'refrigerator', 'desk', 'curtain', 'sofa', 'bookshelf', 'bed',
        'table', 'window', 'cabinet', 'door', 'chair', 'floor', 'wall',
        'sink', 'toilet', 'bathtub', 'shower curtain', 'picture', 'counter'
    ]
    COLOR_MAP = np.array([
        [174, 199, 232], [255, 127, 14], [44, 160, 44], [214, 39, 40],
        [148, 103, 189], [255, 187, 120], [140, 86, 75], [152, 223, 138],
        [23, 190, 207], [247, 182, 210], [196, 156, 148], [127, 127, 127],
        [199, 199, 199], [188, 189, 34], [219, 219, 141], [227, 119, 194],
        [31, 119, 180], [255, 152, 150], [82, 84, 163]
    ])
    try:
        coords = np.load(coords_file)
        predictions = np.load(preds_file)
    except FileNotFoundError as e:
        print(f"Error: file not found: {e.filename}")
        return None
    if len(coords) != len(predictions):
        print("Warning: the number of points does not match the number of predicted labels!")
        return None
    default_ignore_classes = {'floor', 'wall', 'picture'}
    ignore_set = default_ignore_classes.union(set(classes_to_ignore or []))
    show_set = set(classes_to_show) if isinstance(classes_to_show, (list, set)) else None
    final_instances_data, all_instance_points, all_instance_colors = [], [], []
    print("\nSearching for raw instances with DBSCAN...")
    for pred_idx in np.unique(predictions):
        class_name = CLASS_NAMES[pred_idx]
        if class_name in ignore_set or (show_set and class_name not in show_set):
            continue
        dbscan_eps, dbscan_min_points = get_class_specific_dbscan_params(class_name)
        class_points_indices = np.where(predictions == pred_idx)[0]
        if len(class_points_indices) < dbscan_min_points:
            continue
        class_pcd_temp = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(coords[class_points_indices]))
        instance_labels = np.array(class_pcd_temp.cluster_dbscan(eps=dbscan_eps, min_points=dbscan_min_points, print_progress=False))
        unique_instances = np.unique(instance_labels[instance_labels != -1])
        if len(unique_instances) > 0:
            print(f"- Class '{class_name}': found {len(unique_instances)} raw candidate instances")
        for instance_id in unique_instances:
            instance_point_indices = np.where(instance_labels == instance_id)[0]
            if len(instance_point_indices) < dbscan_min_points / 2:
                continue
            instance_pcd = class_pcd_temp.select_by_index(instance_point_indices)
            try:
                aabb = instance_pcd.get_axis_aligned_bounding_box()
                points_np = np.asarray(instance_pcd.points)
                final_instances_data.append({
                    "label": class_name, "color": COLOR_MAP[pred_idx].tolist(),
                    "corners": np.asarray(aabb.get_box_points()).tolist(), "score": len(points_np)
                })
                all_instance_points.append(points_np)
                all_instance_colors.append(np.tile(COLOR_MAP[pred_idx] / 255.0, (len(points_np), 1)))
            except RuntimeError:
                continue
    print("\nFinished processing all raw instances.")
    if save_3d_json_path:
        os.makedirs(os.path.dirname(save_3d_json_path), exist_ok=True)
        with open(save_3d_json_path, 'w', encoding='utf-8') as f:
            json.dump(final_instances_data, f, ensure_ascii=False, indent=4)
        print(f"Raw 3D instance JSON saved to: {save_3d_json_path}")
    if if_save_ply and save_pcd_path and all_instance_points:
        instance_pcd = o3d.geometry.PointCloud()
        instance_pcd.points = o3d.utility.Vector3dVector(np.vstack(all_instance_points))
        instance_pcd.colors = o3d.utility.Vector3dVector(np.vstack(all_instance_colors))
        o3d.io.write_point_cloud(save_pcd_path, instance_pcd)
        print(f"Point cloud of all detected instances saved to: {save_pcd_path}")
    if if_save_vision:
        # The PointCloud constructor only takes points, so colors are assigned afterwards.
        pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(coords))
        pcd.colors = o3d.utility.Vector3dVector(COLOR_MAP[predictions] / 255.0)
        o3d.visualization.draw_geometries([pcd], window_name="Raw point cloud", width=1280, height=720)
    return save_3d_json_path
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process a single scene folder and save the results",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        '-i',
        '--input_folder',
        type=str,
        required=True,
        help='Path to the input scene folder.'
    )
    args = parser.parse_args()
    scene_folder = args.input_folder
    # normpath guards against a trailing slash producing an empty basename
    scene_name = os.path.basename(os.path.normpath(scene_folder))
    coords_file = os.path.join(scene_folder, 'scene/val/process_data/coord.npy')
    preds_file = os.path.join(scene_folder, "output/pred.npy")
    floor_plan_image = os.path.join(scene_folder, f"{scene_name}.png")
    scene_info_json = os.path.join(scene_folder, f"{scene_name}.json")
    output_dir = os.path.join(scene_folder, 'result_2d_filtered')
    os.makedirs(output_dir, exist_ok=True)
    raw_instances3d_json_path = os.path.join(output_dir, 'instances3d_raw.json')
    final_instances2d_json_path = os.path.join(output_dir, 'instances2d_final.json')
    instances_ply_path = os.path.join(output_dir, 'instances_raw.ply')
    segment_onfloor_png_path = os.path.join(output_dir, 'segment_onfloor_final.png')
    saved_3d_json = visualize_point_cloud_segmentation(
        coords_file=coords_file, preds_file=preds_file,
        classes_to_ignore=['curtain', 'bookshelf', 'floor', 'wall', 'sink', 'toilet', 'bathtub', 'shower curtain', 'picture'],
        save_3d_json_path=raw_instances3d_json_path, save_pcd_path=instances_ply_path, if_save_ply=False
    )
    if saved_3d_json and all(os.path.exists(f) for f in [floor_plan_image, scene_info_json]):
        print("\n--- Starting 2D projection and post-processing ---")
        process_and_draw_bboxes(
            picture_name=floor_plan_image, floor_path=scene_info_json,
            instances_path=saved_3d_json, floor_id=0,
            output_image_path=segment_onfloor_png_path, output_json_path=final_instances2d_json_path
        )
    else:
        print("\nSkipping 2D projection due to missing files.")
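# Example invocation (the scene name "scene0000_00" below is illustrative; the
# relative layout matches the paths assembled above):
#     python save_result.py -i /data/scenes/scene0000_00
# Expected inputs inside the scene folder:
#     scene/val/process_data/coord.npy       per-point XYZ coordinates
#     output/pred.npy                        per-point predicted class indices
#     scene0000_00.png, scene0000_00.json    floor-plan image and floor metadata
# Outputs are written to result_2d_filtered/ within the same folder.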