优化案子处理逻辑
This commit is contained in:
@@ -12,6 +12,44 @@ def get_default_datetime():
|
||||
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
|
||||
def ocr_result_to_layout(ocr_result):
    """
    Convert raw OCR segments into a list of ``(box, text)`` layout entries.

    Each segment is read as ``segment[0]`` (four corner points, each an
    ``(x, y)`` pair) and ``segment[1][0]`` (the recognised text).

    :param ocr_result: iterable of OCR segments; a falsy value yields ``[]``
    :return: list of ``([x1, y1, x2, y2], text)`` tuples, in input order
    """

    def _get_box(old_box):
        # Collapse the four corner points into an axis-aligned [x1, y1, x2, y2].
        # NOTE(review): the index pairs used here look like they assume the
        # corners are ordered top-left, top-right, bottom-right, bottom-left
        # -- confirm against the OCR engine's output format.
        return [
            min(old_box[0][0], old_box[3][0]),  # x1
            min(old_box[0][1], old_box[1][1]),  # y1
            max(old_box[1][0], old_box[2][0]),  # x2
            max(old_box[2][1], old_box[3][1]),  # y2
        ]

    def _normal_box(box_data):
        # Reject boxes whose height or width is negative.
        # Zero-sized boxes are still accepted by this check.
        box_height = box_data[3] - box_data[1]
        box_width = box_data[2] - box_data[0]
        return not (box_height < 0 or box_width < 0)

    layout = []
    if not ocr_result:
        return layout

    for segment in ocr_result:
        box = _get_box(segment[0])
        if not _normal_box(box):
            continue
        # The text is only read for segments whose box passed the check
        layout.append((box, segment[1][0]))
    return layout
|
||||
|
||||
|
||||
def ocr_result_to_text(ocr_results, max_length=2048):
    """
    Concatenate the recognised text of OCR segments into one string.

    The text of each segment is read from ``segment[1][0]``. Segments are
    appended in order until the accumulated length reaches ``max_length``.

    :param ocr_results: iterable of OCR segments; a falsy value yields ``''``
    :param max_length: maximum number of characters returned (default 2048)
    :return: the concatenated text, truncated to ``max_length`` characters
    """
    if not ocr_results:
        return ''

    pieces = []
    total = 0
    for ocr_result in ocr_results:
        piece = ocr_result[1][0]
        pieces.append(piece)
        total += len(piece)
        # Stop early: anything beyond the limit would be cut off anyway
        if total >= max_length:
            break
    return ''.join(pieces)[:max_length]
|
||||
|
||||
|
||||
def get_ocr_layout(ocr, img_path):
|
||||
"""
|
||||
获取ocr识别的结果,转为合适的layout形式
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
@@ -12,80 +11,59 @@ from tenacity import retry, stop_after_attempt, wait_random
|
||||
from log import PROJECT_ROOT
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
       after=lambda x: logging.warning('获取图片失败!'))
def read(image_path):
    """
    Read an image from the network or from a local file.

    The retry decorator stops after three attempts, waits a random 1-3
    between attempts, re-raises the last error and logs a warning through
    its ``after`` hook.

    :param image_path: URL (any string starting with ``http``) or local path
    :return: the image as decoded by OpenCV (NumPy array form)
    """
    if image_path.startswith('http'):
        # Fetch the raw bytes; the ``with`` block closes the HTTP response
        # as soon as the payload has been read.
        with urllib.request.urlopen(image_path, timeout=60) as resp:
            image_data = resp.read()
        # Wrap the bytes in a uint8 array so OpenCV can decode them
        image_np = numpy.frombuffer(image_data, numpy.uint8)
        image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    else:
        # NOTE(review): an unreadable local file presumably comes back as
        # None rather than raising, so it would not be retried -- confirm.
        image = cv2.imread(image_path)
    return image
|
||||
|
||||
|
||||
def capture(image, rectangle):
    """
    Crop a rectangular region out of an image.

    :param image: ndarray image; the first two entries of ``image.shape``
        are read as height and width
    :param rectangle: ``(x1, y1, x2, y2)`` corners of the region to crop
    :return: the cropped ndarray image
    """
    x1, y1, x2, y2 = rectangle
    height, width = image.shape[:2]
    # Clamp the rectangle so every coordinate lies inside the image
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(width, x2)
    y2 = min(height, y2)
    # Coordinates may be fractional; truncate them to integer indices
    return image[int(y1):int(y2), int(x1):int(x2)]
|
||||
|
||||
|
||||
def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
|
||||
def split(img_path, ratio=1.414, overlap=0.05, x_compensation=3):
|
||||
"""
|
||||
分割图片
|
||||
:param image:图片,可以是NumPy数组或文件路径
|
||||
:param img_path:图片路径
|
||||
:param ratio: 分割后的比例
|
||||
:param overlap: 图片之间的覆盖比例
|
||||
:param x_compensation: 横向补偿倍率
|
||||
:return: 分割后的图片组(NumPy数组形式)
|
||||
"""
|
||||
split_result = []
|
||||
if isinstance(image, str):
|
||||
image = read(image)
|
||||
image = cv2.imread(img_path)
|
||||
height, width = image.shape[:2]
|
||||
hw_ratio = height / width
|
||||
wh_ratio = width / height
|
||||
|
||||
img_name, img_ext = parse_save_path(img_path)
|
||||
if hw_ratio > ratio: # 纵向过长
|
||||
new_img_height = width * ratio
|
||||
step = width * (ratio - overlap) # 偏移步长
|
||||
for i in range(math.ceil(height / step)):
|
||||
offset = round(step * i)
|
||||
cropped_img = capture(image, [0, offset, width, offset + new_img_height])
|
||||
split_result.append({'img': cropped_img, 'x_offset': 0, 'y_offset': offset})
|
||||
split_path = get_save_path(f'{img_name}.split_{i}.{img_ext}')
|
||||
cv2.imwrite(split_path, cropped_img)
|
||||
split_result.append({'img': split_path, 'x_offset': 0, 'y_offset': offset})
|
||||
elif wh_ratio > ratio: # 横向过长
|
||||
new_img_width = height * ratio
|
||||
step = height * (ratio - overlap * x_compensation) # 一般文字是横向的,所以横向截取时增大重叠部分
|
||||
for i in range(math.ceil(width / step)):
|
||||
offset = round(step * i)
|
||||
cropped_img = capture(image, [offset, 0, offset + new_img_width, width])
|
||||
split_result.append({'img': cropped_img, 'x_offset': offset, 'y_offset': 0})
|
||||
split_path = get_save_path(f'{img_name}.split_{i}.{img_ext}')
|
||||
cv2.imwrite(split_path, cropped_img)
|
||||
split_result.append({'img': split_path, 'x_offset': offset, 'y_offset': 0})
|
||||
else:
|
||||
split_result.append({'img': image, 'x_offset': 0, 'y_offset': 0})
|
||||
split_result.append({'img': img_path, 'x_offset': 0, 'y_offset': 0})
|
||||
return split_result
|
||||
|
||||
|
||||
@@ -108,15 +86,16 @@ def parse_rotation_angles(image):
|
||||
return angles
|
||||
|
||||
|
||||
def rotate(image, angle):
|
||||
def rotate(img_path, angle):
|
||||
"""
|
||||
旋转图片
|
||||
:param image: 图片NumPy数组
|
||||
:param img_path: 图片NumPy数组
|
||||
:param angle: 逆时针旋转角度
|
||||
:return: 旋转后的图片NumPy数组
|
||||
"""
|
||||
if angle == 0:
|
||||
return image
|
||||
return img_path
|
||||
image = cv2.imread(img_path)
|
||||
height, width = image.shape[:2]
|
||||
if angle == 180:
|
||||
new_width = width
|
||||
@@ -132,7 +111,11 @@ def rotate(image, angle):
|
||||
matrix[1, 2] += (new_height - height) / 2
|
||||
# 参数:原始图像 旋转参数 元素图像宽高
|
||||
rotated = cv2.warpAffine(image, matrix, (new_width, new_height))
|
||||
return rotated
|
||||
|
||||
img_name, img_ext = parse_save_path(img_path)
|
||||
rotated_path = get_save_path(f'{img_name}.rotate_{angle}.{img_ext}')
|
||||
cv2.imwrite(rotated_path, rotated)
|
||||
return rotated_path
|
||||
|
||||
|
||||
def invert_rotate_point(point, center, angle):
|
||||
@@ -260,26 +243,38 @@ def parse_img_url(url):
|
||||
:return: 图片名称和图片后缀
|
||||
"""
|
||||
url = url.split('?')[0]
|
||||
return os.path.basename(url).rsplit('.', 1)
|
||||
return os.path.basename(url)
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
||||
after=lambda x: logging.warning('保存图片失败!'))
|
||||
def save_to_local(img_url, save_path=None):
|
||||
def save_to_local(img_url):
|
||||
"""
|
||||
保存图片到本地
|
||||
:param img_url: 图片url
|
||||
:param save_path: 本地保存地址,精确到文件名
|
||||
:return: 本地保存地址
|
||||
"""
|
||||
response = requests.get(img_url)
|
||||
response.raise_for_status() # 检查响应状态码是否正常
|
||||
|
||||
if save_path is None:
|
||||
img_name, img_ext = parse_img_url(img_url)
|
||||
save_path = os.path.join(PROJECT_ROOT, 'tmp_img', img_name + '.' + img_ext)
|
||||
|
||||
save_path = get_save_path(parse_img_url(img_url))
|
||||
with open(save_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
|
||||
return save_path
|
||||
|
||||
|
||||
def get_img_path(img_full_name):
    """
    Look up the local path of a previously saved image.

    :param img_full_name: image file name, passed on to ``get_save_path``
    :return: the path built by ``get_save_path`` if it exists on disk,
        otherwise ``None``
    """
    candidate = get_save_path(img_full_name)
    return candidate if os.path.exists(candidate) else None
|
||||
|
||||
|
||||
def get_save_path(img_full_name):
    """
    Build the local save path for an image file name.

    :param img_full_name: image file name
    :return: ``PROJECT_ROOT`` joined with ``tmp_img`` and the file name
    """
    tmp_img_dir = os.path.join(PROJECT_ROOT, 'tmp_img')
    return os.path.join(tmp_img_dir, img_full_name)
|
||||
|
||||
|
||||
def parse_save_path(img_path):
    """
    Split an image path into its base name and extension.

    Only the final path component is considered, and it is split at its
    last dot.

    :param img_path: path of the image file
    :return: ``(img_name, img_ext)``; ``img_ext`` carries no leading dot and
        is ``''`` when the file name contains no dot
    """
    img_full_name = os.path.basename(img_path)
    img_name, dot, img_ext = img_full_name.rpartition('.')
    if not dot:
        # No dot at all: rpartition leaves the whole name in the last slot,
        # so report it as a name with an empty extension instead of failing.
        return img_full_name, ''
    return img_name, img_ext
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
|
||||
import requests
|
||||
from tenacity import retry, stop_after_attempt, wait_random
|
||||
@@ -16,9 +17,10 @@ def ocr(img_path):
|
||||
url = 'http://ocr:5001'
|
||||
response = requests.post(url, {'img_path': img_path})
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return None
|
||||
ocr_result = response.json()
|
||||
if ocr_result:
|
||||
return ocr_result[0]
|
||||
return None
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
||||
@@ -40,7 +42,7 @@ def ie_settlement(img_path, layout):
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
||||
after=lambda x: logging.warning('从文本抽取基本医保结算单失败!'))
|
||||
def ie_settlement(text):
|
||||
def ie_settlement_text(text):
|
||||
"""
|
||||
请求基本医保结算单信息抽取接口
|
||||
:param text: 待抽取文本
|
||||
@@ -73,7 +75,7 @@ def ie_discharge(img_path, layout):
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
||||
after=lambda x: logging.warning('从文本抽取出院记录失败!'))
|
||||
def ie_discharge(text):
|
||||
def ie_discharge_text(text):
|
||||
"""
|
||||
请求出院记录信息抽取接口
|
||||
:param text: 待抽取文本
|
||||
@@ -106,7 +108,7 @@ def ie_cost(img_path, layout):
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
||||
after=lambda x: logging.warning('从文本抽取费用清单失败!'))
|
||||
def ie_cost(text):
|
||||
def ie_cost_text(text):
|
||||
"""
|
||||
请求费用清单信息抽取接口
|
||||
:param text: 待抽取文本
|
||||
@@ -147,9 +149,22 @@ def det_book(img_path):
|
||||
url = 'http://det_book:5006'
|
||||
response = requests.post(url, {'img_path': img_path})
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
book_path_list = response.json()
|
||||
if len(book_path_list) == 0:
|
||||
return img_path
|
||||
elif len(book_path_list) == 1:
|
||||
return book_path_list[0]
|
||||
else:
|
||||
max_book = img_path
|
||||
max_size = 0
|
||||
for book_path in book_path_list:
|
||||
book_size = os.path.getsize(book_path)
|
||||
if book_size > max_size:
|
||||
max_book = book_path
|
||||
max_size = book_size
|
||||
return max_book
|
||||
else:
|
||||
return [img_path]
|
||||
return img_path
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
||||
|
||||
Reference in New Issue
Block a user