Files
fcb_photo_review/paddle_detection/detector.py
2024-08-29 10:21:08 +08:00

77 lines
2.6 KiB
Python

import base64
import logging
import tempfile
from collections import defaultdict
import cv2
import numpy as np
import requests
from tenacity import retry, stop_after_attempt, wait_random
from paddle_detection import PADDLE_DET
from paddle_detection.deploy.third_engine.onnx.infer import PredictConfig
from paddle_detection.deploy.third_engine.onnx.preprocess import Compose
from util import image_util, util
def predict_image(infer_config, predictor, img_path):
    """Run the detector on one image and group the detections by class id.

    Args:
        infer_config: Inference configuration; this function reads its
            ``preprocess_infos`` and ``draw_threshold`` attributes.
        predictor: Inference session exposing ``get_inputs()`` and
            ``run(output_names=..., input_feed=...)`` (looks like an ONNX
            Runtime session -- confirm against the caller).
        img_path: Path of the image handed to the preprocessing pipeline.

    Returns:
        A ``defaultdict(list)`` keyed by the first column of each output row
        (the class id, kept as the numeric type produced by the model). Each
        value is a list of ``{"score": ..., "box": ...}`` dicts, where
        ``score`` is the second column and ``box`` is the remaining columns.
    """
    # Build the preprocessing pipeline described by the inference config.
    preprocess = Compose(infer_config.preprocess_infos)
    sample = preprocess(img_path)
    sample["image"] = np.array(sample["image"]).astype('float32')

    # Feed only the tensors the model declares, each with a leading batch axis.
    feed_names = [node.name for node in predictor.get_inputs()]
    feed = {}
    for name in feed_names:
        feed[name] = sample[name][None,]

    raw_outputs = predictor.run(output_names=None, input_feed=feed)
    detections = np.array(raw_outputs[0])

    grouped = defaultdict(list)
    for row in detections:
        # Rows with a class id of -1 or below are not real detections.
        if not row[0] > -1:
            continue
        # Keep only detections scoring strictly above the configured threshold.
        if not row[1] > infer_config.draw_threshold:
            continue
        grouped[row[0]].append({"score": row[1], "box": row[2:]})
    return grouped
def detect_image(img_path, infer_cfg="model/object_det_model/infer_cfg.yml"):
    """Detect objects in an image with the module-level ``PADDLE_DET`` predictor.

    Args:
        img_path: Path of the image to run detection on.
        infer_cfg: Path of the inference config YAML passed to
            ``PredictConfig``. Defaults to the bundled object-detection
            model config; the default is a relative path, so it is resolved
            against the current working directory.

    Returns:
        Whatever ``predict_image`` returns for this image: detections grouped
        by class id.
    """
    # Load the inference config (preprocessing steps, thresholds, ...).
    infer_config = PredictConfig(infer_cfg)
    return predict_image(infer_config, PADDLE_DET, img_path)
def get_book_areas(image):
    """Crop every detected region of class 73 out of ``image``.

    The image is written to a temporary JPEG because the detection pipeline
    is handed a file path rather than an in-memory array.

    Args:
        image: Image array accepted by ``cv2.imwrite`` and
            ``image_util.capture``.

    Returns:
        A list with one ``image_util.capture(image, box)`` result per
        detection of class 73; empty when nothing of that class is found.
    """
    # Reserve a unique path and close the handle right away: the file is
    # written and re-read by path, and delete=False keeps it on disk until
    # it is removed explicitly below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    try:
        cv2.imwrite(temp_path, image)
        detect_result = detect_image(temp_path)
    finally:
        # Always clean up, even when writing or detection raises.
        util.delete_temp_file(temp_path)

    # NOTE(review): 73 is presumably the "book" class id in the model's
    # label map -- confirm against the model config.
    book_areas = detect_result[73]
    return [image_util.capture(image, book_area["box"]) for book_area in book_areas]
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
       after=lambda x: logging.warning("获取文档区域失败!"))
def request_book_areas(image):
    """Ask the remote detection service for the book regions of ``image``.

    The image is JPEG-encoded and posted as a multipart upload. On a 200
    response the body is expected to be a JSON list of base64-encoded
    images, each of which is decoded to a BGR array.

    The call is attempted up to three times with a random 1-3 second wait
    between attempts, and the last exception is re-raised if all attempts
    fail. A non-200 response returns an empty list rather than raising, so
    it does not trigger a retry.

    Args:
        image: Image array accepted by ``cv2.imencode``.

    Returns:
        A list of decoded images whose long side is at most 6.5 times the
        short side; undecodable entries are skipped. Empty when the service
        answers with a non-200 status.
    """
    url = "http://det_api:5000/det/detect_books"
    _, encoded_image = cv2.imencode('.jpg', image)
    files = {"image": ("image.jpg", encoded_image.tobytes())}
    # NOTE(review): no timeout is set, so a stalled service blocks this call
    # indefinitely and the retry policy never fires -- consider passing
    # ``timeout=`` once an acceptable upper bound is known.
    response = requests.post(url, files=files)
    if response.status_code != 200:
        return []

    result = []
    for img_str in response.json():
        np_array = np.frombuffer(base64.b64decode(img_str), np.uint8)
        img = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals an undecodable buffer with None; skip it
            # instead of failing on ``img.shape`` and retrying the request.
            continue
        height, width = img.shape[:2]
        # Filter out abnormal results: extremely elongated crops.
        if max(height, width) / min(height, width) <= 6.5:
            result.append(img)
    return result