# save_result.py
import argparse
import json
import os
import time
from collections import deque
from itertools import groupby

import cv2
import numpy as np
import open3d as o3d
  9. CLASS_MAPPING = {
  10. 'refrigerator': {'id': '0', 'name': '冰箱'},
  11. 'desk': {'id': '1', 'name': '书桌'},
  12. 'curtain': {'id': '2', 'name': '窗帘'},
  13. 'sofa': {'id': '3', 'name': '沙发'},
  14. 'bookshelf': {'id': '4', 'name': '书架'},
  15. 'bed': {'id': '5', 'name': '床'},
  16. 'table': {'id': '6', 'name': '桌子'},
  17. 'window': {'id': '7', 'name': '窗户'},
  18. 'cabinet': {'id': '8', 'name': '橱柜'},
  19. 'door': {'id': '9', 'name': '门'},
  20. 'chair': {'id': '10', 'name': '椅子'},
  21. 'floor': {'id': '11', 'name': '地板'},
  22. 'wall': {'id': '12', 'name': '墙'},
  23. 'sink': {'id': '13', 'name': '水槽'},
  24. 'toilet': {'id': '14', 'name': '马桶'},
  25. 'bathtub': {'id': '15', 'name': '浴缸'},
  26. 'shower curtain': {'id': '16', 'name': '浴帘'},
  27. 'picture': {'id': '17', 'name': '画'},
  28. 'counter': {'id': '18', 'name': '柜台'},
  29. }
# ==============================================================================
# Helper functions for the optimization algorithm
# ==============================================================================
  33. def get_class_specific_dbscan_params(class_name):
  34. """为不同类别返回定制化的DBSCAN超参数。"""
  35. default_eps = 0.25
  36. default_min_points = 150
  37. params = {
  38. 'bed': {'eps': 0.23, 'min_points': 100},
  39. 'sofa': {'eps': 0.3, 'min_points': 300},
  40. 'table': {'eps': 0.3, 'min_points': 300},
  41. 'desk': {'eps': 0.3, 'min_points': 300},
  42. 'bookshelf': {'eps': 0.3, 'min_points': 300},
  43. 'chair': {'eps': 0.2, 'min_points': 100},
  44. 'refrigerator': {'eps': 0.25, 'min_points': 200},
  45. 'cabinet': {'eps': 0.3, 'min_points': 200},
  46. 'door': {'eps': 0.2, 'min_points': 100}
  47. }
  48. config = params.get(class_name, {'eps': default_eps, 'min_points': default_min_points})
  49. return config['eps'], config['min_points']
# ==============================================================================
# 2D projection and drawing functions
# ==============================================================================
  53. def is_box_dimension_plausible_2d(box_extent_2d, class_name):
  54. """检查2D包围盒的尺寸是否在合理范围内(单位:米)。"""
  55. plausible_ranges_2d = {
  56. 'bed': ([1.2, 0.7], [3.0, 2.8]), # Significantly relaxed range for beds
  57. 'sofa': ([1.0, 0.7], [4.0, 1.8]),
  58. 'table': ([0.5, 0.5], [3.0, 1.5]),
  59. 'desk': ([0.8, 0.5], [2.5, 1.2]),
  60. 'bookshelf': ([0.5, 0.2], [2.5, 0.8]),
  61. 'chair': ([0.3, 0.3], [1.2, 1.2]),
  62. 'refrigerator': ([0.5, 0.5], [1.2, 1.2]),
  63. 'cabinet': ([0.4, 0.3], [3.0, 1.0]),
  64. 'door': ([0.6, 0.05], [1.2, 0.3]),
  65. 'window': ([0.4, 0.05], [3.0, 0.4])
  66. }
  67. if class_name not in plausible_ranges_2d:
  68. return True
  69. min_dims, max_dims = plausible_ranges_2d[class_name]
  70. sorted_extent = sorted(box_extent_2d)
  71. sorted_min = sorted(min_dims)
  72. sorted_max = sorted(max_dims)
  73. for i in range(2):
  74. if not (sorted_min[i] <= sorted_extent[i] <= sorted_max[i]):
  75. return False
  76. return True
  77. def calculate_2d_iou(box1, box2):
  78. """计算两个2D包围盒的IoU。盒子格式为[min_x, min_y, max_x, max_y]"""
  79. b1 = box1['bbox_2d_pixels']
  80. b2 = box2['bbox_2d_pixels']
  81. xA = max(b1[0], b2[0])
  82. yA = max(b1[1], b2[1])
  83. xB = min(b1[2], b2[2])
  84. yB = min(b1[3], b2[3])
  85. interArea = max(0, xB - xA) * max(0, yB - yA)
  86. if interArea == 0:
  87. return 0.0
  88. box1Area = (b1[2] - b1[0]) * (b1[3] - b1[1])
  89. box2Area = (b2[2] - b2[0]) * (b2[3] - b2[1])
  90. unionArea = float(box1Area + box2Area - interArea)
  91. if unionArea == 0:
  92. return 0.0
  93. return interArea / unionArea
  94. def post_process_in_2d(instances_with_pixel_boxes, x_m_per_px, y_m_per_px, iou_threshold=0.5):
  95. """在2D像素空间中对实例进行尺寸过滤和非极大值抑制(NMS)。"""
  96. # 1. 尺寸过滤
  97. plausible_instances = []
  98. for inst in instances_with_pixel_boxes:
  99. px_box = inst['bbox_2d_pixels'] # [min_x, min_y, max_x, max_y]
  100. px_width = px_box[2] - px_box[0]
  101. px_height = px_box[3] - px_box[1]
  102. metric_width = px_width * x_m_per_px
  103. metric_height = px_height * y_m_per_px
  104. extent_2d = [metric_width, metric_height]
  105. if is_box_dimension_plausible_2d(extent_2d, inst['category']):
  106. plausible_instances.append(inst)
  107. else:
  108. print(f" - 过滤掉一个2D尺寸异常的 '{inst['category']}' 实例,尺寸: {[f'{x:.2f}' for x in extent_2d]}")
  109. if not plausible_instances:
  110. return []
  111. # 2. 按类别分组进行后处理
  112. final_instances = []
  113. plausible_instances.sort(key=lambda x: x['category'])
  114. for class_name, group in groupby(plausible_instances, key=lambda x: x['category']):
  115. class_instances = list(group)
  116. # --- SPECIAL MERGING LOGIC FOR BEDS ---
  117. if class_name == 'bed':
  118. if not class_instances:
  119. continue
  120. # Build adjacency matrix for overlapping beds
  121. num_instances = len(class_instances)
  122. adj_matrix = np.zeros((num_instances, num_instances))
  123. for i in range(num_instances):
  124. for j in range(i, num_instances):
  125. # Use a low threshold to merge any overlap
  126. if calculate_2d_iou(class_instances[i], class_instances[j]) > 0.05:
  127. adj_matrix[i, j] = 1
  128. adj_matrix[j, i] = 1
  129. # Find connected components (groups of overlapping boxes)
  130. visited = [False] * num_instances
  131. groups = []
  132. for i in range(num_instances):
  133. if not visited[i]:
  134. component = []
  135. q = [i]
  136. visited[i] = True
  137. while q:
  138. u = q.pop(0)
  139. component.append(u)
  140. for v in range(num_instances):
  141. if adj_matrix[u, v] == 1 and not visited[v]:
  142. visited[v] = True
  143. q.append(v)
  144. groups.append(component)
  145. # Merge each group into a single instance
  146. merged_instances = []
  147. for group_indices in groups:
  148. instances_in_group = [class_instances[i] for i in group_indices]
  149. # Create the merged 2D bounding box
  150. min_x_2d = min(inst['bbox_2d_pixels'][0] for inst in instances_in_group)
  151. min_y_2d = min(inst['bbox_2d_pixels'][1] for inst in instances_in_group)
  152. max_x_2d = max(inst['bbox_2d_pixels'][2] for inst in instances_in_group)
  153. max_y_2d = max(inst['bbox_2d_pixels'][3] for inst in instances_in_group)
  154. # --- Create the merged 3D bounding box ---
  155. all_3d_corners = np.vstack([inst['bbox'] for inst in instances_in_group])
  156. min_3d = np.min(all_3d_corners, axis=0)
  157. max_3d = np.max(all_3d_corners, axis=0)
  158. merged_3d_bbox = [
  159. [min_3d[0], min_3d[1], min_3d[2]],
  160. [max_3d[0], min_3d[1], min_3d[2]],
  161. [min_3d[0], max_3d[1], min_3d[2]],
  162. [max_3d[0], max_3d[1], min_3d[2]],
  163. [min_3d[0], min_3d[1], max_3d[2]],
  164. [max_3d[0], min_3d[1], max_3d[2]],
  165. [min_3d[0], max_3d[1], max_3d[2]],
  166. [max_3d[0], max_3d[1], max_3d[2]],
  167. ]
  168. # Aggregate score and find a representative instance for metadata
  169. total_score = sum(inst['score'] for inst in instances_in_group)
  170. representative_instance = max(instances_in_group, key=lambda x: x['score'])
  171. new_instance = representative_instance.copy()
  172. new_instance['bbox_2d_pixels'] = [min_x_2d, min_y_2d, max_x_2d, max_y_2d]
  173. new_instance['bbox'] = merged_3d_bbox # Assign the new merged 3D bbox
  174. new_instance['score'] = total_score
  175. merged_instances.append(new_instance)
  176. final_instances.extend(merged_instances)
  177. print(f" - 类别 'bed': Merged {len(class_instances)} candidates into {len(merged_instances)} final instances.")
  178. # --- STANDARD NMS FOR OTHER CLASSES ---
  179. else:
  180. class_instances.sort(key=lambda x: x['score'], reverse=True)
  181. kept_instances = []
  182. while class_instances:
  183. best_inst = class_instances.pop(0)
  184. kept_instances.append(best_inst)
  185. remaining_instances = []
  186. for other_inst in class_instances:
  187. iou = calculate_2d_iou(best_inst, other_inst)
  188. if iou < iou_threshold:
  189. remaining_instances.append(other_inst)
  190. else:
  191. print(f" - 2D NMS: 抑制一个与更佳实例IoU为 {iou:.2f} 的 '{class_name}' 实例。")
  192. class_instances = remaining_instances
  193. final_instances.extend(kept_instances)
  194. print(f" - 类别 '{class_name}': 经过2D过滤和NMS后,剩余 {len(kept_instances)} 个有效实例。")
  195. return final_instances
  196. def build_floor_transform_matrix(j_info: dict, floor_id: int):
  197. tab = [[0.0] * 3 for _ in range(3)]
  198. res_width = None
  199. res_height = None
  200. for in_json in j_info.get("floors", []):
  201. if in_json.get("id") != floor_id:
  202. continue
  203. res_width = in_json.get("resolution", {}).get("width")
  204. res_height = in_json.get("resolution", {}).get("height")
  205. bound = in_json.get("bound", {})
  206. x_min, x_max = bound.get("x_min"), bound.get("x_max")
  207. y_min, y_max = bound.get("y_min"), bound.get("y_max")
  208. tab[0][0] = x_max - x_min
  209. tab[0][2] = x_min
  210. tab[1][1] = y_min - y_max
  211. tab[1][2] = y_max
  212. tab[2][2] = 1.0
  213. break
  214. if res_width is None: return np.identity(3).tolist(), None, None
  215. tab_array = np.array(tab, dtype=np.float64)
  216. if np.linalg.det(tab_array) == 0: raise ValueError("矩阵是奇异的,无法求逆。")
  217. return np.linalg.inv(tab_array).tolist(), res_width, res_height
  218. def process_and_draw_bboxes(picture_name, floor_path, raw_bbox_data, floor_id, output_image_path, output_json_path, output_3d_json_path):
  219. try:
  220. img = cv2.imread(picture_name)
  221. if img is None: raise FileNotFoundError(f"无法加载背景图片: {picture_name}")
  222. with open(floor_path, 'r', encoding='utf-8') as f: j_info = json.load(f)
  223. if not raw_bbox_data:
  224. print("警告: 未提供任何原始3D包围盒数据。")
  225. return None, None
  226. matrix, res_w, res_h = build_floor_transform_matrix(j_info, floor_id)
  227. if res_w is None: raise ValueError(f"未在 {floor_path} 中找到 ID 为 {floor_id} 的楼层信息。")
  228. M = np.array(matrix, dtype=np.float64)
  229. floor_info = next((f for f in j_info.get("floors", []) if f.get("id") == floor_id), None)
  230. bound = floor_info.get("bound", {})
  231. x_m_per_px = (bound.get("x_max") - bound.get("x_min")) / res_w
  232. y_m_per_px = abs(bound.get("y_max") - bound.get("y_min")) / res_h
  233. instances_with_pixel_boxes = []
  234. for item in raw_bbox_data:
  235. corners = item.get("bbox", [])
  236. if len(corners) < 8: continue
  237. points_2d = []
  238. for i in range(8):
  239. norm_pt = M @ np.array([corners[i][0], corners[i][1], 1.0])
  240. points_2d.append([int(norm_pt[0] * res_w), int(norm_pt[1] * res_h)])
  241. x_coords, y_coords = [p[0] for p in points_2d], [p[1] for p in points_2d]
  242. new_item = item.copy()
  243. new_item['bbox_2d_pixels'] = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
  244. instances_with_pixel_boxes.append(new_item)
  245. print("\n开始在2D空间进行后处理...")
  246. filtered_bbox_data = post_process_in_2d(instances_with_pixel_boxes, x_m_per_px, y_m_per_px)
  247. print("2D后处理完成。")
  248. img_height, img_width, _ = img.shape
  249. shapes_2d = []
  250. for item in filtered_bbox_data:
  251. min_x, min_y, max_x, max_y = item['bbox_2d_pixels']
  252. category = item["category"]
  253. color_rgb = item["color"]
  254. color_bgr = (color_rgb[2], color_rgb[1], color_rgb[0])
  255. cv2.rectangle(img, (min_x, min_y), (max_x, max_y), color_bgr, 2)
  256. font = cv2.FONT_HERSHEY_SIMPLEX
  257. (text_w, text_h), _ = cv2.getTextSize(category, font, 0.5, 1)
  258. label_y = min_y - 10 if min_y - 10 > text_h else min_y + text_h + 10
  259. cv2.rectangle(img, (min_x, label_y - text_h - 5), (min_x + text_w, label_y + 5), color_bgr, -1)
  260. cv2.putText(img, category, (min_x, label_y), font, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
  261. bbox_poly = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
  262. class_info = CLASS_MAPPING.get(category, {'id': '-1', 'name': '未知'})
  263. shapes_2d.append({
  264. "bbox": bbox_poly,
  265. "category": category,
  266. "color": color_rgb,
  267. "label": class_info['id'],
  268. "name": class_info['name']
  269. })
  270. # --- Save 2D Results ---
  271. output_2d_json_data = {
  272. "shapes": shapes_2d,
  273. "imageHeight": img_height,
  274. "imagePath": os.path.basename(picture_name),
  275. "imageWidth": img_width,
  276. "version": "4Dage_Furniture_Detection_0.0.1"
  277. }
  278. os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
  279. os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
  280. cv2.imwrite(output_image_path, img)
  281. with open(output_json_path, 'w', encoding='utf-8') as f:
  282. json.dump(output_2d_json_data, f, ensure_ascii=False, indent=4)
  283. print(f"\n处理完成!2D结果已保存到: {output_image_path} 和 {output_json_path}")
  284. # --- Save Final 3D Results ---
  285. shapes_3d = []
  286. for item in filtered_bbox_data:
  287. shapes_3d.append({
  288. "bbox": item['bbox'],
  289. "category": item['category'],
  290. "color": item['color'],
  291. "label": item['label'],
  292. "name": item['name'],
  293. })
  294. output_3d_json_data = {
  295. "shapes": shapes_3d,
  296. "version": "4Dage_Furniture_Detection_0.0.1_3D_final"
  297. }
  298. os.makedirs(os.path.dirname(output_3d_json_path), exist_ok=True)
  299. with open(output_3d_json_path, 'w', encoding='utf-8') as f:
  300. json.dump(output_3d_json_data, f, ensure_ascii=False, indent=4)
  301. print(f"对应的3D结果已保存到: {output_3d_json_path}")
  302. return output_json_path, output_image_path
  303. except Exception as e:
  304. print(f"发生错误: {e}")
  305. return None, None
# ==============================================================================
# Main functions
# ==============================================================================
  309. def visualize_point_cloud_segmentation(coords_file, preds_file, classes_to_show='all',
  310. classes_to_ignore=None,
  311. save_pcd_path=None,
  312. if_save_ply=False,
  313. if_save_vision=False):
  314. CLASS_NAMES = [
  315. 'refrigerator', 'desk', 'curtain', 'sofa', 'bookshelf', 'bed',
  316. 'table', 'window', 'cabinet', 'door', 'chair', 'floor', 'wall',
  317. 'sink', 'toilet', 'bathtub', 'shower curtain', 'picture', 'counter'
  318. ]
  319. COLOR_MAP = np.array([
  320. [174, 199, 232], [255, 127, 14], [44, 160, 44], [214, 39, 40],
  321. [148, 103, 189], [255, 187, 120], [140, 86, 75], [152, 223, 138],
  322. [23, 190, 207], [247, 182, 210], [196, 156, 148], [127, 127, 127],
  323. [199, 199, 199], [188, 189, 34], [219, 219, 141], [227, 119, 194],
  324. [31, 119, 180], [255, 152, 150], [82, 84, 163]
  325. ])
  326. try:
  327. coords = np.load(coords_file)
  328. predictions = np.load(preds_file)
  329. except FileNotFoundError as e:
  330. print(f"错误: 找不到文件 {e.filename}。")
  331. return None
  332. if len(coords) != len(predictions):
  333. print("警告: 坐标点数和预测标签数不匹配!")
  334. return None
  335. default_ignore_classes = {'floor', 'wall', 'picture'}
  336. ignore_set = default_ignore_classes.union(set(classes_to_ignore or []))
  337. show_set = set(classes_to_show) if isinstance(classes_to_show, (list, set)) else None
  338. final_instances_data, all_instance_points, all_instance_colors = [], [], []
  339. print(f"\n通过DBSCAN寻找原始实例...")
  340. for pred_idx in np.unique(predictions):
  341. class_name = CLASS_NAMES[pred_idx]
  342. if class_name in ignore_set or (show_set and class_name not in show_set):
  343. continue
  344. dbscan_eps, dbscan_min_points = get_class_specific_dbscan_params(class_name)
  345. class_points_indices = np.where(predictions == pred_idx)[0]
  346. if len(class_points_indices) < dbscan_min_points: continue
  347. class_pcd_temp = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(coords[class_points_indices]))
  348. instance_labels = np.array(class_pcd_temp.cluster_dbscan(eps=dbscan_eps, min_points=dbscan_min_points, print_progress=False))
  349. unique_instances = np.unique(instance_labels[instance_labels != -1])
  350. if len(unique_instances) > 0: print(f"- 类别 '{class_name}': 找到 {len(unique_instances)} 个原始候选实例")
  351. for instance_id in unique_instances:
  352. instance_point_indices = np.where(instance_labels == instance_id)[0]
  353. if len(instance_point_indices) < dbscan_min_points / 2: continue
  354. instance_pcd = class_pcd_temp.select_by_index(instance_point_indices)
  355. try:
  356. aabb = instance_pcd.get_axis_aligned_bounding_box()
  357. points_np = np.asarray(instance_pcd.points)
  358. class_info = CLASS_MAPPING.get(class_name, {'id': '-1', 'name': '未知'})
  359. final_instances_data.append({
  360. "category": class_name,
  361. "label": class_info['id'],
  362. "name": class_info['name'],
  363. "color": COLOR_MAP[pred_idx].tolist(),
  364. "bbox": np.asarray(aabb.get_box_points()).tolist(),
  365. "score": len(points_np)
  366. })
  367. all_instance_points.append(points_np)
  368. all_instance_colors.append(np.tile(COLOR_MAP[pred_idx] / 255.0, (len(points_np), 1)))
  369. except RuntimeError: continue
  370. print("\n所有原始实例处理完毕。")
  371. if if_save_ply and save_pcd_path and all_instance_points:
  372. instance_pcd = o3d.geometry.PointCloud()
  373. instance_pcd.points = o3d.utility.Vector3dVector(np.vstack(all_instance_points))
  374. instance_pcd.colors = o3d.utility.Vector3dVector(np.vstack(all_instance_colors))
  375. o3d.io.write_point_cloud(save_pcd_path, instance_pcd)
  376. print(f"所有检测到的实例点云已保存至: {save_pcd_path}")
  377. if if_save_vision:
  378. pcd = o3d.geometry.PointCloud(points=o3d.utility.Vector3dVector(coords),
  379. colors=o3d.utility.Vector3dVector(COLOR_MAP[predictions] / 255.0))
  380. o3d.visualization.draw_geometries([pcd], window_name="原始点云", width=1280, height=720)
  381. return final_instances_data
  382. if __name__ == "__main__":
  383. parser = argparse.ArgumentParser(
  384. description="输入单个场景文件夹用于结果保存",
  385. formatter_class=argparse.RawTextHelpFormatter
  386. )
  387. parser.add_argument(
  388. '-i',
  389. '--input_folder',
  390. type=str,
  391. required=True,
  392. help='指定输入场景的文件夹路径。'
  393. )
  394. args = parser.parse_args()
  395. scene_folder = args.input_folder
  396. scenece = os.path.basename(scene_folder)
  397. coords_file = os.path.join(scene_folder, 'scene/val/process_data/coord.npy')
  398. preds_file = os.path.join(scene_folder, "output/pred.npy")
  399. floor_plan_image = os.path.join(scene_folder, f"{scenece}.png")
  400. scene_info_json = os.path.join(scene_folder, f"{scenece}.json")
  401. output_dir = os.path.join(scene_folder, 'result_2d_filtered')
  402. os.makedirs(output_dir, exist_ok=True)
  403. final_instances2d_json_path = os.path.join(output_dir, 'instances2d_final.json')
  404. final_instances3d_json_path = os.path.join(output_dir, 'instances3d_final.json')
  405. instances_ply_path = os.path.join(output_dir, 'instances_raw.ply')
  406. segment_onfloor_png_path = os.path.join(output_dir, 'segment_onfloor_final.png')
  407. raw_3d_instances = visualize_point_cloud_segmentation(
  408. coords_file=coords_file, preds_file=preds_file,
  409. classes_to_ignore=['curtain', 'bookshelf', 'floor', 'wall', 'sink', 'toilet', 'bathtub', 'shower curtain', 'picture'],
  410. save_pcd_path=instances_ply_path, if_save_ply=False
  411. )
  412. if raw_3d_instances and all(os.path.exists(f) for f in [floor_plan_image, scene_info_json]):
  413. print("\n--- 开始进行2D投影和后处理 ---")
  414. process_and_draw_bboxes(
  415. picture_name=floor_plan_image,
  416. floor_path=scene_info_json,
  417. raw_bbox_data=raw_3d_instances,
  418. floor_id=0,
  419. output_image_path=segment_onfloor_png_path,
  420. output_json_path=final_instances2d_json_path,
  421. output_3d_json_path=final_instances3d_json_path
  422. )
  423. else:
  424. print("\nSkipping 2D projection due to missing files or no raw instances detected.")