vis.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. import json
  2. import cv2
  3. import numpy as np
  4. import os
  5. from collections import defaultdict
  6. # ================= 1. 颜色映射 =================
  7. COLOR_MAP = {
  8. "living_room": (255, 180, 100), # BGR: 橙色
  9. "bed_room": (100, 255, 100), # BGR: 绿色
  10. "bath_room": (255, 100, 255), # BGR: 紫色
  11. "kitchen_room": (100, 255, 255), # BGR: 黄色
  12. "other_room": (180, 180, 180), # BGR: 灰色
  13. "balcony": (200, 150, 100), # BGR: 棕色
  14. }
  15. DEFAULT_COLOR = (200, 200, 200)
  16. # ================= 2. 工具函数 =================
  17. def get_image_size(data):
  18. """从 JSON 数据获取图像尺寸"""
  19. if 'image_size' in data:
  20. return data['image_size']['width'], data['image_size']['height']
  21. # 从线段数据推断
  22. max_x, max_y = 0, 0
  23. for block in data.get('block', []):
  24. segments = block.get('points', [])
  25. if isinstance(segments, list) and len(segments) > 0:
  26. if isinstance(segments[0], list) and len(segments[0]) == 4:
  27. for seg in segments:
  28. max_x = max(max_x, seg[0], seg[2])
  29. max_y = max(max_y, seg[1], seg[3])
  30. for area in data.get('connect_area', []):
  31. max_x = max(max_x, area.get('x', 0) + area.get('w', 0))
  32. max_y = max(max_y, area.get('y', 0) + area.get('h', 0))
  33. return int(max_x) + 100, int(max_y) + 100
  34. def get_points_from_segments(segments):
  35. """从线段列表提取所有唯一点"""
  36. if not segments:
  37. return []
  38. points_set = set()
  39. for seg in segments:
  40. if len(seg) == 4:
  41. points_set.add((int(seg[0]), int(seg[1])))
  42. points_set.add((int(seg[2]), int(seg[3])))
  43. return list(points_set)
  44. # ================= 3. 主可视化函数 =================
  45. def visualize_final_json(json_path, output_path, rgb_path=None, vis_door=False):
  46. """
  47. 最终可视化:线段格式 JSON
  48. 修改点:
  49. 1. 房间端点:白色点 (255, 255, 255)
  50. 2. 标签位置:使用 JSON 中的 center 字段
  51. """
  52. print(f"\n{'=' * 60}")
  53. print(f"🎨 最终可视化:{os.path.basename(json_path)}")
  54. print('=' * 60)
  55. # 1. 读取 JSON
  56. with open(json_path, 'r', encoding='utf-8') as f:
  57. data = json.load(f)
  58. # 2. 获取尺寸并创建画布
  59. width, height = get_image_size(data)
  60. if rgb_path and os.path.exists(rgb_path):
  61. canvas = cv2.imread(rgb_path)
  62. print(f"🖼️ 使用背景图:{rgb_path}")
  63. else:
  64. # canvas = np.ones((height, width, 3), dtype=np.uint8) * 255
  65. canvas = np.zeros((height, width, 3), dtype=np.uint8)
  66. print(f"📐 创建白色画布:{width}x{height}")
  67. # 3. 绘制房间块 (Blocks)
  68. blocks = data.get('block', [])
  69. print(f"\n🏠 绘制 {len(blocks)} 个房间块...")
  70. for block_idx, block in enumerate(blocks):
  71. block_id = block.get('id', block_idx)
  72. label = block.get('label', 'door')
  73. segments = block.get('points', [])
  74. center = block.get('center', None) # 【修改】使用 JSON 中的 center 字段
  75. # 获取颜色
  76. color = COLOR_MAP.get(label.lower(), DEFAULT_COLOR)
  77. # 验证线段格式
  78. if not (isinstance(segments, list) and len(segments) > 0):
  79. print(f" ⚠️ Block {block_id}: 无数据,跳过")
  80. continue
  81. if not (isinstance(segments[0], list) and len(segments[0]) == 4):
  82. print(f" ⚠️ Block {block_id}: 格式不正确,跳过")
  83. continue
  84. # 提取所有点
  85. points = get_points_from_segments(segments)
  86. if len(points) < 2:
  87. print(f" ⚠️ Block {block_id}: 点数不足,跳过")
  88. continue
  89. # 【1】绘制线段(按元素连线)
  90. for seg in segments:
  91. p1 = (int(seg[0]), int(seg[1]))
  92. p2 = (int(seg[2]), int(seg[3]))
  93. cv2.line(canvas, p1, p2, color, 2, cv2.LINE_AA)
  94. # 【2】绘制点(所有顶点标记)【修改】白色点
  95. for pt in points:
  96. cv2.circle(canvas, pt, 4, (255, 255, 255), -1) # 白色点
  97. # 【3】绘制标签【修改】使用 center 字段
  98. if center and len(center) == 2:
  99. cx, cy = int(center[0]), int(center[1])
  100. else:
  101. # 如果没有 center,回退到计算中心
  102. pts = np.array(points, dtype=np.int32)
  103. M = cv2.moments(pts)
  104. if M["m00"] != 0:
  105. cx = int(M["m10"] / M["m00"])
  106. cy = int(M["m01"] / M["m00"])
  107. else:
  108. cx = int(np.mean(pts[:, 0]))
  109. cy = int(np.mean(pts[:, 1]))
  110. print(f" ⚠️ Block {block_id}: 无 center 字段,使用计算中心")
  111. # 确保在画布内
  112. font = cv2.FONT_HERSHEY_SIMPLEX
  113. font_scale = 0.6
  114. thickness = 1
  115. (tw, th), _ = cv2.getTextSize(label, font, font_scale, thickness)
  116. cx = max(tw // 2 + 5, min(cx, width - tw // 2 - 5))
  117. cy = max(th // 2 + 5, min(cy, height - th // 2 - 5))
  118. # 标签背景(白色矩形)
  119. cv2.rectangle(canvas, (cx - tw // 2 - 3, cy - th - 3),
  120. (cx + tw // 2 + 3, cy + 3), (255, 255, 255), -1)
  121. # 标签文字(黑色)
  122. cv2.putText(canvas, label, (cx - tw // 2, cy),
  123. font, font_scale, (0, 0, 0), thickness)
  124. # Block ID(灰色小字)
  125. id_text = f"ID:{block_id}"
  126. (iw, ih), _ = cv2.getTextSize(id_text, font, 0.4, 1)
  127. cv2.putText(canvas, id_text, (cx - iw // 2, cy + 15),
  128. font, 0.4, (100, 100, 100), 1)
  129. print(f" ✅ Block {block_id} ({label}): {len(segments)} 线段,{len(points)} 点,center=({cx},{cy})")
  130. # 4. 绘制连接区域 (Connect Area)
  131. connect_areas = data.get('connect_area', [])
  132. print(f"\n🔗 绘制 {len(connect_areas)} 个连接区域...")
  133. for area_idx, area in enumerate(connect_areas):
  134. area_id = area.get('id', area_idx)
  135. x = area.get('x', 0)
  136. y = area.get('y', 0)
  137. w = area.get('w', 0)
  138. h = area.get('h', 0)
  139. label = area.get('label', 'door').lower()
  140. block_pair = area.get('block_pair', [])
  141. # 确保在画布内
  142. x = max(0, min(x, width - 1))
  143. y = max(0, min(y, height - 1))
  144. w = max(1, min(w, width - x))
  145. h = max(1, min(h, height - y))
  146. # 【unknown/unknow】红色实心矩形
  147. if vis_door and ("door" in label):
  148. cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 0, 255), -1) # 红色填充
  149. cv2.rectangle(canvas, (x, y), (x + w, y + h), (255, 255, 255), 1) # 白色边框
  150. print(f" ✅ Connect {area_id}: 🔴 红色实心矩形")
  151. # 5. 绘制家具 (Furniture)
  152. furniture_list = data.get('furniture', [])
  153. print(f"\n🛋️ 绘制 {len(furniture_list)} 个家具...")
  154. FURNITURE_COLOR = (0, 165, 255) # BGR: 橙色
  155. for item in furniture_list:
  156. label = item.get('label', '')
  157. center = item.get('center', None)
  158. pts = item.get('points', {})
  159. if pts:
  160. bx1, by1 = pts['x1'], pts['y1']
  161. bx2, by2 = pts['x3'], pts['y3']
  162. cv2.rectangle(canvas, (bx1, by1), (bx2, by2), FURNITURE_COLOR, 2)
  163. if center and len(center) == 2:
  164. cx, cy = int(center[0]), int(center[1])
  165. font = cv2.FONT_HERSHEY_SIMPLEX
  166. (tw, th), _ = cv2.getTextSize(label, font, 0.5, 1)
  167. cv2.rectangle(canvas, (cx - tw // 2 - 3, cy - th - 3),
  168. (cx + tw // 2 + 3, cy + 3), FURNITURE_COLOR, -1)
  169. cv2.putText(canvas, label, (cx - tw // 2, cy),
  170. font, 0.5, (255, 255, 255), 1)
  171. print(f" ✅ {label} center=({center})")
  172. # 6. 添加图例
  173. print(f"\n📋 添加图例...")
  174. legend_y = 30
  175. legend_x = 20
  176. cv2.putText(canvas, "Legend:", (legend_x, legend_y - 5),
  177. cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
  178. legend_y += 5
  179. for room_type, color in COLOR_MAP.items():
  180. if room_type in ["door"]:
  181. continue
  182. cv2.rectangle(canvas, (legend_x, legend_y),
  183. (legend_x + 20, legend_y + 15), color, -1)
  184. cv2.rectangle(canvas, (legend_x, legend_y),
  185. (legend_x + 20, legend_y + 15), (0, 0, 0), 1)
  186. cv2.putText(canvas, room_type, (legend_x + 25, legend_y + 12),
  187. cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1)
  188. legend_y += 20
  189. # 添加 unknown 图例
  190. cv2.rectangle(canvas, (legend_x, legend_y),
  191. (legend_x + 20, legend_y + 15), (0, 0, 255), -1)
  192. cv2.rectangle(canvas, (legend_x, legend_y),
  193. (legend_x + 20, legend_y + 15), (0, 0, 0), 1)
  194. cv2.putText(canvas, "door", (legend_x + 25, legend_y + 12),
  195. cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1)
  196. # 6. 保存结果
  197. cv2.imwrite(output_path, canvas)
  198. print(f"\n{'=' * 60}")
  199. print(f"✨ 可视化完成!")
  200. print(f" 📁 输出路径:{output_path}")
  201. print(f" 📐 画布尺寸:{width}x{height}")
  202. print(f" 🏠 Block 数量:{len(blocks)}")
  203. print(f" 🔗 Connect Area 数量:{len(connect_areas)}")
  204. print('=' * 60)
  205. return canvas
  206. # ================= 4. 批量处理 =================
  207. def visualize_batch(json_folder, output_folder, rgb_folder=None, vis_door=False):
  208. """批量可视化文件夹中的所有 JSON"""
  209. if not os.path.exists(output_folder):
  210. os.makedirs(output_folder)
  211. json_files = [f for f in os.listdir(json_folder) if f.endswith('.json')]
  212. print(f"\n📁 发现 {len(json_files)} 个 JSON 文件\n")
  213. for json_file in json_files:
  214. json_path = os.path.join(json_folder, json_file)
  215. output_path = os.path.join(output_folder, json_file.replace('.json', '.png'))
  216. rgb_path = None
  217. if rgb_folder:
  218. rgb_path = os.path.join(rgb_folder, json_file.replace('.json', '.png'))
  219. if not os.path.exists(rgb_path):
  220. rgb_path = None
  221. try:
  222. visualize_final_json(json_path, output_path, rgb_path, vis_door)
  223. except Exception as e:
  224. print(f"❌ 处理 {json_file} 失败:{e}")
  225. # ================= 5. 主运行流程 =================
  226. if __name__ == '__main__':
  227. import sys
  228. if len(sys.argv) < 2:
  229. print("用法:python vis.py <文件夹名称>")
  230. print("示例:python vis.py SG-n6nV8B2oW95")
  231. sys.exit(1)
  232. folder_name = sys.argv[1]
  233. img_folder = "temp_data"
  234. folder_path = os.path.join(img_folder, folder_name)
  235. if not os.path.isdir(folder_path):
  236. print(f"错误:文件夹不存在 {folder_path}")
  237. sys.exit(1)
  238. # 查找对应的 JSON 文件(文件名与文件夹同名)
  239. json_files = [f for f in os.listdir(folder_path)
  240. if f.startswith(folder_name) and f.endswith('.json')]
  241. if not json_files:
  242. print(f"错误:未找到 JSON 文件")
  243. sys.exit(1)
  244. json_name = json_files[0]
  245. json_path = os.path.join(folder_path, json_name)
  246. # 查找对应的 RGB 图片
  247. rgb_name = json_name.replace('.json', '.png')
  248. rgb_path = os.path.join(folder_path, rgb_name)
  249. if not os.path.exists(rgb_path):
  250. print(f"错误:未找到 {rgb_name}")
  251. sys.exit(1)
  252. # 输出可视化路径:name_vis.png(保存到子文件夹)
  253. vis_name = json_name.replace('.json', '_vis.png')
  254. output_path = os.path.join(folder_path, vis_name)
  255. print(f"处理文件夹:{folder_name}")
  256. print(f" JSON: {json_path}")
  257. print(f" RGB: {rgb_path}")
  258. print(f" Vis: {output_path}")
  259. try:
  260. visualize_final_json(json_path, output_path, rgb_path, vis_door=True)
  261. print(f"✅ 完成 -> {os.path.basename(output_path)}")
  262. except Exception as e:
  263. print(f"❌ 错误:{e}")
  264. sys.exit(1)