| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 入户门检测 FastAPI 服务
- 提供 HTTP API 接口用于入户门位置检测:
- - POST /detect - 检测入户门位置
- - GET /status/{task_id} - 查询任务状态
- - GET /result/{task_id} - 获取检测结果
- - GET /health - 健康检查
- """
- import os
- import sys
- import json
- import shutil
- import uuid
- import asyncio
- from pathlib import Path
- from typing import Dict, List, Optional, Any
- from datetime import datetime
- from concurrent.futures import ThreadPoolExecutor
- import tempfile
- from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
- from fastapi.responses import JSONResponse, FileResponse
- from pydantic import BaseModel, Field
- import uvicorn
- # 导入入户门检测器
- from export_entrance_position import EntranceDoorDetector
- # ============================================================================
- # 数据模型
- # ============================================================================
class DetectRequest(BaseModel):
    """Request body for ``POST /detect``."""

    # Path of the scene folder on the server (required).
    scene_folder: str = Field(..., description="场景文件夹路径")
    # YOLOE weight file handed to the detector.
    model_path: Optional[str] = Field("yoloe-26x-seg.pt", description="YOLOE 模型路径")
    # Detection confidence threshold, validated to the closed range [0, 1].
    conf: Optional[float] = Field(0.35, description="检测置信度阈值", ge=0, le=1)
    # NMS IoU threshold, validated to the closed range [0, 1].
    iou: Optional[float] = Field(0.45, description="NMS IoU 阈值", ge=0, le=1)
    # Point-cloud voxel size. A zero or negative size is not a meaningful grid
    # spacing, so it is rejected at request time (HTTP 422) instead of being
    # handed to the background detection job.
    voxel_size: Optional[float] = Field(0.03, description="点云体素尺寸", gt=0)
    # YOLOE input image size as [height, width]; run_detection reads the
    # first two items.
    imgsz: Optional[List[int]] = Field([1024, 2048], description="YOLOE 输入图像尺寸 [高,宽]")
    # When true, a visualisation PLY is exported alongside the JSON result.
    vis_ply: Optional[bool] = Field(False, description="是否生成可视化 PLY")
class DetectResponse(BaseModel):
    """Response returned when a detection task is submitted."""
    # Id to use with /status/{task_id} and /result/{task_id}.
    task_id: str
    # Task state at submission time ("pending" in the submit endpoints).
    status: str
    # Human-readable submission message.
    message: str
    # Both paths are left as None by the submit endpoints; results are
    # fetched later via /result/{task_id} and /result/{task_id}/ply.
    result_path: Optional[str] = None
    vis_ply_path: Optional[str] = None
class StatusResponse(BaseModel):
    """Snapshot of a task's state, as returned by GET /status/{task_id}."""
    task_id: str
    status: str  # pending, processing, completed, failed
    progress: Optional[float] = None  # 0-1
    message: Optional[str] = None  # human-readable progress / outcome message
    result: Optional[Dict[str, Any]] = None  # parsed result JSON once the task completed
    error: Optional[str] = None  # error text on failure (may include a traceback)
    created_at: str  # ISO-8601 creation timestamp
    completed_at: Optional[str] = None  # ISO-8601 timestamp set when the task completes or fails
class HealthResponse(BaseModel):
    """Response body of GET /health."""
    # Liveness indicator ("healthy" whenever the endpoint answers).
    status: str
    # Whether startup found the default YOLOE weight file.
    model_loaded: bool
    # Result of torch.cuda.is_available() at request time.
    gpu_available: bool
    # ISO-8601 timestamp of the check (naive local time from datetime.now()).
    timestamp: str
- # ============================================================================
- # 任务管理
- # ============================================================================
class TaskManager:
    """In-memory registry of detection tasks, keyed by task id.

    Each task is a plain dict holding its status, progress, message, result,
    error, timestamps, scene folder and output file paths. Nothing is
    persisted: the registry lives only as long as the process.

    Attributes:
        tasks: Mapping of task id to task record.
        results_dir: ``server_results`` directory (relative to the working
            directory), created on construction.
    """

    def __init__(self):
        self.tasks: Dict[str, Dict[str, Any]] = {}
        self.results_dir = Path("server_results")
        self.results_dir.mkdir(exist_ok=True)

    def create_task(self, task_id: str, scene_folder: str) -> Dict[str, Any]:
        """Register a new task in the ``pending`` state and return its record."""
        record: Dict[str, Any] = dict(
            task_id=task_id,
            status="pending",
            progress=0.0,
            message="任务已创建,等待处理",
            result=None,
            error=None,
            created_at=datetime.now().isoformat(),
            completed_at=None,
            scene_folder=scene_folder,
            result_path=None,
            vis_ply_path=None,
        )
        self.tasks[task_id] = record
        return record

    def update_task(self, task_id: str, **kwargs):
        """Overwrite the given fields of an existing task record.

        Raises:
            ValueError: If *task_id* is not registered.
        """
        record = self.tasks.get(task_id)
        if record is None:
            raise ValueError(f"任务 {task_id} 不存在")
        record.update(kwargs)

    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the raw record for *task_id*, or None if it is unknown."""
        return self.tasks.get(task_id)

    def get_task_status(self, task_id: str) -> StatusResponse:
        """Build the public status view of *task_id*.

        Raises:
            HTTPException: 404 when the task is unknown.
        """
        task = self.get_task(task_id)
        if not task:
            raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
        # Only these keys are exposed; internal paths stay private.
        public_fields = (
            "task_id", "status", "progress", "message",
            "result", "error", "created_at", "completed_at",
        )
        return StatusResponse(**{name: task[name] for name in public_fields})
- # ============================================================================
- # FastAPI 应用
- # ============================================================================
app = FastAPI(
    title="入户门检测服务",
    description="基于 YOLOE 和 RGB-D 点云的入户门位置检测 API",
    version="1.0.0"
)
# Process-wide task registry shared by all endpoints. It is an in-memory dict,
# so task state is neither persisted nor shared between worker processes.
task_manager = TaskManager()
# Thread pool capped at two worker threads; presumably meant for running the
# blocking detection jobs off the event loop (confirm against the handlers).
executor = ThreadPoolExecutor(max_workers=2)
# Set by startup_event depending on whether the default weight file exists;
# reported by /health.
model_loaded = False
# Reserved for a shared detector instance. startup_event declares it global
# but never assigns it, so it stays None.
detector_instance = None
@app.on_event("startup")
async def startup_event():
    """Record at startup whether the default YOLOE weight file is present.

    No model is actually loaded here, because a detector needs a scene
    folder. The check only sets the module-level ``model_loaded`` flag,
    which /health reports.
    """
    global model_loaded, detector_instance
    model_path = "yoloe-26x-seg.pt"

    if not Path(model_path).exists():
        print(f"模型文件不存在:{model_path}")
        model_loaded = False
        return

    try:
        print(f"预加载 YOLOE 模型:{model_path}")
        model_loaded = True
        print("模型加载成功")
    except Exception as e:
        print(f"模型加载失败:{e}")
        model_loaded = False
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report liveness, the startup model-file check, and CUDA availability."""
    import torch

    cuda_ok = torch.cuda.is_available()
    checked_at = datetime.now().isoformat()
    return HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        gpu_available=cuda_ok,
        timestamp=checked_at,
    )
@app.post("/detect", response_model=DetectResponse)
async def detect_entrance_door(
    request: DetectRequest,
    background_tasks: BackgroundTasks
):
    """Submit an entrance-door detection job for a scene folder on the server.

    The call returns immediately with a ``task_id``. Poll
    ``GET /status/{task_id}``, then fetch ``GET /result/{task_id}`` once the
    task is completed.

    The job is queued on the module-level ``executor`` (``max_workers=2``),
    so at most two detections run at once. Later jobs keep the ``pending``
    status until a worker picks them up. ``background_tasks`` is no longer
    used; it is kept so the endpoint signature is unchanged.

    Raises:
        HTTPException: 400 if ``scene_folder`` is not an existing directory.
    """
    scene_path = Path(request.scene_folder)
    # is_dir() rather than exists(): a regular file is not a usable scene folder.
    if not scene_path.is_dir():
        raise HTTPException(status_code=400, detail=f"场景文件夹不存在:{request.scene_folder}")

    task_id = str(uuid.uuid4())
    task_manager.create_task(task_id, request.scene_folder)

    # Use the bounded pool instead of BackgroundTasks, whose shared threadpool
    # would let many heavy detections run concurrently.
    executor.submit(
        run_detection,
        task_id=task_id,
        scene_folder=request.scene_folder,
        model_path=request.model_path,
        conf=request.conf,
        iou=request.iou,
        voxel_size=request.voxel_size,
        imgsz=request.imgsz,
        vis_ply=request.vis_ply
    )

    return DetectResponse(
        task_id=task_id,
        status="pending",
        message="任务已提交,正在处理中",
        result_path=None,
        vis_ply_path=None
    )
@app.post("/detect/upload", response_model=DetectResponse)
async def detect_entrance_door_upload(
    scene_folder: str = Form(...),
    model_path: Optional[str] = Form("yoloe-26x-seg.pt"),
    conf: Optional[float] = Form(0.35),
    iou: Optional[float] = Form(0.45),
    voxel_size: Optional[float] = Form(0.03),
    imgsz_height: Optional[int] = Form(1024),
    imgsz_width: Optional[int] = Form(2048),
    vis_ply: Optional[bool] = Form(False),
    files: List[UploadFile] = File(...),
    background_tasks: BackgroundTasks = None
):
    """Upload scene files and submit a detection job for them.

    Uploaded files are sorted by file name:
    - ``vision.txt`` goes to the scene root;
    - ``*.jpg`` panorama images go to ``pano_img/``;
    - ``*.png`` depth images go to ``depth_img/``;
    - anything else is ignored.

    Files are written to a fresh temporary directory, and ``run_detection`` is
    asked to clean it up (``cleanup_temp=True``). The ``scene_folder`` form
    field is accepted for compatibility but not used. The job is queued on
    the module-level ``executor`` (at most two run at once).
    ``background_tasks`` is kept only so the signature is unchanged.

    Raises:
        HTTPException: 500 if the files cannot be saved or the job cannot be
            submitted. The temporary directory is removed in that case.
    """
    temp_dir = tempfile.mkdtemp(prefix="entrance_detect_")
    temp_path = Path(temp_dir)
    try:
        pano_dir = temp_path / "pano_img"
        depth_dir = temp_path / "depth_img"
        pano_dir.mkdir()
        depth_dir.mkdir()

        for upload in files:
            # The filename is client-controlled (and may be None). Keep only
            # its final component so names like "../../x.jpg" cannot write
            # outside temp_path.
            name = Path(upload.filename or "").name
            if name.endswith(".jpg"):
                target = pano_dir / name
            elif name.endswith(".png"):
                target = depth_dir / name
            elif name == "vision.txt":
                target = temp_path / "vision.txt"
            else:
                continue
            target.write_bytes(await upload.read())

        task_id = str(uuid.uuid4())
        task_manager.create_task(task_id, str(temp_path))

        executor.submit(
            run_detection,
            task_id=task_id,
            scene_folder=str(temp_path),
            model_path=model_path,
            conf=conf,
            iou=iou,
            voxel_size=voxel_size,
            imgsz=[imgsz_height, imgsz_width],
            vis_ply=vis_ply,
            cleanup_temp=True
        )

        return DetectResponse(
            task_id=task_id,
            status="pending",
            message="文件上传成功,任务已提交",
            result_path=None,
            vis_ply_path=None
        )
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=f"文件上传失败:{str(e)}") from e
@app.get("/status/{task_id}", response_model=StatusResponse)
async def get_task_status(task_id: str):
    """Return the current status of *task_id* (HTTP 404 if it is unknown)."""
    status = task_manager.get_task_status(task_id)
    return status
@app.get("/result/{task_id}")
async def get_task_result(task_id: str):
    """Return the detection result JSON of a completed task.

    The result file on disk is preferred. If it is gone, the parsed result
    stored on the task by ``run_detection`` is returned instead. This happens,
    for example, when a temporary upload folder holding the file has already
    been cleaned up.

    Raises:
        HTTPException: 404 if the task is unknown or no result is available,
            400 if the task has not completed yet.
    """
    task = task_manager.get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
    if task["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"任务尚未完成,当前状态:{task['status']}"
        )

    result_path = task["result_path"]
    if result_path and Path(result_path).exists():
        with open(result_path, "r", encoding="utf-8") as f:
            return JSONResponse(content=json.load(f))

    if task.get("result") is not None:
        return JSONResponse(content=task["result"])

    raise HTTPException(status_code=404, detail="结果文件不存在")
@app.get("/result/{task_id}/ply")
async def get_task_ply(task_id: str):
    """Download the visualisation PLY of a completed task as ``vis.ply``.

    Raises:
        HTTPException: 404 if the task or its PLY file is missing, 400 if the
            task has not completed yet.
    """
    task = task_manager.get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")

    state = task["status"]
    if state != "completed":
        raise HTTPException(status_code=400, detail=f"任务尚未完成,当前状态:{state}")

    ply_path = task["vis_ply_path"]
    if not (ply_path and Path(ply_path).exists()):
        raise HTTPException(status_code=404, detail="PLY 文件不存在")

    return FileResponse(ply_path, media_type="application/octet-stream", filename="vis.ply")
def _is_upload_temp_dir(path: Path) -> bool:
    """Return True if *path* is a scratch directory created by /detect/upload.

    Upload folders come from ``tempfile.mkdtemp(prefix="entrance_detect_")``,
    so they sit directly inside ``tempfile.gettempdir()`` and carry that
    prefix. Anything else, such as a user's own scene folder passed to
    /detect (even one under /tmp), must never be deleted.
    """
    try:
        resolved = path.resolve()
        temp_root = Path(tempfile.gettempdir()).resolve()
    except (OSError, RuntimeError):
        return False
    return resolved.parent == temp_root and resolved.name.startswith("entrance_detect_")


@app.delete("/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a task record together with the files it produced.

    Removes the following, then drops the record:
    - the result JSON and visualisation PLY, if any;
    - the per-task directory under ``task_manager.results_dir``, if present;
    - for tasks created via /detect/upload, the temporary scene folder.

    Raises:
        HTTPException: 404 if *task_id* is unknown.
    """
    task = task_manager.get_task(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")

    for key in ("result_path", "vis_ply_path"):
        file_path = task.get(key)
        if file_path:
            # missing_ok avoids a race between an exists() check and removal.
            Path(file_path).unlink(missing_ok=True)

    # task_id is a registered uuid at this point, so it cannot escape results_dir.
    shutil.rmtree(task_manager.results_dir / task_id, ignore_errors=True)

    scene_folder = task.get("scene_folder")
    if scene_folder and _is_upload_temp_dir(Path(scene_folder)):
        shutil.rmtree(scene_folder, ignore_errors=True)

    del task_manager.tasks[task_id]
    return {"message": f"任务 {task_id} 已删除"}
- # ============================================================================
- # 后台任务
- # ============================================================================
def run_detection(
    task_id: str,
    scene_folder: str,
    model_path: str,
    conf: float,
    iou: float,
    voxel_size: float,
    imgsz: List[int],
    vis_ply: bool,
    cleanup_temp: bool = False
):
    """Run entrance-door detection for one task and record the outcome.

    This is a blocking function that the request handlers schedule in the
    background. Progress, the final status and any error are written to
    ``task_manager`` under *task_id*. Nothing is raised to the caller.

    Args:
        task_id: Id of a task registered with ``task_manager``.
        scene_folder: Scene directory passed to ``EntranceDoorDetector``.
        model_path: YOLOE weight file.
        conf: Detection confidence threshold.
        iou: NMS IoU threshold.
        voxel_size: Point-cloud voxel size.
        imgsz: YOLOE input size as ``[height, width]``; only the first two
            items are used.
        vis_ply: Also export a visualisation PLY when True.
        cleanup_temp: When True, *scene_folder* is a temporary upload
            directory. Outputs are then written under
            ``task_manager.results_dir / task_id`` so they outlive the
            folder. The folder itself is deleted when the run ends, whether
            it succeeded, failed or raised.
    """
    try:
        task_manager.update_task(
            task_id,
            status="processing",
            progress=0.1,
            message="正在初始化检测器..."
        )

        detector = EntranceDoorDetector(
            scene_folder=scene_folder,
            model_path=model_path,
            conf=conf,
            iou=iou,
            voxel_size=voxel_size,
            imgsz=(imgsz[0], imgsz[1])
        )

        task_manager.update_task(
            task_id,
            progress=0.3,
            message="正在检测门并识别入户门..."
        )

        success = detector.detect_and_identify()

        task_manager.update_task(
            task_id,
            progress=0.7,
            message="正在导出结果..."
        )

        if success:
            if cleanup_temp:
                # scene_folder is deleted in the finally block below. Files
                # written inside it would vanish before the client could
                # fetch them, so keep outputs in the server results directory.
                output_folder = task_manager.results_dir / task_id
            else:
                output_folder = Path(scene_folder) / "output"
            output_folder.mkdir(parents=True, exist_ok=True)

            result_path = output_folder / "entrance_position.json"
            detector.export_json(str(result_path))

            vis_ply_path = None
            if vis_ply:
                vis_ply_path = output_folder / "vis.ply"
                detector.export_vis_ply(detector.combined_pc, str(vis_ply_path))

            # Keep a parsed copy on the task so /status can include it.
            with open(result_path, "r", encoding="utf-8") as f:
                result = json.load(f)

            task_manager.update_task(
                task_id,
                status="completed",
                progress=1.0,
                message="检测完成",
                result=result,
                result_path=str(result_path),
                vis_ply_path=str(vis_ply_path) if vis_ply_path else None,
                completed_at=datetime.now().isoformat()
            )
        else:
            task_manager.update_task(
                task_id,
                status="failed",
                progress=0.0,
                message="检测失败:无法确定入户门位置",
                error="无法确定入户门位置",
                completed_at=datetime.now().isoformat()
            )
    except Exception as e:
        import traceback
        error_msg = f"{str(e)}\n{traceback.format_exc()}"
        # The task may have been deleted while running. update_task raises
        # ValueError for unknown ids, which would escape this handler, so
        # only record the failure if the task still exists.
        if task_manager.get_task(task_id) is not None:
            task_manager.update_task(
                task_id,
                status="failed",
                progress=0.0,
                message=f"检测失败:{str(e)}",
                error=error_msg,
                completed_at=datetime.now().isoformat()
            )
    finally:
        # Remove the upload scratch directory on every exit path, including
        # exceptions (previously it leaked when detection raised).
        if cleanup_temp:
            shutil.rmtree(scene_folder, ignore_errors=True)
- # ============================================================================
- # 主程序
- # ============================================================================
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="入户门检测 FastAPI 服务")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址")
    parser.add_argument("--port", type=int, default=8000, help="监听端口")
    parser.add_argument("--reload", action="store_true", help="启用热重载(开发模式)")
    parser.add_argument("--workers", type=int, default=1, help="工作进程数")
    args = parser.parse_args()

    print("启动入户门检测服务...")
    print(f" 监听地址:http://{args.host}:{args.port}")
    print(f" API 文档:http://{args.host}:{args.port}/docs")
    print(f" 健康检查:http://{args.host}:{args.port}/health")

    if args.workers > 1:
        # Each worker process has its own in-memory TaskManager, so a
        # status/result request may reach a worker that never saw the task.
        print("警告:任务状态保存在进程内存中,workers > 1 时查询可能返回 404", file=sys.stderr)

    # uvicorn needs a "module:attr" import string for reload/workers. Derive
    # the module name from this file instead of assuming it is server.py.
    uvicorn.run(
        f"{Path(__file__).stem}:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers
    )
|