| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- import torch
- from diffusers import Flux2KleinPipeline
- from diffusers.utils import load_image
- from PIL import Image
- import os
- import sys
- import subprocess
- import gc
- import time
# Defaults used by the command-line entry point. Paths are relative to the
# current working directory.
WEIGHTS = "./black-forest-labs/FLUX.2-klein-4B"
IMG_FOLDER = "temp_data"
MASK_PREFIX = "initial_mask_"
OUTPUT_PREFIX = "refine_mask_"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
WORKER_TIMEOUT_SECONDS = 300

# Source of the inference worker, run with ``python -c`` in a fresh process.
# It is a constant (NOT an f-string): the weights directory, input path and
# output path are received through ``sys.argv`` so that quotes, backslashes or
# any other character in a path can never be interpreted as Python code.
_WORKER_SOURCE = '''
import gc
import sys

import torch
from diffusers import Flux2KleinPipeline
from diffusers.utils import load_image

weights, img_path, output_path = sys.argv[1:4]

# Load the model.
pipe = Flux2KleinPipeline.from_pretrained(weights, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)

# Read the initial mask and round its size down to a multiple of 8 for the
# pipeline call; the result is resized back to the original size below.
image = load_image(img_path)
base_width, base_height = image.size
target_width = (base_width // 8) * 8
target_height = (base_height // 8) * 8

prompt = """
(best quality, 4k), architectural floor plan mask, instance segmentation,
do not add extra blocks,
distinct separate white blocks, clear black gaps between rooms,
separated connected components, clean sharp edges, top-down view,
binary mask style, white rooms on black background, no touching blocks,
The image should be positioned exactly as it was in the original image; do not shift it.
logical room separation
"""

result_img = pipe(
    image=image,
    prompt=prompt,
    height=target_height,
    width=target_width,
    guidance_scale=4.0,
    num_inference_steps=4,
    generator=generator,
).images[0]

# Scale back to the input resolution and save.
scaled_image = result_img.resize((base_width, base_height))
scaled_image.save(output_path)

# Explicit cleanup; the process exits right after this anyway.
del pipe
del generator
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
gc.collect()
'''


def _find_initial_mask(folder_path):
    """Return the name of the initial-mask image in *folder_path*, or None.

    A candidate is a directory entry whose name starts with ``initial_mask_``
    and ends with a supported image extension (compared case-insensitively).
    Candidates are sorted so that the choice does not depend on the arbitrary
    order returned by ``os.listdir``.
    """
    candidates = sorted(
        name
        for name in os.listdir(folder_path)
        if name.startswith(MASK_PREFIX)
        and name.lower().endswith(IMAGE_EXTENSIONS)
    )
    return candidates[0] if candidates else None


def _run_worker(weights, img_path, output_path, timeout):
    """Run the inference worker in a child interpreter.

    Running in a separate process means all memory held by the worker is
    returned to the system when it exits.

    Returns:
        An error description (str) when the worker failed or timed out,
        or None when it exited with return code 0.
    """
    command = [sys.executable, "-c", _WORKER_SOURCE, weights, img_path, output_path]
    try:
        completed = subprocess.run(
            command,
            capture_output=True,
            text=True,
            errors="replace",  # undecodable worker output must not crash us
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return f"推理超时(超过 {timeout} 秒)"
    if completed.returncode != 0:
        return completed.stderr
    return None


def main(argv, img_folder=IMG_FOLDER, weights=WEIGHTS, timeout=WORKER_TIMEOUT_SECONDS):
    """Refine the initial mask found in ``<img_folder>/<argv[1]>``.

    Args:
        argv: Command-line arguments, ``argv[1]`` being the folder name.
        img_folder: Directory that contains the per-job folders.
        weights: Model directory passed to the worker.
        timeout: Maximum number of seconds the worker may run.

    Returns:
        Process exit status: 0 on success, 1 on a usage error, a missing
        folder, a missing mask image, or a failed / timed-out worker.
    """
    if len(argv) < 2:
        print("用法:python inference_refine_mask.py <文件夹名称>")
        print("示例:python inference_refine_mask.py SG-n6nV8B2oW95")
        return 1

    folder_name = argv[1]
    folder_path = os.path.join(img_folder, folder_name)
    if not os.path.isdir(folder_path):
        print(f"错误:文件夹不存在 {folder_path}")
        return 1

    # Locate the initial_mask image inside the folder.
    img_name = _find_initial_mask(folder_path)
    if img_name is None:
        print("错误:未找到 initial_mask 图片")
        return 1

    img_path = os.path.join(folder_path, img_name)
    # initial_mask_xxx.png -> refine_mask_xxx.png
    output_name = OUTPUT_PREFIX + img_name[len(MASK_PREFIX):]
    output_path = os.path.join(folder_path, output_name)

    print(f"处理文件夹:{folder_name}")
    print(f"输入:{img_path}")
    print(f"输出:{output_path}")

    error = _run_worker(weights, img_path, output_path, timeout)
    if error is not None:
        # Report the failure and signal it through the exit status instead of
        # announcing success.
        print(f" 错误:{error}")
        return 1

    print(f" 完成 -> {output_name}")
    print(f"✅ 处理完成!结果已保存到:{output_path}")
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))
|