server.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 入户门检测 FastAPI 服务
  5. 提供 HTTP API 接口用于入户门位置检测:
  6. - POST /detect - 检测入户门位置
  7. - GET /status/{task_id} - 查询任务状态
  8. - GET /result/{task_id} - 获取检测结果
  9. - GET /health - 健康检查
  10. """
  11. import os
  12. import sys
  13. import json
  14. import shutil
  15. import uuid
  16. import asyncio
  17. from pathlib import Path
  18. from typing import Dict, List, Optional, Any
  19. from datetime import datetime
  20. from concurrent.futures import ThreadPoolExecutor
  21. import tempfile
  22. from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
  23. from fastapi.responses import JSONResponse, FileResponse
  24. from pydantic import BaseModel, Field
  25. import uvicorn
  26. # 导入入户门检测器
  27. from export_entrance_position import EntranceDoorDetector
  28. # ============================================================================
  29. # 数据模型
  30. # ============================================================================
  31. class DetectRequest(BaseModel):
  32. """检测请求体"""
  33. scene_folder: str = Field(..., description="场景文件夹路径")
  34. model_path: Optional[str] = Field("yoloe-26x-seg.pt", description="YOLOE 模型路径")
  35. conf: Optional[float] = Field(0.35, description="检测置信度阈值", ge=0, le=1)
  36. iou: Optional[float] = Field(0.45, description="NMS IoU 阈值", ge=0, le=1)
  37. voxel_size: Optional[float] = Field(0.03, description="点云体素尺寸")
  38. imgsz: Optional[List[int]] = Field([1024, 2048], description="YOLOE 输入图像尺寸 [高,宽]")
  39. vis_ply: Optional[bool] = Field(False, description="是否生成可视化 PLY")
  40. class DetectResponse(BaseModel):
  41. """检测响应体"""
  42. task_id: str
  43. status: str
  44. message: str
  45. result_path: Optional[str] = None
  46. vis_ply_path: Optional[str] = None
  47. class StatusResponse(BaseModel):
  48. """任务状态响应"""
  49. task_id: str
  50. status: str # pending, processing, completed, failed
  51. progress: Optional[float] = None # 0-1
  52. message: Optional[str] = None
  53. result: Optional[Dict[str, Any]] = None
  54. error: Optional[str] = None
  55. created_at: str
  56. completed_at: Optional[str] = None
  57. class HealthResponse(BaseModel):
  58. """健康检查响应"""
  59. status: str
  60. model_loaded: bool
  61. gpu_available: bool
  62. timestamp: str
  63. # ============================================================================
  64. # 任务管理
  65. # ============================================================================
  66. class TaskManager:
  67. """任务管理器"""
  68. def __init__(self):
  69. self.tasks: Dict[str, Dict[str, Any]] = {}
  70. self.results_dir = Path("server_results")
  71. self.results_dir.mkdir(exist_ok=True)
  72. def create_task(self, task_id: str, scene_folder: str) -> Dict[str, Any]:
  73. """创建新任务"""
  74. task = {
  75. "task_id": task_id,
  76. "status": "pending",
  77. "progress": 0.0,
  78. "message": "任务已创建,等待处理",
  79. "result": None,
  80. "error": None,
  81. "created_at": datetime.now().isoformat(),
  82. "completed_at": None,
  83. "scene_folder": scene_folder,
  84. "result_path": None,
  85. "vis_ply_path": None
  86. }
  87. self.tasks[task_id] = task
  88. return task
  89. def update_task(self, task_id: str, **kwargs):
  90. """更新任务状态"""
  91. if task_id not in self.tasks:
  92. raise ValueError(f"任务 {task_id} 不存在")
  93. for key, value in kwargs.items():
  94. self.tasks[task_id][key] = value
  95. def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
  96. """获取任务信息"""
  97. return self.tasks.get(task_id)
  98. def get_task_status(self, task_id: str) -> StatusResponse:
  99. """获取任务状态"""
  100. task = self.tasks.get(task_id)
  101. if not task:
  102. raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
  103. return StatusResponse(
  104. task_id=task["task_id"],
  105. status=task["status"],
  106. progress=task["progress"],
  107. message=task["message"],
  108. result=task["result"],
  109. error=task["error"],
  110. created_at=task["created_at"],
  111. completed_at=task["completed_at"]
  112. )
  113. # ============================================================================
  114. # FastAPI 应用
  115. # ============================================================================
  116. app = FastAPI(
  117. title="入户门检测服务",
  118. description="基于 YOLOE 和 RGB-D 点云的入户门位置检测 API",
  119. version="1.0.0"
  120. )
  121. # 任务管理器
  122. task_manager = TaskManager()
  123. # 线程池
  124. executor = ThreadPoolExecutor(max_workers=2)
  125. # 模型加载状态
  126. model_loaded = False
  127. detector_instance = None
  128. @app.on_event("startup")
  129. async def startup_event():
  130. """服务启动时预加载模型"""
  131. global model_loaded, detector_instance
  132. model_path = "yoloe-26x-seg.pt"
  133. if Path(model_path).exists():
  134. try:
  135. # 预加载模型(创建一个临时检测器实例)
  136. print(f"预加载 YOLOE 模型:{model_path}")
  137. # 注意:这里不实际初始化,因为需要场景文件夹
  138. model_loaded = True
  139. print("模型加载成功")
  140. except Exception as e:
  141. print(f"模型加载失败:{e}")
  142. model_loaded = False
  143. else:
  144. print(f"模型文件不存在:{model_path}")
  145. model_loaded = False
  146. @app.get("/health", response_model=HealthResponse)
  147. async def health_check():
  148. """健康检查"""
  149. import torch
  150. return HealthResponse(
  151. status="healthy",
  152. model_loaded=model_loaded,
  153. gpu_available=torch.cuda.is_available(),
  154. timestamp=datetime.now().isoformat()
  155. )
  156. @app.post("/detect", response_model=DetectResponse)
  157. async def detect_entrance_door(
  158. request: DetectRequest,
  159. background_tasks: BackgroundTasks
  160. ):
  161. """
  162. 检测入户门位置
  163. 提交检测任务后,使用 task_id 查询状态和结果
  164. """
  165. # 验证场景文件夹
  166. scene_path = Path(request.scene_folder)
  167. if not scene_path.exists():
  168. raise HTTPException(status_code=400, detail=f"场景文件夹不存在:{request.scene_folder}")
  169. # 创建任务
  170. task_id = str(uuid.uuid4())
  171. task_manager.create_task(task_id, request.scene_folder)
  172. # 异步执行检测
  173. background_tasks.add_task(
  174. run_detection,
  175. task_id=task_id,
  176. scene_folder=request.scene_folder,
  177. model_path=request.model_path,
  178. conf=request.conf,
  179. iou=request.iou,
  180. voxel_size=request.voxel_size,
  181. imgsz=request.imgsz,
  182. vis_ply=request.vis_ply
  183. )
  184. return DetectResponse(
  185. task_id=task_id,
  186. status="pending",
  187. message="任务已提交,正在处理中",
  188. result_path=None,
  189. vis_ply_path=None
  190. )
  191. @app.post("/detect/upload", response_model=DetectResponse)
  192. async def detect_entrance_door_upload(
  193. scene_folder: str = Form(...),
  194. model_path: Optional[str] = Form("yoloe-26x-seg.pt"),
  195. conf: Optional[float] = Form(0.35),
  196. iou: Optional[float] = Form(0.45),
  197. voxel_size: Optional[float] = Form(0.03),
  198. imgsz_height: Optional[int] = Form(1024),
  199. imgsz_width: Optional[int] = Form(2048),
  200. vis_ply: Optional[bool] = Form(False),
  201. files: List[UploadFile] = File(...),
  202. background_tasks: BackgroundTasks = None
  203. ):
  204. """
  205. 上传场景文件并检测入户门位置
  206. 需要上传:
  207. - vision.txt 文件
  208. - pano_img/ 文件夹中的 JPG 图像
  209. - depth_img/ 文件夹中的 PNG 深度图
  210. """
  211. # 创建临时目录
  212. temp_dir = tempfile.mkdtemp(prefix="entrance_detect_")
  213. temp_path = Path(temp_dir)
  214. try:
  215. # 保存上传的文件
  216. pano_dir = temp_path / "pano_img"
  217. depth_dir = temp_path / "depth_img"
  218. pano_dir.mkdir()
  219. depth_dir.mkdir()
  220. for file in files:
  221. if file.filename.endswith(".jpg"):
  222. file_path = pano_dir / file.filename
  223. elif file.filename.endswith(".png"):
  224. file_path = depth_dir / file.filename
  225. elif file.filename == "vision.txt":
  226. file_path = temp_path / "vision.txt"
  227. else:
  228. continue
  229. with open(file_path, "wb") as f:
  230. content = await file.read()
  231. f.write(content)
  232. # 创建任务
  233. task_id = str(uuid.uuid4())
  234. task_manager.create_task(task_id, str(temp_path))
  235. # 异步执行检测
  236. background_tasks.add_task(
  237. run_detection,
  238. task_id=task_id,
  239. scene_folder=str(temp_path),
  240. model_path=model_path,
  241. conf=conf,
  242. iou=iou,
  243. voxel_size=voxel_size,
  244. imgsz=[imgsz_height, imgsz_width],
  245. vis_ply=vis_ply,
  246. cleanup_temp=True
  247. )
  248. return DetectResponse(
  249. task_id=task_id,
  250. status="pending",
  251. message="文件上传成功,任务已提交",
  252. result_path=None,
  253. vis_ply_path=None
  254. )
  255. except Exception as e:
  256. # 清理临时目录
  257. shutil.rmtree(temp_dir, ignore_errors=True)
  258. raise HTTPException(status_code=500, detail=f"文件上传失败:{str(e)}")
  259. @app.get("/status/{task_id}", response_model=StatusResponse)
  260. async def get_task_status(task_id: str):
  261. """查询任务状态"""
  262. return task_manager.get_task_status(task_id)
  263. @app.get("/result/{task_id}")
  264. async def get_task_result(task_id: str):
  265. """获取检测结果"""
  266. task = task_manager.get_task(task_id)
  267. if not task:
  268. raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
  269. if task["status"] != "completed":
  270. raise HTTPException(
  271. status_code=400,
  272. detail=f"任务尚未完成,当前状态:{task['status']}"
  273. )
  274. # 读取结果文件
  275. if task["result_path"] and Path(task["result_path"]).exists():
  276. with open(task["result_path"], "r", encoding="utf-8") as f:
  277. result = json.load(f)
  278. return JSONResponse(content=result)
  279. else:
  280. raise HTTPException(status_code=404, detail="结果文件不存在")
  281. @app.get("/result/{task_id}/ply")
  282. async def get_task_ply(task_id: str):
  283. """获取可视化 PLY 文件"""
  284. task = task_manager.get_task(task_id)
  285. if not task:
  286. raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
  287. if task["status"] != "completed":
  288. raise HTTPException(
  289. status_code=400,
  290. detail=f"任务尚未完成,当前状态:{task['status']}"
  291. )
  292. if task["vis_ply_path"] and Path(task["vis_ply_path"]).exists():
  293. return FileResponse(
  294. task["vis_ply_path"],
  295. media_type="application/octet-stream",
  296. filename="vis.ply"
  297. )
  298. else:
  299. raise HTTPException(status_code=404, detail="PLY 文件不存在")
  300. @app.delete("/task/{task_id}")
  301. async def delete_task(task_id: str):
  302. """删除任务"""
  303. if task_id not in task_manager.tasks:
  304. raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
  305. task = task_manager.tasks[task_id]
  306. # 清理结果文件
  307. if task["result_path"] and Path(task["result_path"]).exists():
  308. os.remove(task["result_path"])
  309. if task["vis_ply_path"] and Path(task["vis_ply_path"]).exists():
  310. os.remove(task["vis_ply_path"])
  311. # 清理场景文件夹(如果是临时创建的)
  312. scene_path = Path(task.get("scene_folder", ""))
  313. if scene_path.exists() and str(scene_path).startswith("/tmp"):
  314. shutil.rmtree(scene_path, ignore_errors=True)
  315. del task_manager.tasks[task_id]
  316. return {"message": f"任务 {task_id} 已删除"}
  317. # ============================================================================
  318. # 后台任务
  319. # ============================================================================
  320. def run_detection(
  321. task_id: str,
  322. scene_folder: str,
  323. model_path: str,
  324. conf: float,
  325. iou: float,
  326. voxel_size: float,
  327. imgsz: List[int],
  328. vis_ply: bool,
  329. cleanup_temp: bool = False
  330. ):
  331. """运行入户门检测(在后台线程中执行)"""
  332. try:
  333. # 更新状态为处理中
  334. task_manager.update_task(
  335. task_id,
  336. status="processing",
  337. progress=0.1,
  338. message="正在初始化检测器..."
  339. )
  340. # 创建检测器
  341. detector = EntranceDoorDetector(
  342. scene_folder=scene_folder,
  343. model_path=model_path,
  344. conf=conf,
  345. iou=iou,
  346. voxel_size=voxel_size,
  347. imgsz=(imgsz[0], imgsz[1])
  348. )
  349. task_manager.update_task(
  350. task_id,
  351. progress=0.3,
  352. message="正在检测门并识别入户门..."
  353. )
  354. # 执行检测
  355. success = detector.detect_and_identify()
  356. task_manager.update_task(
  357. task_id,
  358. progress=0.7,
  359. message="正在导出结果..."
  360. )
  361. if success:
  362. # 导出 JSON 结果
  363. output_folder = Path(scene_folder) / "output"
  364. output_folder.mkdir(exist_ok=True)
  365. result_path = output_folder / "entrance_position.json"
  366. detector.export_json(str(result_path))
  367. vis_ply_path = None
  368. if vis_ply:
  369. vis_ply_path = output_folder / "vis.ply"
  370. detector.export_vis_ply(detector.combined_pc, str(vis_ply_path))
  371. # 读取结果
  372. with open(result_path, "r", encoding="utf-8") as f:
  373. result = json.load(f)
  374. # 更新任务为完成
  375. task_manager.update_task(
  376. task_id,
  377. status="completed",
  378. progress=1.0,
  379. message="检测完成",
  380. result=result,
  381. result_path=str(result_path),
  382. vis_ply_path=str(vis_ply_path) if vis_ply_path else None,
  383. completed_at=datetime.now().isoformat()
  384. )
  385. else:
  386. task_manager.update_task(
  387. task_id,
  388. status="failed",
  389. progress=0.0,
  390. message="检测失败:无法确定入户门位置",
  391. error="无法确定入户门位置",
  392. completed_at=datetime.now().isoformat()
  393. )
  394. # 清理临时目录
  395. if cleanup_temp:
  396. shutil.rmtree(scene_folder, ignore_errors=True)
  397. except Exception as e:
  398. import traceback
  399. error_msg = f"{str(e)}\n{traceback.format_exc()}"
  400. task_manager.update_task(
  401. task_id,
  402. status="failed",
  403. progress=0.0,
  404. message=f"检测失败:{str(e)}",
  405. error=error_msg,
  406. completed_at=datetime.now().isoformat()
  407. )
  408. # ============================================================================
  409. # 主程序
  410. # ============================================================================
  411. if __name__ == "__main__":
  412. import argparse
  413. parser = argparse.ArgumentParser(description="入户门检测 FastAPI 服务")
  414. parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址")
  415. parser.add_argument("--port", type=int, default=8000, help="监听端口")
  416. parser.add_argument("--reload", action="store_true", help="启用热重载(开发模式)")
  417. parser.add_argument("--workers", type=int, default=1, help="工作进程数")
  418. args = parser.parse_args()
  419. print(f"启动入户门检测服务...")
  420. print(f" 监听地址:http://{args.host}:{args.port}")
  421. print(f" API 文档:http://{args.host}:{args.port}/docs")
  422. print(f" 健康检查:http://{args.host}:{args.port}/health")
  423. uvicorn.run(
  424. "server:app",
  425. host=args.host,
  426. port=args.port,
  427. reload=args.reload,
  428. workers=args.workers
  429. )