pipeline.py 90 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469
  1. import argparse
  2. import cv2
  3. import numpy as np
  4. import os
  5. import json
  6. import math
  7. import shutil
  8. from collections import defaultdict
  9. from datetime import datetime
# Optional dependency: use the real tqdm progress bar when installed,
# otherwise fall back to a no-op pass-through so loops still run.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(x, **kw):
        # Minimal stand-in: return the iterable unchanged (no progress bar).
        return x
  15. from ultralytics import YOLO
  16. # ===========================================================================
  17. # PART 1: merge_all
  18. # ===========================================================================
  19. def remove_pure_red(img):
  20. if img is None: return
  21. # 在 OpenCV (BGR) 中,纯红色是 [0, 0, 255]
  22. # 找到所有三个通道完全匹配该值的像素
  23. red_pixels = (img[:, :, 0] == 0) & (img[:, :, 1] == 0) & (img[:, :, 2] == 255)
  24. # 直接将这些位置涂黑
  25. img[red_pixels] = [0, 0, 0]
  26. return img
  27. def remove_edge_regions_image(img):
  28. """移除边缘黑边/红边,返回清洗后的单张 RGB 图。"""
  29. if img is None:
  30. raise ValueError("输入图像为空")
  31. result = img.copy()
  32. img_2 = np.zeros_like(result)
  33. mask = (result[:, :, 0] == 0) & (result[:, :, 1] == 0) & (result[:, :, 2] == 0)
  34. img_2[~mask] = (255, 255, 255)
  35. edges = cv2.Canny(img_2, 50, 150)
  36. kernel = np.ones((9, 9), np.uint8)
  37. edges = cv2.dilate(edges, kernel, 1)
  38. result[edges > 0] = (0, 0, 0)
  39. return remove_pure_red(result)
  40. def extract_gaps_from_mask(mask):
  41. """从单通道 mask 中提取裂隙,并返回 gaps 与 gap_add_mask。"""
  42. if mask is None:
  43. raise ValueError("mask 为空")
  44. _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
  45. kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 25))
  46. stitched_mask = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
  47. gaps = cv2.subtract(stitched_mask, binary)
  48. refine_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
  49. gaps_dilated = cv2.dilate(gaps, refine_kernel, iterations=1)
  50. merged = cv2.add(mask, gaps_dilated)
  51. return gaps, merged
  52. def extract_mask_region_from_arrays(rgb_img, ori_mask, full_mask):
  53. """在内存中提取 full_mask 相比 ori_mask 新增的 RGB 区域。"""
  54. if rgb_img is None or ori_mask is None or full_mask is None:
  55. raise ValueError("rgb_img / ori_mask / full_mask 不能为空")
  56. _, binary_ori_mask = cv2.threshold(ori_mask, 127, 255, cv2.THRESH_BINARY)
  57. _, binary_full_mask = cv2.threshold(full_mask, 127, 255, cv2.THRESH_BINARY)
  58. result_rgb_ori_mask = cv2.bitwise_and(rgb_img, rgb_img, mask=binary_ori_mask)
  59. result_rgb_full_mask = cv2.bitwise_and(rgb_img, rgb_img, mask=binary_full_mask)
  60. return cv2.subtract(result_rgb_full_mask, result_rgb_ori_mask)
  61. def _to_gray_mask(mask):
  62. if mask is None:
  63. raise ValueError("mask 为空")
  64. if len(mask.shape) == 2:
  65. return mask
  66. return cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
  67. def _expand_rect(x, y, w, h, expand_pixel, width, height):
  68. if expand_pixel <= 0:
  69. return int(x), int(y), int(w), int(h)
  70. x1 = max(0, int(x) - int(expand_pixel))
  71. y1 = max(0, int(y) - int(expand_pixel))
  72. x2 = min(width, int(x) + int(w) + int(expand_pixel))
  73. y2 = min(height, int(y) + int(h) + int(expand_pixel))
  74. return x1, y1, max(1, x2 - x1), max(1, y2 - y1)
  75. def _load_yolo_model(model_or_path):
  76. if isinstance(model_or_path, (str, os.PathLike)):
  77. return YOLO(model_or_path)
  78. return model_or_path
  79. def extract_gaps(mask_path, gap_path, gap_add_mask_path):
  80. # 1. 读取原始 Mask 图
  81. mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
  82. if mask is None:
  83. print("错误:无法读取图像,请检查路径。")
  84. return
  85. # 确保图像是纯粹的二值图 (0 和 255)
  86. _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
  87. # 2. 核心步骤:使用形态学闭运算“抹平”裂隙
  88. # 核 (Kernel) 的大小决定了算法能填补多宽的裂缝。
  89. # 这里的 (25, 25) 是一个经验值,刚好略大于你图中裂隙的像素宽度。
  90. # 如果裂隙更宽,可以增大这个值,比如 (35, 35)。
  91. kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 25))
  92. # 闭运算:先膨胀(挤满裂缝)再腐蚀(恢复外部原有边界)
  93. stitched_mask = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
  94. # 3. 提取裂隙:缝合后的整体 - 原始的断裂块
  95. gaps = cv2.subtract(stitched_mask, binary)
  96. cv2.imwrite(gap_path, gaps)
  97. refine_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
  98. gaps = cv2.dilate(gaps, refine_kernel, iterations=1)
  99. merged = cv2.add(mask, gaps)
  100. cv2.imwrite(gap_add_mask_path, merged)
  101. def extract_mask_region(rgb_path, ori_mask_path, full_mask_path, output_path):
  102. rgb_img = cv2.imread(rgb_path)
  103. ori_mask = cv2.imread(ori_mask_path, cv2.IMREAD_GRAYSCALE)
  104. full_mask = cv2.imread(full_mask_path, cv2.IMREAD_GRAYSCALE)
  105. if rgb_img is None or ori_mask is None or full_mask is None:
  106. print(f"无法读取文件: {rgb_path} 或 {ori_mask_path} 或 {full_mask_path}")
  107. return
  108. _, binary_ori_mask = cv2.threshold(ori_mask, 127, 255, cv2.THRESH_BINARY)
  109. # birefNet生成的mask和原图有些位移,在生成连通区域时会有噪声,边缘长细条
  110. # kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
  111. # binary_ori_mask = cv2.dilate(binary_ori_mask, kernel, iterations=1)
  112. # birefNet生成的mask和原图有些位移,在生成连通区域时会有噪声,边缘长细条
  113. _, binary_full_mask = cv2.threshold(full_mask, 127, 255, cv2.THRESH_BINARY)
  114. result_rgb_ori_mask = cv2.bitwise_and(rgb_img, rgb_img, mask=binary_ori_mask)
  115. result_rgb_full_mask = cv2.bitwise_and(rgb_img, rgb_img, mask=binary_full_mask)
  116. result = result_rgb_full_mask - result_rgb_ori_mask
  117. # 5. 保存结果
  118. cv2.imwrite(output_path, result)
  119. print(f"✅ 提取成功:{os.path.basename(output_path)}")
def identify_boundary_connections(block_mask_path, contour_mask_path,
                                  output_path='final.png',
                                  mapping_output_path='final_mapping.png',  # path of the relation-mapping image
                                  max_dist=120, max_thickness=45):
    """
    Identify boundary connections, merge roughly collinear contours that lie
    close together, and produce two visualizations:
    1. final_mapping.png: mapping image labelling room IDs and contour connectivity
    2. final.png: final image with the merged red rectangles drawn on the base mask

    Returns (connections, merged_rects_info):
        connections: per-contour list of the block IDs it touches
        merged_rects_info: per-kept-rectangle dicts (pair, box, thickness, overlap_score)
    """
    # 1. Load data
    blocks = cv2.imread(block_mask_path, cv2.IMREAD_GRAYSCALE)
    contours = cv2.imread(contour_mask_path, cv2.IMREAD_GRAYSCALE)
    target_mask = cv2.imread(block_mask_path)
    if target_mask is None:
        target_mask = np.zeros((blocks.shape[0], blocks.shape[1], 3), dtype=np.uint8)
    # Pre-process contours: thicken slightly before labelling.
    kernel = np.ones((5, 5), np.uint8)
    contours_dilated = cv2.dilate(contours, kernel, iterations=1)
    _, blocks_bin = cv2.threshold(blocks, 127, 255, cv2.THRESH_BINARY)
    _, contours_bin = cv2.threshold(contours_dilated, 0, 255, cv2.THRESH_BINARY)
    # 2. Connected-component labelling (stats/centroids kept for annotation)
    num_blocks, block_labels = cv2.connectedComponents(blocks_bin, connectivity=8)
    num_contours, contour_labels, stats, centroids = cv2.connectedComponentsWithStats(contours_bin, connectivity=8)
    print(f"检测到房间数: {num_blocks - 1}")
    print(f"检测到轮廓段: {num_contours - 1}")
    # Colour canvas used for the mapping visualization (final_mapping.png)
    mapping_img = cv2.cvtColor(contours_bin, cv2.COLOR_GRAY2BGR)
    # 3. Collect contours and draw the mapping relations
    pair_to_contours = defaultdict(list)
    connections = []
    for c_id in range(1, num_contours):
        single_contour = (contour_labels == c_id).astype(np.uint8) * 255
        # Thicken this segment to probe for neighbouring blocks.
        kernel_detect = np.ones((9, 9), np.uint8)
        dilated_c = cv2.dilate(single_contour, kernel_detect, iterations=1)
        neighboring_blocks = np.unique(block_labels[dilated_c > 0])
        neighboring_blocks = neighboring_blocks[neighboring_blocks > 0]
        block_ids = sorted(neighboring_blocks.tolist())
        connections.append({"contour_id": c_id, "connects": block_ids})
        # ==================== draw mapping-image elements ====================
        cx, cy = int(centroids[c_id][0]), int(centroids[c_id][1])
        relation_text = "-".join(map(str, block_ids))
        # Paint this contour segment green.
        mapping_img[contour_labels == c_id] = [0, 255, 0]
        # Annotate the connection relation (e.g. "1-2").
        if len(block_ids) >= 2:
            cv2.putText(mapping_img, relation_text, (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 255), 1)
        # =====================================================================
        # Collect contours that bridge exactly two blocks for the merge step.
        if len(block_ids) == 2:
            pair_key = f"{block_ids[0]}-{block_ids[1]}"
            y_coords, x_coords = np.where(contour_labels == c_id)
            if len(x_coords) > 0:
                points = np.column_stack((x_coords, y_coords))
                pair_to_contours[pair_key].append(points)
    # ==================== annotate room IDs on the mapping image ====================
    for b_id in range(1, num_blocks):
        b_mask = (block_labels == b_id).astype(np.uint8)
        M = cv2.moments(b_mask)
        if M["m00"] != 0:
            bx, by = int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])
            cv2.putText(mapping_img, f"B{b_id}", (bx, by),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
    # Save the mapping image.
    if mapping_output_path:
        cv2.imwrite(mapping_output_path, mapping_img)
        print(f"✅ 映射关系图已保存至: {mapping_output_path}")
    # ===================================================================
    # 4. Cluster contours and draw the combined bounding rectangles

    def should_merge(pts1, pts2):
        # Merge when the rect centres are close AND the combined min-area
        # rect stays thin (i.e. the pieces are roughly collinear).
        rect1 = cv2.minAreaRect(pts1)
        rect2 = cv2.minAreaRect(pts2)
        c1, c2 = np.array(rect1[0]), np.array(rect2[0])
        if np.linalg.norm(c1 - c2) > max_dist:
            return False
        combined_pts = np.vstack((pts1, pts2))
        comb_rect = cv2.minAreaRect(combined_pts)
        thickness = min(comb_rect[1][0], comb_rect[1][1])
        if thickness > max_thickness:
            return False
        return True
    merged_rects_info = []
    for pair_key, contour_list in pair_to_contours.items():
        n = len(contour_list)
        if n == 0: continue
        # Build a merge graph over this pair's contour fragments.
        adj = {i: [] for i in range(n)}
        for i in range(n):
            for j in range(i + 1, n):
                if should_merge(contour_list[i], contour_list[j]):
                    adj[i].append(j)
                    adj[j].append(i)
        # BFS over the merge graph to find connected groups.
        visited = set()
        groups = []
        for i in range(n):
            if i not in visited:
                comp = []
                queue = [i]
                visited.add(i)
                while queue:
                    curr = queue.pop(0)
                    comp.append(curr)
                    for neighbor in adj[curr]:
                        if neighbor not in visited:
                            visited.add(neighbor)
                            queue.append(neighbor)
                groups.append(comp)
        for group_indices in groups:
            combined_group_points = np.vstack([contour_list[idx] for idx in group_indices])
            points_np = np.array(combined_group_points, dtype=np.int32)
            # A. Minimum-area bounding rectangle of the merged group.
            rect = cv2.minAreaRect(points_np)
            (cx, cy), (w, h), angle = rect
            thickness = min(w, h)
            length = max(w, h)
            # B. Filter 1: thickness threshold — a rect thicker than
            # max_thickness is not a "gap" (likely a bogus diagonal
            # connection), so skip it.
            if thickness > max_thickness:
                continue
            # C. Filter 2: mask constraint (the rect must mostly lie on
            # actual gap pixels). Draw the candidate rect on a scratch mask.
            temp_rect_mask = np.zeros_like(contours_bin)
            box = cv2.boxPoints(rect)
            box = np.int32(box)
            cv2.fillPoly(temp_rect_mask, [box], 255)
            # Fraction of rect pixels that belong to the original gaps.
            overlap = cv2.bitwise_and(temp_rect_mask, contours_bin)
            overlap_score = np.sum(overlap > 0) / np.sum(temp_rect_mask > 0)
            # Low overlap means the rect spans mostly non-gap area — discard.
            if overlap_score < 0.6:
                continue
            # --- only rects passing both filters get drawn and recorded ---
            cv2.fillPoly(target_mask, [box], (0, 0, 255))
            merged_rects_info.append({
                "pair": pair_key,
                "box": box.tolist(),
                "thickness": thickness,
                "overlap_score": float(overlap_score)
            })
    # 5. Save the base image with the red merged boxes.
    if output_path:
        cv2.imwrite(output_path, target_mask)
        print(f"✅ 合并后的轮廓底图已保存至: {output_path}")
    return connections, merged_rects_info
  265. def mask_add_conncet(mask_path, conncet_path, add_path):
  266. img1 = cv2.imread(mask_path)
  267. img2 = cv2.imread(conncet_path)
  268. img2 = remove_pure_red(img2)
  269. # 1. 创建掩码:找出 img2 中所有非黑色的像素点
  270. img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
  271. _, mask = cv2.threshold(img2_gray, 1, 255, cv2.THRESH_BINARY)
  272. # 2. 直接覆盖
  273. result = img1.copy()
  274. result[mask > 0] = img2[mask > 0]
  275. cv2.imwrite(add_path, result)
  276. def overlay_images(base_img_path, overlay_img_path, output_path):
  277. """
  278. 将 overlay_img 叠加到 base_img 上。
  279. 在重叠部分,只显示 overlay_img 的像素。
  280. """
  281. img1 = cv2.imread(base_img_path) # 底图
  282. img2 = cv2.imread(overlay_img_path) # 要叠加的图
  283. if img1 is None or img2 is None:
  284. print("错误:无法读取图片,请检查路径。")
  285. return
  286. if img1.shape != img2.shape:
  287. img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
  288. print("注意:已将 img2 的尺寸调整为与 img1 一致。")
  289. img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
  290. _, mask = cv2.threshold(img2_gray, 1, 255, cv2.THRESH_BINARY)
  291. result = img1.copy()
  292. result[mask > 0] = img2[mask > 0]
  293. cv2.imwrite(output_path, result)
  294. print(f"✅ 图片叠加完成,结果已保存至: {output_path}")
def remove_pure_red(img):
    """Blacken, in place, every pure-red (BGR [0, 0, 255]) pixel of img.

    NOTE(review): this is an exact duplicate of remove_pure_red defined near
    the top of this file; the re-definition silently rebinds the name and
    should eventually be removed.
    """
    if img is None: return
    # In OpenCV (BGR), pure red is [0, 0, 255].
    # Find pixels where all three channels match that value exactly.
    red_pixels = (img[:, :, 0] == 0) & (img[:, :, 1] == 0) & (img[:, :, 2] == 255)
    # Paint those positions black directly.
    img[red_pixels] = [0, 0, 0]
    return img
  303. def remove_edge_regions(ori_rgb_folder, no_rededge_mask):
  304. test_names = os.listdir(ori_rgb_folder)
  305. for name in test_names:
  306. img_name = os.path.join(ori_rgb_folder, name)
  307. save_path = os.path.join(no_rededge_mask, name)
  308. img = cv2.imread(img_name)
  309. img_2 = np.zeros_like(img)
  310. mask = (img[:, :, 0] == 0) & (img[:, :, 1] == 0) & (img[:, :, 2] == 0)
  311. img_2[~mask] = (255, 255, 255)
  312. edges = cv2.Canny(img_2, 50, 150)
  313. kernel = np.ones((9, 9), np.uint8)
  314. edges = cv2.dilate(edges, kernel, 1)
  315. img[edges > 0] = (0, 0, 0)
  316. img = remove_pure_red(img)
  317. cv2.imwrite(save_path, img)
def merge_gap_fillers_from_arrays(mask1, mask2, image_path="",
                                  dilation_kernel_size=5, center_threshold=50,
                                  expand_pixel=10, min_rect_short_side=30):
    """Merge mask1 and mask2 in memory; return (connect_area JSON dict, preview image, stats).

    mask1 holds the solid blocks (e.g. rooms); mask2 holds small fragments.
    A fragment that touches exactly two blocks is treated as a "bridge"
    (door candidate). Bridges connecting the same block pair whose centres
    lie within center_threshold are grouped, and each group (or, when the
    merged rect is too thick, each thin fragment inside it) is emitted as a
    green connect-area rectangle expanded by expand_pixel.
    """
    m1 = _to_gray_mask(mask1)
    m2 = _to_gray_mask(mask2)
    if m1.shape != m2.shape:
        m2 = cv2.resize(m2, (m1.shape[1], m1.shape[0]))
    _, m1_bin = cv2.threshold(m1, 127, 255, cv2.THRESH_BINARY)
    # Threshold 0: any non-zero pixel of mask2 counts as fragment.
    _, m2_bin = cv2.threshold(m2, 0, 255, cv2.THRESH_BINARY)
    num1, block_labels = cv2.connectedComponents(m1_bin)
    num2, frag_labels, _, _ = cv2.connectedComponentsWithStats(m2_bin)
    # Preview image: mask1 blocks in white, connect areas drawn in green later.
    result = cv2.cvtColor(m1_bin, cv2.COLOR_GRAY2BGR)
    # --- find fragments that bridge exactly two blocks ---
    bridge_frags = []
    for i in range(1, num2):
        single_frag_mask = (frag_labels == i).astype(np.uint8) * 255
        # Dilate the fragment so it reaches the blocks it sits between.
        kernel = np.ones((dilation_kernel_size, dilation_kernel_size), np.uint8)
        dilated_frag = cv2.dilate(single_frag_mask, kernel, iterations=1)
        touched_labels = np.unique(block_labels[dilated_frag > 0])
        neighbors = sorted(int(n) for n in touched_labels if n > 0)
        if len(neighbors) == 2:
            pts = np.column_stack(np.where(single_frag_mask > 0))
            if len(pts) > 0:
                # np.where yields (row, col) -> (cy, cx).
                cy, cx = np.mean(pts, axis=0)
                bridge_frags.append({
                    'id': int(i),
                    'cx': float(cx),
                    'cy': float(cy),
                    # Block labels are shifted to 0-based pair IDs.
                    'block_pair': tuple(int(n - 1) for n in neighbors),
                    'mask': single_frag_mask
                })
    # --- greedily group bridges of the same block pair with close centres ---
    groups = []
    for frag in bridge_frags:
        assigned = False
        for group in groups:
            if group[0]['block_pair'] != frag['block_pair']:
                continue
            for existing in group:
                dx = abs(frag['cx'] - existing['cx'])
                dy = abs(frag['cy'] - existing['cy'])
                # Close along either axis is enough to join the group.
                if dx < center_threshold or dy < center_threshold:
                    group.append(frag)
                    assigned = True
                    break
            if assigned:
                break
        if not assigned:
            groups.append([frag])
    # --- emit one (or several) connect-area rectangles per group ---
    connect_areas = []
    rect_id = 0
    height, width = m1.shape[:2]
    for group in groups:
        points_list = []
        frag_rects = []
        for frag in group:
            pts = np.column_stack(np.where(frag['mask'] > 0))
            if pts.size > 0:
                # Swap (row, col) -> (x, y) for cv2.boundingRect.
                points_xy = pts[:, ::-1]
                points_list.append(points_xy)
                raw_x, raw_y, raw_w, raw_h = cv2.boundingRect(points_xy)
                frag_rects.append((int(raw_x), int(raw_y), int(raw_w), int(raw_h), frag))
        if len(points_list) < 1:
            continue
        all_points = np.vstack(points_list)
        if len(all_points) < 3:
            continue
        raw_x, raw_y, raw_w, raw_h = cv2.boundingRect(all_points)
        merged_short_side = min(raw_w, raw_h)
        # If the merged rect is too thick, fall back to per-fragment rects
        # and keep only the thin ones.
        if len(group) > 1 and merged_short_side > min_rect_short_side:
            for fx, fy, fw, fh, frag in frag_rects:
                single_short_side = min(fw, fh)
                if single_short_side <= min_rect_short_side:
                    fx, fy, fw, fh = _expand_rect(fx, fy, fw, fh, expand_pixel, width, height)
                    cv2.rectangle(result, (fx, fy), (fx + fw, fy + fh), (0, 255, 0), -1)
                    connect_areas.append({
                        'id': int(rect_id),
                        'x': int(fx),
                        'y': int(fy),
                        'w': int(fw),
                        'h': int(fh),
                        'block_pair': [int(n) for n in frag['block_pair']],
                        'label': 'door'
                    })
                    rect_id += 1
            continue
        # Thin merged rect: emit it as a single connect area.
        if merged_short_side <= min_rect_short_side:
            x, y, w, h = _expand_rect(raw_x, raw_y, raw_w, raw_h, expand_pixel, width, height)
            cv2.rectangle(result, (x, y), (x + w, y + h), (0, 255, 0), -1)
            connect_areas.append({
                'id': int(rect_id),
                'x': int(x),
                'y': int(y),
                'w': int(w),
                'h': int(h),
                'block_pair': [int(n) for n in group[0]['block_pair']],
                'label': 'door'
            })
            rect_id += 1
    json_data = {
        'image_path': str(image_path),
        'image_size': {
            'width': int(m1.shape[1]),
            'height': int(m1.shape[0])
        },
        'connect_area': connect_areas
    }
    stats = {
        'mask1_blocks': int(num1 - 1),
        'mask2_fragments': int(num2 - 1),
        'bridge_fragments': int(len(bridge_frags)),
        'group_count': int(len(groups)),
        'connect_area_count': int(len(connect_areas))
    }
    return json_data, result, stats
  432. def merge_gap_fillers(mask1_path, mask2_path, output_image_path, output_json_path,
  433. dilation_kernel_size=5, center_threshold=50,
  434. expand_pixel=10, min_rect_short_side=30):
  435. """
  436. 合并 mask1 和 mask2,利用 mask2 填充 mask1 块之间的裂隙。
  437. """
  438. m1 = cv2.imread(mask1_path, cv2.IMREAD_GRAYSCALE)
  439. m2 = cv2.imread(mask2_path, cv2.IMREAD_GRAYSCALE)
  440. if m1 is None or m2 is None:
  441. raise FileNotFoundError("无法加载图像,请检查路径")
  442. json_data, result, stats = merge_gap_fillers_from_arrays(
  443. m1,
  444. m2,
  445. image_path=output_image_path,
  446. dilation_kernel_size=dilation_kernel_size,
  447. center_threshold=center_threshold,
  448. expand_pixel=expand_pixel,
  449. min_rect_short_side=min_rect_short_side,
  450. )
  451. os.makedirs(os.path.dirname(os.path.abspath(output_image_path)), exist_ok=True)
  452. os.makedirs(os.path.dirname(os.path.abspath(output_json_path)), exist_ok=True)
  453. cv2.imwrite(output_image_path, result)
  454. with open(output_json_path, 'w', encoding='utf-8') as f:
  455. json.dump(json_data, f, indent=2, ensure_ascii=False)
  456. print(f"✅ 处理完成!")
  457. print(f" - mask1 块数:{stats['mask1_blocks']} (白色)")
  458. print(f" - mask2 碎片总数:{stats['mask2_fragments']}")
  459. print(f" - 连接 2 块的碎片:{stats['bridge_fragments']} 个")
  460. print(f" - 聚类后的组数:{stats['group_count']} 组")
  461. print(f" - 总连接区域数:{stats['connect_area_count']}")
  462. print(f" - 结果图片保存至:{output_image_path}")
  463. print(f" - JSON 坐标保存至:{output_json_path}")
def verify_json_coordinates(json_path, output_verify_path, mask1_path=None):
    """
    Read the coordinate data from a JSON file and draw green rectangles on an
    image to verify them.

    Args:
        json_path: path to the JSON file
        output_verify_path: where to save the verification image
        mask1_path: optional path to the mask1 image (when given, its white
            blocks are shown as the background)

    Returns:
        The parsed JSON data dict.
    """
    # 1. Load the JSON data
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"📊 读取 JSON 文件:{json_path}")
    print(f" - 图片路径:{data['image_path']}")
    print(f" - 图片尺寸:{data['image_size']['width']} x {data['image_size']['height']}")
    print(f" - 连接区域数:{len(data['connect_area'])}")
    # 2. Build the canvas
    width = data['image_size']['width']
    height = data['image_size']['height']
    if mask1_path:
        # With mask1: load it and show its white blocks as background.
        m1 = cv2.imread(mask1_path, cv2.IMREAD_GRAYSCALE)
        if m1 is not None:
            if m1.shape[0] != height or m1.shape[1] != width:
                m1 = cv2.resize(m1, (width, height))
            _, m1_bin = cv2.threshold(m1, 127, 255, cv2.THRESH_BINARY)
            result = cv2.cvtColor(m1_bin, cv2.COLOR_GRAY2BGR)
            print(f" - 已加载 mask1,显示白色块")
        else:
            result = np.zeros((height, width, 3), dtype=np.uint8)
            print(f" - 未找到 mask1,使用黑色背景")
    else:
        # Without mask1: plain black background.
        result = np.zeros((height, width, 3), dtype=np.uint8)
        print(f" - 未提供 mask1,使用黑色背景")
    # 3. Draw a green rectangle for every connect_area entry in the JSON
    for area in data['connect_area']:
        x = area['x']
        y = area['y']
        w = area['w']
        h = area['h']
        block_pair = area['block_pair']
        label = area.get('label', 'door')
        # Filled green rectangle.
        cv2.rectangle(result, (x, y), (x + w, y + h), (0, 255, 0), -1)
        # Red outline so the boundary stays visible.
        cv2.rectangle(result, (x, y), (x + w, y + h), (0, 0, 255), 2)
        # Mark the ID and block_pair at the rectangle centre.
        cx, cy = x + w // 2, y + h // 2
        text = f"ID:{area['id']} [{block_pair[0]},{block_pair[1]}]"
        cv2.putText(result, text, (cx - 80, cy),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
    # 4. Save the verification image
    cv2.imwrite(output_verify_path, result)
    # 5. Print a detailed table of all areas
    print(f"\n📋 连接区域详情:")
    print(f" {'ID':<6} {'X':<8} {'Y':<8} {'W':<6} {'H':<6} {'Block_Pair':<15} {'Label':<10}")
    print(f" {'-' * 70}")
    for area in data['connect_area']:
        print(
            f" {area['id']:<6} {area['x']:<8} {area['y']:<8} {area['w']:<6} {area['h']:<6} {str(area['block_pair']):<15} {area['label']:<10}")
    print(f"\n✅ 验证图片保存至:{output_verify_path}")
    return data
  526. # def build_block_data(rgb_img, block_mask, model_path="room_seg.pt"):
  527. # """基于单张 RGB 和 block mask 生成 block 列表。"""
  528. # model = _load_yolo_model(model_path)
  529. # blocks = _to_gray_mask(block_mask)
  530. #
  531. # if rgb_img is None or blocks is None:
  532. # raise FileNotFoundError("图片路径错误,请检查文件是否存在")
  533. #
  534. # if rgb_img.shape[:2] != blocks.shape[:2]:
  535. # rgb_img = cv2.resize(rgb_img, (blocks.shape[1], blocks.shape[0]))
  536. #
  537. # _, blocks_bin = cv2.threshold(blocks, 127, 255, cv2.THRESH_BINARY)
  538. # num_blocks, block_labels = cv2.connectedComponents(blocks_bin, connectivity=8)
  539. # actual_room_count = num_blocks - 1
  540. # print(f"🔍 检测到 {actual_room_count} 个独立的房间块")
  541. #
  542. # block_list = []
  543. #
  544. # for b_id in range(1, num_blocks):
  545. # print(f"\n📍 处理 Block {b_id - 1}...")
  546. # mask_single = (block_labels == b_id).astype(np.uint8)
  547. # contours, _ = cv2.findContours(mask_single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  548. #
  549. # points = []
  550. # cx = cy = 0
  551. # if contours:
  552. # largest_contour = max(contours, key=cv2.contourArea)
  553. # epsilon = 2.0
  554. # simplified = cv2.approxPolyDP(largest_contour, epsilon, True)
  555. # for pt in simplified:
  556. # points.append(int(pt[0][0]))
  557. # points.append(int(pt[0][1]))
  558. #
  559. # M = cv2.moments(simplified)
  560. # if M['m00'] != 0:
  561. # cx = int(M['m10'] / M['m00'])
  562. # cy = int(M['m01'] / M['m00'])
  563. # else:
  564. # cx = int(np.mean(simplified[:, 0, 0]))
  565. # cy = int(np.mean(simplified[:, 0, 1]))
  566. #
  567. # x, y, w, h = cv2.boundingRect(mask_single)
  568. # if not contours:
  569. # cx = int(x + w / 2)
  570. # cy = int(y + h / 2)
  571. #
  572. # pad = 20
  573. # y1, y2 = max(0, y - pad), min(rgb_img.shape[0], y + h + pad)
  574. # x1, x2 = max(0, x - pad), min(rgb_img.shape[1], x + w + pad)
  575. # roi_rgb = rgb_img[y1:y2, x1:x2]
  576. #
  577. # valid_boxes = []
  578. # if roi_rgb.shape[0] > 10 and roi_rgb.shape[1] > 10:
  579. # results = model(roi_rgb, conf=0.15, verbose=False)[0]
  580. #
  581. # if len(results.boxes) > 0:
  582. # boxes_xyxy = results.boxes.xyxy.cpu().numpy()
  583. # boxes_conf = results.boxes.conf.cpu().numpy()
  584. # boxes_cls = results.boxes.cls.cpu().numpy()
  585. #
  586. # for i in range(len(boxes_xyxy)):
  587. # bx1, by1, bx2, by2 = boxes_xyxy[i]
  588. # center_x = (bx1 + bx2) / 2
  589. # center_y = (by1 + by2) / 2
  590. # global_cx = x1 + center_x
  591. # global_cy = y1 + center_y
  592. #
  593. # if 0 <= global_cy < block_labels.shape[0] and 0 <= global_cx < block_labels.shape[1]:
  594. # if block_labels[int(global_cy), int(global_cx)] == b_id:
  595. # valid_boxes.append({
  596. # "cls": int(boxes_cls[i]),
  597. # "conf": float(boxes_conf[i]),
  598. # "name": model.names[int(boxes_cls[i])]
  599. # })
  600. #
  601. # final_label = "other_room"
  602. # final_conf = 0.0
  603. # if valid_boxes:
  604. # best_box = max(valid_boxes, key=lambda k: k['conf'])
  605. # final_label = best_box['name']
  606. # final_conf = best_box['conf']
  607. # print(f" ✅ 检测到:{final_label} (conf: {final_conf:.2f})")
  608. # else:
  609. # print(f" ⚠️ 未检测到有效物体,标记为 other_room")
  610. #
  611. # block_list.append({
  612. # "id": int(b_id - 1),
  613. # "points": points,
  614. # "label": final_label,
  615. # "center": [int(cx), int(cy)]
  616. # })
  617. #
  618. # return block_list
  619. def build_block_data(rgb_img, block_mask, model_path="room_cls.pt"):
  620. """基于单张 RGB 和 block mask 生成 block 列表(使用 YOLO 分类模型)。"""
  621. model = _load_yolo_model(model_path)
  622. blocks = _to_gray_mask(block_mask)
  623. if rgb_img is None or blocks is None:
  624. raise FileNotFoundError("图片路径错误,请检查文件是否存在")
  625. if rgb_img.shape[:2] != blocks.shape[:2]:
  626. rgb_img = cv2.resize(rgb_img, (blocks.shape[1], blocks.shape[0]))
  627. _, blocks_bin = cv2.threshold(blocks, 127, 255, cv2.THRESH_BINARY)
  628. num_blocks, block_labels = cv2.connectedComponents(blocks_bin, connectivity=8)
  629. actual_room_count = num_blocks - 1
  630. print(f"🔍 检测到 {actual_room_count} 个独立的房间块")
  631. block_list = []
  632. for b_id in range(1, num_blocks):
  633. print(f"\n📍 处理 Block {b_id - 1}...")
  634. mask_single = (block_labels == b_id).astype(np.uint8)
  635. contours, _ = cv2.findContours(mask_single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  636. points = []
  637. cx = cy = 0
  638. if contours:
  639. largest_contour = max(contours, key=cv2.contourArea)
  640. epsilon = 2.0
  641. simplified = cv2.approxPolyDP(largest_contour, epsilon, True)
  642. for pt in simplified:
  643. points.append(int(pt[0][0]))
  644. points.append(int(pt[0][1]))
  645. M = cv2.moments(simplified)
  646. if M['m00'] != 0:
  647. cx = int(M['m10'] / M['m00'])
  648. cy = int(M['m01'] / M['m00'])
  649. else:
  650. cx = int(np.mean(simplified[:, 0, 0]))
  651. cy = int(np.mean(simplified[:, 0, 1]))
  652. x, y, w, h = cv2.boundingRect(mask_single)
  653. if not contours:
  654. cx = int(x + w / 2)
  655. cy = int(y + h / 2)
  656. # ---------------------------------------------------------
  657. # 修改点 1:提取纯净的 ROI(剔除背景噪声)
  658. # ---------------------------------------------------------
  659. pad = 20
  660. y1, y2 = max(0, y - pad), min(rgb_img.shape[0], y + h + pad)
  661. x1, x2 = max(0, x - pad), min(rgb_img.shape[1], x + w + pad)
  662. # 使用 .copy() 防止修改原图
  663. roi_rgb = rgb_img[y1:y2, x1:x2].copy()
  664. # 提取相同位置的掩码切片
  665. roi_mask_patch = mask_single[y1:y2, x1:x2]
  666. # 核心逻辑:将掩码为 0(非当前房间)的像素,全部置为黑色 [0, 0, 0]
  667. roi_rgb[roi_mask_patch == 0] = [0, 0, 0]
  668. # ---------------------------------------------------------
  669. # 修改点 2:使用 YOLO 分类模型的输出解析逻辑
  670. # ---------------------------------------------------------
  671. final_label = "other_room"
  672. final_conf = 0.0
  673. if roi_rgb.shape[0] > 10 and roi_rgb.shape[1] > 10:
  674. # 执行分类推理
  675. results = model(roi_rgb, verbose=False)[0]
  676. # 分类模型的输出在 results.probs 中,不再是 results.boxes
  677. if hasattr(results, 'probs') and results.probs is not None:
  678. # 获取置信度最高的类别索引
  679. top1_idx = int(results.probs.top1)
  680. # 获取最高置信度数值
  681. top1_conf = float(results.probs.top1conf.cpu().numpy())
  682. # 获取类别名称
  683. top1_label = model.names[top1_idx]
  684. # 设定分类置信度阈值 (可根据需求调整,比如 0.4)
  685. if top1_conf >= 0.15:
  686. final_label = top1_label
  687. final_conf = top1_conf
  688. print(f" ✅ 分类检测到:{final_label} (conf: {final_conf:.2f})")
  689. else:
  690. print(f" ⚠️ 最高分类置信度偏低 ({top1_label}: {top1_conf:.2f}),标记为 other_room")
  691. else:
  692. print(" ⚠️ 模型输出中没有分类概率,请检查是否加载了正确的 -cls 分类模型!")
  693. block_list.append({
  694. "id": int(b_id - 1),
  695. "points": points,
  696. "label": final_label,
  697. "center": [int(cx), int(cy)]
  698. })
  699. return block_list
  700. def add_block_labels_to_json(json_path, rgb_path, block_mask_path, model_path="room_seg.pt", output_json_path=None):
  701. """
  702. 读取现有 JSON 文件,追加 block 信息(轮廓坐标 + YOLO 检测标签)
  703. """
  704. with open(json_path, 'r', encoding='utf-8') as f:
  705. json_data = json.load(f)
  706. print(f"📊 读取原有 JSON: {json_path}")
  707. print(f" - 连接区域数:{len(json_data.get('connect_area', []))}")
  708. img_rgb = cv2.imread(rgb_path)
  709. blocks = cv2.imread(block_mask_path, cv2.IMREAD_GRAYSCALE)
  710. if img_rgb is None or blocks is None:
  711. raise FileNotFoundError("图片路径错误,请检查文件是否存在")
  712. json_data['block'] = build_block_data(img_rgb, blocks, model_path=model_path)
  713. if output_json_path is None:
  714. output_json_path = json_path
  715. with open(output_json_path, 'w', encoding='utf-8') as f:
  716. json.dump(json_data, f, indent=2, ensure_ascii=False)
  717. print(f"\n✅ 处理完成!")
  718. print(f" - 总房间块数:{len(json_data['block'])}")
  719. print(f" - JSON 保存至:{output_json_path}")
  720. print(f"\n📋 Block 详情:")
  721. print(f" {'ID':<6} {'Points':<10} {'Label':<20}")
  722. print(f" {'-' * 40}")
  723. for block in json_data['block']:
  724. points_info = f"{len(block['points']) // 2}个点"
  725. print(f" {block['id']:<6} {points_info:<10} {block['label']:<20}")
  726. return json_data
  727. def detect_furniture_in_image(rgb_img, model_path='furniture_detect.onnx'):
  728. """在内存中执行家具检测,返回家具列表。"""
  729. model = _load_yolo_model(model_path)
  730. if rgb_img is None:
  731. raise FileNotFoundError("无法读取 RGB 图像")
  732. results = model(rgb_img, conf=0.25, verbose=False)[0]
  733. allowed_labels = {'sofa', 'chair', 'desk', 'bed', 'window'}
  734. furniture_list = []
  735. if len(results.boxes) > 0:
  736. boxes_xyxy = results.boxes.xyxy.cpu().numpy()
  737. boxes_cls = results.boxes.cls.cpu().numpy()
  738. for i in range(len(boxes_xyxy)):
  739. bx1, by1, bx2, by2 = [int(v) for v in boxes_xyxy[i]]
  740. label = model.names[int(boxes_cls[i])]
  741. if label not in allowed_labels:
  742. continue
  743. furniture_list.append({
  744. 'id': len(furniture_list),
  745. 'label': label,
  746. 'center': [(bx1 + bx2) // 2, (by1 + by2) // 2],
  747. 'points': {
  748. 'x1': bx1, 'y1': by1,
  749. 'x2': bx2, 'y2': by1,
  750. 'x3': bx2, 'y3': by2,
  751. 'x4': bx1, 'y4': by2
  752. }
  753. })
  754. return furniture_list
  755. def detect_furniture(json_path, rgb_path, model_path='furniture_detect.onnx', output_json_path=None):
  756. """
  757. 使用 YOLO 模型检测家具,并将结果写入 JSON。
  758. """
  759. with open(json_path, 'r', encoding='utf-8') as f:
  760. json_data = json.load(f)
  761. img_rgb = cv2.imread(rgb_path)
  762. if img_rgb is None:
  763. raise FileNotFoundError(f"无法读取图片:{rgb_path}")
  764. json_data['furniture'] = detect_furniture_in_image(img_rgb, model_path=model_path)
  765. if output_json_path is None:
  766. output_json_path = json_path
  767. with open(output_json_path, 'w', encoding='utf-8') as f:
  768. json.dump(json_data, f, indent=2, ensure_ascii=False)
  769. print(f"✅ 家具检测完成,共检测到 {len(json_data['furniture'])} 个家具,JSON 保存至:{output_json_path}")
  770. return json_data
def verify_final_json(json_path, output_verify_path):
    """Visualize every entry of the final JSON onto a black canvas for QA.

    Drawing rules:
        - block: white filled polygon with its label centered on it
        - connect_area: green filled rectangle with a red border

    Args:
        json_path: path of the JSON file to verify.
        output_verify_path: where the rendered verification image is written.

    Returns:
        The parsed JSON data (dict).
    """
    # 1. Load the JSON payload
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"📊 读取 JSON 文件:{json_path}")
    print(f" - 图片尺寸:{data['image_size']['width']} x {data['image_size']['height']}")
    print(f" - 连接区域数:{len(data.get('connect_area', []))}")
    print(f" - 房间块数:{len(data.get('block', []))}")
    # 2. Create the canvas (black background, same size as the source image)
    width = data['image_size']['width']
    height = data['image_size']['height']
    result = np.zeros((height, width, 3), dtype=np.uint8)
    # 3. Draw blocks (white polygon + centered label)
    print(f"\n📋 绘制 {len(data.get('block', []))} 个房间块:")
    for block in data.get('block', []):
        block_id = block['id']
        points = block['points']
        label = block['label']
        # Convert the flat [x1, y1, x2, y2, ...] list to [[x, y], ...] pairs.
        if len(points) >= 6:  # a polygon needs at least 3 points
            pts = np.array([[points[j], points[j + 1]] for j in range(0, len(points), 2)], dtype=np.int32)
            # Filled white polygon
            cv2.fillPoly(result, [pts], (255, 255, 255))
            # Light-grey outline so adjacent blocks stay distinguishable
            cv2.polylines(result, [pts], True, (200, 200, 200), 2)
            # Centroid via image moments; fall back to the vertex mean for
            # degenerate (zero-area) polygons.
            M = cv2.moments(pts)
            if M['m00'] != 0:
                cx = int(M['m10'] / M['m00'])
                cy = int(M['m01'] / M['m00'])
            else:
                cx, cy = pts.mean(axis=0).astype(int)
            # Label text at the centroid (black text on a white background)
            text = f"{label}"
            (text_w, text_h), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            # White background rectangle behind the text
            cv2.rectangle(result,
                          (cx - text_w // 2 - 5, cy - text_h // 2 - 5),
                          (cx + text_w // 2 + 5, cy + text_h // 2 + 5),
                          (255, 255, 255), -1)
            # Black label text
            cv2.putText(result, text, (cx - text_w // 2, cy + text_h // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
            print(f" ✅ Block {block_id}: {label} ({len(points) // 2}个点)")
        else:
            print(f" ⚠️ Block {block_id}: 点数不足,跳过")
    # 4. Draw connect areas (green rectangles)
    print(f"\n📋 绘制 {len(data.get('connect_area', []))} 个连接区域:")
    for area in data.get('connect_area', []):
        area_id = area['id']
        x = area['x']
        y = area['y']
        w = area['w']
        h = area['h']
        block_pair = area['block_pair']
        # Filled green rectangle
        cv2.rectangle(result, (x, y), (x + w, y + h), (0, 255, 0), -1)
        # Red border to make the boundary visible
        cv2.rectangle(result, (x, y), (x + w, y + h), (0, 0, 255), 2)
        print(f" ✅ Area {area_id}: ({x}, {y}) {w}x{h} 连接块 [{block_pair[0]}, {block_pair[1]}]")
    # 5. Legend overlay (top-left corner)
    legend_y = 30
    cv2.putText(result, "Legend:", (10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
    # White swatch — room block
    cv2.rectangle(result, (10, legend_y + 10), (30, legend_y + 30), (255, 255, 255), -1)
    cv2.putText(result, "Room Block", (40, legend_y + 27), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    # Green swatch — connect area
    cv2.rectangle(result, (150, legend_y + 10), (170, legend_y + 30), (0, 255, 0), -1)
    cv2.putText(result, "Connect Area", (180, legend_y + 27), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    # 6. Save the verification image
    cv2.imwrite(output_verify_path, result)
    print(f"\n✅ 验证图片保存至:{output_verify_path}")
    return data
  853. # ===========================================================================
  854. # PART 2: 深光 refine (内联 txt2json)
  855. # ===========================================================================
  856. # ================= 1. 工具函数 =================
  857. def to_key(p, precision=1):
  858. return (round(float(p[0]), precision), round(float(p[1]), precision))
  859. def is_orthogonal(seg, thresh=1e-1):
  860. dx, dy = abs(seg[2] - seg[0]), abs(seg[3] - seg[1])
  861. return dx < thresh or dy < thresh
  862. def get_line_intersection(s1, s2_point, s2_dir):
  863. x1, y1, x2, y2 = s1
  864. dx1, dy1 = x2 - x1, y2 - y1
  865. x3, y3 = s2_point
  866. vx, vy = s2_dir
  867. det = vx * dy1 - vy * dx1
  868. if abs(det) < 1e-6:
  869. return [float(x2), float(y2)]
  870. t1 = (vx * (y3 - y1) - vy * (x3 - x1)) / det
  871. return [float(x1 + t1 * dx1), float(y1 + t1 * dy1)]
  872. def is_loop_inside(small_loop, large_loop):
  873. poly = np.array([[s[0], s[1]] for s in large_loop], dtype=np.float32)
  874. for s in small_loop:
  875. pt = (float(s[0]), float(s[1]))
  876. dist = cv2.pointPolygonTest(poly, pt, False)
  877. if dist < 0:
  878. return False
  879. return True
  880. # ================= 2. 拓扑基础逻辑 =================
  881. def group_segments_by_connectivity(segments, stitch_dist=5.0):
  882. num_segments = len(segments)
  883. parent = list(range(num_segments))
  884. def find(i):
  885. if parent[i] == i: return i
  886. parent[i] = find(parent[i])
  887. return parent[i]
  888. def union(i, j):
  889. root_i, root_j = find(i), find(j)
  890. if root_i != root_j: parent[root_i] = root_j
  891. for i in range(num_segments):
  892. for j in range(i + 1, num_segments):
  893. p1_set = [segments[i][:2], segments[i][2:]]
  894. p2_set = [segments[j][:2], segments[j][2:]]
  895. if any(np.linalg.norm(np.array(a) - np.array(b)) < stitch_dist for a in p1_set for b in p2_set):
  896. union(i, j)
  897. groups = {}
  898. for i in range(num_segments):
  899. groups.setdefault(find(i), []).append(segments[i])
  900. return list(groups.values())
def orthogonalize_and_move_nodes(segment_groups, angle_thresh_deg=15):
    """Snap near-horizontal/near-vertical segments to exact axis alignment.

    Endpoints within ~2.5 px are merged into shared mutable nodes so that
    moving one endpoint moves every segment that touches it; three snapping
    passes let adjustments propagate through the group.

    Args:
        segment_groups: list of groups, each a list of [x1, y1, x2, y2].
        angle_thresh_deg: max deviation from 0/90 degrees to trigger snapping.

    Returns:
        New list of groups with snapped coordinates.
    """
    refined_groups = []
    for group in segment_groups:
        node_pool = {}
        # Return an existing nearby node (shared np array) or register a new one.
        def get_shared_node(p):
            pk = to_key(p, precision=1)
            for ex_pk, obj in node_pool.items():
                if np.linalg.norm(np.array(pk) - np.array(ex_pk)) < 2.5: return obj
            new_node = np.array(p, dtype=np.float32)
            node_pool[pk] = new_node
            return new_node
        graph_segs = [(get_shared_node(s[:2]), get_shared_node(s[2:])) for s in group]
        # Multiple passes so snapping one segment can propagate to neighbours.
        for _ in range(3):
            for p1, p2 in graph_segs:
                dx, dy = abs(p2[0] - p1[0]), abs(p2[1] - p1[1])
                angle = math.degrees(math.atan2(dy, dx))
                if angle < angle_thresh_deg or angle > (180 - angle_thresh_deg):
                    # Near-horizontal: level both endpoints to their mean y.
                    y_avg = (p1[1] + p2[1]) / 2
                    p1[1] = p2[1] = y_avg
                elif abs(angle - 90) < angle_thresh_deg:
                    # Near-vertical: align both endpoints to their mean x.
                    x_avg = (p1[0] + p2[0]) / 2
                    p1[0] = p2[0] = x_avg
        refined_groups.append([[p1[0], p1[1], p2[0], p2[1]] for p1, p2 in graph_segs])
    return refined_groups
  925. def bridge_isolated_endpoints(all_segments, snap_thresh=25.0):
  926. degree_map = {}
  927. for s in all_segments:
  928. for p in [tuple(s[:2]), tuple(s[2:])]:
  929. k = to_key(p, precision=1)
  930. degree_map.setdefault(k, []).append(p)
  931. isolated_pts = [nodes[0] for k, nodes in degree_map.items() if len(nodes) == 1]
  932. new_lines = []
  933. used = set()
  934. for i in range(len(isolated_pts)):
  935. if i in used: continue
  936. best_j, min_d = -1, snap_thresh
  937. for j in range(i + 1, len(isolated_pts)):
  938. if j in used: continue
  939. d = np.linalg.norm(np.array(isolated_pts[i]) - np.array(isolated_pts[j]))
  940. if d < min_d: min_d, best_j = d, j
  941. if best_j != -1:
  942. p1, p2 = isolated_pts[i], isolated_pts[best_j]
  943. new_lines.append([p1[0], p1[1], p2[0], p2[1]])
  944. used.add(i)
  945. used.add(best_j)
  946. return new_lines
def find_max_length_loop(group, dist_thresh=2.0):
    """Find the longest closed loop (by summed edge length) in a segment group.

    Endpoints within *dist_thresh* are snapped to shared graph nodes, then a
    DFS searches for the maximum-length cycle (from at most 4 start nodes to
    bound runtime). The winning cycle is returned as an ordered list of
    segments oriented head-to-tail; [] when no cycle exists.
    """
    if not group: return []
    nodes = []
    # Snap a raw point to an existing node index, or register a new node.
    def get_node_idx(p):
        for i, target in enumerate(nodes):
            if np.linalg.norm(np.array(p) - target) < dist_thresh: return i
        nodes.append(np.array(p))
        return len(nodes) - 1
    edges = []
    for s in group:
        u, v = get_node_idx(s[:2]), get_node_idx(s[2:])
        if u == v: continue  # segment collapses to a point after snapping
        edges.append({'u': u, 'v': v, 'len': np.linalg.norm(np.array(s[:2]) - np.array(s[2:])), 'data': s})
    # Adjacency: node index -> incident edge indices.
    graph = {}
    for i, e in enumerate(edges):
        graph.setdefault(e['u'], []).append(i)
        graph.setdefault(e['v'], []).append(i)
    memo = {"best": [], "max": 0}
    # DFS over edges; record the heaviest cycle that closes back at *start*.
    def dfs(curr, start, path, length, visited):
        for idx in graph.get(curr, []):
            nxt = edges[idx]['v'] if edges[idx]['u'] == curr else edges[idx]['u']
            if nxt == start and len(path) >= 2:
                if length + edges[idx]['len'] > memo["max"]:
                    memo["max"] = length + edges[idx]['len']
                    memo["best"] = path + [idx]
                continue
            if nxt not in visited:
                dfs(nxt, start, path + [idx], length + edges[idx]['len'], visited | {nxt})
    starts = [n for n in graph if len(graph[n]) >= 2]
    for s_node in starts[:4]:  # cap the number of DFS roots to bound runtime
        dfs(s_node, s_node, [], 0, {s_node})
    if not memo["best"]: return []
    # Re-orient the winning edge sequence so consecutive segments chain
    # head-to-tail (flip a segment when its stored direction is reversed).
    res = []
    e0, e1 = edges[memo["best"][0]], edges[memo["best"][1]]
    common = e1['u'] if e1['u'] in (e0['u'], e0['v']) else e1['v']
    curr_n = e0['v'] if e0['u'] == common else e0['u']
    for idx in memo["best"]:
        e = edges[idx]
        s = e['data']
        if get_node_idx(s[:2]) == curr_n:
            res.append(list(s))
            curr_n = get_node_idx(s[2:])
        else:
            res.append([s[2], s[3], s[0], s[1]])
            curr_n = get_node_idx(s[:2])
    return res
  993. # ================= 3. 整流逻辑 =================
def apply_user_refinement_with_moving(group):
    """Rectify an ordered segment loop: each run of non-orthogonal segments is
    replaced by a single axis-aligned L-connector, moving the start point of
    the next orthogonal segment so the corner closes cleanly.

    Returns the refined segment list (the input when nothing is orthogonal).
    """
    if len(group) < 2: return group
    # Rotate the loop so processing starts on an orthogonal segment.
    start_idx = -1
    for i, s in enumerate(group):
        if is_orthogonal(s):
            start_idx = i
            break
    if start_idx == -1: return group
    w_list = [list(s) for s in (group[start_idx:] + group[:start_idx])]
    refined = []
    i = 0
    while i < len(w_list):
        curr_seg = w_list[i]
        refined.append(curr_seg)
        nxt_idx = i + 1
        if nxt_idx >= len(w_list): break
        if not is_orthogonal(w_list[nxt_idx]):
            # Scan forward past the slanted run for the next orthogonal segment.
            p_idx = nxt_idx
            found_lp = False
            while p_idx < len(w_list):
                if is_orthogonal(w_list[p_idx]):
                    found_lp = True
                    break
                p_idx += 1
            if found_lp:
                lp = w_list[p_idx]
                p2_end, lp_start = [curr_seg[2], curr_seg[3]], [lp[0], lp[1]]
                is_h1, is_h2 = abs(curr_seg[3] - curr_seg[1]) < 1e-1, abs(lp[3] - lp[1]) < 1e-1
                # Choose the elbow point so the connector stays axis-aligned,
                # depending on whether the two orthogonal sides are parallel
                # or perpendicular to each other.
                if is_h1 == is_h2:
                    p_mid = [p2_end[0], lp_start[1]] if is_h1 else [lp_start[0], p2_end[1]]
                else:
                    p_mid = [lp_start[0], p2_end[1]] if is_h1 else [p2_end[0], lp_start[1]]
                lp[0], lp[1] = p_mid[0], p_mid[1]
                refined.append([p2_end[0], p2_end[1], lp[0], lp[1]])
                i = p_idx
            else:
                # No orthogonal segment remains — keep the tail unchanged.
                refined.extend(w_list[nxt_idx:])
                break
        else:
            i += 1
    return refined
  1035. def merge_collinear_segments(ordered_group, dist_thresh=0.5):
  1036. if len(ordered_group) < 2: return ordered_group
  1037. merged, curr_seg = [], list(ordered_group[0])
  1038. for i in range(1, len(ordered_group)):
  1039. next_seg = ordered_group[i]
  1040. is_h = abs(curr_seg[1] - curr_seg[3]) < dist_thresh and abs(next_seg[1] - next_seg[3]) < dist_thresh and abs(
  1041. curr_seg[3] - next_seg[1]) < dist_thresh
  1042. is_v = abs(curr_seg[0] - curr_seg[2]) < dist_thresh and abs(next_seg[0] - next_seg[2]) < dist_thresh and abs(
  1043. curr_seg[2] - next_seg[0]) < dist_thresh
  1044. if is_h or is_v:
  1045. curr_seg[2], curr_seg[3] = next_seg[2], next_seg[3]
  1046. else:
  1047. merged.append(curr_seg)
  1048. curr_seg = list(next_seg)
  1049. merged.append(curr_seg)
  1050. return merged
def merge_parallel_lines_topology(group, dist_thresh=15.0):
    """Collapse nearby parallel overlapping segments into one, and propagate
    the merged coordinate to every other segment that shared either original
    coordinate so the connected topology stays stitched together.

    Repeats until a full scan finds nothing to merge (hard-capped iterations).
    Returns a new list of segments.
    """
    if not group: return group
    segs = [list(s) for s in group]
    # 1-D overlap length of two intervals (0 when disjoint).
    def get_overlap(min1, max1, min2, max2):
        overlap_min = max(min1, min2)
        overlap_max = min(max1, max2)
        return overlap_max - overlap_min if overlap_max > overlap_min else 0
    iteration = 0
    while True:
        merged_in_this_round = False
        merged_indices = set()
        for i in range(len(segs)):
            if i in merged_indices: continue
            for j in range(i + 1, len(segs)):
                if j in merged_indices: continue
                s1, s2 = segs[i], segs[j]
                is_h1, is_h2 = abs(s1[1] - s1[3]) < 1e-1, abs(s2[1] - s2[3]) < 1e-1
                is_v1, is_v2 = abs(s1[0] - s1[2]) < 1e-1, abs(s2[0] - s2[2]) < 1e-1
                should_merge = False
                new_pos = 0
                if is_h1 and is_h2:
                    # Two horizontals: merge when close in y and overlapping in x.
                    dist = abs(s1[1] - s2[1])
                    overlap = get_overlap(min(s1[0], s1[2]), max(s1[0], s1[2]), min(s2[0], s2[2]), max(s2[0], s2[2]))
                    if dist < dist_thresh and overlap > 0:
                        should_merge = True
                        new_pos = (s1[1] + s2[1]) / 2
                        # Re-snap every endpoint that sat on either old y level.
                        old_y1, old_y2 = s1[1], s2[1]
                        for s in segs:
                            if abs(s[1] - old_y1) < 1e-1 or abs(s[1] - old_y2) < 1e-1: s[1] = new_pos
                            if abs(s[3] - old_y1) < 1e-1 or abs(s[3] - old_y2) < 1e-1: s[3] = new_pos
                elif is_v1 and is_v2:
                    # Two verticals: merge when close in x and overlapping in y.
                    dist = abs(s1[0] - s2[0])
                    overlap = get_overlap(min(s1[1], s1[3]), max(s1[1], s1[3]), min(s2[1], s2[3]), max(s2[1], s2[3]))
                    if dist < dist_thresh and overlap > 0:
                        should_merge = True
                        new_pos = (s1[0] + s2[0]) / 2
                        # Re-snap every endpoint that sat on either old x level.
                        old_x1, old_x2 = s1[0], s2[0]
                        for s in segs:
                            if abs(s[0] - old_x1) < 1e-1 or abs(s[0] - old_x2) < 1e-1: s[0] = new_pos
                            if abs(s[2] - old_x1) < 1e-1 or abs(s[2] - old_x2) < 1e-1: s[2] = new_pos
                if should_merge:
                    # Grow s1 to the bounding extent of both, mark s2 for
                    # removal, then restart the scan from scratch.
                    s1[0], s1[2] = min(s1[0], s1[2], s2[0], s2[2]), max(s1[0], s1[2], s2[0], s2[2])
                    s1[1], s1[3] = min(s1[1], s1[3], s2[1], s2[3]), max(s1[1], s1[3], s2[1], s2[3])
                    merged_indices.add(j)
                    merged_in_this_round = True
                    break
            if merged_in_this_round: break
        if merged_indices:
            segs = [s for idx, s in enumerate(segs) if idx not in merged_indices]
        if not merged_in_this_round: break
        iteration += 1
        if iteration > 100: break  # safety valve against pathological inputs
    return segs
  1104. def get_line_angle_0_90(s1, s2):
  1105. v1 = (s1[2] - s1[0], s1[3] - s1[1])
  1106. v2 = (s2[2] - s2[0], s2[3] - s2[1])
  1107. mag1 = math.sqrt(v1[0] ** 2 + v1[1] ** 2)
  1108. mag2 = math.sqrt(v2[0] ** 2 + v2[1] ** 2)
  1109. if mag1 < 1e-6 or mag2 < 1e-6: return 0
  1110. dot_product = abs(v1[0] * v2[0] + v1[1] * v2[1])
  1111. cos_theta = dot_product / (mag1 * mag2)
  1112. cos_theta = max(-1.0, min(1.0, cos_theta))
  1113. return math.acos(cos_theta)
def apply_user_refinement_with_moving_refine_distance_angle(group, length_threshold=30, angle_threshold_deg=30):
    """Variant of apply_user_refinement_with_moving with length/angle limits.

    Short slanted runs (total length < *length_threshold*) are replaced by a
    single axis-aligned L-connector; longer runs are preserved but re-chained,
    with near-collinear neighbours (angle < *angle_threshold_deg*) merged.

    Returns the refined segment list (the input when nothing is orthogonal).
    """
    if len(group) < 2: return group
    angle_thresh_rad = math.radians(angle_threshold_deg)
    # Rotate the loop so processing starts on an orthogonal segment.
    start_idx = -1
    for i, s in enumerate(group):
        if is_orthogonal(s):
            start_idx = i
            break
    if start_idx == -1: return group
    w_list = [list(s) for s in (group[start_idx:] + group[:start_idx])]
    refined = []
    i = 0
    while i < len(w_list):
        curr_seg = w_list[i]
        refined.append(curr_seg)
        nxt_idx = i + 1
        if nxt_idx >= len(w_list): break
        if not is_orthogonal(w_list[nxt_idx]):
            # Collect the slanted run and its accumulated length.
            p_idx = nxt_idx
            found_lp = False
            total_gap_len = 0
            gap_segs = []
            while p_idx < len(w_list):
                if is_orthogonal(w_list[p_idx]):
                    found_lp = True
                    break
                seg_tmp = w_list[p_idx]
                total_gap_len += math.sqrt((seg_tmp[2] - seg_tmp[0]) ** 2 + (seg_tmp[3] - seg_tmp[1]) ** 2)
                gap_segs.append(seg_tmp)
                p_idx += 1
            if found_lp:
                lp = w_list[p_idx]
                if total_gap_len < length_threshold:
                    # Short run: replace it by a single axis-aligned elbow,
                    # moving the next orthogonal segment's start point.
                    p2_end, lp_start = [curr_seg[2], curr_seg[3]], [lp[0], lp[1]]
                    is_h1, is_h2 = abs(curr_seg[3] - curr_seg[1]) < 1e-1, abs(lp[3] - lp[1]) < 1e-1
                    if is_h1 == is_h2:
                        p_mid = [p2_end[0], lp_start[1]] if is_h1 else [lp_start[0], p2_end[1]]
                    else:
                        p_mid = [lp_start[0], p2_end[1]] if is_h1 else [p2_end[0], lp_start[1]]
                    lp[0], lp[1] = p_mid[0], p_mid[1]
                    refined.append([p2_end[0], p2_end[1], lp[0], lp[1]])
                    i = p_idx
                else:
                    # Long run: keep it, merging near-collinear neighbours and
                    # re-chaining endpoints so the polyline stays connected.
                    gap_segs[0][0], gap_segs[0][1] = curr_seg[2], curr_seg[3]
                    merged_gap = []
                    temp_s = list(gap_segs[0])
                    for k in range(1, len(gap_segs)):
                        next_s = gap_segs[k]
                        if get_line_angle_0_90(temp_s, next_s) < angle_thresh_rad:
                            # Nearly collinear: extend the running segment.
                            temp_s[2], temp_s[3] = next_s[2], next_s[3]
                        else:
                            merged_gap.append(temp_s)
                            temp_s = list(next_s)
                            temp_s[0], temp_s[1] = merged_gap[-1][2], merged_gap[-1][3]
                    merged_gap.append(temp_s)
                    # Pin the run's tail onto the next orthogonal segment.
                    merged_gap[-1][2], merged_gap[-1][3] = lp[0], lp[1]
                    refined.extend(merged_gap)
                    i = p_idx
            else:
                # No orthogonal segment remains — keep the tail unchanged.
                refined.extend(w_list[nxt_idx:])
                break
        else:
            i += 1
    return refined
  1178. # ================= 4. 针对单个 Block 的处理流程 =================
  1179. def refine_single_block_segments(segments):
  1180. if not segments: return []
  1181. temp_groups = [segments]
  1182. temp_groups = orthogonalize_and_move_nodes(temp_groups, angle_thresh_deg=15)
  1183. segments = temp_groups[0]
  1184. rect_loop = apply_user_refinement_with_moving_refine_distance_angle(segments, length_threshold=30,
  1185. angle_threshold_deg=30)
  1186. rect_loop = merge_collinear_segments(rect_loop, dist_thresh=2)
  1187. rect_loop = merge_parallel_lines_topology(rect_loop, dist_thresh=12)
  1188. return rect_loop
  1189. # ================= 5. 主处理流程 =================
  1190. def refine_blocks_in_data(data):
  1191. """将 block.points 从扁平点集转换为整流后的线段格式。"""
  1192. blocks = data.get("block", data.get("blocks", []))
  1193. if not blocks:
  1194. if isinstance(data, list):
  1195. blocks = data
  1196. else:
  1197. print("⚠️ JSON 中未找到 'block' 或 'blocks' 字段")
  1198. return data
  1199. print(f"🚀 开始处理 {len(blocks)} 个块...")
  1200. total_segments = 0
  1201. for idx, block in enumerate(tqdm(blocks, desc="Refining Blocks")):
  1202. block_id = block.get('id', idx)
  1203. points_data = block.get("points", [])
  1204. if not points_data:
  1205. print(f"⚠️ Block {block_id} 无 points,跳过")
  1206. continue
  1207. if isinstance(points_data[0], list) and len(points_data[0]) == 4:
  1208. source_segments = [[float(v) for v in seg] for seg in points_data]
  1209. else:
  1210. if len(points_data) < 4:
  1211. print(f"⚠️ Block {block_id} 点数不足,跳过")
  1212. continue
  1213. points = []
  1214. for i in range(0, len(points_data), 2):
  1215. if i + 1 < len(points_data):
  1216. points.append([float(points_data[i]), float(points_data[i + 1])])
  1217. if len(points) < 2:
  1218. continue
  1219. source_segments = []
  1220. for i in range(len(points)):
  1221. p1 = points[i]
  1222. p2 = points[(i + 1) % len(points)]
  1223. source_segments.append([p1[0], p1[1], p2[0], p2[1]])
  1224. refined_segments = refine_single_block_segments(source_segments)
  1225. if not refined_segments:
  1226. refined_segments = source_segments
  1227. segments_int = []
  1228. for seg in refined_segments:
  1229. segments_int.append([
  1230. int(round(seg[0])),
  1231. int(round(seg[1])),
  1232. int(round(seg[2])),
  1233. int(round(seg[3]))
  1234. ])
  1235. block["points"] = segments_int
  1236. block["refined"] = True
  1237. block["format"] = "segments"
  1238. block["segment_count"] = len(segments_int)
  1239. total_segments += len(segments_int)
  1240. data["format_version"] = "segments_v1"
  1241. data["total_segments"] = total_segments
  1242. return data
  1243. def process_json_blocks(json_path, rgb_path, save_path, txt_folder, save_json=False, json_backup=True):
  1244. """
  1245. 直接保存绘制的 p1, p2 坐标到.txt 文件
  1246. """
  1247. # 1. 加载 JSON
  1248. try:
  1249. with open(json_path, 'r', encoding='utf-8') as f:
  1250. data = json.load(f)
  1251. except Exception as e:
  1252. print(f"❌ 读取 JSON 失败:{e}")
  1253. return
  1254. # 2. 加载背景图
  1255. output_img = cv2.imread(rgb_path)
  1256. if output_img is None:
  1257. print(f"❌ 无法读取背景图:{rgb_path}")
  1258. return
  1259. # 3. 获取 Blocks 列表
  1260. blocks = data.get("block", data.get("blocks", []))
  1261. if not blocks:
  1262. if isinstance(data, list):
  1263. blocks = data
  1264. else:
  1265. print("⚠️ JSON 中未找到 'block' 或 'blocks' 字段")
  1266. return
  1267. print(f"🚀 开始处理 {len(blocks)} 个块...")
  1268. # 4. 创建 txt 文件夹(txt_folder=None 时跳过 txt 保存)
  1269. txt_path = None
  1270. if txt_folder is not None:
  1271. if not os.path.exists(txt_folder):
  1272. os.makedirs(txt_folder)
  1273. txt_filename = os.path.basename(save_path).replace('.png', '.txt')
  1274. txt_path = os.path.join(txt_folder, txt_filename)
  1275. # 6. 遍历每个块,独立处理
  1276. all_lines_data = [] # 存储所有线段数据用于保存 txt
  1277. for idx, block in enumerate(tqdm(blocks, desc="Refining Blocks")):
  1278. block_id = block.get('id', idx)
  1279. label = block.get('label', 'door')
  1280. points_flat = block.get("points", [])
  1281. if len(points_flat) < 4:
  1282. print(f"⚠️ Block {block_id} 点数不足,跳过")
  1283. continue
  1284. # 将扁平数组转换为点列表
  1285. points = []
  1286. for i in range(0, len(points_flat), 2):
  1287. if i + 1 < len(points_flat):
  1288. points.append([float(points_flat[i]), float(points_flat[i + 1])])
  1289. if len(points) < 2:
  1290. continue
  1291. # A. 将多边形点转换为线段
  1292. segments = []
  1293. for i in range(len(points)):
  1294. p1 = points[i]
  1295. p2 = points[(i + 1) % len(points)]
  1296. segments.append([p1[0], p1[1], p2[0], p2[1]])
  1297. # B. 对该块的线段进行完整整流优化
  1298. refined_segments = refine_single_block_segments(segments)
  1299. # C. 保存 JSON(可选)
  1300. if save_json and refined_segments:
  1301. segments_int = []
  1302. for seg in refined_segments:
  1303. segments_int.append([
  1304. int(round(seg[0])),
  1305. int(round(seg[1])),
  1306. int(round(seg[2])),
  1307. int(round(seg[3]))
  1308. ])
  1309. block["points"] = segments_int
  1310. block["refined"] = True
  1311. block["format"] = "segments"
  1312. # D. 【关键】绘制并记录坐标(与可视化完全一致)
  1313. color = np.random.randint(100, 255, (3,)).tolist()
  1314. for s in refined_segments:
  1315. p1 = (int(s[0]), int(s[1]))
  1316. p2 = (int(s[2]), int(s[3]))
  1317. # 绘制到图像
  1318. cv2.line(output_img, p1, p2, color, 2, cv2.LINE_AA)
  1319. cv2.circle(output_img, p1, 3, (255, 255, 255), -1)
  1320. # 记录到 txt 数据
  1321. # 格式:block_id,label,x1,y1,x2,y2
  1322. all_lines_data.append(f"{block_id},{label},{p1[0]},{p1[1]},{p2[0]},{p2[1]}")
  1323. # 7. 保存 txt 文件(可选)
  1324. if txt_path is not None:
  1325. with open(txt_path, 'w', encoding='utf-8') as f:
  1326. f.write(f"# {os.path.basename(txt_path)}\n")
  1327. f.write(f"# Format: block_id,label,x1,y1,x2,y2\n")
  1328. f.write(f"# Total lines: {len(all_lines_data)}\n")
  1329. f.write(f"#\n")
  1330. for line in all_lines_data:
  1331. f.write(line + "\n")
  1332. print(f"📄 线段数据已保存至:{txt_path}")
  1333. print(f" 共 {len(all_lines_data)} 条线段")
  1334. # 8. 保存图像
  1335. cv2.imwrite(save_path, output_img)
  1336. print(f"✨ 图像已保存至:{save_path}")
  1337. # 9. 保存 JSON(可选)
  1338. if save_json:
  1339. try:
  1340. if json_backup:
  1341. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  1342. backup_path = json_path.replace(".json", f"_backup_{timestamp}.json")
  1343. shutil.copy2(json_path, backup_path)
  1344. print(f"💾 原 JSON 已备份至:{backup_path}")
  1345. save_json_path = json_path
  1346. with open(save_json_path, 'w', encoding='utf-8') as f:
  1347. json.dump(data, f, indent=2, ensure_ascii=False)
  1348. print(f"💾 优化后的 JSON 已保存至:{save_json_path}")
  1349. except Exception as e:
  1350. print(f"❌ 保存 JSON 失败:{e}")
  1351. print(f"✨ 全部处理完成!")
  1352. # ================= 6. 主运行流程 =================
  1353. # ===========================================================================
  1354. # PART 3: txt2json — merge_txt_to_json
  1355. # ===========================================================================
  1356. def merge_txt_to_json(txt_path, json_path, output_path=None):
  1357. """
  1358. 将 txt 文件中的线段数据整合到 JSON 文件中
  1359. txt 格式:block_id,label,x1,y1,x2,y2
  1360. JSON 格式:points: [(x1,y1,x2,y2), (x1,y1,x2,y2), ...]
  1361. """
  1362. print(f"\n{'=' * 60}")
  1363. print(f"🔄 合并 txt 数据到 JSON")
  1364. print('=' * 60)
  1365. # 1. 读取 txt 文件
  1366. print(f"📄 读取 txt 文件:{txt_path}")
  1367. lines_data = []
  1368. with open(txt_path, 'r', encoding='utf-8') as f:
  1369. for line in f:
  1370. line = line.strip()
  1371. if line.startswith('#') or not line:
  1372. continue
  1373. parts = line.split(',')
  1374. if len(parts) == 6:
  1375. try:
  1376. block_id = int(parts[0])
  1377. label = parts[1]
  1378. x1, y1, x2, y2 = int(parts[2]), int(parts[3]), int(parts[4]), int(parts[5])
  1379. lines_data.append({
  1380. 'block_id': block_id,
  1381. 'label': label,
  1382. 'segment': (x1, y1, x2, y2)
  1383. })
  1384. except Exception as e:
  1385. print(f" ⚠️ 解析失败:{line} - {e}")
  1386. print(f" 共读取 {len(lines_data)} 条线段")
  1387. # 2. 按 block_id 分组线段
  1388. blocks_segments = defaultdict(list)
  1389. blocks_labels = {}
  1390. for item in lines_data:
  1391. block_id = item['block_id']
  1392. blocks_segments[block_id].append(item['segment'])
  1393. blocks_labels[block_id] = item['label']
  1394. print(f" 共 {len(blocks_segments)} 个 block")
  1395. # 3. 读取原始 JSON
  1396. print(f"\n📄 读取 JSON 文件:{json_path}")
  1397. with open(json_path, 'r', encoding='utf-8') as f:
  1398. data = json.load(f)
  1399. # 4. 更新 blocks 数据
  1400. blocks = data.get('block', data.get('blocks', []))
  1401. updated_count = 0
  1402. print(f"\n🔄 更新 {len(blocks)} 个 block...")
  1403. for block in blocks:
  1404. block_id = block.get('id', 0)
  1405. if block_id in blocks_segments:
  1406. # 将线段列表转换为 JSON 格式
  1407. # 使用列表而不是元组(JSON 不支持元组)
  1408. segments_list = []
  1409. for seg in blocks_segments[block_id]:
  1410. segments_list.append(list(seg)) # (x1,y1,x2,y2) → [x1,y1,x2,y2]
  1411. # 更新 points 字段
  1412. block['points'] = segments_list
  1413. block['refined'] = True
  1414. block['format'] = 'segments' # 标记格式
  1415. block['segment_count'] = len(segments_list)
  1416. # 更新 label(如果 txt 中有)
  1417. if block_id in blocks_labels:
  1418. block['label'] = blocks_labels[block_id]
  1419. updated_count += 1
  1420. print(f" ✅ Block {block_id}: {len(segments_list)} 线段")
  1421. else:
  1422. print(f" ⚠️ Block {block_id}: 未找到对应线段数据")
  1423. # 5. 添加合并标记
  1424. data['txt_merged'] = True
  1425. data['total_segments'] = len(lines_data)
  1426. data['format_version'] = 'segments_v1'
  1427. # 6. 保存结果
  1428. if output_path is None:
  1429. output_path = json_path
  1430. print(f"\n💾 保存至:{output_path}")
  1431. with open(output_path, 'w', encoding='utf-8') as f:
  1432. json.dump(data, f, indent=2, ensure_ascii=False)
  1433. print(f"\n✨ 合并完成!")
  1434. print(f" - 更新 block 数量:{updated_count}/{len(blocks)}")
  1435. print(f" - 总线段数:{len(lines_data)}")
  1436. return data
  1437. def verify_json_segments(json_path):
  1438. """
  1439. 验证 JSON 中的线段格式是否正确
  1440. """
  1441. print(f"\n{'=' * 60}")
  1442. print(f"🔍 验证 JSON 线段格式")
  1443. print('=' * 60)
  1444. with open(json_path, 'r', encoding='utf-8') as f:
  1445. data = json.load(f)
  1446. blocks = data.get('block', data.get('blocks', []))
  1447. for block in blocks:
  1448. block_id = block.get('id', 0)
  1449. label = block.get('label', 'door')
  1450. points = block.get('points', [])
  1451. # 检查格式
  1452. if isinstance(points, list) and len(points) > 0:
  1453. if isinstance(points[0], list) and len(points[0]) == 4:
  1454. # 正确的线段格式
  1455. print(f" ✅ Block {block_id} ({label}): {len(points)} 线段 [x1,y1,x2,y2]")
  1456. elif isinstance(points[0], (int, float)):
  1457. # 旧的扁平数组格式
  1458. print(f" ⚠️ Block {block_id} ({label}): 扁平数组格式 ({len(points) // 2} 点)")
  1459. else:
  1460. print(f" ❌ Block {block_id} ({label}): 未知格式")
  1461. else:
  1462. print(f" ❌ Block {block_id} ({label}): 无数据")
  1463. print(f"\n✨ 验证完成!")
  1464. # ================= 主运行流程 =================
  1465. # ===========================================================================
  1466. # PART 4: merge_segments
  1467. # ===========================================================================
from tqdm import tqdm
# NOTE(review): the def below deliberately shadows the tqdm import above,
# turning every tqdm(...) call into a no-op passthrough (progress bars
# disabled). Delete this stub to restore real progress bars.
def tqdm(x, **kw):
    # Return the iterable unchanged; all keyword arguments are ignored.
    return x
  1471. # ---------------------------------------------------------------------------
  1472. # Step 1: Angle normalization
  1473. # ---------------------------------------------------------------------------
  1474. def normalize_segment(seg, angle=10):
  1475. """
  1476. If a segment is nearly horizontal or nearly vertical, snap it.
  1477. angle: tolerance in degrees
  1478. - |theta| <= angle or |theta| >= 180-angle -> horizontal (fix y to midpoint)
  1479. - |90 - |theta|| <= angle -> vertical (fix x to midpoint)
  1480. seg: [x1, y1, x2, y2]
  1481. Returns new [x1, y1, x2, y2].
  1482. """
  1483. x1, y1, x2, y2 = seg
  1484. dx, dy = x2 - x1, y2 - y1
  1485. if dx == 0 and dy == 0:
  1486. return seg
  1487. theta = abs(math.degrees(math.atan2(abs(dy), abs(dx)))) # 0..90
  1488. if theta <= angle: # near horizontal
  1489. mid_y = round((y1 + y2) / 2)
  1490. return [x1, mid_y, x2, mid_y]
  1491. if theta >= 90 - angle: # near vertical
  1492. mid_x = round((x1 + x2) / 2)
  1493. return [mid_x, y1, mid_x, y2]
  1494. return seg
  1495. def normalize_all(data, angle=10):
  1496. for block in data.get('block', []):
  1497. block['points'] = [normalize_segment(s, angle) for s in block['points']]
  1498. # ---------------------------------------------------------------------------
  1499. # Step 2: Cluster-based parallel merge (mean coordinate)
  1500. # ---------------------------------------------------------------------------
  1501. def segment_orientation(seg):
  1502. x1, y1, x2, y2 = seg
  1503. if y1 == y2:
  1504. return 'H'
  1505. if x1 == x2:
  1506. return 'V'
  1507. return None
  1508. def snap_coord_global(all_segs, old_vals, new_val, ori):
  1509. """
  1510. Replace all occurrences of any value in old_vals with new_val
  1511. for the relevant axis across all segments.
  1512. ori='H': affects y-coords; ori='V': affects x-coords.
  1513. """
  1514. old_set = set(old_vals)
  1515. for k, s in enumerate(all_segs):
  1516. if ori == 'H':
  1517. if segment_orientation(s) == 'H':
  1518. if s[1] in old_set:
  1519. all_segs[k][1] = new_val
  1520. all_segs[k][3] = new_val
  1521. else: # V seg: update y endpoints
  1522. if s[1] in old_set:
  1523. all_segs[k][1] = new_val
  1524. if s[3] in old_set:
  1525. all_segs[k][3] = new_val
  1526. else: # ori == 'V'
  1527. if segment_orientation(s) == 'V':
  1528. if s[0] in old_set:
  1529. all_segs[k][0] = new_val
  1530. all_segs[k][2] = new_val
  1531. else: # H seg: update x endpoints
  1532. if s[0] in old_set:
  1533. all_segs[k][0] = new_val
  1534. if s[2] in old_set:
  1535. all_segs[k][2] = new_val
  1536. def segments_overlap(sa, sb, ori):
  1537. """
  1538. Check if two parallel segments have overlapping projection on the perpendicular axis.
  1539. ori='H': compare x ranges; ori='V': compare y ranges.
  1540. """
  1541. if ori == 'H':
  1542. a1, a2 = min(sa[0], sa[2]), max(sa[0], sa[2])
  1543. b1, b2 = min(sb[0], sb[2]), max(sb[0], sb[2])
  1544. else:
  1545. a1, a2 = min(sa[1], sa[3]), max(sa[1], sa[3])
  1546. b1, b2 = min(sb[1], sb[3]), max(sb[1], sb[3])
  1547. return a1 <= b2 and b1 <= a2
def cluster_and_merge(all_segs, ori, threshold):
    """
    Overlap-aware parallel merge:
    Two parallel segments are merged only if:
    1. Their axis-coordinate distance < threshold, AND
    2. Their projections on the perpendicular axis overlap.
    Uses union-find to group segments into clusters, then snaps each
    cluster to the mean coordinate.
    Returns True if any snapping occurred.
    """
    # Indices of segments with this orientation
    idxs = [i for i, s in enumerate(all_segs) if segment_orientation(s) == ori]
    if not idxs:
        return False
    # Union-Find over those indices
    parent = {i: i for i in idxs}
    def find(x):
        # Path-halving while walking to the root.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x
    def union(x, y):
        parent[find(x)] = find(y)
    # O(n^2) pairwise linking: near in the axis coordinate AND
    # overlapping in the perpendicular projection.
    for ii in range(len(idxs)):
        for jj in range(ii + 1, len(idxs)):
            i, j = idxs[ii], idxs[jj]
            si, sj = all_segs[i], all_segs[j]
            coord_i = si[1] if ori == 'H' else si[0]
            coord_j = sj[1] if ori == 'H' else sj[0]
            if abs(coord_i - coord_j) <= threshold and segments_overlap(si, sj, ori):
                union(i, j)
    # Group by cluster root (insertion order of idxs is preserved)
    from collections import defaultdict
    clusters = defaultdict(list)
    for i in idxs:
        clusters[find(i)].append(i)
    changed = False
    for members in clusters.values():
        if len(members) < 2:
            continue
        coords = [all_segs[i][1] if ori == 'H' else all_segs[i][0] for i in members]
        old_vals = list(set(coords))
        if len(old_vals) == 1:  # already aligned, nothing to do
            continue
        new_val = round(sum(coords) / len(coords))
        # NOTE: snapping is global — perpendicular segments whose endpoint
        # matches any old value are re-anchored too (see snap_coord_global),
        # so later clusters may observe coordinates moved by earlier ones.
        snap_coord_global(all_segs, old_vals, new_val, ori)
        changed = True
    return changed
  1596. def seg_key(s):
  1597. p1 = (s[0], s[1])
  1598. p2 = (s[2], s[3])
  1599. return (min(p1, p2), max(p1, p2))
  1600. def remove_degenerate_and_duplicates(segs):
  1601. seen = set()
  1602. result = []
  1603. for s in segs:
  1604. if s[0] == s[2] and s[1] == s[3]:
  1605. continue
  1606. k = seg_key(s)
  1607. if k in seen:
  1608. continue
  1609. seen.add(k)
  1610. result.append(s)
  1611. return result
  1612. def merge_all_blocks(data, threshold=6):
  1613. """
  1614. Global cross-block parallel merge using cluster mean strategy.
  1615. Iterates H and V directions until no more clusters can be merged.
  1616. """
  1617. blocks = data.get('block', [])
  1618. all_segs = []
  1619. block_counts = []
  1620. for block in blocks:
  1621. segs = [list(s) for s in block['points']]
  1622. all_segs.extend(segs)
  1623. block_counts.append(len(segs))
  1624. # Iterate until stable
  1625. changed = True
  1626. while changed:
  1627. ch_h = cluster_and_merge(all_segs, 'H', threshold)
  1628. ch_v = cluster_and_merge(all_segs, 'V', threshold)
  1629. changed = ch_h or ch_v
  1630. # Write back to blocks
  1631. idx = 0
  1632. for block, count in zip(blocks, block_counts):
  1633. block_segs = remove_degenerate_and_duplicates(all_segs[idx:idx + count])
  1634. block['points'] = block_segs
  1635. block['segment_count'] = len(block_segs)
  1636. idx += count
  1637. data['total_segments'] = sum(len(b['points']) for b in blocks)
  1638. # ---------------------------------------------------------------------------
  1639. # Pipeline
  1640. # ---------------------------------------------------------------------------
  1641. def process_json(input_path, output_path, threshold=6, angle=10):
  1642. with open(input_path, 'r') as f:
  1643. data = json.load(f)
  1644. normalize_all(data, angle=angle)
  1645. merge_all_blocks(data, threshold=threshold)
  1646. with open(output_path, 'w') as f:
  1647. json.dump(data, f, indent=2)
  1648. # ===========================================================================
  1649. # PART 5: vis
  1650. # ===========================================================================
# ================= 1. Color mapping =================
# Room label -> draw color. OpenCV uses BGR channel order.
COLOR_MAP = {
    "living_room": (255, 180, 100),   # BGR: orange
    "bed_room": (100, 255, 100),      # BGR: green
    "bath_room": (255, 100, 255),     # BGR: purple
    "kitchen_room": (100, 255, 255),  # BGR: yellow
    "other_room": (180, 180, 180),    # BGR: gray
    "balcony": (200, 150, 100),       # BGR: brown
}
# Fallback color for labels missing from COLOR_MAP.
DEFAULT_COLOR = (200, 200, 200)
  1661. # ================= 2. 工具函数 =================
  1662. def get_image_size(data):
  1663. """从 JSON 数据获取图像尺寸"""
  1664. if 'image_size' in data:
  1665. return data['image_size']['width'], data['image_size']['height']
  1666. # 从线段数据推断
  1667. max_x, max_y = 0, 0
  1668. for block in data.get('block', []):
  1669. segments = block.get('points', [])
  1670. if isinstance(segments, list) and len(segments) > 0:
  1671. if isinstance(segments[0], list) and len(segments[0]) == 4:
  1672. for seg in segments:
  1673. max_x = max(max_x, seg[0], seg[2])
  1674. max_y = max(max_y, seg[1], seg[3])
  1675. for area in data.get('connect_area', []):
  1676. max_x = max(max_x, area.get('x', 0) + area.get('w', 0))
  1677. max_y = max(max_y, area.get('y', 0) + area.get('h', 0))
  1678. return int(max_x) + 100, int(max_y) + 100
  1679. def get_points_from_segments(segments):
  1680. """从线段列表提取所有唯一点"""
  1681. if not segments:
  1682. return []
  1683. points_set = set()
  1684. for seg in segments:
  1685. if len(seg) == 4:
  1686. points_set.add((int(seg[0]), int(seg[1])))
  1687. points_set.add((int(seg[2]), int(seg[3])))
  1688. return list(points_set)
  1689. # ================= 3. 主可视化函数 =================
def visualize_final_json(json_path, output_path, rgb_path=None, vis_door=False):
    """
    Final visualization of a segment-format JSON.

    Draws room blocks (segments, white vertex dots, centered labels),
    optional door connect-areas, furniture boxes and a color legend,
    then writes the canvas to *output_path*.

    Args:
        json_path: path to the segment-format JSON file.
        output_path: path of the image to write.
        rgb_path: optional background image; a blank canvas is used when
            absent or unreadable.
        vis_door: when True, connect areas whose label contains "door"
            are drawn as filled red rectangles.

    Returns:
        The drawn canvas (numpy BGR image array).
    """
    print(f"\n{'=' * 60}")
    print(f"🎨 最终可视化:{os.path.basename(json_path)}")
    print('=' * 60)
    # 1. Load the JSON
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # 2. Determine the canvas size and create the canvas
    width, height = get_image_size(data)
    if rgb_path and os.path.exists(rgb_path):
        canvas = cv2.imread(rgb_path)
        print(f"🖼️ 使用背景图:{rgb_path}")
    else:
        # canvas = np.ones((height, width, 3), dtype=np.uint8) * 255
        # NOTE(review): despite the log message ("white canvas"), this
        # creates a BLACK canvas via np.zeros; the white-canvas line
        # above is commented out — confirm which is intended.
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        print(f"📐 创建白色画布:{width}x{height}")
    # 3. Draw room blocks
    blocks = data.get('block', [])
    print(f"\n🏠 绘制 {len(blocks)} 个房间块...")
    for block_idx, block in enumerate(blocks):
        block_id = block.get('id', block_idx)
        label = block.get('label', 'door')
        segments = block.get('points', [])
        center = block.get('center', None)  # label anchor taken from JSON
        # Room color, with a fallback for unknown labels
        color = COLOR_MAP.get(label.lower(), DEFAULT_COLOR)
        # Validate the segment format before drawing
        if not (isinstance(segments, list) and len(segments) > 0):
            print(f" ⚠️ Block {block_id}: 无数据,跳过")
            continue
        if not (isinstance(segments[0], list) and len(segments[0]) == 4):
            print(f" ⚠️ Block {block_id}: 格式不正确,跳过")
            continue
        # Collect all unique endpoints
        points = get_points_from_segments(segments)
        if len(points) < 2:
            print(f" ⚠️ Block {block_id}: 点数不足,跳过")
            continue
        # [1] Draw every segment
        for seg in segments:
            p1 = (int(seg[0]), int(seg[1]))
            p2 = (int(seg[2]), int(seg[3]))
            cv2.line(canvas, p1, p2, color, 2, cv2.LINE_AA)
        # [2] Mark every vertex with a white dot
        for pt in points:
            cv2.circle(canvas, pt, 4, (255, 255, 255), -1)  # white dot
        # [3] Label placement: prefer the JSON 'center' field
        if center and len(center) == 2:
            cx, cy = int(center[0]), int(center[1])
        else:
            # Fallback: centroid of the vertices via image moments
            pts = np.array(points, dtype=np.int32)
            M = cv2.moments(pts)
            if M["m00"] != 0:
                cx = int(M["m10"] / M["m00"])
                cy = int(M["m01"] / M["m00"])
            else:
                # Collinear points give zero area; use the plain mean.
                cx = int(np.mean(pts[:, 0]))
                cy = int(np.mean(pts[:, 1]))
            print(f" ⚠️ Block {block_id}: 无 center 字段,使用计算中心")
        # Clamp the label anchor inside the canvas
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        thickness = 1
        (tw, th), _ = cv2.getTextSize(label, font, font_scale, thickness)
        cx = max(tw // 2 + 5, min(cx, width - tw // 2 - 5))
        cy = max(th // 2 + 5, min(cy, height - th // 2 - 5))
        # Label background (white rectangle)
        cv2.rectangle(canvas, (cx - tw // 2 - 3, cy - th - 3),
                      (cx + tw // 2 + 3, cy + 3), (255, 255, 255), -1)
        # Label text (black)
        cv2.putText(canvas, label, (cx - tw // 2, cy),
                    font, font_scale, (0, 0, 0), thickness)
        # Block ID (small gray text below the label)
        id_text = f"ID:{block_id}"
        (iw, ih), _ = cv2.getTextSize(id_text, font, 0.4, 1)
        cv2.putText(canvas, id_text, (cx - iw // 2, cy + 15),
                    font, 0.4, (100, 100, 100), 1)
        print(f" ✅ Block {block_id} ({label}): {len(segments)} 线段,{len(points)} 点,center=({cx},{cy})")
    # 4. Draw connect areas
    connect_areas = data.get('connect_area', [])
    print(f"\n🔗 绘制 {len(connect_areas)} 个连接区域...")
    for area_idx, area in enumerate(connect_areas):
        area_id = area.get('id', area_idx)
        x = area.get('x', 0)
        y = area.get('y', 0)
        w = area.get('w', 0)
        h = area.get('h', 0)
        label = area.get('label', 'door').lower()
        block_pair = area.get('block_pair', [])
        # Clamp the rectangle inside the canvas
        x = max(0, min(x, width - 1))
        y = max(0, min(y, height - 1))
        w = max(1, min(w, width - x))
        h = max(1, min(h, height - y))
        # Doors: solid red rectangle + white border (only when vis_door)
        if vis_door and ("door" in label):
            cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 0, 255), -1)  # red fill
            cv2.rectangle(canvas, (x, y), (x + w, y + h), (255, 255, 255), 1)  # white border
            print(f" ✅ Connect {area_id}: 🔴 红色实心矩形")
        # (non-door rendering intentionally disabled; the former green
        # outline + label drawing was commented out)
    # 5. Draw furniture boxes and labels
    furniture_list = data.get('furniture', [])
    print(f"\n🛋️ 绘制 {len(furniture_list)} 个家具...")
    FURNITURE_COLOR = (0, 165, 255)  # BGR: orange
    for item in furniture_list:
        label = item.get('label', '')
        center = item.get('center', None)
        pts = item.get('points', {})
        if pts:
            # 'points' is a corner dict; (x1,y1)/(x3,y3) are assumed to be
            # opposite corners of the bounding box — TODO confirm schema.
            bx1, by1 = pts['x1'], pts['y1']
            bx2, by2 = pts['x3'], pts['y3']
            cv2.rectangle(canvas, (bx1, by1), (bx2, by2), FURNITURE_COLOR, 2)
        if center and len(center) == 2:
            cx, cy = int(center[0]), int(center[1])
            font = cv2.FONT_HERSHEY_SIMPLEX
            (tw, th), _ = cv2.getTextSize(label, font, 0.5, 1)
            cv2.rectangle(canvas, (cx - tw // 2 - 3, cy - th - 3),
                          (cx + tw // 2 + 3, cy + 3), FURNITURE_COLOR, -1)
            cv2.putText(canvas, label, (cx - tw // 2, cy),
                        font, 0.5, (255, 255, 255), 1)
        print(f" ✅ {label} center=({center})")
    # 6. Legend
    print(f"\n📋 添加图例...")
    legend_y = 30
    legend_x = 20
    cv2.putText(canvas, "Legend:", (legend_x, legend_y - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
    legend_y += 5
    for room_type, color in COLOR_MAP.items():
        if room_type in ["door"]:
            continue  # doors get their dedicated red entry below
        cv2.rectangle(canvas, (legend_x, legend_y),
                      (legend_x + 20, legend_y + 15), color, -1)
        cv2.rectangle(canvas, (legend_x, legend_y),
                      (legend_x + 20, legend_y + 15), (0, 0, 0), 1)
        cv2.putText(canvas, room_type, (legend_x + 25, legend_y + 12),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1)
        legend_y += 20
    # Dedicated legend entry for doors (red swatch)
    cv2.rectangle(canvas, (legend_x, legend_y),
                  (legend_x + 20, legend_y + 15), (0, 0, 255), -1)
    cv2.rectangle(canvas, (legend_x, legend_y),
                  (legend_x + 20, legend_y + 15), (0, 0, 0), 1)
    cv2.putText(canvas, "door", (legend_x + 25, legend_y + 12),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1)
    # 7. Save the result
    cv2.imwrite(output_path, canvas)
    print(f"\n{'=' * 60}")
    print(f"✨ 可视化完成!")
    print(f" 📁 输出路径:{output_path}")
    print(f" 📐 画布尺寸:{width}x{height}")
    print(f" 🏠 Block 数量:{len(blocks)}")
    print(f" 🔗 Connect Area 数量:{len(connect_areas)}")
    print('=' * 60)
    return canvas
  1862. # ================= 4. 批量处理 =================
  1863. def visualize_batch(json_folder, output_folder, rgb_folder=None, vis_door=False):
  1864. """批量可视化文件夹中的所有 JSON"""
  1865. if not os.path.exists(output_folder):
  1866. os.makedirs(output_folder)
  1867. json_files = [f for f in os.listdir(json_folder) if f.endswith('.json')]
  1868. print(f"\n📁 发现 {len(json_files)} 个 JSON 文件\n")
  1869. for json_file in json_files:
  1870. json_path = os.path.join(json_folder, json_file)
  1871. output_path = os.path.join(output_folder, json_file.replace('.json', '.png'))
  1872. rgb_path = None
  1873. if rgb_folder:
  1874. rgb_path = os.path.join(rgb_folder, json_file.replace('.json', '.png'))
  1875. if not os.path.exists(rgb_path):
  1876. rgb_path = None
  1877. try:
  1878. visualize_final_json(json_path, output_path, rgb_path, vis_door)
  1879. except Exception as e:
  1880. print(f"❌ 处理 {json_file} 失败:{e}")
  1881. # ================= 5. 主运行流程 =================
  1882. # ===========================================================================
  1883. # MAIN PIPELINE
  1884. # ===========================================================================
  1885. def run_single_image_pipeline(rgb_path, mask_path, output_json_path=None,
  1886. room_model_path="room_seg.pt",
  1887. furniture_model_path="furniture_detect.onnx",
  1888. merge_threshold=15, angle=10,
  1889. dilation_kernel_size=5, center_threshold=15,
  1890. expand_pixel=2, min_rect_short_side=30):
  1891. """单张图流水线:输入 RGB 与 mask,直接输出最终 JSON。"""
  1892. rgb_img = cv2.imread(rgb_path)
  1893. block_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
  1894. if rgb_img is None:
  1895. raise FileNotFoundError(f"无法读取 RGB 图像:{rgb_path}")
  1896. if block_mask is None:
  1897. raise FileNotFoundError(f"无法读取 mask 图像:{mask_path}")
  1898. if output_json_path is None:
  1899. output_json_path = os.path.splitext(rgb_path)[0] + '.json'
  1900. print("===== Single Image Pipeline =====")
  1901. print(f"RGB : {rgb_path}")
  1902. print(f"Mask: {mask_path}")
  1903. print(f"Out : {output_json_path}")
  1904. clean_rgb = remove_edge_regions_image(rgb_img)
  1905. _, gap_add_mask = extract_gaps_from_mask(block_mask)
  1906. connect_mask = extract_mask_region_from_arrays(clean_rgb, block_mask, gap_add_mask)
  1907. room_model = _load_yolo_model(room_model_path)
  1908. furniture_model = _load_yolo_model(furniture_model_path)
  1909. json_data, _, stats = merge_gap_fillers_from_arrays(
  1910. block_mask,
  1911. connect_mask,
  1912. image_path=rgb_path,
  1913. dilation_kernel_size=dilation_kernel_size,
  1914. center_threshold=center_threshold,
  1915. expand_pixel=expand_pixel,
  1916. min_rect_short_side=min_rect_short_side,
  1917. )
  1918. json_data['source_mask_path'] = mask_path
  1919. json_data['connect_area_stats'] = stats
  1920. print("===== Block Classification =====")
  1921. json_data['block'] = build_block_data(rgb_img, block_mask, model_path=room_model)
  1922. print("===== Furniture Detection =====")
  1923. json_data['furniture'] = detect_furniture_in_image(rgb_img, model_path=furniture_model)
  1924. print("===== Block Refinement =====")
  1925. refine_blocks_in_data(json_data)
  1926. print("===== Segment Merge =====")
  1927. normalize_all(json_data, angle=angle)
  1928. merge_all_blocks(json_data, threshold=merge_threshold)
  1929. os.makedirs(os.path.dirname(os.path.abspath(output_json_path)), exist_ok=True)
  1930. with open(output_json_path, 'w', encoding='utf-8') as f:
  1931. json.dump(json_data, f, indent=2, ensure_ascii=False)
  1932. print("===== Pipeline 完成 =====")
  1933. print(f" - JSON 保存至:{output_json_path}")
  1934. print(f" - Connect Area 数量:{len(json_data.get('connect_area', []))}")
  1935. print(f" - Block 数量:{len(json_data.get('block', []))}")
  1936. print(f" - Furniture 数量:{len(json_data.get('furniture', []))}")
  1937. print(f" - Total Segments:{json_data.get('total_segments', 0)}")
  1938. return json_data
  1939. def parse_args():
  1940. parser = argparse.ArgumentParser(description="Single-image floorplan pipeline")
  1941. parser.add_argument('rgb', help='RGB 图路径')
  1942. parser.add_argument('mask', help='Mask 图路径')
  1943. parser.add_argument('-o', '--output-json', help='输出 JSON 路径,默认与 RGB 同名')
  1944. parser.add_argument('--room-model', default='room_cls.pt', help='房间分类模型路径')
  1945. parser.add_argument('--furniture-model', default='furniture_detect.onnx', help='家具检测模型路径')
  1946. parser.add_argument('--merge-threshold', type=int, default=15, help='跨 block 线段合并阈值')
  1947. parser.add_argument('--angle', type=int, default=10, help='水平/垂直吸附角度阈值')
  1948. parser.add_argument('--center-threshold', type=int, default=15, help='连接区域聚类阈值')
  1949. parser.add_argument('--dilation-kernel-size', type=int, default=5, help='连接区域膨胀核大小')
  1950. parser.add_argument('--expand-pixel', type=int, default=2, help='door 矩形扩展像素')
  1951. parser.add_argument('--min-rect-short-side', type=int, default=30, help='door 矩形短边上限')
  1952. return parser.parse_args()
  1953. if __name__ == "__main__":
  1954. import sys
  1955. if len(sys.argv) < 2:
  1956. print("用法:python pipeline.py <文件夹名称>")
  1957. print("示例:python pipeline.py SG-n6nV8B2oW95")
  1958. sys.exit(1)
  1959. folder_name = sys.argv[1]
  1960. img_folder = "temp_data"
  1961. folder_path = os.path.join(img_folder, folder_name)
  1962. if not os.path.isdir(folder_path):
  1963. print(f"错误:文件夹不存在 {folder_path}")
  1964. sys.exit(1)
  1965. # 获取 RGB 图片(文件名与文件夹同名)
  1966. rgb_imgs = [f for f in os.listdir(folder_path)
  1967. if f.startswith(folder_name) and not f.startswith('initial_mask_') and not f.startswith('refine_mask_')
  1968. and f.endswith(('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'))]
  1969. if not rgb_imgs:
  1970. print(f"错误:未找到 RGB 图片")
  1971. sys.exit(1)
  1972. rgb_name = rgb_imgs[0]
  1973. rgb_path = os.path.join(folder_path, rgb_name)
  1974. # 查找对应的 refine_mask 文件
  1975. refine_mask_name = f"refine_mask_{rgb_name}"
  1976. refine_mask_path = os.path.join(folder_path, refine_mask_name)
  1977. if not os.path.exists(refine_mask_path):
  1978. print(f"错误:未找到 {refine_mask_name}")
  1979. sys.exit(1)
  1980. # 输出 JSON 路径(保存到子文件夹)
  1981. output_json = os.path.join(folder_path, rgb_name.rsplit('.', 1)[0] + '.json')
  1982. print(f"处理文件夹:{folder_name}")
  1983. print(f" RGB: {rgb_path}")
  1984. print(f" Mask: {refine_mask_path}")
  1985. print(f" JSON: {output_json}")
  1986. try:
  1987. run_single_image_pipeline(
  1988. rgb_path=rgb_path,
  1989. mask_path=refine_mask_path,
  1990. output_json_path=output_json,
  1991. room_model_path="room_cls.pt",
  1992. furniture_model_path="furniture_detect.onnx",
  1993. merge_threshold=15,
  1994. angle=10,
  1995. dilation_kernel_size=5,
  1996. center_threshold=15,
  1997. expand_pixel=2,
  1998. min_rect_short_side=30,
  1999. )
  2000. print(f"✅ 完成 -> {os.path.basename(output_json)}")
  2001. except Exception as e:
  2002. print(f"❌ 错误:{e}")
  2003. sys.exit(1)