| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- import os
- import cv2
- import numpy as np
- import onnxruntime
- from PIL import Image
- import psutil
- import subprocess
- import sys
- import shutil
def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MiB (float)."""
    current = psutil.Process(os.getpid())
    bytes_per_mib = 1024 * 1024
    return current.memory_info().rss / bytes_per_mib
def red_edge_generate(img_path, save_path, alpha=0.9):
    """Draw red contours around the non-black region of an image and save it.

    Every pixel that is not pure black (0, 0, 0) counts as foreground. The
    boundary of that foreground is found with Canny, thickened by one 3x3
    dilation, and blended onto the image in red (BGR ``(0, 0, 255)``) with
    opacity ``alpha``. ``save_path`` may equal ``img_path`` to overwrite the
    input in place.

    Args:
        img_path: Path of the input image (read by OpenCV as 3-channel BGR).
        save_path: Path the annotated image is written to.
        alpha: Opacity of the red contour in [0, 1]. Defaults to 0.9.

    Raises:
        FileNotFoundError: if ``img_path`` cannot be read as an image.
        OSError: if the result cannot be written to ``save_path``.
    """
    img = cv2.imread(img_path)
    if img is None:  # imread signals failure by returning None, not raising
        raise FileNotFoundError(f"cannot read image: {img_path}")

    # Binary silhouette: 255 wherever any channel is non-zero, 0 for pure black.
    silhouette = np.any(img != 0, axis=2).astype(np.uint8) * 255
    edges = cv2.Canny(silhouette, 50, 150)
    kernel = np.ones((3, 3), np.uint8)
    # iterations must be passed by keyword: the third positional parameter
    # of cv2.dilate is the output array `dst`, not the iteration count.
    edges = cv2.dilate(edges, kernel, iterations=1)

    # Per-pixel blend weight of shape (H, W, 1): alpha on contour pixels, 0 elsewhere.
    weight = (edges / 255.0)[:, :, None] * alpha
    red = np.array([0.0, 0.0, 255.0])
    blended = red * weight + img * (1.0 - weight)

    # Convert explicitly to rounded, clipped uint8 instead of handing float64
    # data to imwrite and relying on its implicit per-format conversion.
    out = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    if not cv2.imwrite(save_path, out):
        raise OSError(f"failed to write image to {save_path}")
def predict_birefnet_onnx(image_path, onnx_session, mask_dir, input_size=(1024, 1024)):
    """Run BiRefNet inference with an ONNX model and post-process the mask.

    The image is converted to RGB, resized to ``input_size`` and normalised
    with ImageNet mean/std. It is then run through ``onnx_session``. The last
    model output goes through a sigmoid, is resized back to the original
    resolution, converted to 8-bit and eroded once with a 3x3 kernel. The
    eroded mask is saved in ``mask_dir`` under the input image's basename.

    Args:
        image_path: Path of the input image.
        onnx_session: An initialised ``onnxruntime.InferenceSession``.
        mask_dir: Directory the eroded mask is written to (created if missing).
        input_size: ``(width, height)`` the image is resized to before
            inference (PIL ``resize`` order); must match the model's input.

    Returns:
        ``(mask_eroded, edges)``. ``mask_eroded`` is the uint8 mask at the
        original image size. ``edges`` is its Canny edge map, which is not
        saved to disk.

    Raises:
        OSError: if the mask cannot be written to ``mask_dir``.
    """
    # 1. Load and preprocess. The context manager closes the file handle that
    #    Image.open keeps open; convert() returns an independent image.
    with Image.open(image_path) as pil_img:
        orig_img = pil_img.convert("RGB")
    w_orig, h_orig = orig_img.size

    # Resize and convert to a float32 array scaled to [0, 1].
    img_resized = orig_img.resize(input_size, resample=Image.BILINEAR)
    img_np = np.array(img_resized).astype(np.float32) / 255.0
    # ImageNet normalisation; float32 constants avoid promotion to float64.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img_np = (img_np - mean) / std
    # HWC -> CHW, add a batch dimension of 1, make contiguous for onnxruntime.
    img_np = np.ascontiguousarray(img_np.transpose(2, 0, 1)[np.newaxis, :])

    # 2. ONNX inference; the last output is used as the prediction logits.
    input_name = onnx_session.get_inputs()[0].name
    outputs = onnx_session.run(None, {input_name: img_np})
    # Cast to float32, a dtype cv2.resize accepts (guards against e.g. an
    # fp16 export); a no-op for float32 outputs.
    logits = np.asarray(outputs[-1], dtype=np.float32)

    # 3. Post-processing.
    # Very negative logits overflow exp() to inf, which still yields the
    # correct sigmoid value 0, so the overflow warning is silenced.
    with np.errstate(over="ignore"):
        pred_mask = 1 / (1 + np.exp(-logits))
    # assumes the output squeezes to (H, W), e.g. from (1, 1, H, W) -- TODO confirm
    pred_mask = pred_mask.squeeze()
    # Back to the original resolution; cv2 takes (width, height).
    mask_resized = cv2.resize(pred_mask, (w_orig, h_orig), interpolation=cv2.INTER_LINEAR)
    mask_8bit = (mask_resized * 255).astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    mask_eroded = cv2.erode(mask_8bit, kernel, iterations=1)
    edges = cv2.Canny(mask_eroded, 150, 200)

    # 4. Save the eroded mask (edges are only returned, not written).
    os.makedirs(mask_dir, exist_ok=True)
    save_path = os.path.join(mask_dir, os.path.basename(image_path))
    if not cv2.imwrite(save_path, mask_eroded):
        raise OSError(f"failed to write mask to {save_path}")
    return mask_eroded, edges
- # --- 使用示例 ---
- # if __name__ == "__main__":
- # # 初始化 Session (只需初始化一次)
- # weights = "/media/gu/d54b9541-2b55-4c75-b059-3006d51983d53/lqc/Wall-mask/BiRefNet/ckpts/own_dataset_finetune/shenguang_v1.onnx"
- # # session = onnxruntime.InferenceSession(weights, providers=['CPUExecutionProvider'])
- # session = onnxruntime.InferenceSession(weights, providers=['CUDAExecutionProvider'])
- #
- #
- # img_p = '00004-VqCaAuuoeWk.png'
- # out_d = './output_results'
- # print(f"推理前内存: {get_memory_usage():.2f} MB")
- # mask, edge = predict_birefnet_onnx(img_p, session, out_d)
- # print(f"推理后内存: {get_memory_usage():.2f} MB")
- # print("处理完成,结果已保存。")
if __name__ == "__main__":
    # Source run by the child process. All paths arrive through argv instead
    # of being interpolated into the source, so quotes or backslashes in file
    # names (e.g. Windows paths) can neither break nor inject code.
    _CHILD_SCRIPT = '''
import importlib
import os
import shutil
import sys

script_dir, module_name, rgb_path, weights, out_d, img_name, output_path = sys.argv[1:8]
# Make the parent script importable even when cwd is not its directory.
sys.path.insert(0, script_dir)

try:
    # NOTE(review): presumably imported before onnxruntime so the CUDA/cuDNN
    # libraries shipped with torch are loaded for the CUDA provider -- confirm.
    import torch  # noqa: F401
except ImportError:
    pass
import onnxruntime

module = importlib.import_module(module_name)

# 1. Overlay red contours on the RGB image (overwrites it in place).
module.red_edge_generate(rgb_path, rgb_path)

# 2. Predict the mask from the red-edged image.
session = onnxruntime.InferenceSession(
    weights, providers=[("CUDAExecutionProvider", {"device_id": 0})]
)
temp_mask_dir = os.path.join(out_d, "_temp_masks")
os.makedirs(temp_mask_dir, exist_ok=True)
module.predict_birefnet_onnx(rgb_path, session, temp_mask_dir)

# 3. Move the mask (saved under the image basename) to its prefixed final name.
shutil.move(os.path.join(temp_mask_dir, os.path.basename(img_name)), output_path)

# 4. Best-effort removal of the temporary directory.
try:
    os.rmdir(temp_mask_dir)
except OSError:
    pass
'''

    if len(sys.argv) < 2:
        print("用法:python extract_initial_mask.py <图片名称>")
        print("示例:python extract_initial_mask.py SG-n6nV8B2oW95")
        sys.exit(1)

    img_name = sys.argv[1]
    weights = "initial_mask.onnx"
    img_folder = "temp_data"

    # Look for the image in img_folder: first as given, then with common extensions.
    rgb_path = os.path.join(img_folder, img_name)
    if not os.path.exists(rgb_path):
        for ext in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']:
            test_path = os.path.join(img_folder, img_name + ext)
            if os.path.exists(test_path):
                rgb_path = test_path
                img_name = img_name + ext
                break
        else:  # no candidate extension matched
            print(f"错误:找不到图片 {img_name}")
            sys.exit(1)

    # Per-image output folder: temp_data/<name>/
    base_name = os.path.splitext(img_name)[0]
    out_d = os.path.join(img_folder, base_name)
    os.makedirs(out_d, exist_ok=True)

    # Work on a copy inside the output folder; the child overwrites that copy
    # with the red-edged version.
    rgb_path_in_folder = os.path.join(out_d, img_name)
    shutil.copy2(rgb_path, rgb_path_in_folder)
    rgb_path = rgb_path_in_folder

    output_name = f"initial_mask_{img_name}"  # initial_mask_xxx.png
    output_path = os.path.join(out_d, output_name)

    print(f"处理:{img_name}")
    print(f"输出目录:{out_d}")

    # Run inference in a child process: GPU memory is released when it exits.
    script_path = os.path.abspath(__file__)
    result = subprocess.run(
        [
            sys.executable, "-c", _CHILD_SCRIPT,
            os.path.dirname(script_path),
            os.path.splitext(os.path.basename(script_path))[0],
            rgb_path, weights, out_d, img_name, output_path,
        ],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        print(f" 错误:{result.stderr}")
        # Keep the original image so the run can be retried, and report failure.
        sys.exit(result.returncode)
    print(f" 完成 -> {output_name}")

    # Only after success: delete the original RGB image in the temp_data root.
    original_rgb_path = os.path.join(img_folder, img_name)
    if os.path.exists(original_rgb_path):
        os.remove(original_rgb_path)
        print(f"已删除原始图片:{original_rgb_path}")

    print(f"✅ 处理完成!结果已保存到:{out_d}")
|