# inference_refine_mask.py
  1. import torch
  2. from diffusers import Flux2KleinPipeline
  3. from diffusers.utils import load_image
  4. from PIL import Image
  5. import os
  6. import sys
  7. import subprocess
  8. import gc
  9. import time
  10. if __name__ == '__main__':
  11. if len(sys.argv) < 2:
  12. print("用法:python inference_refine_mask.py <文件夹名称>")
  13. print("示例:python inference_refine_mask.py SG-n6nV8B2oW95")
  14. sys.exit(1)
  15. folder_name = sys.argv[1]
  16. weights = "./black-forest-labs/FLUX.2-klein-4B"
  17. img_folder = "temp_data"
  18. folder_path = os.path.join(img_folder, folder_name)
  19. if not os.path.isdir(folder_path):
  20. print(f"错误:文件夹不存在 {folder_path}")
  21. sys.exit(1)
  22. # 查找文件夹中的 initial_mask 图片
  23. initial_masks = [f for f in os.listdir(folder_path)
  24. if f.startswith('initial_mask_') and f.endswith(('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'))]
  25. if not initial_masks:
  26. print(f"错误:未找到 initial_mask 图片")
  27. sys.exit(1)
  28. img_name = initial_masks[0]
  29. img_path = os.path.join(folder_path, img_name)
  30. output_name = "refine_mask_" + img_name[len("initial_mask_"):] # refine_mask_xxx.png
  31. output_path = os.path.join(folder_path, output_name)
  32. print(f"处理文件夹:{folder_name}")
  33. print(f"输入:{img_path}")
  34. print(f"输出:{output_path}")
  35. # 在子进程中推理,进程退出时显存 100% 释放
  36. result = subprocess.run([
  37. sys.executable, "-c", f'''
  38. import sys
  39. import gc
  40. import torch
  41. from diffusers import Flux2KleinPipeline
  42. from diffusers.utils import load_image
  43. from PIL import Image
  44. # 加载模型
  45. pipe = Flux2KleinPipeline.from_pretrained("{weights}", torch_dtype=torch.bfloat16)
  46. pipe = pipe.to("cuda")
  47. generator = torch.Generator(device="cuda").manual_seed(0)
  48. # 读取 initial_mask
  49. image = load_image("{img_path}")
  50. base_width, base_height = image.size
  51. target_width = (base_width // 8) * 8
  52. target_height = (base_height // 8) * 8
  53. prompt = """
  54. (best quality, 4k), architectural floor plan mask, instance segmentation,
  55. do not add extra blocks,
  56. distinct separate white blocks, clear black gaps between rooms,
  57. separated connected components, clean sharp edges, top-down view,
  58. binary mask style, white rooms on black background, no touching blocks,
  59. The image should be positioned exactly as it was in the original image; do not shift it.
  60. logical room separation
  61. """
  62. result_img = pipe(
  63. image=image,
  64. prompt=prompt,
  65. height=target_height,
  66. width=target_width,
  67. guidance_scale=4.0,
  68. num_inference_steps=4,
  69. generator=generator
  70. ).images[0]
  71. # 缩放并保存
  72. scaled_image = result_img.resize((base_width, base_height))
  73. scaled_image.save("{output_path}")
  74. # 清理显存
  75. del pipe
  76. del generator
  77. torch.cuda.empty_cache()
  78. torch.cuda.synchronize()
  79. gc.collect()
  80. gc.collect()
  81. '''
  82. ], capture_output=True, text=True, timeout=300)
  83. if result.returncode != 0:
  84. print(f" 错误:{result.stderr}")
  85. else:
  86. print(f" 完成 -> {output_name}")
  87. print(f"✅ 处理完成!结果已保存到:{output_path}")