demo.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. """
  2. YOLOE 门目标检测脚本
  3. 基于 yoloe-26x-seg.pt 模型检测全景图中的门/门口目标
  4. """
  5. import os
  6. import argparse
  7. from pathlib import Path
  8. from tqdm import tqdm
  9. from ultralytics import YOLOE
  10. class DoorDetector:
  11. """门目标检测器,支持配置和批量处理"""
  12. # 门相关类别(精简版,覆盖主要语义)
  13. DOOR_CLASSES = [
  14. "door", # 门(通用)
  15. "doorway", # 门口
  16. "entrance", # 入口
  17. "door frame", # 门框
  18. "open door", # 打开的门
  19. "closed door", # 关闭的门
  20. ]
  21. def __init__(
  22. self,
  23. model_path: str = "yoloe-26x-seg.pt",
  24. classes: list = None,
  25. conf: float = 0.35,
  26. iou: float = 0.45,
  27. max_det: int = 50,
  28. ):
  29. """
  30. 初始化检测器
  31. Args:
  32. model_path: YOLOE 模型路径
  33. classes: 检测类别列表
  34. conf: 置信度阈值
  35. iou: NMS IoU 阈值
  36. max_det: 最大检测数量
  37. """
  38. self.classes = classes or self.DOOR_CLASSES
  39. self.conf = conf
  40. self.iou = iou
  41. self.max_det = max_det
  42. print(f"加载模型:{model_path}")
  43. self.model = YOLOE(model_path)
  44. self.model.set_classes(self.classes)
  45. print(f"检测类别 ({len(self.classes)}): {self.classes}")
  46. def detect(self, img_path: str, save_path: str = None) -> dict:
  47. """
  48. 检测单张图像
  49. Args:
  50. img_path: 输入图像路径
  51. save_path: 结果保存路径(可选)
  52. Returns:
  53. 检测结果字典 {boxes, masks, scores, class_names}
  54. """
  55. if not os.path.exists(img_path):
  56. raise FileNotFoundError(f"图像不存在:{img_path}")
  57. results = self.model.predict(
  58. img_path,
  59. imgsz=(1024, 2048),
  60. conf=self.conf,
  61. iou=self.iou,
  62. max_det=self.max_det,
  63. augment=True,
  64. retina_masks=True,
  65. half=False,
  66. verbose=False,
  67. )
  68. result = results[0]
  69. # 提取检测结果
  70. detection_info = {
  71. "boxes": [],
  72. "scores": [],
  73. "class_names": [],
  74. "count": 0,
  75. }
  76. if result.boxes is not None:
  77. detection_info["boxes"] = result.boxes.xyxy.cpu().numpy().tolist()
  78. detection_info["scores"] = result.boxes.conf.cpu().numpy().tolist()
  79. detection_info["class_names"] = [
  80. self.model.names[int(cls)]
  81. for cls in result.boxes.cls.cpu().numpy()
  82. ]
  83. detection_info["count"] = len(detection_info["boxes"])
  84. # 保存可视化结果
  85. if save_path:
  86. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  87. result.save(save_path)
  88. detection_info["save_path"] = save_path
  89. return detection_info
  90. def detect_batch(
  91. self,
  92. input_folder: str,
  93. output_folder: str,
  94. extensions: tuple = (".jpg", ".jpeg", ".png", ".bmp")
  95. ) -> list:
  96. """
  97. 批量检测文件夹中的所有图像
  98. Args:
  99. input_folder: 输入文件夹路径
  100. output_folder: 输出文件夹路径
  101. extensions: 支持的图像扩展名
  102. Returns:
  103. 所有检测结果的列表
  104. """
  105. input_path = Path(input_folder)
  106. output_path = Path(output_folder)
  107. output_path.mkdir(parents=True, exist_ok=True)
  108. # 获取所有图像文件
  109. img_files = [
  110. f for f in input_path.iterdir()
  111. if f.is_file() and f.suffix.lower() in extensions
  112. ]
  113. if not img_files:
  114. print(f"警告:在 '{input_folder}' 中未找到支持的图像文件")
  115. return []
  116. print(f"找到 {len(img_files)} 张图像,开始检测...")
  117. all_results = []
  118. for img_file in tqdm(img_files, desc="检测进度"):
  119. try:
  120. save_path = output_path / img_file.name
  121. result = self.detect(str(img_file), str(save_path))
  122. all_results.append({
  123. "image": img_file.name,
  124. **result
  125. })
  126. # 打印简要结果
  127. if result["count"] > 0:
  128. tqdm.write(
  129. f" {img_file.name}: {result['count']} 个目标 "
  130. f"(classes: {', '.join(set(result['class_names']))})"
  131. )
  132. else:
  133. tqdm.write(f" {img_file.name}: 未检测到目标")
  134. except Exception as e:
  135. print(f"❌ 处理 {img_file.name} 时出错:{e}")
  136. all_results.append({
  137. "image": img_file.name,
  138. "error": str(e),
  139. "count": 0
  140. })
  141. # 汇总统计
  142. total_detections = sum(r["count"] for r in all_results)
  143. images_with_detections = sum(1 for r in all_results if r["count"] > 0)
  144. print(f"\n{'='*50}")
  145. print(f"检测完成!")
  146. print(f" 总图像数:{len(img_files)}")
  147. print(f" 检测到目标的图像:{images_with_detections}")
  148. print(f" 总检测数量:{total_detections}")
  149. print(f" 结果保存至:{output_folder}")
  150. print(f"{'='*50}")
  151. return all_results
  152. def main():
  153. parser = argparse.ArgumentParser(
  154. description="YOLOE 门目标检测工具",
  155. formatter_class=argparse.RawDescriptionHelpFormatter,
  156. )
  157. parser.add_argument(
  158. "--input", "-i",
  159. type=str,
  160. default="image",
  161. help="输入图像文件夹 (默认:image)"
  162. )
  163. parser.add_argument(
  164. "--output", "-o",
  165. type=str,
  166. default="result",
  167. help="输出结果文件夹 (默认:result)"
  168. )
  169. parser.add_argument(
  170. "--model", "-m",
  171. type=str,
  172. default="yoloe-26x-seg.pt",
  173. help="YOLOE 模型路径 (默认:yoloe-26x-seg.pt)"
  174. )
  175. parser.add_argument(
  176. "--conf",
  177. type=float,
  178. default=0.35,
  179. help="置信度阈值 (默认:0.35)"
  180. )
  181. parser.add_argument(
  182. "--iou",
  183. type=float,
  184. default=0.45,
  185. help="NMS IoU 阈值 (默认:0.45)"
  186. )
  187. parser.add_argument(
  188. "--max-det",
  189. type=int,
  190. default=50,
  191. help="最大检测数量 (默认:50)"
  192. )
  193. parser.add_argument(
  194. "--classes",
  195. type=str,
  196. nargs="+",
  197. default=None,
  198. help="自定义检测类别列表 (可选)"
  199. )
  200. args = parser.parse_args()
  201. # 创建检测器
  202. detector = DoorDetector(
  203. model_path=args.model,
  204. classes=args.classes,
  205. conf=args.conf,
  206. iou=args.iou,
  207. max_det=args.max_det,
  208. )
  209. # 执行批量检测
  210. results = detector.detect_batch(args.input, args.output)
  211. return results
  212. if __name__ == "__main__":
  213. main()