| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469 |
- import argparse
- import cv2
- import numpy as np
- import os
- import json
- import math
- import shutil
- from collections import defaultdict
- from datetime import datetime
- try:
- from tqdm import tqdm
- except ImportError:
- def tqdm(x, **kw):
- return x
- from ultralytics import YOLO
- # ===========================================================================
- # PART 1: merge_all
- # ===========================================================================
def remove_pure_red(img):
    """Blacken every pixel that is exactly pure red (BGR == [0, 0, 255]).

    Mutates *img* in place and returns it; returns None when *img* is None.
    """
    if img is None:
        return
    # OpenCV stores channels as BGR, so "pure red" means B=0, G=0, R=255.
    blue, green, red = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    is_pure_red = (blue == 0) & (green == 0) & (red == 255)
    # Paint the matching positions black, in place.
    img[is_pure_red] = [0, 0, 0]
    return img
def remove_edge_regions_image(img):
    """Remove edge artifacts (black/red fringes) from a single BGR image.

    Builds a white silhouette of all non-black pixels, detects its Canny
    edges, dilates them into a band that is painted black, then strips any
    remaining pure-red pixels via remove_pure_red.

    Args:
        img: BGR image array.

    Returns:
        The cleaned copy of *img*.

    Raises:
        ValueError: if *img* is None.
    """
    if img is None:
        raise ValueError("输入图像为空")
    result = img.copy()
    # Silhouette: white wherever the source pixel is not pure black.
    silhouette = np.zeros_like(result)
    black = (result[:, :, 0] == 0) & (result[:, :, 1] == 0) & (result[:, :, 2] == 0)
    silhouette[~black] = (255, 255, 255)
    edges = cv2.Canny(silhouette, 50, 150)
    kernel = np.ones((9, 9), np.uint8)
    # BUG FIX: the third positional argument of cv2.dilate is `dst`, not
    # `iterations` — pass it by keyword so the edge band is widened as intended.
    edges = cv2.dilate(edges, kernel, iterations=1)
    # Blank out the whole dilated edge band.
    result[edges > 0] = (0, 0, 0)
    return remove_pure_red(result)
def extract_gaps_from_mask(mask):
    """Extract the gaps of a single-channel mask; return (gaps, merged mask).

    Raises:
        ValueError: if *mask* is None.
    """
    if mask is None:
        raise ValueError("mask 为空")
    # Binarise, then close with a large kernel so narrow gaps get filled in.
    _, binarized = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 25))
    closed = cv2.morphologyEx(binarized, cv2.MORPH_CLOSE, close_kernel)
    # Gaps = closed shape minus the original broken pieces.
    gap_img = cv2.subtract(closed, binarized)
    # Slightly thicken the gaps before merging them back into the mask.
    grow_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    thickened = cv2.dilate(gap_img, grow_kernel, iterations=1)
    return gap_img, cv2.add(mask, thickened)
def extract_mask_region_from_arrays(rgb_img, ori_mask, full_mask):
    """Return the RGB pixels covered by full_mask but not by ori_mask.

    Raises:
        ValueError: if any input array is None.
    """
    if rgb_img is None or ori_mask is None or full_mask is None:
        raise ValueError("rgb_img / ori_mask / full_mask 不能为空")
    _, ori_bin = cv2.threshold(ori_mask, 127, 255, cv2.THRESH_BINARY)
    _, full_bin = cv2.threshold(full_mask, 127, 255, cv2.THRESH_BINARY)
    masked_ori = cv2.bitwise_and(rgb_img, rgb_img, mask=ori_bin)
    masked_full = cv2.bitwise_and(rgb_img, rgb_img, mask=full_bin)
    # Saturating subtraction keeps only the region newly added by full_mask.
    return cv2.subtract(masked_full, masked_ori)
- def _to_gray_mask(mask):
- if mask is None:
- raise ValueError("mask 为空")
- if len(mask.shape) == 2:
- return mask
- return cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
- def _expand_rect(x, y, w, h, expand_pixel, width, height):
- if expand_pixel <= 0:
- return int(x), int(y), int(w), int(h)
- x1 = max(0, int(x) - int(expand_pixel))
- y1 = max(0, int(y) - int(expand_pixel))
- x2 = min(width, int(x) + int(w) + int(expand_pixel))
- y2 = min(height, int(y) + int(h) + int(expand_pixel))
- return x1, y1, max(1, x2 - x1), max(1, y2 - y1)
- def _load_yolo_model(model_or_path):
- if isinstance(model_or_path, (str, os.PathLike)):
- return YOLO(model_or_path)
- return model_or_path
def extract_gaps(mask_path, gap_path, gap_add_mask_path):
    """Read a mask from disk, extract its gaps and write two result images.

    Writes the raw gap image to *gap_path* and the gap-filled mask
    (original mask + slightly dilated gaps) to *gap_add_mask_path*.
    Prints an error and returns early when the mask cannot be read.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        print("错误:无法读取图像,请检查路径。")
        return
    # Delegate to the in-memory implementation instead of duplicating the
    # morphology pipeline (25x25 closing to bridge gaps, then 3x3 dilation
    # of the extracted gaps before adding them back onto the mask).
    gaps, merged = extract_gaps_from_mask(mask)
    cv2.imwrite(gap_path, gaps)
    cv2.imwrite(gap_add_mask_path, merged)
def extract_mask_region(rgb_path, ori_mask_path, full_mask_path, output_path):
    """Extract the RGB region that full_mask adds on top of ori_mask; save it.

    Reads the three images from disk, keeps the pixels selected by each
    binarised mask and writes the full-minus-ori difference to *output_path*.
    Prints an error and returns early when any input cannot be read.
    """
    rgb_img = cv2.imread(rgb_path)
    ori_mask = cv2.imread(ori_mask_path, cv2.IMREAD_GRAYSCALE)
    full_mask = cv2.imread(full_mask_path, cv2.IMREAD_GRAYSCALE)
    if rgb_img is None or ori_mask is None or full_mask is None:
        print(f"无法读取文件: {rgb_path} 或 {ori_mask_path} 或 {full_mask_path}")
        return
    _, binary_ori_mask = cv2.threshold(ori_mask, 127, 255, cv2.THRESH_BINARY)
    # NOTE: the BirefNet mask is slightly shifted relative to the RGB image,
    # which can leave thin noisy strips along connected-region edges.
    _, binary_full_mask = cv2.threshold(full_mask, 127, 255, cv2.THRESH_BINARY)
    result_rgb_ori_mask = cv2.bitwise_and(rgb_img, rgb_img, mask=binary_ori_mask)
    result_rgb_full_mask = cv2.bitwise_and(rgb_img, rgb_img, mask=binary_full_mask)
    # BUG FIX: use cv2.subtract (saturates at 0) instead of raw numpy `-`,
    # which wraps around on uint8 underflow wherever the ori region is not a
    # subset of the full region. Also matches extract_mask_region_from_arrays.
    result = cv2.subtract(result_rgb_full_mask, result_rgb_ori_mask)
    cv2.imwrite(output_path, result)
    print(f"✅ 提取成功:{os.path.basename(output_path)}")
def identify_boundary_connections(block_mask_path, contour_mask_path,
                                  output_path='final.png',
                                  mapping_output_path='final_mapping.png',  # path for the relationship-mapping image
                                  max_dist=120, max_thickness=45):
    """
    Identify boundary connections, merge contour fragments that are roughly
    collinear and close together, and produce two visualisations:
    1. final_mapping.png: annotates room IDs and contour connectivity
    2. final.png: the base image with the merged red rectangles drawn on top

    Args:
        block_mask_path: grayscale mask whose white blobs are room blocks.
        contour_mask_path: grayscale mask of contour fragments between rooms.
        output_path: where the image with red merged boxes is written.
        mapping_output_path: where the room/contour relationship image goes.
        max_dist: max distance between fragment centers allowed for merging.
        max_thickness: max short side of a merged rectangle to count as a gap.

    Returns:
        (connections, merged_rects_info): per-contour room-id connectivity,
        and a record of each accepted merged rectangle.
    """
    # 1. Load data
    blocks = cv2.imread(block_mask_path, cv2.IMREAD_GRAYSCALE)
    contours = cv2.imread(contour_mask_path, cv2.IMREAD_GRAYSCALE)
    target_mask = cv2.imread(block_mask_path)
    if target_mask is None:
        target_mask = np.zeros((blocks.shape[0], blocks.shape[1], 3), dtype=np.uint8)
    # Pre-process the contours: thicken slightly before binarising
    kernel = np.ones((5, 5), np.uint8)
    contours_dilated = cv2.dilate(contours, kernel, iterations=1)
    _, blocks_bin = cv2.threshold(blocks, 127, 255, cv2.THRESH_BINARY)
    _, contours_bin = cv2.threshold(contours_dilated, 0, 255, cv2.THRESH_BINARY)
    # 2. Connected-component labelling (stats/centroids feed the text labels)
    num_blocks, block_labels = cv2.connectedComponents(blocks_bin, connectivity=8)
    num_contours, contour_labels, stats, centroids = cv2.connectedComponentsWithStats(contours_bin, connectivity=8)
    print(f"检测到房间数: {num_blocks - 1}")
    print(f"检测到轮廓段: {num_contours - 1}")
    # Colour canvas for the relationship visualisation (final_mapping.png)
    mapping_img = cv2.cvtColor(contours_bin, cv2.COLOR_GRAY2BGR)
    # 3. Collect contours and draw their connectivity
    pair_to_contours = defaultdict(list)
    connections = []
    for c_id in range(1, num_contours):
        single_contour = (contour_labels == c_id).astype(np.uint8) * 255
        # Thicken this fragment to probe for neighbouring room blocks
        kernel_detect = np.ones((9, 9), np.uint8)
        dilated_c = cv2.dilate(single_contour, kernel_detect, iterations=1)
        neighboring_blocks = np.unique(block_labels[dilated_c > 0])
        neighboring_blocks = neighboring_blocks[neighboring_blocks > 0]
        block_ids = sorted(neighboring_blocks.tolist())
        connections.append({"contour_id": c_id, "connects": block_ids})
        # ==================== Draw mapping-image elements ====================
        cx, cy = int(centroids[c_id][0]), int(centroids[c_id][1])
        relation_text = "-".join(map(str, block_ids))
        # Paint this contour fragment green
        mapping_img[contour_labels == c_id] = [0, 255, 0]
        # Annotate the connection relation (e.g. "1-2")
        if len(block_ids) >= 2:
            cv2.putText(mapping_img, relation_text, (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 255), 1)
        # =============================================================
        # Keep fragments touching exactly two rooms for the merge step below
        if len(block_ids) == 2:
            pair_key = f"{block_ids[0]}-{block_ids[1]}"
            y_coords, x_coords = np.where(contour_labels == c_id)
            if len(x_coords) > 0:
                points = np.column_stack((x_coords, y_coords))
                pair_to_contours[pair_key].append(points)
    # ==================== Mark room IDs on the mapping image ====================
    for b_id in range(1, num_blocks):
        b_mask = (block_labels == b_id).astype(np.uint8)
        M = cv2.moments(b_mask)
        if M["m00"] != 0:
            bx, by = int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])
            cv2.putText(mapping_img, f"B{b_id}", (bx, by),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
    # Save the relationship image
    if mapping_output_path:
        cv2.imwrite(mapping_output_path, mapping_img)
        print(f"✅ 映射关系图已保存至: {mapping_output_path}")
    # ===================================================================
    # 4. Cluster fragments and draw the combined bounding rectangles

    def should_merge(pts1, pts2):
        # Two fragments merge when their min-area-rect centers are close and
        # the combined rectangle stays thin enough to still look like a gap.
        rect1 = cv2.minAreaRect(pts1)
        rect2 = cv2.minAreaRect(pts2)
        c1, c2 = np.array(rect1[0]), np.array(rect2[0])
        if np.linalg.norm(c1 - c2) > max_dist:
            return False
        combined_pts = np.vstack((pts1, pts2))
        comb_rect = cv2.minAreaRect(combined_pts)
        thickness = min(comb_rect[1][0], comb_rect[1][1])
        if thickness > max_thickness:
            return False
        return True
    merged_rects_info = []
    for pair_key, contour_list in pair_to_contours.items():
        n = len(contour_list)
        if n == 0: continue
        # Build a merge adjacency graph among this pair's fragments
        adj = {i: [] for i in range(n)}
        for i in range(n):
            for j in range(i + 1, n):
                if should_merge(contour_list[i], contour_list[j]):
                    adj[i].append(j)
                    adj[j].append(i)
        # BFS over the graph to find connected groups of mergeable fragments
        visited = set()
        groups = []
        for i in range(n):
            if i not in visited:
                comp = []
                queue = [i]
                visited.add(i)
                while queue:
                    curr = queue.pop(0)
                    comp.append(curr)
                    for neighbor in adj[curr]:
                        if neighbor not in visited:
                            visited.add(neighbor)
                            queue.append(neighbor)
                groups.append(comp)
        for group_indices in groups:
            combined_group_points = np.vstack([contour_list[idx] for idx in group_indices])
            points_np = np.array(combined_group_points, dtype=np.int32)
            # A. Minimum-area bounding rectangle of the whole group
            rect = cv2.minAreaRect(points_np)
            (cx, cy), (w, h), angle = rect
            thickness = min(w, h)
            length = max(w, h)
            # B. Filter 1: thickness threshold — too thick means this is not
            # a "gap" but likely a bad diagonal connection, so skip it.
            if thickness > max_thickness:
                continue
            # C. Filter 2: mask constraint (the rectangle must lie inside the gap).
            # Draw only this rectangle on a temporary black canvas.
            temp_rect_mask = np.zeros_like(contours_bin)
            box = cv2.boxPoints(rect)
            box = np.int32(box)
            cv2.fillPoly(temp_rect_mask, [box], 255)
            # Fraction of the rectangle's pixels that truly belong to the
            # original gap mask contours_bin (logical AND overlap).
            overlap = cv2.bitwise_and(temp_rect_mask, contours_bin)
            overlap_score = np.sum(overlap > 0) / np.sum(temp_rect_mask > 0)
            # Low overlap means the rectangle spans mostly non-gap area; drop it.
            if overlap_score < 0.6:
                continue
            # --- Only rectangles passing both filters are drawn and recorded ---
            cv2.fillPoly(target_mask, [box], (0, 0, 255))
            merged_rects_info.append({
                "pair": pair_key,
                "box": box.tolist(),
                "thickness": thickness,
                "overlap_score": float(overlap_score)
            })
    # 5. Save the final base image carrying the red merged boxes
    if output_path:
        cv2.imwrite(output_path, target_mask)
        print(f"✅ 合并后的轮廓底图已保存至: {output_path}")
    return connections, merged_rects_info
def mask_add_conncet(mask_path, conncet_path, add_path):
    """Overlay the connect image onto the mask image and save to *add_path*.

    Pure-red pixels in the connect image are removed first; every remaining
    non-black connect pixel overwrites the mask pixel at that position.
    Prints an error and returns early when either image cannot be read.
    """
    img1 = cv2.imread(mask_path)
    img2 = cv2.imread(conncet_path)
    # ROBUSTNESS FIX: fail loudly instead of crashing inside cvtColor when a
    # path is wrong (same guard style as overlay_images).
    if img1 is None or img2 is None:
        print("错误:无法读取图片,请检查路径。")
        return
    img2 = remove_pure_red(img2)
    # 1. Build a mask of every non-black pixel in img2
    img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(img2_gray, 1, 255, cv2.THRESH_BINARY)
    # 2. Directly overwrite those positions on the base image
    result = img1.copy()
    result[mask > 0] = img2[mask > 0]
    cv2.imwrite(add_path, result)
def overlay_images(base_img_path, overlay_img_path, output_path):
    """
    Paste the overlay image onto the base image.

    Wherever the overlay has a non-black pixel the overlay wins; the base
    shows through everywhere else. The composite is written to *output_path*.
    """
    base = cv2.imread(base_img_path)  # background
    overlay = cv2.imread(overlay_img_path)  # image to paste on top
    if base is None or overlay is None:
        print("错误:无法读取图片,请检查路径。")
        return
    # Force matching dimensions before compositing.
    if base.shape != overlay.shape:
        overlay = cv2.resize(overlay, (base.shape[1], base.shape[0]))
        print("注意:已将 img2 的尺寸调整为与 img1 一致。")
    gray = cv2.cvtColor(overlay, cv2.COLOR_BGR2GRAY)
    _, nonblack = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    composite = base.copy()
    composite[nonblack > 0] = overlay[nonblack > 0]
    cv2.imwrite(output_path, composite)
    print(f"✅ 图片叠加完成,结果已保存至: {output_path}")
def remove_pure_red(img):
    """Blacken every pixel that is exactly pure red; mutates *img* in place."""
    # NOTE(review): exact duplicate of remove_pure_red defined earlier in this
    # file; this later definition is the one that wins at import time.
    # Consider deleting one copy.
    if img is None: return
    # In OpenCV (BGR) pure red is [0, 0, 255];
    # find pixels where all three channels match that value exactly.
    red_pixels = (img[:, :, 0] == 0) & (img[:, :, 1] == 0) & (img[:, :, 2] == 255)
    # Paint those positions black in place.
    img[red_pixels] = [0, 0, 0]
    return img
def remove_edge_regions(ori_rgb_folder, no_rededge_mask):
    """Clean edge artifacts for every image in a folder.

    For each file in *ori_rgb_folder*: paint the (dilated) Canny edges of the
    non-black silhouette black, strip pure-red pixels, and write the result
    under the same name into *no_rededge_mask*.
    """
    test_names = os.listdir(ori_rgb_folder)
    for name in test_names:
        img_name = os.path.join(ori_rgb_folder, name)
        save_path = os.path.join(no_rededge_mask, name)
        img = cv2.imread(img_name)
        # ROBUSTNESS FIX: skip unreadable/non-image files instead of
        # crashing with a TypeError on None.
        if img is None:
            print(f"无法读取文件: {img_name}")
            continue
        # White silhouette of all non-black pixels.
        img_2 = np.zeros_like(img)
        mask = (img[:, :, 0] == 0) & (img[:, :, 1] == 0) & (img[:, :, 2] == 0)
        img_2[~mask] = (255, 255, 255)
        edges = cv2.Canny(img_2, 50, 150)
        kernel = np.ones((9, 9), np.uint8)
        # BUG FIX: the third positional argument of cv2.dilate is `dst`, not
        # `iterations` — pass it by keyword.
        edges = cv2.dilate(edges, kernel, iterations=1)
        img[edges > 0] = (0, 0, 0)
        img = remove_pure_red(img)
        cv2.imwrite(save_path, img)
def merge_gap_fillers_from_arrays(mask1, mask2, image_path="",
                                  dilation_kernel_size=5, center_threshold=50,
                                  expand_pixel=10, min_rect_short_side=30):
    """Merge mask1 and mask2 in memory; return connect_area JSON, a preview
    image, and statistics.

    mask2 fragments that bridge exactly two mask1 blocks are grouped by block
    pair and center proximity; each group (or, for over-thick merged groups,
    each thin-enough fragment) becomes an axis-aligned "door" rectangle.

    Args:
        mask1: block mask (grayscale or BGR; binarised internally).
        mask2: fragment mask; resized to mask1's size if they differ.
        image_path: recorded verbatim in the returned JSON.
        dilation_kernel_size: dilation used when probing a fragment's
            neighbouring blocks.
        center_threshold: max center offset (x OR y) to join a group.
        expand_pixel: padding added around each accepted rectangle.
        min_rect_short_side: rectangles whose short side exceeds this are
            rejected (too thick to be a door gap).

    Returns:
        (json_data, result, stats): the connect_area dict, a BGR preview
        image with green rectangles, and a counters dict.
    """
    m1 = _to_gray_mask(mask1)
    m2 = _to_gray_mask(mask2)
    if m1.shape != m2.shape:
        m2 = cv2.resize(m2, (m1.shape[1], m1.shape[0]))
    _, m1_bin = cv2.threshold(m1, 127, 255, cv2.THRESH_BINARY)
    _, m2_bin = cv2.threshold(m2, 0, 255, cv2.THRESH_BINARY)
    num1, block_labels = cv2.connectedComponents(m1_bin)
    num2, frag_labels, _, _ = cv2.connectedComponentsWithStats(m2_bin)
    result = cv2.cvtColor(m1_bin, cv2.COLOR_GRAY2BGR)
    # Collect fragments that touch exactly two blocks ("bridges").
    bridge_frags = []
    for i in range(1, num2):
        single_frag_mask = (frag_labels == i).astype(np.uint8) * 255
        kernel = np.ones((dilation_kernel_size, dilation_kernel_size), np.uint8)
        dilated_frag = cv2.dilate(single_frag_mask, kernel, iterations=1)
        touched_labels = np.unique(block_labels[dilated_frag > 0])
        neighbors = sorted(int(n) for n in touched_labels if n > 0)
        if len(neighbors) == 2:
            pts = np.column_stack(np.where(single_frag_mask > 0))
            if len(pts) > 0:
                # np.where yields (row, col) = (y, x).
                cy, cx = np.mean(pts, axis=0)
                bridge_frags.append({
                    'id': int(i),
                    'cx': float(cx),
                    'cy': float(cy),
                    # Block ids are shifted to be 0-based in the output.
                    'block_pair': tuple(int(n - 1) for n in neighbors),
                    'mask': single_frag_mask
                })
    # Greedy grouping: same block pair AND close in x or y to any member.
    groups = []
    for frag in bridge_frags:
        assigned = False
        for group in groups:
            if group[0]['block_pair'] != frag['block_pair']:
                continue
            for existing in group:
                dx = abs(frag['cx'] - existing['cx'])
                dy = abs(frag['cy'] - existing['cy'])
                if dx < center_threshold or dy < center_threshold:
                    group.append(frag)
                    assigned = True
                    break
            if assigned:
                break
        if not assigned:
            groups.append([frag])
    # Turn each group into one or more "door" rectangles.
    connect_areas = []
    rect_id = 0
    height, width = m1.shape[:2]
    for group in groups:
        points_list = []
        frag_rects = []
        for frag in group:
            pts = np.column_stack(np.where(frag['mask'] > 0))
            if pts.size > 0:
                # Flip (y, x) -> (x, y) for cv2.boundingRect.
                points_xy = pts[:, ::-1]
                points_list.append(points_xy)
                raw_x, raw_y, raw_w, raw_h = cv2.boundingRect(points_xy)
                frag_rects.append((int(raw_x), int(raw_y), int(raw_w), int(raw_h), frag))
        if len(points_list) < 1:
            continue
        all_points = np.vstack(points_list)
        if len(all_points) < 3:
            continue
        raw_x, raw_y, raw_w, raw_h = cv2.boundingRect(all_points)
        merged_short_side = min(raw_w, raw_h)
        # If the merged rectangle is too thick, fall back to emitting the
        # individual fragments that are thin enough on their own.
        if len(group) > 1 and merged_short_side > min_rect_short_side:
            for fx, fy, fw, fh, frag in frag_rects:
                single_short_side = min(fw, fh)
                if single_short_side <= min_rect_short_side:
                    fx, fy, fw, fh = _expand_rect(fx, fy, fw, fh, expand_pixel, width, height)
                    cv2.rectangle(result, (fx, fy), (fx + fw, fy + fh), (0, 255, 0), -1)
                    connect_areas.append({
                        'id': int(rect_id),
                        'x': int(fx),
                        'y': int(fy),
                        'w': int(fw),
                        'h': int(fh),
                        'block_pair': [int(n) for n in frag['block_pair']],
                        'label': 'door'
                    })
                    rect_id += 1
            continue
        # Normal case: emit the merged rectangle when it is thin enough.
        if merged_short_side <= min_rect_short_side:
            x, y, w, h = _expand_rect(raw_x, raw_y, raw_w, raw_h, expand_pixel, width, height)
            cv2.rectangle(result, (x, y), (x + w, y + h), (0, 255, 0), -1)
            connect_areas.append({
                'id': int(rect_id),
                'x': int(x),
                'y': int(y),
                'w': int(w),
                'h': int(h),
                'block_pair': [int(n) for n in group[0]['block_pair']],
                'label': 'door'
            })
            rect_id += 1
    json_data = {
        'image_path': str(image_path),
        'image_size': {
            'width': int(m1.shape[1]),
            'height': int(m1.shape[0])
        },
        'connect_area': connect_areas
    }
    stats = {
        'mask1_blocks': int(num1 - 1),
        'mask2_fragments': int(num2 - 1),
        'bridge_fragments': int(len(bridge_frags)),
        'group_count': int(len(groups)),
        'connect_area_count': int(len(connect_areas))
    }
    return json_data, result, stats
def merge_gap_fillers(mask1_path, mask2_path, output_image_path, output_json_path,
                      dilation_kernel_size=5, center_threshold=50,
                      expand_pixel=10, min_rect_short_side=30):
    """
    File-based wrapper: merge two masks from disk and persist the results.

    Loads both masks, delegates the gap-filling analysis to
    merge_gap_fillers_from_arrays, then writes the preview image and the
    connect-area JSON to the given output paths.
    """
    gray1 = cv2.imread(mask1_path, cv2.IMREAD_GRAYSCALE)
    gray2 = cv2.imread(mask2_path, cv2.IMREAD_GRAYSCALE)
    if gray1 is None or gray2 is None:
        raise FileNotFoundError("无法加载图像,请检查路径")
    json_data, preview, stats = merge_gap_fillers_from_arrays(
        gray1,
        gray2,
        image_path=output_image_path,
        dilation_kernel_size=dilation_kernel_size,
        center_threshold=center_threshold,
        expand_pixel=expand_pixel,
        min_rect_short_side=min_rect_short_side,
    )
    # Ensure both destination directories exist before writing.
    for target in (output_image_path, output_json_path):
        os.makedirs(os.path.dirname(os.path.abspath(target)), exist_ok=True)
    cv2.imwrite(output_image_path, preview)
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False)
    print(f"✅ 处理完成!")
    print(f" - mask1 块数:{stats['mask1_blocks']} (白色)")
    print(f" - mask2 碎片总数:{stats['mask2_fragments']}")
    print(f" - 连接 2 块的碎片:{stats['bridge_fragments']} 个")
    print(f" - 聚类后的组数:{stats['group_count']} 组")
    print(f" - 总连接区域数:{stats['connect_area_count']}")
    print(f" - 结果图片保存至:{output_image_path}")
    print(f" - JSON 坐标保存至:{output_json_path}")
def verify_json_coordinates(json_path, output_verify_path, mask1_path=None):
    """
    Render the connect-area rectangles recorded in a JSON file for visual checking.

    Args:
        json_path: path of the JSON produced by merge_gap_fillers.
        output_verify_path: where the rendered verification image is saved.
        mask1_path: optional mask1 image; when readable, its white blocks
            form the canvas background instead of plain black.

    Returns:
        The parsed JSON data.
    """
    # 1. Load the JSON payload and echo its summary
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"📊 读取 JSON 文件:{json_path}")
    print(f" - 图片路径:{data['image_path']}")
    print(f" - 图片尺寸:{data['image_size']['width']} x {data['image_size']['height']}")
    print(f" - 连接区域数:{len(data['connect_area'])}")
    # 2. Build the canvas (mask1's white blocks when available, else black)
    width = data['image_size']['width']
    height = data['image_size']['height']
    canvas = None
    if mask1_path:
        m1 = cv2.imread(mask1_path, cv2.IMREAD_GRAYSCALE)
        if m1 is None:
            print(f" - 未找到 mask1,使用黑色背景")
        else:
            if (m1.shape[0], m1.shape[1]) != (height, width):
                m1 = cv2.resize(m1, (width, height))
            _, m1_bin = cv2.threshold(m1, 127, 255, cv2.THRESH_BINARY)
            canvas = cv2.cvtColor(m1_bin, cv2.COLOR_GRAY2BGR)
            print(f" - 已加载 mask1,显示白色块")
    else:
        print(f" - 未提供 mask1,使用黑色背景")
    if canvas is None:
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
    # 3. Draw every recorded rectangle
    for area in data['connect_area']:
        x, y = area['x'], area['y']
        w, h = area['w'], area['h']
        block_pair = area['block_pair']
        label = area.get('label', 'door')
        # Filled green box plus a red outline so the bounds stand out.
        cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 255, 0), -1)
        cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 0, 255), 2)
        # Annotate the rectangle center with its id and block pair.
        cx, cy = x + w // 2, y + h // 2
        text = f"ID:{area['id']} [{block_pair[0]},{block_pair[1]}]"
        cv2.putText(canvas, text, (cx - 80, cy),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
    # 4. Persist the verification image
    cv2.imwrite(output_verify_path, canvas)
    # 5. Print a detail table of every area
    print(f"\n📋 连接区域详情:")
    print(f" {'ID':<6} {'X':<8} {'Y':<8} {'W':<6} {'H':<6} {'Block_Pair':<15} {'Label':<10}")
    print(f" {'-' * 70}")
    for area in data['connect_area']:
        print(
            f" {area['id']:<6} {area['x']:<8} {area['y']:<8} {area['w']:<6} {area['h']:<6} {str(area['block_pair']):<15} {area['label']:<10}")
    print(f"\n✅ 验证图片保存至:{output_verify_path}")
    return data
- # def build_block_data(rgb_img, block_mask, model_path="room_seg.pt"):
- # """基于单张 RGB 和 block mask 生成 block 列表。"""
- # model = _load_yolo_model(model_path)
- # blocks = _to_gray_mask(block_mask)
- #
- # if rgb_img is None or blocks is None:
- # raise FileNotFoundError("图片路径错误,请检查文件是否存在")
- #
- # if rgb_img.shape[:2] != blocks.shape[:2]:
- # rgb_img = cv2.resize(rgb_img, (blocks.shape[1], blocks.shape[0]))
- #
- # _, blocks_bin = cv2.threshold(blocks, 127, 255, cv2.THRESH_BINARY)
- # num_blocks, block_labels = cv2.connectedComponents(blocks_bin, connectivity=8)
- # actual_room_count = num_blocks - 1
- # print(f"🔍 检测到 {actual_room_count} 个独立的房间块")
- #
- # block_list = []
- #
- # for b_id in range(1, num_blocks):
- # print(f"\n📍 处理 Block {b_id - 1}...")
- # mask_single = (block_labels == b_id).astype(np.uint8)
- # contours, _ = cv2.findContours(mask_single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
- #
- # points = []
- # cx = cy = 0
- # if contours:
- # largest_contour = max(contours, key=cv2.contourArea)
- # epsilon = 2.0
- # simplified = cv2.approxPolyDP(largest_contour, epsilon, True)
- # for pt in simplified:
- # points.append(int(pt[0][0]))
- # points.append(int(pt[0][1]))
- #
- # M = cv2.moments(simplified)
- # if M['m00'] != 0:
- # cx = int(M['m10'] / M['m00'])
- # cy = int(M['m01'] / M['m00'])
- # else:
- # cx = int(np.mean(simplified[:, 0, 0]))
- # cy = int(np.mean(simplified[:, 0, 1]))
- #
- # x, y, w, h = cv2.boundingRect(mask_single)
- # if not contours:
- # cx = int(x + w / 2)
- # cy = int(y + h / 2)
- #
- # pad = 20
- # y1, y2 = max(0, y - pad), min(rgb_img.shape[0], y + h + pad)
- # x1, x2 = max(0, x - pad), min(rgb_img.shape[1], x + w + pad)
- # roi_rgb = rgb_img[y1:y2, x1:x2]
- #
- # valid_boxes = []
- # if roi_rgb.shape[0] > 10 and roi_rgb.shape[1] > 10:
- # results = model(roi_rgb, conf=0.15, verbose=False)[0]
- #
- # if len(results.boxes) > 0:
- # boxes_xyxy = results.boxes.xyxy.cpu().numpy()
- # boxes_conf = results.boxes.conf.cpu().numpy()
- # boxes_cls = results.boxes.cls.cpu().numpy()
- #
- # for i in range(len(boxes_xyxy)):
- # bx1, by1, bx2, by2 = boxes_xyxy[i]
- # center_x = (bx1 + bx2) / 2
- # center_y = (by1 + by2) / 2
- # global_cx = x1 + center_x
- # global_cy = y1 + center_y
- #
- # if 0 <= global_cy < block_labels.shape[0] and 0 <= global_cx < block_labels.shape[1]:
- # if block_labels[int(global_cy), int(global_cx)] == b_id:
- # valid_boxes.append({
- # "cls": int(boxes_cls[i]),
- # "conf": float(boxes_conf[i]),
- # "name": model.names[int(boxes_cls[i])]
- # })
- #
- # final_label = "other_room"
- # final_conf = 0.0
- # if valid_boxes:
- # best_box = max(valid_boxes, key=lambda k: k['conf'])
- # final_label = best_box['name']
- # final_conf = best_box['conf']
- # print(f" ✅ 检测到:{final_label} (conf: {final_conf:.2f})")
- # else:
- # print(f" ⚠️ 未检测到有效物体,标记为 other_room")
- #
- # block_list.append({
- # "id": int(b_id - 1),
- # "points": points,
- # "label": final_label,
- # "center": [int(cx), int(cy)]
- # })
- #
- # return block_list
def build_block_data(rgb_img, block_mask, model_path="room_cls.pt"):
    """Build the block list from one RGB image and a block mask using a YOLO
    *classification* model (not a detector).

    Each 8-connected component of the binarized mask becomes one block entry
    with a simplified polygon outline, a centroid, and a room label predicted
    on the masked-out ROI of that component.

    Args:
        rgb_img: image as loaded by ``cv2.imread`` (BGR order — TODO confirm
            the model expects this channel order).
        block_mask: mask image; converted to gray via ``_to_gray_mask``,
            non-zero (>127) pixels mark room area.
        model_path: path to a YOLO ``-cls`` classification checkpoint.

    Returns:
        list[dict]: one dict per block with keys ``id`` (0-based), ``points``
        (flat [x1, y1, x2, y2, ...] polygon), ``label`` and ``center``.

    Raises:
        FileNotFoundError: if either input image is ``None``.
    """
    model = _load_yolo_model(model_path)
    blocks = _to_gray_mask(block_mask)
    if rgb_img is None or blocks is None:
        raise FileNotFoundError("图片路径错误,请检查文件是否存在")
    # Resize the RGB image to the mask resolution so pixel coordinates align.
    if rgb_img.shape[:2] != blocks.shape[:2]:
        rgb_img = cv2.resize(rgb_img, (blocks.shape[1], blocks.shape[0]))
    _, blocks_bin = cv2.threshold(blocks, 127, 255, cv2.THRESH_BINARY)
    # Label connected regions; label 0 is the background, 1..num_blocks-1 are rooms.
    num_blocks, block_labels = cv2.connectedComponents(blocks_bin, connectivity=8)
    actual_room_count = num_blocks - 1
    print(f"🔍 检测到 {actual_room_count} 个独立的房间块")
    block_list = []
    for b_id in range(1, num_blocks):
        print(f"\n📍 处理 Block {b_id - 1}...")
        mask_single = (block_labels == b_id).astype(np.uint8)
        contours, _ = cv2.findContours(mask_single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        points = []
        cx = cy = 0
        if contours:
            # Keep only the largest outline and simplify it to a coarse polygon.
            largest_contour = max(contours, key=cv2.contourArea)
            epsilon = 2.0  # approxPolyDP tolerance in pixels
            simplified = cv2.approxPolyDP(largest_contour, epsilon, True)
            for pt in simplified:
                points.append(int(pt[0][0]))
                points.append(int(pt[0][1]))
            # Centroid from polygon moments; fall back to the vertex mean for
            # degenerate (zero-area) polygons.
            M = cv2.moments(simplified)
            if M['m00'] != 0:
                cx = int(M['m10'] / M['m00'])
                cy = int(M['m01'] / M['m00'])
            else:
                cx = int(np.mean(simplified[:, 0, 0]))
                cy = int(np.mean(simplified[:, 0, 1]))
        x, y, w, h = cv2.boundingRect(mask_single)
        if not contours:
            # No contour found: use the bounding-box center as the centroid.
            cx = int(x + w / 2)
            cy = int(y + h / 2)
        # ---------------------------------------------------------
        # Step 1: extract a clean ROI (strip background noise)
        # ---------------------------------------------------------
        pad = 20
        y1, y2 = max(0, y - pad), min(rgb_img.shape[0], y + h + pad)
        x1, x2 = max(0, x - pad), min(rgb_img.shape[1], x + w + pad)
        # .copy() so blacking out pixels does not modify the original image.
        roi_rgb = rgb_img[y1:y2, x1:x2].copy()
        # Mask slice for the same region.
        roi_mask_patch = mask_single[y1:y2, x1:x2]
        # Core step: pixels outside the current room (mask == 0) become black.
        roi_rgb[roi_mask_patch == 0] = [0, 0, 0]
        # ---------------------------------------------------------
        # Step 2: parse the YOLO classification model's output
        # ---------------------------------------------------------
        final_label = "other_room"
        final_conf = 0.0
        if roi_rgb.shape[0] > 10 and roi_rgb.shape[1] > 10:
            # Run classification inference on the masked ROI.
            results = model(roi_rgb, verbose=False)[0]
            # Classification output lives in results.probs, not results.boxes.
            if hasattr(results, 'probs') and results.probs is not None:
                # Index of the highest-confidence class.
                top1_idx = int(results.probs.top1)
                # Its confidence value.
                top1_conf = float(results.probs.top1conf.cpu().numpy())
                # Its class name.
                top1_label = model.names[top1_idx]
                # Confidence threshold (tune as needed, e.g. 0.4).
                if top1_conf >= 0.15:
                    final_label = top1_label
                    final_conf = top1_conf
                    print(f" ✅ 分类检测到:{final_label} (conf: {final_conf:.2f})")
                else:
                    print(f" ⚠️ 最高分类置信度偏低 ({top1_label}: {top1_conf:.2f}),标记为 other_room")
            else:
                print(" ⚠️ 模型输出中没有分类概率,请检查是否加载了正确的 -cls 分类模型!")
        block_list.append({
            "id": int(b_id - 1),
            "points": points,
            "label": final_label,
            "center": [int(cx), int(cy)]
        })
    return block_list
def add_block_labels_to_json(json_path, rgb_path, block_mask_path, model_path="room_seg.pt", output_json_path=None):
    """Read an existing JSON file and append block info (contours + YOLO labels).

    Loads the RGB image and block mask from disk, delegates to
    ``build_block_data`` and stores the result under the ``'block'`` key.
    Writes to *output_json_path* (defaults to overwriting *json_path*) and
    returns the updated JSON dict.

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    print(f"📊 读取原有 JSON: {json_path}")
    print(f" - 连接区域数:{len(json_data.get('connect_area', []))}")
    img_rgb = cv2.imread(rgb_path)
    blocks = cv2.imread(block_mask_path, cv2.IMREAD_GRAYSCALE)
    if img_rgb is None or blocks is None:
        raise FileNotFoundError("图片路径错误,请检查文件是否存在")
    json_data['block'] = build_block_data(img_rgb, blocks, model_path=model_path)
    # Overwrite the input JSON unless a separate output path is given.
    if output_json_path is None:
        output_json_path = json_path
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False)
    print(f"\n✅ 处理完成!")
    print(f" - 总房间块数:{len(json_data['block'])}")
    print(f" - JSON 保存至:{output_json_path}")
    print(f"\n📋 Block 详情:")
    print(f" {'ID':<6} {'Points':<10} {'Label':<20}")
    print(f" {'-' * 40}")
    for block in json_data['block']:
        points_info = f"{len(block['points']) // 2}个点"
        print(f" {block['id']:<6} {points_info:<10} {block['label']:<20}")
    return json_data
def detect_furniture_in_image(rgb_img, model_path='furniture_detect.onnx'):
    """Run furniture detection on an in-memory image and return the hits.

    Only detections whose class name is one of sofa/chair/desk/bed/window are
    kept. Each hit carries an id, its label, its center and the four corner
    coordinates of the axis-aligned box.

    Raises:
        FileNotFoundError: if *rgb_img* is None.
    """
    model = _load_yolo_model(model_path)
    if rgb_img is None:
        raise FileNotFoundError("无法读取 RGB 图像")
    detections = model(rgb_img, conf=0.25, verbose=False)[0]
    keep = {'sofa', 'chair', 'desk', 'bed', 'window'}
    found = []
    if len(detections.boxes) > 0:
        xyxy = detections.boxes.xyxy.cpu().numpy()
        cls_ids = detections.boxes.cls.cpu().numpy()
        for box, cls_id in zip(xyxy, cls_ids):
            left, top, right, bottom = (int(v) for v in box)
            name = model.names[int(cls_id)]
            if name not in keep:
                continue
            found.append({
                'id': len(found),  # ids stay dense even when boxes are skipped
                'label': name,
                'center': [(left + right) // 2, (top + bottom) // 2],
                'points': {
                    'x1': left, 'y1': top,
                    'x2': right, 'y2': top,
                    'x3': right, 'y3': bottom,
                    'x4': left, 'y4': bottom
                }
            })
    return found
def detect_furniture(json_path, rgb_path, model_path='furniture_detect.onnx', output_json_path=None):
    """Detect furniture with a YOLO model and write the results into the JSON.

    Reads the JSON at *json_path*, stores detections from
    ``detect_furniture_in_image`` under the ``'furniture'`` key, then saves to
    *output_json_path* (defaults to overwriting *json_path*).

    Returns:
        dict: the updated JSON data.

    Raises:
        FileNotFoundError: if the RGB image cannot be read.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    img_rgb = cv2.imread(rgb_path)
    if img_rgb is None:
        raise FileNotFoundError(f"无法读取图片:{rgb_path}")
    json_data['furniture'] = detect_furniture_in_image(img_rgb, model_path=model_path)
    # Overwrite the input JSON in place unless a separate output is given.
    if output_json_path is None:
        output_json_path = json_path
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False)
    print(f"✅ 家具检测完成,共检测到 {len(json_data['furniture'])} 个家具,JSON 保存至:{output_json_path}")
    return json_data
def verify_final_json(json_path, output_verify_path):
    """Validate the final JSON file by rendering all of its data to an image.

    Drawing rules:
    - block: white filled polygon with the label text at its center
    - connect_area: green filled rectangle with a red border

    Args:
        json_path: path to the JSON file.
        output_verify_path: where the verification image is saved.

    Returns:
        dict: the loaded JSON data.
    """
    # 1. Load the JSON data
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"📊 读取 JSON 文件:{json_path}")
    print(f" - 图片尺寸:{data['image_size']['width']} x {data['image_size']['height']}")
    print(f" - 连接区域数:{len(data.get('connect_area', []))}")
    print(f" - 房间块数:{len(data.get('block', []))}")
    # 2. Create the canvas (black background)
    width = data['image_size']['width']
    height = data['image_size']['height']
    result = np.zeros((height, width, 3), dtype=np.uint8)
    # 3. Draw blocks (white polygon + centered label)
    print(f"\n📋 绘制 {len(data.get('block', []))} 个房间块:")
    for block in data.get('block', []):
        block_id = block['id']
        points = block['points']
        label = block['label']
        # Convert flat points to polygon form [[x1, y1], [x2, y2], ...]
        if len(points) >= 6:  # at least 3 vertices
            pts = np.array([[points[j], points[j + 1]] for j in range(0, len(points), 2)], dtype=np.int32)
            # Filled white polygon
            cv2.fillPoly(result, [pts], (255, 255, 255))
            # Light-gray outline so the boundary stays visible
            cv2.polylines(result, [pts], True, (200, 200, 200), 2)
            # Centroid (moments; vertex mean for zero-area polygons)
            M = cv2.moments(pts)
            if M['m00'] != 0:
                cx = int(M['m10'] / M['m00'])
                cy = int(M['m01'] / M['m00'])
            else:
                cx, cy = pts.mean(axis=0).astype(int)
            # Label text at the center (black text on a white background box)
            text = f"{label}"
            (text_w, text_h), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            # White background rectangle behind the text
            cv2.rectangle(result,
                          (cx - text_w // 2 - 5, cy - text_h // 2 - 5),
                          (cx + text_w // 2 + 5, cy + text_h // 2 + 5),
                          (255, 255, 255), -1)
            # Black text
            cv2.putText(result, text, (cx - text_w // 2, cy + text_h // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
            print(f" ✅ Block {block_id}: {label} ({len(points) // 2}个点)")
        else:
            print(f" ⚠️ Block {block_id}: 点数不足,跳过")
    # 4. Draw connect areas (green rectangles)
    print(f"\n📋 绘制 {len(data.get('connect_area', []))} 个连接区域:")
    for area in data.get('connect_area', []):
        area_id = area['id']
        x = area['x']
        y = area['y']
        w = area['w']
        h = area['h']
        block_pair = area['block_pair']
        # Green filled rectangle
        cv2.rectangle(result, (x, y), (x + w, y + h), (0, 255, 0), -1)
        # Red border so the boundary stays visible
        cv2.rectangle(result, (x, y), (x + w, y + h), (0, 0, 255), 2)
        print(f" ✅ Area {area_id}: ({x}, {y}) {w}x{h} 连接块 [{block_pair[0]}, {block_pair[1]}]")
    # 5. Draw the legend
    legend_y = 30
    cv2.putText(result, "Legend:", (10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
    # White square — room block
    cv2.rectangle(result, (10, legend_y + 10), (30, legend_y + 30), (255, 255, 255), -1)
    cv2.putText(result, "Room Block", (40, legend_y + 27), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    # Green square — connect area
    cv2.rectangle(result, (150, legend_y + 10), (170, legend_y + 30), (0, 255, 0), -1)
    cv2.putText(result, "Connect Area", (180, legend_y + 27), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    # 6. Save the verification image
    cv2.imwrite(output_verify_path, result)
    print(f"\n✅ 验证图片保存至:{output_verify_path}")
    return data
- # ===========================================================================
- # PART 2: 深光 refine (内联 txt2json)
- # ===========================================================================
- # ================= 1. 工具函数 =================
def to_key(p, precision=1):
    """Quantize a 2-D point into a hashable (x, y) tuple rounded to *precision* decimals."""
    x, y = float(p[0]), float(p[1])
    return (round(x, precision), round(y, precision))
def is_orthogonal(seg, thresh=1e-1):
    """True if the segment [x1, y1, x2, y2] is axis-aligned within *thresh*."""
    run = abs(seg[2] - seg[0])
    rise = abs(seg[3] - seg[1])
    return min(run, rise) < thresh
def get_line_intersection(s1, s2_point, s2_dir):
    """Intersect the infinite line through segment *s1* with the line defined
    by *s2_point* and direction *s2_dir*.

    Returns the intersection as [x, y]; when the lines are (near-)parallel
    the end point of *s1* is returned as a fallback.
    """
    ax, ay, bx, by = s1
    seg_dx = bx - ax
    seg_dy = by - ay
    px, py = s2_point
    dir_x, dir_y = s2_dir
    denom = dir_x * seg_dy - dir_y * seg_dx
    if abs(denom) < 1e-6:
        # Parallel (or degenerate): fall back to the segment's end point.
        return [float(bx), float(by)]
    t = (dir_x * (py - ay) - dir_y * (px - ax)) / denom
    return [float(ax + t * seg_dx), float(ay + t * seg_dy)]
def is_loop_inside(small_loop, large_loop):
    """True if every segment start point of *small_loop* lies inside or on
    the polygon formed by the start points of *large_loop*."""
    outline = np.array([[seg[0], seg[1]] for seg in large_loop], dtype=np.float32)
    return all(
        cv2.pointPolygonTest(outline, (float(seg[0]), float(seg[1])), False) >= 0
        for seg in small_loop
    )
- # ================= 2. 拓扑基础逻辑 =================
def group_segments_by_connectivity(segments, stitch_dist=5.0):
    """Partition segments into connectivity groups.

    Two segments belong to the same group when any pair of their endpoints is
    closer than *stitch_dist*. Implemented as a union-find over all endpoint
    pairs; returns a list of groups, each a list of the original segments.
    """
    n = len(segments)
    parent = list(range(n))

    def find(i):
        # Path-compressing find.
        if parent[i] != i:
            parent[i] = find(parent[i])
        return parent[i]

    def union(i, j):
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[ri] = rj

    for i in range(n):
        ends_i = (segments[i][:2], segments[i][2:])
        for j in range(i + 1, n):
            ends_j = (segments[j][:2], segments[j][2:])
            close = any(
                np.linalg.norm(np.array(a) - np.array(b)) < stitch_dist
                for a in ends_i for b in ends_j
            )
            if close:
                union(i, j)

    grouped = {}
    for i in range(n):
        grouped.setdefault(find(i), []).append(segments[i])
    return list(grouped.values())
def orthogonalize_and_move_nodes(segment_groups, angle_thresh_deg=15):
    """For each group, snap near-horizontal/vertical segments onto the axes.

    Endpoints within ~2.5 px of each other share one mutable numpy node, so
    flattening one segment also moves its neighbours and the chain stays
    connected. Three relaxation passes let adjustments propagate.

    Args:
        segment_groups: list of groups of [x1, y1, x2, y2] segments.
        angle_thresh_deg: tolerance (degrees) for "near horizontal/vertical".

    Returns:
        list: refined groups in the same [x1, y1, x2, y2] format.
    """
    refined_groups = []
    for group in segment_groups:
        node_pool = {}
        # Reuse an existing node when a prior endpoint is within 2.5 px of the
        # quantized key; otherwise register a new shared (mutable) node.
        def get_shared_node(p):
            pk = to_key(p, precision=1)
            for ex_pk, obj in node_pool.items():
                if np.linalg.norm(np.array(pk) - np.array(ex_pk)) < 2.5: return obj
            new_node = np.array(p, dtype=np.float32)
            node_pool[pk] = new_node
            return new_node
        graph_segs = [(get_shared_node(s[:2]), get_shared_node(s[2:])) for s in group]
        # Relaxation passes: near-horizontal segments collapse onto their mean
        # y, near-vertical ones onto their mean x (mutating shared nodes).
        for _ in range(3):
            for p1, p2 in graph_segs:
                dx, dy = abs(p2[0] - p1[0]), abs(p2[1] - p1[1])
                angle = math.degrees(math.atan2(dy, dx))
                if angle < angle_thresh_deg or angle > (180 - angle_thresh_deg):
                    y_avg = (p1[1] + p2[1]) / 2
                    p1[1] = p2[1] = y_avg
                elif abs(angle - 90) < angle_thresh_deg:
                    x_avg = (p1[0] + p2[0]) / 2
                    p1[0] = p2[0] = x_avg
        refined_groups.append([[p1[0], p1[1], p2[0], p2[1]] for p1, p2 in graph_segs])
    return refined_groups
def bridge_isolated_endpoints(all_segments, snap_thresh=25.0):
    """Pair up degree-1 endpoints and return bridging segments.

    An endpoint that appears in exactly one segment (after quantization via
    ``to_key``) counts as isolated; each is greedily connected to its nearest
    unused isolated partner within *snap_thresh*.
    """
    endpoint_groups = {}
    for seg in all_segments:
        for pt in (tuple(seg[:2]), tuple(seg[2:])):
            endpoint_groups.setdefault(to_key(pt, precision=1), []).append(pt)
    lonely = [pts[0] for pts in endpoint_groups.values() if len(pts) == 1]
    bridges = []
    taken = set()
    for i, pi in enumerate(lonely):
        if i in taken:
            continue
        # Nearest unused partner strictly within snap_thresh.
        best, best_d = -1, snap_thresh
        for j in range(i + 1, len(lonely)):
            if j in taken:
                continue
            d = np.linalg.norm(np.array(pi) - np.array(lonely[j]))
            if d < best_d:
                best_d, best = d, j
        if best != -1:
            pj = lonely[best]
            bridges.append([pi[0], pi[1], pj[0], pj[1]])
            taken.update((i, best))
    return bridges
def find_max_length_loop(group, dist_thresh=2.0):
    """Find the longest closed loop (by total edge length) in a segment soup.

    Endpoints within *dist_thresh* are merged into graph nodes; a depth-first
    search started from at most 4 branching nodes looks for the heaviest
    cycle. Returns the loop as consistently-oriented [x1, y1, x2, y2]
    segments, or [] when no cycle is found.
    """
    if not group: return []
    nodes = []
    # Map a point to an existing node index within dist_thresh, or mint one.
    def get_node_idx(p):
        for i, target in enumerate(nodes):
            if np.linalg.norm(np.array(p) - target) < dist_thresh: return i
        nodes.append(np.array(p))
        return len(nodes) - 1
    edges = []
    for s in group:
        u, v = get_node_idx(s[:2]), get_node_idx(s[2:])
        if u == v: continue  # degenerate self-loop after node merging
        edges.append({'u': u, 'v': v, 'len': np.linalg.norm(np.array(s[:2]) - np.array(s[2:])), 'data': s})
    # Adjacency: node index -> indices of incident edges.
    graph = {}
    for i, e in enumerate(edges):
        graph.setdefault(e['u'], []).append(i)
        graph.setdefault(e['v'], []).append(i)
    # memo holds the best cycle found so far (edge-index path and its length).
    memo = {"best": [], "max": 0}
    def dfs(curr, start, path, length, visited):
        for idx in graph.get(curr, []):
            nxt = edges[idx]['v'] if edges[idx]['u'] == curr else edges[idx]['u']
            if nxt == start and len(path) >= 2:
                # Closed a cycle back to the start; keep it if it is heavier.
                if length + edges[idx]['len'] > memo["max"]:
                    memo["max"] = length + edges[idx]['len']
                    memo["best"] = path + [idx]
                continue
            if nxt not in visited:
                dfs(nxt, start, path + [idx], length + edges[idx]['len'], visited | {nxt})
    # Bound the exponential search: only try the first 4 branching nodes.
    starts = [n for n in graph if len(graph[n]) >= 2]
    for s_node in starts[:4]:
        dfs(s_node, s_node, [], 0, {s_node})
    if not memo["best"]: return []
    res = []
    # Orient the walk: pick the first edge's endpoint that is NOT shared with
    # the second edge as the walk's starting node, then emit each edge
    # head-to-tail, flipping raw segment data where needed.
    e0, e1 = edges[memo["best"][0]], edges[memo["best"][1]]
    common = e1['u'] if e1['u'] in (e0['u'], e0['v']) else e1['v']
    curr_n = e0['v'] if e0['u'] == common else e0['u']
    for idx in memo["best"]:
        e = edges[idx]
        s = e['data']
        if get_node_idx(s[:2]) == curr_n:
            res.append(list(s))
            curr_n = get_node_idx(s[2:])
        else:
            res.append([s[2], s[3], s[0], s[1]])
            curr_n = get_node_idx(s[:2])
    return res
- # ================= 3. 整流逻辑 =================
def apply_user_refinement_with_moving(group):
    """Rectify an ordered segment loop by replacing each run of non-orthogonal
    segments with a single axis-aligned connector.

    The loop is rotated so it starts at an orthogonal segment. For every run
    of skewed segments, the start of the next orthogonal ("landing") segment
    is moved onto an L-corner and a straight bridge is inserted. Returns the
    group unchanged when it has fewer than 2 segments or no orthogonal
    segment at all.
    """
    if len(group) < 2: return group
    start_idx = -1
    for i, s in enumerate(group):
        if is_orthogonal(s):
            start_idx = i
            break
    if start_idx == -1: return group
    # Rotate the loop so it begins on an orthogonal segment.
    w_list = [list(s) for s in (group[start_idx:] + group[:start_idx])]
    refined = []
    i = 0
    while i < len(w_list):
        curr_seg = w_list[i]
        refined.append(curr_seg)
        nxt_idx = i + 1
        if nxt_idx >= len(w_list): break
        if not is_orthogonal(w_list[nxt_idx]):
            # Scan forward for the next orthogonal ("landing") segment.
            p_idx = nxt_idx
            found_lp = False
            while p_idx < len(w_list):
                if is_orthogonal(w_list[p_idx]):
                    found_lp = True
                    break
                p_idx += 1
            if found_lp:
                lp = w_list[p_idx]
                p2_end, lp_start = [curr_seg[2], curr_seg[3]], [lp[0], lp[1]]
                # Orientation of current and landing segments (horizontal?).
                is_h1, is_h2 = abs(curr_seg[3] - curr_seg[1]) < 1e-1, abs(lp[3] - lp[1]) < 1e-1
                # Corner of the connector depends on whether the two
                # orthogonal segments are parallel or perpendicular.
                if is_h1 == is_h2:
                    p_mid = [p2_end[0], lp_start[1]] if is_h1 else [lp_start[0], p2_end[1]]
                else:
                    p_mid = [lp_start[0], p2_end[1]] if is_h1 else [p2_end[0], lp_start[1]]
                # Move the landing segment's start onto the corner and bridge.
                lp[0], lp[1] = p_mid[0], p_mid[1]
                refined.append([p2_end[0], p2_end[1], lp[0], lp[1]])
                i = p_idx
            else:
                # No further orthogonal segment: keep the tail unchanged.
                refined.extend(w_list[nxt_idx:])
                break
        else:
            i += 1
    return refined
def merge_collinear_segments(ordered_group, dist_thresh=0.5):
    """Fuse consecutive collinear segments of an ordered chain.

    Two neighbours merge when both are horizontal (or both vertical) within
    *dist_thresh* and the first ends where the second begins on that axis.
    """
    if len(ordered_group) < 2:
        return ordered_group
    fused = []
    run = list(ordered_group[0])
    for nxt in ordered_group[1:]:
        same_h = (abs(run[1] - run[3]) < dist_thresh
                  and abs(nxt[1] - nxt[3]) < dist_thresh
                  and abs(run[3] - nxt[1]) < dist_thresh)
        same_v = (abs(run[0] - run[2]) < dist_thresh
                  and abs(nxt[0] - nxt[2]) < dist_thresh
                  and abs(run[2] - nxt[0]) < dist_thresh)
        if same_h or same_v:
            # Extend the current run to the neighbour's end point.
            run[2], run[3] = nxt[2], nxt[3]
        else:
            fused.append(run)
            run = list(nxt)
    fused.append(run)
    return fused
def merge_parallel_lines_topology(group, dist_thresh=15.0):
    """Collapse close, overlapping parallel segments onto a single line.

    When two horizontal (or vertical) segments sit within *dist_thresh* of
    each other and their projections overlap, every coordinate in the group
    lying on either old line is snapped to the pair's average, the pair is
    merged into one spanning segment, and the scan restarts. Iteration is
    capped at 100 rounds as a safety net.
    """
    if not group: return group
    segs = [list(s) for s in group]
    # Length of the 1-D overlap between two intervals (0 when disjoint).
    def get_overlap(min1, max1, min2, max2):
        overlap_min = max(min1, min2)
        overlap_max = min(max1, max2)
        return overlap_max - overlap_min if overlap_max > overlap_min else 0
    iteration = 0
    while True:
        merged_in_this_round = False
        merged_indices = set()
        for i in range(len(segs)):
            if i in merged_indices: continue
            for j in range(i + 1, len(segs)):
                if j in merged_indices: continue
                s1, s2 = segs[i], segs[j]
                is_h1, is_h2 = abs(s1[1] - s1[3]) < 1e-1, abs(s2[1] - s2[3]) < 1e-1
                is_v1, is_v2 = abs(s1[0] - s1[2]) < 1e-1, abs(s2[0] - s2[2]) < 1e-1
                should_merge = False
                new_pos = 0
                if is_h1 and is_h2:
                    dist = abs(s1[1] - s2[1])
                    overlap = get_overlap(min(s1[0], s1[2]), max(s1[0], s1[2]), min(s2[0], s2[2]), max(s2[0], s2[2]))
                    if dist < dist_thresh and overlap > 0:
                        should_merge = True
                        new_pos = (s1[1] + s2[1]) / 2
                        old_y1, old_y2 = s1[1], s2[1]
                        # Snap every y coordinate in the group that sat on
                        # either old line onto the averaged line.
                        for s in segs:
                            if abs(s[1] - old_y1) < 1e-1 or abs(s[1] - old_y2) < 1e-1: s[1] = new_pos
                            if abs(s[3] - old_y1) < 1e-1 or abs(s[3] - old_y2) < 1e-1: s[3] = new_pos
                elif is_v1 and is_v2:
                    dist = abs(s1[0] - s2[0])
                    overlap = get_overlap(min(s1[1], s1[3]), max(s1[1], s1[3]), min(s2[1], s2[3]), max(s2[1], s2[3]))
                    if dist < dist_thresh and overlap > 0:
                        should_merge = True
                        new_pos = (s1[0] + s2[0]) / 2
                        old_x1, old_x2 = s1[0], s2[0]
                        # Same snap as above, for x coordinates.
                        for s in segs:
                            if abs(s[0] - old_x1) < 1e-1 or abs(s[0] - old_x2) < 1e-1: s[0] = new_pos
                            if abs(s[2] - old_x1) < 1e-1 or abs(s[2] - old_x2) < 1e-1: s[2] = new_pos
                if should_merge:
                    # Grow s1 to the bounding extent of the pair; s2 is dropped.
                    s1[0], s1[2] = min(s1[0], s1[2], s2[0], s2[2]), max(s1[0], s1[2], s2[0], s2[2])
                    s1[1], s1[3] = min(s1[1], s1[3], s2[1], s2[3]), max(s1[1], s1[3], s2[1], s2[3])
                    merged_indices.add(j)
                    merged_in_this_round = True
                    break
            if merged_in_this_round: break
        if merged_indices:
            segs = [s for idx, s in enumerate(segs) if idx not in merged_indices]
        if not merged_in_this_round: break
        iteration += 1
        if iteration > 100: break  # safety cap against pathological input
    return segs
def get_line_angle_0_90(s1, s2):
    """Return the acute angle (radians, in [0, pi/2]) between two segments.

    Direction is ignored (the dot product is taken in absolute value); a
    zero-length segment yields 0.
    """
    ux, uy = s1[2] - s1[0], s1[3] - s1[1]
    vx, vy = s2[2] - s2[0], s2[3] - s2[1]
    norm_u = math.sqrt(ux * ux + uy * uy)
    norm_v = math.sqrt(vx * vx + vy * vy)
    if norm_u < 1e-6 or norm_v < 1e-6:
        return 0
    cos_t = abs(ux * vx + uy * vy) / (norm_u * norm_v)
    # Clamp against floating-point drift before acos.
    return math.acos(min(1.0, max(-1.0, cos_t)))
def apply_user_refinement_with_moving_refine_distance_angle(group, length_threshold=30, angle_threshold_deg=30):
    """Rectify an ordered loop like ``apply_user_refinement_with_moving``,
    but only replace short skewed runs with an L-connector.

    Runs of non-orthogonal segments whose total length is below
    *length_threshold* are replaced by an axis-aligned connector; longer runs
    are kept but their segments are merged whenever consecutive ones differ by
    less than *angle_threshold_deg* degrees, and the run is re-stitched to the
    neighbouring orthogonal segments.
    """
    if len(group) < 2: return group
    angle_thresh_rad = math.radians(angle_threshold_deg)
    start_idx = -1
    for i, s in enumerate(group):
        if is_orthogonal(s):
            start_idx = i
            break
    if start_idx == -1: return group
    # Rotate the loop so it begins on an orthogonal segment.
    w_list = [list(s) for s in (group[start_idx:] + group[:start_idx])]
    refined = []
    i = 0
    while i < len(w_list):
        curr_seg = w_list[i]
        refined.append(curr_seg)
        nxt_idx = i + 1
        if nxt_idx >= len(w_list): break
        if not is_orthogonal(w_list[nxt_idx]):
            # Scan forward for the next orthogonal ("landing") segment,
            # accumulating the skewed run and its total length.
            p_idx = nxt_idx
            found_lp = False
            total_gap_len = 0
            gap_segs = []
            while p_idx < len(w_list):
                if is_orthogonal(w_list[p_idx]):
                    found_lp = True
                    break
                seg_tmp = w_list[p_idx]
                total_gap_len += math.sqrt((seg_tmp[2] - seg_tmp[0]) ** 2 + (seg_tmp[3] - seg_tmp[1]) ** 2)
                gap_segs.append(seg_tmp)
                p_idx += 1
            if found_lp:
                lp = w_list[p_idx]
                if total_gap_len < length_threshold:
                    # Short run: drop it and insert an L-shaped connector.
                    p2_end, lp_start = [curr_seg[2], curr_seg[3]], [lp[0], lp[1]]
                    is_h1, is_h2 = abs(curr_seg[3] - curr_seg[1]) < 1e-1, abs(lp[3] - lp[1]) < 1e-1
                    if is_h1 == is_h2:
                        p_mid = [p2_end[0], lp_start[1]] if is_h1 else [lp_start[0], p2_end[1]]
                    else:
                        p_mid = [lp_start[0], p2_end[1]] if is_h1 else [p2_end[0], lp_start[1]]
                    lp[0], lp[1] = p_mid[0], p_mid[1]
                    refined.append([p2_end[0], p2_end[1], lp[0], lp[1]])
                    i = p_idx
                else:
                    # Long run: keep it, but merge near-collinear neighbours
                    # and re-attach both ends to the orthogonal segments.
                    gap_segs[0][0], gap_segs[0][1] = curr_seg[2], curr_seg[3]
                    merged_gap = []
                    temp_s = list(gap_segs[0])
                    for k in range(1, len(gap_segs)):
                        next_s = gap_segs[k]
                        if get_line_angle_0_90(temp_s, next_s) < angle_thresh_rad:
                            # Similar direction: absorb into the current piece.
                            temp_s[2], temp_s[3] = next_s[2], next_s[3]
                        else:
                            merged_gap.append(temp_s)
                            temp_s = list(next_s)
                            # Chain the new piece to the previous one's end.
                            temp_s[0], temp_s[1] = merged_gap[-1][2], merged_gap[-1][3]
                    merged_gap.append(temp_s)
                    # Snap the run's tail onto the landing segment's start.
                    merged_gap[-1][2], merged_gap[-1][3] = lp[0], lp[1]
                    refined.extend(merged_gap)
                    i = p_idx
            else:
                # No further orthogonal segment: keep the tail unchanged.
                refined.extend(w_list[nxt_idx:])
                break
        else:
            i += 1
    return refined
- # ================= 4. 针对单个 Block 的处理流程 =================
def refine_single_block_segments(segments):
    """Full rectification pipeline for one block's segment loop.

    Stages: axis-snap shared endpoints, replace/merge skewed runs, fuse
    collinear neighbours, then merge close parallel lines.
    """
    if not segments:
        return []
    snapped = orthogonalize_and_move_nodes([segments], angle_thresh_deg=15)[0]
    loop = apply_user_refinement_with_moving_refine_distance_angle(
        snapped, length_threshold=30, angle_threshold_deg=30)
    loop = merge_collinear_segments(loop, dist_thresh=2)
    return merge_parallel_lines_topology(loop, dist_thresh=12)
- # ================= 5. 主处理流程 =================
def refine_blocks_in_data(data):
    """Convert each block's ``points`` into rectified segment format in place.

    Accepts either a dict carrying a ``'block'``/``'blocks'`` list or a bare
    list of blocks. Flat ``[x1, y1, x2, y2, ...]`` polygons are closed into
    segment loops and rectified via ``refine_single_block_segments``; data
    already in ``[[x1, y1, x2, y2], ...]`` segment form is re-refined as-is.

    Returns:
        The same *data* object, with each processed block gaining
        ``refined``/``format``/``segment_count`` keys and (for dict input)
        top-level ``format_version``/``total_segments`` metadata.
    """
    # BUG FIX: the original called data.get(...) before checking for a bare
    # list, which raised AttributeError on list input and made the list
    # branch unreachable.
    if isinstance(data, list):
        blocks = data
    else:
        blocks = data.get("block", data.get("blocks", []))
    if not blocks:
        print("⚠️ JSON 中未找到 'block' 或 'blocks' 字段")
        return data
    print(f"🚀 开始处理 {len(blocks)} 个块...")
    total_segments = 0
    for idx, block in enumerate(tqdm(blocks, desc="Refining Blocks")):
        block_id = block.get('id', idx)
        points_data = block.get("points", [])
        if not points_data:
            print(f"⚠️ Block {block_id} 无 points,跳过")
            continue
        if isinstance(points_data[0], list) and len(points_data[0]) == 4:
            # Already segment-formatted: just normalize values to float.
            source_segments = [[float(v) for v in seg] for seg in points_data]
        else:
            # Flat polygon: need at least two (x, y) pairs.
            if len(points_data) < 4:
                print(f"⚠️ Block {block_id} 点数不足,跳过")
                continue
            points = []
            for i in range(0, len(points_data), 2):
                if i + 1 < len(points_data):
                    points.append([float(points_data[i]), float(points_data[i + 1])])
            if len(points) < 2:
                continue
            # Close the polygon into consecutive segments.
            source_segments = []
            for i in range(len(points)):
                p1 = points[i]
                p2 = points[(i + 1) % len(points)]
                source_segments.append([p1[0], p1[1], p2[0], p2[1]])
        refined_segments = refine_single_block_segments(source_segments)
        if not refined_segments:
            # Rectification produced nothing — fall back to the raw segments.
            refined_segments = source_segments
        segments_int = [
            [int(round(seg[0])), int(round(seg[1])), int(round(seg[2])), int(round(seg[3]))]
            for seg in refined_segments
        ]
        block["points"] = segments_int
        block["refined"] = True
        block["format"] = "segments"
        block["segment_count"] = len(segments_int)
        total_segments += len(segments_int)
    if isinstance(data, dict):
        # BUG FIX: top-level metadata keys only make sense (and only work)
        # on a dict container; item assignment on a bare list would raise.
        data["format_version"] = "segments_v1"
        data["total_segments"] = total_segments
    return data
def process_json_blocks(json_path, rgb_path, save_path, txt_folder, save_json=False, json_backup=True):
    """Rectify every block, draw the segments onto the RGB image, and dump the
    exact drawn p1/p2 coordinates to a .txt file.

    txt format per line: ``block_id,label,x1,y1,x2,y2`` — the integer
    coordinates recorded match the drawn lines exactly.

    Args:
        json_path: input JSON carrying a 'block'/'blocks' list.
        rgb_path: background image the segments are drawn onto.
        save_path: output path for the annotated image.
        txt_folder: folder for the txt dump; ``None`` skips txt output.
        save_json: also write rectified segments back into the JSON file.
        json_backup: when saving JSON, first keep a timestamped backup copy.
    """
    # 1. Load the JSON
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        print(f"❌ 读取 JSON 失败:{e}")
        return
    # 2. Load the background image
    output_img = cv2.imread(rgb_path)
    if output_img is None:
        print(f"❌ 无法读取背景图:{rgb_path}")
        return
    # 3. Fetch the block list
    # NOTE(review): if the JSON root is a list, data.get() raises
    # AttributeError before the isinstance check below can run — the
    # bare-list branch is unreachable as written. Confirm intended input.
    blocks = data.get("block", data.get("blocks", []))
    if not blocks:
        if isinstance(data, list):
            blocks = data
        else:
            print("⚠️ JSON 中未找到 'block' 或 'blocks' 字段")
            return
    print(f"🚀 开始处理 {len(blocks)} 个块...")
    # 4. Prepare the txt output path (txt_folder=None skips txt saving)
    txt_path = None
    if txt_folder is not None:
        if not os.path.exists(txt_folder):
            os.makedirs(txt_folder)
        txt_filename = os.path.basename(save_path).replace('.png', '.txt')
        txt_path = os.path.join(txt_folder, txt_filename)
    # 6. Process each block independently
    all_lines_data = []  # collected "block_id,label,x1,y1,x2,y2" lines
    for idx, block in enumerate(tqdm(blocks, desc="Refining Blocks")):
        block_id = block.get('id', idx)
        label = block.get('label', 'door')
        points_flat = block.get("points", [])
        if len(points_flat) < 4:
            print(f"⚠️ Block {block_id} 点数不足,跳过")
            continue
        # Flat [x1, y1, ...] array -> list of [x, y] points
        points = []
        for i in range(0, len(points_flat), 2):
            if i + 1 < len(points_flat):
                points.append([float(points_flat[i]), float(points_flat[i + 1])])
        if len(points) < 2:
            continue
        # A. Close the polygon into consecutive segments
        segments = []
        for i in range(len(points)):
            p1 = points[i]
            p2 = points[(i + 1) % len(points)]
            segments.append([p1[0], p1[1], p2[0], p2[1]])
        # B. Run the full rectification pipeline on this block's segments
        refined_segments = refine_single_block_segments(segments)
        # C. Optionally write the refined segments back into the JSON data
        if save_json and refined_segments:
            segments_int = []
            for seg in refined_segments:
                segments_int.append([
                    int(round(seg[0])),
                    int(round(seg[1])),
                    int(round(seg[2])),
                    int(round(seg[3]))
                ])
            block["points"] = segments_int
            block["refined"] = True
            block["format"] = "segments"
        # D. Draw and record coordinates (txt matches the drawing exactly)
        color = np.random.randint(100, 255, (3,)).tolist()
        for s in refined_segments:
            p1 = (int(s[0]), int(s[1]))
            p2 = (int(s[2]), int(s[3]))
            # Draw onto the image
            cv2.line(output_img, p1, p2, color, 2, cv2.LINE_AA)
            cv2.circle(output_img, p1, 3, (255, 255, 255), -1)
            # Record for the txt dump
            # Format: block_id,label,x1,y1,x2,y2
            all_lines_data.append(f"{block_id},{label},{p1[0]},{p1[1]},{p2[0]},{p2[1]}")
    # 7. Write the txt file (optional)
    if txt_path is not None:
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(f"# {os.path.basename(txt_path)}\n")
            f.write(f"# Format: block_id,label,x1,y1,x2,y2\n")
            f.write(f"# Total lines: {len(all_lines_data)}\n")
            f.write(f"#\n")
            for line in all_lines_data:
                f.write(line + "\n")
        print(f"📄 线段数据已保存至:{txt_path}")
        print(f" 共 {len(all_lines_data)} 条线段")
    # 8. Save the annotated image
    cv2.imwrite(save_path, output_img)
    print(f"✨ 图像已保存至:{save_path}")
    # 9. Save the JSON (optional, with timestamped backup)
    if save_json:
        try:
            if json_backup:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                backup_path = json_path.replace(".json", f"_backup_{timestamp}.json")
                shutil.copy2(json_path, backup_path)
                print(f"💾 原 JSON 已备份至:{backup_path}")
            save_json_path = json_path
            with open(save_json_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            print(f"💾 优化后的 JSON 已保存至:{save_json_path}")
        except Exception as e:
            print(f"❌ 保存 JSON 失败:{e}")
    print(f"✨ 全部处理完成!")
- # ================= 6. 主运行流程 =================
- # ===========================================================================
- # PART 3: txt2json — merge_txt_to_json
- # ===========================================================================
def merge_txt_to_json(txt_path, json_path, output_path=None):
    """Merge segment data from a txt file back into a JSON file.

    txt format: ``block_id,label,x1,y1,x2,y2`` per line (``#`` lines and
    blanks are ignored). Resulting JSON block format:
    ``points: [[x1, y1, x2, y2], ...]``. Saves to *output_path* (defaults to
    overwriting *json_path*) and returns the updated JSON data.
    """
    print(f"\n{'=' * 60}")
    print(f"🔄 合并 txt 数据到 JSON")
    print('=' * 60)
    # 1. Read the txt file
    print(f"📄 读取 txt 文件:{txt_path}")
    lines_data = []
    with open(txt_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line.startswith('#') or not line:
                continue
            parts = line.split(',')
            if len(parts) == 6:
                try:
                    block_id = int(parts[0])
                    label = parts[1]
                    x1, y1, x2, y2 = int(parts[2]), int(parts[3]), int(parts[4]), int(parts[5])
                    lines_data.append({
                        'block_id': block_id,
                        'label': label,
                        'segment': (x1, y1, x2, y2)
                    })
                except Exception as e:
                    # Malformed line: report and keep going.
                    print(f" ⚠️ 解析失败:{line} - {e}")
    print(f" 共读取 {len(lines_data)} 条线段")
    # 2. Group segments by block_id (last label seen per block wins)
    blocks_segments = defaultdict(list)
    blocks_labels = {}
    for item in lines_data:
        block_id = item['block_id']
        blocks_segments[block_id].append(item['segment'])
        blocks_labels[block_id] = item['label']
    print(f" 共 {len(blocks_segments)} 个 block")
    # 3. Read the original JSON
    print(f"\n📄 读取 JSON 文件:{json_path}")
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # 4. Update the block entries
    blocks = data.get('block', data.get('blocks', []))
    updated_count = 0
    print(f"\n🔄 更新 {len(blocks)} 个 block...")
    for block in blocks:
        block_id = block.get('id', 0)
        if block_id in blocks_segments:
            # Convert the segment tuples to lists
            # (JSON has no tuple type).
            segments_list = []
            for seg in blocks_segments[block_id]:
                segments_list.append(list(seg))  # (x1,y1,x2,y2) -> [x1,y1,x2,y2]
            # Replace the points field with the segment list
            block['points'] = segments_list
            block['refined'] = True
            block['format'] = 'segments'  # mark the new format
            block['segment_count'] = len(segments_list)
            # Carry over the label from the txt data (if present)
            if block_id in blocks_labels:
                block['label'] = blocks_labels[block_id]
            updated_count += 1
            print(f" ✅ Block {block_id}: {len(segments_list)} 线段")
        else:
            print(f" ⚠️ Block {block_id}: 未找到对应线段数据")
    # 5. Add merge markers to the top-level metadata
    data['txt_merged'] = True
    data['total_segments'] = len(lines_data)
    data['format_version'] = 'segments_v1'
    # 6. Save the result
    if output_path is None:
        output_path = json_path
    print(f"\n💾 保存至:{output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    print(f"\n✨ 合并完成!")
    print(f" - 更新 block 数量:{updated_count}/{len(blocks)}")
    print(f" - 总线段数:{len(lines_data)}")
    return data
def verify_json_segments(json_path):
    """Print a per-block report on whether 'points' uses the segment format.

    Three cases are recognized: a list of [x1, y1, x2, y2] segments (good),
    a flat numeric array (legacy point format), or anything else
    (unknown/empty).
    """
    print(f"\n{'=' * 60}")
    print(f"🔍 验证 JSON 线段格式")
    print('=' * 60)
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for block in data.get('block', data.get('blocks', [])):
        block_id = block.get('id', 0)
        label = block.get('label', 'door')
        points = block.get('points', [])
        if not (isinstance(points, list) and len(points) > 0):
            print(f" ❌ Block {block_id} ({label}): 无数据")
        elif isinstance(points[0], list) and len(points[0]) == 4:
            # Proper segment format.
            print(f" ✅ Block {block_id} ({label}): {len(points)} 线段 [x1,y1,x2,y2]")
        elif isinstance(points[0], (int, float)):
            # Legacy flat point array.
            print(f" ⚠️ Block {block_id} ({label}): 扁平数组格式 ({len(points) // 2} 点)")
        else:
            print(f" ❌ Block {block_id} ({label}): 未知格式")
    print(f"\n✨ 验证完成!")
- # ================= 主运行流程 =================
- # ===========================================================================
- # PART 4: merge_segments
- # ===========================================================================
- from tqdm import tqdm
def tqdm(x, **kw):
    # NOTE(review): this deliberately shadows the `tqdm` imported on the line
    # above, turning every progress bar in this file into a plain pass-through
    # (all keyword args such as `desc` are ignored). Remove this stub to
    # restore progress display — confirm the no-op behavior is intended.
    return x
- # ---------------------------------------------------------------------------
- # Step 1: Angle normalization
- # ---------------------------------------------------------------------------
- def normalize_segment(seg, angle=10):
- """
- If a segment is nearly horizontal or nearly vertical, snap it.
- angle: tolerance in degrees
- - |theta| <= angle or |theta| >= 180-angle -> horizontal (fix y to midpoint)
- - |90 - |theta|| <= angle -> vertical (fix x to midpoint)
- seg: [x1, y1, x2, y2]
- Returns new [x1, y1, x2, y2].
- """
- x1, y1, x2, y2 = seg
- dx, dy = x2 - x1, y2 - y1
- if dx == 0 and dy == 0:
- return seg
- theta = abs(math.degrees(math.atan2(abs(dy), abs(dx)))) # 0..90
- if theta <= angle: # near horizontal
- mid_y = round((y1 + y2) / 2)
- return [x1, mid_y, x2, mid_y]
- if theta >= 90 - angle: # near vertical
- mid_x = round((x1 + x2) / 2)
- return [mid_x, y1, mid_x, y2]
- return seg
def normalize_all(data, angle=10):
    """Axis-snap every segment of every block in *data* in place."""
    for blk in data.get('block', []):
        blk['points'] = [normalize_segment(seg, angle) for seg in blk['points']]
- # ---------------------------------------------------------------------------
- # Step 2: Cluster-based parallel merge (mean coordinate)
- # ---------------------------------------------------------------------------
def segment_orientation(seg):
    """Classify a segment: 'H' if horizontal, 'V' if vertical, else None.

    A degenerate point segment reports 'H' because the y-equality check
    runs first.
    """
    x1, y1, x2, y2 = seg
    if y1 == y2:
        return 'H'
    return 'V' if x1 == x2 else None
def snap_coord_global(all_segs, old_vals, new_val, ori):
    """Globally rewrite axis coordinates equal to any value in *old_vals*.

    ori='H' rewrites y coordinates; ori='V' rewrites x coordinates. Segments
    parallel to the snapped axis move as a whole (both slots follow the
    first), while perpendicular/diagonal segments only have the matching
    endpoint slot(s) adjusted. Mutates *all_segs* in place.
    """
    targets = set(old_vals)
    # Slot indices of the affected axis: y -> (1, 3), x -> (0, 2).
    first, second = (1, 3) if ori == 'H' else (0, 2)
    for seg in all_segs:
        if segment_orientation(seg) == ori:
            # Parallel segment: shift it as one piece when its line matches.
            if seg[first] in targets:
                seg[first] = new_val
                seg[second] = new_val
        else:
            # Other segment: adjust each matching endpoint independently.
            if seg[first] in targets:
                seg[first] = new_val
            if seg[second] in targets:
                seg[second] = new_val
def segments_overlap(sa, sb, ori):
    """True if two parallel segments' projections overlap (touching counts).

    ori='H' compares x ranges; ori='V' compares y ranges.
    """
    lo_idx, hi_idx = (0, 2) if ori == 'H' else (1, 3)
    a_lo, a_hi = sorted((sa[lo_idx], sa[hi_idx]))
    b_lo, b_hi = sorted((sb[lo_idx], sb[hi_idx]))
    return a_lo <= b_hi and b_lo <= a_hi
def cluster_and_merge(all_segs, ori, threshold):
    """
    Overlap-aware parallel merge.

    Two parallel segments are grouped only if:
      1. Their axis-coordinate distance is <= threshold, AND
      2. Their projections on the perpendicular axis overlap.
    Groups are built transitively with union-find, then every group is
    snapped to the (rounded) mean coordinate via snap_coord_global.

    Fixes vs. previous version: docstring said "< threshold" while the code
    uses "<=" (documented correctly now); the defaultdict import was buried
    mid-function and is hoisted to the top.

    Parameters
    ----------
    all_segs : list of [x1, y1, x2, y2] lists; mutated in place.
    ori : 'H' or 'V' — which orientation to merge.
    threshold : maximum axis distance for two segments to be clustered.

    Returns
    -------
    bool : True if any snapping occurred.
    """
    from collections import defaultdict

    # Indices of segments with this orientation
    idxs = [i for i, s in enumerate(all_segs) if segment_orientation(s) == ori]
    if not idxs:
        return False

    # Union-Find with path halving
    parent = {i: i for i in idxs}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(x, y):
        parent[find(x)] = find(y)

    # H segments share a y coordinate; V segments share an x coordinate.
    coord_idx = 1 if ori == 'H' else 0

    # Pairwise clustering — O(n^2) over segments of this orientation.
    for ii in range(len(idxs)):
        for jj in range(ii + 1, len(idxs)):
            i, j = idxs[ii], idxs[jj]
            si, sj = all_segs[i], all_segs[j]
            if (abs(si[coord_idx] - sj[coord_idx]) <= threshold
                    and segments_overlap(si, sj, ori)):
                union(i, j)

    # Group by cluster root
    clusters = defaultdict(list)
    for i in idxs:
        clusters[find(i)].append(i)

    changed = False
    for members in clusters.values():
        if len(members) < 2:
            continue
        coords = [all_segs[i][coord_idx] for i in members]
        old_vals = list(set(coords))
        if len(old_vals) == 1:  # already aligned, nothing to do
            continue
        new_val = round(sum(coords) / len(coords))
        snap_coord_global(all_segs, old_vals, new_val, ori)
        changed = True
    return changed
def seg_key(s):
    """Canonical (endpoint-order-independent) key for a segment."""
    endpoints = sorted([(s[0], s[1]), (s[2], s[3])])
    return tuple(endpoints)
def remove_degenerate_and_duplicates(segs):
    """
    Drop zero-length segments and (order-insensitive) duplicates,
    keeping the first occurrence of each distinct segment.
    """
    kept = []
    seen_keys = set()
    for seg in segs:
        if (seg[0], seg[1]) == (seg[2], seg[3]):
            continue  # degenerate: both endpoints coincide
        key = seg_key(seg)
        if key not in seen_keys:
            seen_keys.add(key)
            kept.append(seg)
    return kept
def merge_all_blocks(data, threshold=6):
    """
    Global cross-block parallel merge using the cluster-mean strategy.
    Alternates H and V passes until neither direction changes anything,
    then writes the cleaned segments back into their original blocks.
    """
    blocks = data.get('block', [])
    all_segs = []
    block_counts = []
    for blk in blocks:
        copied = [list(seg) for seg in blk['points']]
        block_counts.append(len(copied))
        all_segs.extend(copied)
    # Repeat until a full H+V sweep makes no change
    while True:
        moved_h = cluster_and_merge(all_segs, 'H', threshold)
        moved_v = cluster_and_merge(all_segs, 'V', threshold)
        if not (moved_h or moved_v):
            break
    # Slice the flat list back into per-block segment lists
    offset = 0
    for blk, count in zip(blocks, block_counts):
        cleaned = remove_degenerate_and_duplicates(all_segs[offset:offset + count])
        blk['points'] = cleaned
        blk['segment_count'] = len(cleaned)
        offset += count
    data['total_segments'] = sum(len(b['points']) for b in blocks)
- # ---------------------------------------------------------------------------
- # Pipeline
- # ---------------------------------------------------------------------------
def process_json(input_path, output_path, threshold=6, angle=10):
    """
    Standalone pipeline step: load a segment JSON, snap near-axis segments
    (normalize_all), merge parallel segments across blocks (merge_all_blocks),
    and write the result.

    Fix: use explicit UTF-8 encoding and ensure_ascii=False, consistent with
    the other JSON I/O in this file — labels may contain non-ASCII text.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    normalize_all(data, angle=angle)
    merge_all_blocks(data, threshold=threshold)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
- # ===========================================================================
# PART 5: visualization
- # ===========================================================================
# ================= 1. Color mapping =================
# Room label -> drawing color (OpenCV BGR order).
COLOR_MAP = {
    "living_room": (255, 180, 100),   # BGR: orange
    "bed_room": (100, 255, 100),      # BGR: green
    "bath_room": (255, 100, 255),     # BGR: magenta
    "kitchen_room": (100, 255, 255),  # BGR: yellow
    "other_room": (180, 180, 180),    # BGR: gray
    "balcony": (200, 150, 100),       # BGR: brown
}
# Fallback color for labels not present in COLOR_MAP
DEFAULT_COLOR = (200, 200, 200)
# ================= 2. Utility functions =================
def get_image_size(data):
    """
    Return (width, height) for the JSON data.

    Uses the explicit 'image_size' field when present; otherwise infers a
    canvas from the extents of segments and connect areas, plus a 100 px margin.
    """
    if 'image_size' in data:
        return data['image_size']['width'], data['image_size']['height']
    # No explicit size: scan geometry for the furthest extents
    max_x = max_y = 0
    for block in data.get('block', []):
        segs = block.get('points', [])
        if isinstance(segs, list) and segs and isinstance(segs[0], list) and len(segs[0]) == 4:
            for seg in segs:
                max_x = max(max_x, seg[0], seg[2])
                max_y = max(max_y, seg[1], seg[3])
    for area in data.get('connect_area', []):
        max_x = max(max_x, area.get('x', 0) + area.get('w', 0))
        max_y = max(max_y, area.get('y', 0) + area.get('h', 0))
    return int(max_x) + 100, int(max_y) + 100
def get_points_from_segments(segments):
    """Collect the unique integer endpoints of a list of 4-value segments."""
    if not segments:
        return []
    unique = {
        pt
        for seg in segments if len(seg) == 4
        for pt in ((int(seg[0]), int(seg[1])), (int(seg[2]), int(seg[3])))
    }
    return list(unique)
# ================= 3. Main visualization function =================
def visualize_final_json(json_path, output_path, rgb_path=None, vis_door=False):
    """
    Render a segment-format floorplan JSON to an image.

    Draws room blocks (colored segments, white endpoint dots, labels placed
    at each block's 'center' field when present), optional door connect
    areas, furniture boxes, and a legend, then writes the canvas to
    output_path and returns it (BGR numpy image).

    Parameters:
        json_path: path of the segment-format JSON to render.
        output_path: where the rendered image is written.
        rgb_path: optional background image; a black canvas is used otherwise.
        vis_door: when True, draw connect areas labeled 'door' as filled red boxes.
    """
    print(f"\n{'=' * 60}")
    print(f"🎨 最终可视化:{os.path.basename(json_path)}")
    print('=' * 60)
    # 1. Load the JSON
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # 2. Determine the canvas size and create the canvas
    width, height = get_image_size(data)
    if rgb_path and os.path.exists(rgb_path):
        canvas = cv2.imread(rgb_path)
        print(f"🖼️ 使用背景图:{rgb_path}")
    else:
        # canvas = np.ones((height, width, 3), dtype=np.uint8) * 255
        # NOTE(review): the canvas is black (np.zeros) although the log line
        # below says "white canvas" — confirm which is intended.
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        print(f"📐 创建白色画布:{width}x{height}")
    # 3. Draw room blocks
    blocks = data.get('block', [])
    print(f"\n🏠 绘制 {len(blocks)} 个房间块...")
    for block_idx, block in enumerate(blocks):
        block_id = block.get('id', block_idx)
        label = block.get('label', 'door')
        segments = block.get('points', [])
        center = block.get('center', None)  # label anchor taken from the JSON 'center' field
        # Room color with fallback for unknown labels
        color = COLOR_MAP.get(label.lower(), DEFAULT_COLOR)
        # Validate the segment list shape before drawing
        if not (isinstance(segments, list) and len(segments) > 0):
            print(f" ⚠️ Block {block_id}: 无数据,跳过")
            continue
        if not (isinstance(segments[0], list) and len(segments[0]) == 4):
            print(f" ⚠️ Block {block_id}: 格式不正确,跳过")
            continue
        # Collect unique endpoints
        points = get_points_from_segments(segments)
        if len(points) < 2:
            print(f" ⚠️ Block {block_id}: 点数不足,跳过")
            continue
        # [1] Draw the wall segments
        for seg in segments:
            p1 = (int(seg[0]), int(seg[1]))
            p2 = (int(seg[2]), int(seg[3]))
            cv2.line(canvas, p1, p2, color, 2, cv2.LINE_AA)
        # [2] Mark every vertex with a white dot
        for pt in points:
            cv2.circle(canvas, pt, 4, (255, 255, 255), -1)  # white dot
        # [3] Label placement: prefer the JSON 'center' field
        if center and len(center) == 2:
            cx, cy = int(center[0]), int(center[1])
        else:
            # Fall back to a computed centroid when 'center' is missing
            pts = np.array(points, dtype=np.int32)
            M = cv2.moments(pts)
            if M["m00"] != 0:
                cx = int(M["m10"] / M["m00"])
                cy = int(M["m01"] / M["m00"])
            else:
                cx = int(np.mean(pts[:, 0]))
                cy = int(np.mean(pts[:, 1]))
            print(f" ⚠️ Block {block_id}: 无 center 字段,使用计算中心")
        # Clamp the label position inside the canvas
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        thickness = 1
        (tw, th), _ = cv2.getTextSize(label, font, font_scale, thickness)
        cx = max(tw // 2 + 5, min(cx, width - tw // 2 - 5))
        cy = max(th // 2 + 5, min(cy, height - th // 2 - 5))
        # Label background (white rectangle)
        cv2.rectangle(canvas, (cx - tw // 2 - 3, cy - th - 3),
                      (cx + tw // 2 + 3, cy + 3), (255, 255, 255), -1)
        # Label text (black)
        cv2.putText(canvas, label, (cx - tw // 2, cy),
                    font, font_scale, (0, 0, 0), thickness)
        # Block ID (small gray text under the label)
        id_text = f"ID:{block_id}"
        (iw, ih), _ = cv2.getTextSize(id_text, font, 0.4, 1)
        cv2.putText(canvas, id_text, (cx - iw // 2, cy + 15),
                    font, 0.4, (100, 100, 100), 1)
        print(f" ✅ Block {block_id} ({label}): {len(segments)} 线段,{len(points)} 点,center=({cx},{cy})")
    # 4. Draw connect areas
    connect_areas = data.get('connect_area', [])
    print(f"\n🔗 绘制 {len(connect_areas)} 个连接区域...")
    for area_idx, area in enumerate(connect_areas):
        area_id = area.get('id', area_idx)
        x = area.get('x', 0)
        y = area.get('y', 0)
        w = area.get('w', 0)
        h = area.get('h', 0)
        label = area.get('label', 'door').lower()
        block_pair = area.get('block_pair', [])
        # Clamp the rectangle inside the canvas
        x = max(0, min(x, width - 1))
        y = max(0, min(y, height - 1))
        w = max(1, min(w, width - x))
        h = max(1, min(h, height - y))
        # Door areas: filled red rectangle with a white border
        if vis_door and ("door" in label):
            cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 0, 255), -1)  # red fill
            cv2.rectangle(canvas, (x, y), (x + w, y + h), (255, 255, 255), 1)  # white border
            print(f" ✅ Connect {area_id}: 🔴 红色实心矩形")
        # else:
        #     # Other areas: green outline (currently disabled)
        #     cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 255, 0), 2)
        #     if label:
        #         cv2.putText(canvas, label, (x + 3, y + 15),
        #                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
        #     if block_pair:
        #         pair_text = f"{block_pair[0]}-{block_pair[1]}"
        #         cv2.putText(canvas, pair_text, (x + 3, y + h - 3),
        #                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
        #     print(f" ✅ Connect {area_id}: 🟢 {label} ({block_pair})")
    # 5. Draw furniture
    furniture_list = data.get('furniture', [])
    print(f"\n🛋️ 绘制 {len(furniture_list)} 个家具...")
    FURNITURE_COLOR = (0, 165, 255)  # BGR: orange
    for item in furniture_list:
        label = item.get('label', '')
        center = item.get('center', None)
        pts = item.get('points', {})
        if pts:
            # assumes 'points' holds opposite corners as x1/y1 and x3/y3 — TODO confirm schema
            bx1, by1 = pts['x1'], pts['y1']
            bx2, by2 = pts['x3'], pts['y3']
            cv2.rectangle(canvas, (bx1, by1), (bx2, by2), FURNITURE_COLOR, 2)
        if center and len(center) == 2:
            cx, cy = int(center[0]), int(center[1])
            font = cv2.FONT_HERSHEY_SIMPLEX
            (tw, th), _ = cv2.getTextSize(label, font, 0.5, 1)
            cv2.rectangle(canvas, (cx - tw // 2 - 3, cy - th - 3),
                          (cx + tw // 2 + 3, cy + 3), FURNITURE_COLOR, -1)
            cv2.putText(canvas, label, (cx - tw // 2, cy),
                        font, 0.5, (255, 255, 255), 1)
        print(f" ✅ {label} center=({center})")
    # 6. Draw the legend
    print(f"\n📋 添加图例...")
    legend_y = 30
    legend_x = 20
    cv2.putText(canvas, "Legend:", (legend_x, legend_y - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
    legend_y += 5
    for room_type, color in COLOR_MAP.items():
        if room_type in ["door"]:
            continue
        cv2.rectangle(canvas, (legend_x, legend_y),
                      (legend_x + 20, legend_y + 15), color, -1)
        cv2.rectangle(canvas, (legend_x, legend_y),
                      (legend_x + 20, legend_y + 15), (0, 0, 0), 1)
        cv2.putText(canvas, room_type, (legend_x + 25, legend_y + 12),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1)
        legend_y += 20
    # Legend entry for doors (red swatch)
    cv2.rectangle(canvas, (legend_x, legend_y),
                  (legend_x + 20, legend_y + 15), (0, 0, 255), -1)
    cv2.rectangle(canvas, (legend_x, legend_y),
                  (legend_x + 20, legend_y + 15), (0, 0, 0), 1)
    cv2.putText(canvas, "door", (legend_x + 25, legend_y + 12),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1)
    # 7. Save the result
    cv2.imwrite(output_path, canvas)
    print(f"\n{'=' * 60}")
    print(f"✨ 可视化完成!")
    print(f" 📁 输出路径:{output_path}")
    print(f" 📐 画布尺寸:{width}x{height}")
    print(f" 🏠 Block 数量:{len(blocks)}")
    print(f" 🔗 Connect Area 数量:{len(connect_areas)}")
    print('=' * 60)
    return canvas
# ================= 4. Batch processing =================
def visualize_batch(json_folder, output_folder, rgb_folder=None, vis_door=False):
    """Visualize every JSON file in a folder, writing one PNG per file."""
    os.makedirs(output_folder, exist_ok=True)
    json_files = [name for name in os.listdir(json_folder) if name.endswith('.json')]
    print(f"\n📁 发现 {len(json_files)} 个 JSON 文件\n")
    for json_file in json_files:
        json_path = os.path.join(json_folder, json_file)
        png_name = json_file.replace('.json', '.png')
        output_path = os.path.join(output_folder, png_name)
        # Optional background: a same-named PNG in rgb_folder, if it exists
        rgb_path = None
        if rgb_folder:
            candidate = os.path.join(rgb_folder, png_name)
            if os.path.exists(candidate):
                rgb_path = candidate
        try:
            visualize_final_json(json_path, output_path, rgb_path, vis_door)
        except Exception as e:
            print(f"❌ 处理 {json_file} 失败:{e}")
# ================= 5. Main entry flow =================
- # ===========================================================================
- # MAIN PIPELINE
- # ===========================================================================
def run_single_image_pipeline(rgb_path, mask_path, output_json_path=None,
                              room_model_path="room_seg.pt",
                              furniture_model_path="furniture_detect.onnx",
                              merge_threshold=15, angle=10,
                              dilation_kernel_size=5, center_threshold=15,
                              expand_pixel=2, min_rect_short_side=30):
    """
    Single-image pipeline: from an RGB image and its block mask, produce the
    final floorplan JSON (connect areas, classified blocks, furniture, and
    merged wall segments), write it to disk, and return it.

    Parameters:
        rgb_path: path of the RGB floorplan image.
        mask_path: path of the grayscale block mask.
        output_json_path: output path; defaults to the RGB path with .json.
        room_model_path / furniture_model_path: model files loaded via _load_yolo_model.
        merge_threshold: cross-block segment merge distance (merge_all_blocks).
        angle: snap-to-axis angle threshold (normalize_all).
        dilation_kernel_size, center_threshold, expand_pixel, min_rect_short_side:
            forwarded to merge_gap_fillers_from_arrays.

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    rgb_img = cv2.imread(rgb_path)
    block_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if rgb_img is None:
        raise FileNotFoundError(f"无法读取 RGB 图像:{rgb_path}")
    if block_mask is None:
        raise FileNotFoundError(f"无法读取 mask 图像:{mask_path}")
    if output_json_path is None:
        # Default: same basename as the RGB image, with a .json extension
        output_json_path = os.path.splitext(rgb_path)[0] + '.json'
    print("===== Single Image Pipeline =====")
    print(f"RGB : {rgb_path}")
    print(f"Mask: {mask_path}")
    print(f"Out : {output_json_path}")
    clean_rgb = remove_edge_regions_image(rgb_img)
    _, gap_add_mask = extract_gaps_from_mask(block_mask)
    connect_mask = extract_mask_region_from_arrays(clean_rgb, block_mask, gap_add_mask)
    room_model = _load_yolo_model(room_model_path)
    furniture_model = _load_yolo_model(furniture_model_path)
    json_data, _, stats = merge_gap_fillers_from_arrays(
        block_mask,
        connect_mask,
        image_path=rgb_path,
        dilation_kernel_size=dilation_kernel_size,
        center_threshold=center_threshold,
        expand_pixel=expand_pixel,
        min_rect_short_side=min_rect_short_side,
    )
    json_data['source_mask_path'] = mask_path
    json_data['connect_area_stats'] = stats
    print("===== Block Classification =====")
    # NOTE(review): the *loaded* model object is passed as `model_path` —
    # confirm build_block_data / detect_furniture_in_image accept either a
    # path or a model instance.
    json_data['block'] = build_block_data(rgb_img, block_mask, model_path=room_model)
    print("===== Furniture Detection =====")
    json_data['furniture'] = detect_furniture_in_image(rgb_img, model_path=furniture_model)
    print("===== Block Refinement =====")
    refine_blocks_in_data(json_data)
    print("===== Segment Merge =====")
    normalize_all(json_data, angle=angle)
    merge_all_blocks(json_data, threshold=merge_threshold)
    os.makedirs(os.path.dirname(os.path.abspath(output_json_path)), exist_ok=True)
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False)
    print("===== Pipeline 完成 =====")
    print(f" - JSON 保存至:{output_json_path}")
    print(f" - Connect Area 数量:{len(json_data.get('connect_area', []))}")
    print(f" - Block 数量:{len(json_data.get('block', []))}")
    print(f" - Furniture 数量:{len(json_data.get('furniture', []))}")
    print(f" - Total Segments:{json_data.get('total_segments', 0)}")
    return json_data
def parse_args():
    """Build the CLI parser for the single-image pipeline and parse sys.argv."""
    parser = argparse.ArgumentParser(description="Single-image floorplan pipeline")
    parser.add_argument('rgb', help='RGB 图路径')
    parser.add_argument('mask', help='Mask 图路径')
    parser.add_argument('-o', '--output-json', help='输出 JSON 路径,默认与 RGB 同名')
    parser.add_argument('--room-model', default='room_cls.pt', help='房间分类模型路径')
    parser.add_argument('--furniture-model', default='furniture_detect.onnx', help='家具检测模型路径')
    # Integer tuning knobs, registered table-style: (flag, default, help)
    int_options = (
        ('--merge-threshold', 15, '跨 block 线段合并阈值'),
        ('--angle', 10, '水平/垂直吸附角度阈值'),
        ('--center-threshold', 15, '连接区域聚类阈值'),
        ('--dilation-kernel-size', 5, '连接区域膨胀核大小'),
        ('--expand-pixel', 2, 'door 矩形扩展像素'),
        ('--min-rect-short-side', 30, 'door 矩形短边上限'),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    return parser.parse_args()
if __name__ == "__main__":
    import sys
    # NOTE(review): parse_args() above is never invoked here — this entry
    # point reads only a folder name from sys.argv. Confirm which CLI is
    # the intended one.
    if len(sys.argv) < 2:
        print("用法:python pipeline.py <文件夹名称>")
        print("示例:python pipeline.py SG-n6nV8B2oW95")
        sys.exit(1)
    folder_name = sys.argv[1]
    img_folder = "temp_data"
    folder_path = os.path.join(img_folder, folder_name)
    if not os.path.isdir(folder_path):
        print(f"错误:文件夹不存在 {folder_path}")
        sys.exit(1)
    # Find the RGB image: filename starts with the folder's name and is not
    # one of the generated mask files.
    rgb_imgs = [f for f in os.listdir(folder_path)
                if f.startswith(folder_name) and not f.startswith('initial_mask_') and not f.startswith('refine_mask_')
                and f.endswith(('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'))]
    if not rgb_imgs:
        print(f"错误:未找到 RGB 图片")
        sys.exit(1)
    rgb_name = rgb_imgs[0]
    rgb_path = os.path.join(folder_path, rgb_name)
    # Locate the matching refine_mask file for that RGB image
    refine_mask_name = f"refine_mask_{rgb_name}"
    refine_mask_path = os.path.join(folder_path, refine_mask_name)
    if not os.path.exists(refine_mask_path):
        print(f"错误:未找到 {refine_mask_name}")
        sys.exit(1)
    # Output JSON path (saved inside the same subfolder, same basename)
    output_json = os.path.join(folder_path, rgb_name.rsplit('.', 1)[0] + '.json')
    print(f"处理文件夹:{folder_name}")
    print(f" RGB: {rgb_path}")
    print(f" Mask: {refine_mask_path}")
    print(f" JSON: {output_json}")
    try:
        run_single_image_pipeline(
            rgb_path=rgb_path,
            mask_path=refine_mask_path,
            output_json_path=output_json,
            room_model_path="room_cls.pt",
            furniture_model_path="furniture_detect.onnx",
            merge_threshold=15,
            angle=10,
            dilation_kernel_size=5,
            center_threshold=15,
            expand_pixel=2,
            min_rect_short_side=30,
        )
        print(f"✅ 完成 -> {os.path.basename(output_json)}")
    except Exception as e:
        print(f"❌ 错误:{e}")
        sys.exit(1)
|