更新OCR版本,Bata版,还不能上线
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
import logging
|
||||
import math
|
||||
import urllib.request
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from PIL import Image
|
||||
from PIL.ExifTags import TAGS
|
||||
from paddleclas import PaddleClas
|
||||
from tenacity import retry, stop_after_attempt, wait_random
|
||||
|
||||
@@ -14,20 +17,36 @@ def read(image_path):
|
||||
"""
|
||||
从网络或本地读取图片
|
||||
:param image_path: 网络或本地路径
|
||||
:return: NumPy数组形式的图片
|
||||
:return: NumPy数组形式的图片, EXIF数据
|
||||
"""
|
||||
if image_path.startswith("http"):
|
||||
# 发送HTTP请求并获取图像数据
|
||||
resp = urllib.request.urlopen(image_path, timeout=60)
|
||||
# 将数据读取为字节流
|
||||
image_data = resp.read()
|
||||
# 将字节流转换为NumPy数组
|
||||
image_np = numpy.frombuffer(image_data, numpy.uint8)
|
||||
# 解码NumPy数组为OpenCV图像格式
|
||||
image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
|
||||
else:
|
||||
image = cv2.imread(image_path)
|
||||
return image
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = f.read()
|
||||
|
||||
# 解析EXIF信息(基于原始字节流)
|
||||
exif_data = {}
|
||||
try:
|
||||
# 用PIL打开原始字节流
|
||||
with Image.open(BytesIO(image_data)) as img:
|
||||
# 获取EXIF字典
|
||||
exif_info = img._getexif()
|
||||
if exif_info:
|
||||
# 将EXIF标签的数字ID转换为可读名称(如36867对应"DateTimeOriginal")
|
||||
for tag_id, value in exif_info.items():
|
||||
tag_name = TAGS.get(tag_id, tag_id)
|
||||
exif_data[tag_name] = value
|
||||
except Exception as e:
|
||||
logging.error("解析EXIF信息失败", exc_info=e)
|
||||
# 将字节流转换为NumPy数组
|
||||
image_np = numpy.frombuffer(image_data, numpy.uint8)
|
||||
# 解码NumPy数组为OpenCV图像格式
|
||||
image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
|
||||
return image, exif_data
|
||||
|
||||
|
||||
def capture(image, rectangle):
|
||||
@@ -61,7 +80,7 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
|
||||
"""
|
||||
split_result = []
|
||||
if isinstance(image, str):
|
||||
image = read(image)
|
||||
image, _ = read(image)
|
||||
height, width = image.shape[:2]
|
||||
hw_ratio = height / width
|
||||
wh_ratio = width / height
|
||||
|
||||
19
util/util.py
19
util/util.py
@@ -12,9 +12,10 @@ def get_default_datetime():
|
||||
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
|
||||
def get_ocr_layout(ocr, img_path):
|
||||
def get_ocr_layout(ocr, img_path, is_screenshot=False):
|
||||
"""
|
||||
获取ocr识别的结果,转为合适的layout形式
|
||||
:param is_screenshot: 是否是截图
|
||||
:param ocr: ocr模型
|
||||
:param img_path: 图片本地路径
|
||||
:return:
|
||||
@@ -36,18 +37,18 @@ def get_ocr_layout(ocr, img_path):
|
||||
return True
|
||||
|
||||
layout = []
|
||||
ocr_result = ocr.ocr(img_path, cls=False)
|
||||
ocr_result = ocr_result[0]
|
||||
ocr_result = ocr.predict(input=img_path, use_doc_orientation_classify=not is_screenshot, use_doc_unwarping=not is_screenshot)
|
||||
ocr_result = next(ocr_result)
|
||||
if not ocr_result:
|
||||
return layout
|
||||
for segment in ocr_result:
|
||||
box = segment[0]
|
||||
return layout, "0"
|
||||
angle = ocr_result.get("doc_preprocessor_res", {}).get("angle", "0")
|
||||
for i in range(len(ocr_result.get('rec_texts'))):
|
||||
box = ocr_result.get("rec_polys")[i].tolist()
|
||||
box = _get_box(box)
|
||||
if not _normal_box(box):
|
||||
continue
|
||||
text = segment[1][0]
|
||||
layout.append((box, text))
|
||||
return layout
|
||||
layout.append((box, ocr_result.get("rec_texts")[i]))
|
||||
return layout, str(angle)
|
||||
|
||||
|
||||
def delete_temp_file(temp_files):
|
||||
|
||||
Reference in New Issue
Block a user