# extract_initial_mask.py (6.6 KB)
  1. import os
  2. import cv2
  3. import numpy as np
  4. import onnxruntime
  5. from PIL import Image
  6. import psutil
  7. import subprocess
  8. import sys
  9. import shutil
  10. def get_memory_usage():
  11. process = psutil.Process(os.getpid())
  12. return process.memory_info().rss / 1024 / 1024
  13. def red_edge_generate(img_path, save_path):
  14. img = cv2.imread(img_path)
  15. img_2 = np.zeros_like(img)
  16. mask = (img[:, :, 0] == 0) * (img[:, :, 1] == 0) * (img[:, :, 2] == 0)
  17. img_2[~mask] = (255, 255, 255)
  18. edges = cv2.Canny(img_2, 50, 150)
  19. kernel = np.ones((3, 3), np.uint8)
  20. edges = cv2.dilate(edges, kernel, 1)
  21. # vis
  22. mask = edges[:, :, None] / 255.0
  23. masks = np.concatenate([mask, mask, mask], axis=-1)
  24. img1 = (masks * (0.0, 0.0, 255.0)).clip(0, 255)
  25. alpha = 0.9
  26. img = img1 * alpha + img * (1 - masks * alpha)
  27. cv2.imwrite(save_path, img)
  28. def predict_birefnet_onnx(image_path, onnx_session, mask_dir, input_size=(1024, 1024)):
  29. """
  30. 使用 ONNX 模型进行 BiRefNet 推理并进行后处理
  31. 参数:
  32. - image_path: 输入图片路径
  33. - onnx_session: 已初始化的 onnxruntime.InferenceSession 对象
  34. - mask_dir: 结果保存目录
  35. - input_size: 模型要求的输入尺寸
  36. """
  37. # 1. 图像加载与预处理
  38. orig_img = Image.open(image_path).convert("RGB")
  39. w_orig, h_orig = orig_img.size
  40. # 缩放并转为 float32 Numpy 数组
  41. img_resized = orig_img.resize(input_size, resample=Image.BILINEAR)
  42. img_np = np.array(img_resized).astype(np.float32) / 255.0
  43. # 标准化 (使用 float32 避免类型提升)
  44. mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
  45. std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
  46. img_np = (img_np - mean) / std
  47. # HWC -> CHW 并增加 Batch 维度
  48. img_np = img_np.transpose(2, 0, 1)[np.newaxis, :]
  49. img_np = np.ascontiguousarray(img_np)
  50. # 2. ONNX 推理
  51. input_name = onnx_session.get_inputs()[0].name
  52. outputs = onnx_session.run(None, {input_name: img_np})
  53. raw_preds = outputs[-1] # 取最后一个输出
  54. # 3. 后处理
  55. # Sigmoid 激活
  56. pred_mask = 1 / (1 + np.exp(-raw_preds))
  57. pred_mask = pred_mask.squeeze()
  58. # 尺寸还原至原图大小 (Width, Height)
  59. mask_resized = cv2.resize(pred_mask, (w_orig, h_orig), interpolation=cv2.INTER_LINEAR)
  60. # 转为 8bit 灰度图
  61. mask_8bit = (mask_resized * 255).astype(np.uint8)
  62. # 腐蚀处理 (Erosion)
  63. kernel = np.ones((3, 3), np.uint8)
  64. mask_eroded = cv2.erode(mask_8bit, kernel, iterations=1)
  65. # Canny 边缘检测
  66. edges = cv2.Canny(mask_eroded, 150, 200)
  67. # 4. 保存与返回
  68. if not os.path.exists(mask_dir):
  69. os.makedirs(mask_dir)
  70. output_filename = os.path.basename(image_path)
  71. save_path = os.path.join(mask_dir, output_filename)
  72. cv2.imwrite(save_path, mask_eroded)
  73. # 如果需要,可以将 edges 也保存或返回
  74. # cv2.imwrite(os.path.join(mask_dir, "edge_" + output_filename), edges)
  75. return mask_eroded, edges
  76. # --- 使用示例 ---
  77. # if __name__ == "__main__":
  78. # # 初始化 Session (只需初始化一次)
  79. # weights = "/media/gu/d54b9541-2b55-4c75-b059-3006d51983d53/lqc/Wall-mask/BiRefNet/ckpts/own_dataset_finetune/shenguang_v1.onnx"
  80. # # session = onnxruntime.InferenceSession(weights, providers=['CPUExecutionProvider'])
  81. # session = onnxruntime.InferenceSession(weights, providers=['CUDAExecutionProvider'])
  82. #
  83. #
  84. # img_p = '00004-VqCaAuuoeWk.png'
  85. # out_d = './output_results'
  86. # print(f"推理前内存: {get_memory_usage():.2f} MB")
  87. # mask, edge = predict_birefnet_onnx(img_p, session, out_d)
  88. # print(f"推理后内存: {get_memory_usage():.2f} MB")
  89. # print("处理完成,结果已保存。")
  90. if __name__ == "__main__":
  91. import gc
  92. import sys
  93. if len(sys.argv) < 2:
  94. print("用法:python extract_initial_mask.py <图片名称>")
  95. print("示例:python extract_initial_mask.py SG-n6nV8B2oW95")
  96. sys.exit(1)
  97. img_name = sys.argv[1]
  98. weights = "initial_mask.onnx"
  99. img_folder = "temp_data"
  100. # 检查文件是否存在(优先在 img_folder 下查找)
  101. rgb_path = os.path.join(img_folder, img_name)
  102. if not os.path.exists(rgb_path):
  103. # 尝试带扩展名
  104. for ext in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']:
  105. test_path = os.path.join(img_folder, img_name + ext)
  106. if os.path.exists(test_path):
  107. rgb_path = test_path
  108. img_name = img_name + ext
  109. break
  110. else:
  111. print(f"错误:找不到图片 {img_name}")
  112. sys.exit(1)
  113. # 创建子文件夹:temp_data/name/
  114. base_name = os.path.splitext(img_name)[0]
  115. out_d = os.path.join(img_folder, base_name)
  116. os.makedirs(out_d, exist_ok=True)
  117. # 复制原始 RGB 图片到子文件夹
  118. rgb_path_in_folder = os.path.join(out_d, img_name)
  119. shutil.copy2(rgb_path, rgb_path_in_folder)
  120. rgb_path = rgb_path_in_folder # 更新路径为子文件夹中的路径
  121. output_name = f"initial_mask_{img_name}" # initial_mask_xxx.png
  122. output_path = os.path.join(out_d, output_name)
  123. print(f"处理:{img_name}")
  124. print(f"输出目录:{out_d}")
  125. # 在子进程中推理,进程退出时显存 100% 释放
  126. result = subprocess.run([
  127. sys.executable, "-c", f'''
  128. import sys
  129. import gc
  130. import torch
  131. import onnxruntime
  132. from extract_initial_mask import predict_birefnet_onnx, red_edge_generate
  133. # 1. 给 RGB 图像添加红色轮廓(覆盖原图)
  134. red_edge_generate("{rgb_path}", "{rgb_path}")
  135. # 2. 使用带红边的图生成 mask
  136. session = onnxruntime.InferenceSession("{weights}", providers=[("CUDAExecutionProvider", {{"device_id": 0}})])
  137. temp_mask_dir = "{out_d}/_temp_masks"
  138. import os
  139. os.makedirs(temp_mask_dir, exist_ok=True)
  140. predict_birefnet_onnx("{rgb_path}", session, temp_mask_dir)
  141. # 3. 将生成的 mask 重命名并移动到目标目录,添加前缀
  142. import shutil
  143. temp_mask_path = os.path.join(temp_mask_dir, os.path.basename("{img_name}"))
  144. output_path = "{output_path}"
  145. shutil.move(temp_mask_path, output_path)
  146. # 4. 清理临时目录
  147. try:
  148. os.rmdir(temp_mask_dir)
  149. except:
  150. pass
  151. del session
  152. torch.cuda.empty_cache()
  153. torch.cuda.synchronize()
  154. gc.collect()
  155. gc.collect()
  156. '''
  157. ], capture_output=True, text=True)
  158. if result.returncode != 0:
  159. print(f" 错误:{result.stderr}")
  160. else:
  161. print(f" 完成 -> {output_name}")
  162. # 删除原始 RGB 图片(temp_data 根目录下的)
  163. original_rgb_path = os.path.join(img_folder, img_name)
  164. if os.path.exists(original_rgb_path):
  165. os.remove(original_rgb_path)
  166. print(f"已删除原始图片:{original_rgb_path}")
  167. print(f"✅ 处理完成!结果已保存到:{out_d}")