Files
fcb_photo_review/photo_mask.py
2024-07-10 09:24:36 +08:00

298 lines
11 KiB
Python

import logging.config
import math
import os
import tempfile
import traceback
import urllib.request
from time import sleep
import cv2
import numpy as np
import paddleclas
from paddleocr import PaddleOCR
from sqlalchemy import update
from auto_email.error_email import send_an_error_email
from config.log import LOGGING_CONFIG
from config.mysql import MysqlSession
from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES
from config.ucloud import BUCKET
from models import ZxPhrec, ZxPhhd
from ucloud import ucloud
# Shared PaddleOCR engine, created once when this module is imported and reused
# by mask_image() for every tile. lang="ch" selects the Chinese model.
# Text-direction classification is disabled (use_angle_cls=False, and
# get_ocr_layout also passes cls=False) because page orientation is handled
# separately via get_image_rotation_angles()/rotate_image() before OCR.
OCR = PaddleOCR(use_angle_cls=False, lang="ch", show_log=False)
def open_image(img_path):
    """Load an image from an http(s) URL or a local file path.

    Args:
        img_path: ``http://`` / ``https://`` URL, or a local file path.

    Returns:
        The decoded image as an OpenCV array (IMREAD_COLOR, i.e. 3-channel BGR),
        or None when the data cannot be decoded / the file cannot be read
        (that is how cv2.imdecode and cv2.imread report failure).
    """
    # Match the scheme explicitly so a local path that merely begins with
    # "http" (e.g. "http_backup/a.jpg") is not sent to urlopen.
    if img_path.startswith(("http://", "https://")):
        # The context manager closes the HTTP response deterministically
        # instead of leaving the connection to the garbage collector.
        with urllib.request.urlopen(img_path) as resp:
            image_data = resp.read()
        # Wrap the raw bytes without copying, then decode them as an image.
        image_np = np.frombuffer(image_data, np.uint8)
        return cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    return cv2.imread(img_path)
def split_image(img, max_ratio=2.82, best_ration=1.41, overlap=0.05):
    """Cut a very elongated image into overlapping tiles along its long side.

    Images whose long/short side ratio is at most *max_ratio* are returned
    unchanged as a single tile. Otherwise the image is cut along its long side
    into windows ``short_side * best_ration`` pixels long. Each window starts
    ``short_side * (best_ration - overlap)`` pixels after the previous one, so
    neighbouring tiles overlap. Tiling stops as soon as a window reaches the
    far edge. That last tile is padded with black (zeros) up to the window
    length, so all tiles have the same shape.

    Args:
        img: image array, H x W or H x W x C.
        max_ratio: long/short aspect ratio above which the image is split.
        best_ration: long/short ratio of each tile (parameter name kept for
            backward compatibility).
        overlap: fraction of the short side by which consecutive tiles overlap.

    Returns:
        list of ``{"img": tile, "x_offset": int, "y_offset": int}`` dicts, where
        the offsets are the tile's top-left corner in *img* coordinates.

    Raises:
        ValueError: if ``best_ration - overlap`` is not positive, because
            the windows would then never advance.
    """
    height, width = img.shape[:2]
    ratio = max(width, height) / min(width, height)
    if ratio <= max_ratio:
        return [{"img": img, "x_offset": 0, "y_offset": 0}]

    step_ratio = best_ration - overlap
    if step_ratio <= 0:
        raise ValueError("best_ration - overlap must be positive")

    # vertical=True: the height is the long side, so tiles are stacked top to bottom.
    vertical = width < height
    short_side, long_side = (width, height) if vertical else (height, width)
    window = short_side * best_ration
    step = short_side * step_ratio

    split_result = []
    i = 0
    while True:
        offset = round(step * i)
        end = round(offset + window)
        if vertical:
            # numpy slicing is [y1:y2, x1:x2]
            split_result.append({"img": img[offset:end, 0:width], "x_offset": 0, "y_offset": offset})
        else:
            split_result.append({"img": img[0:height, offset:end], "x_offset": offset, "y_offset": 0})
        # Stop once this window covers the far edge. Counting ceil(long/step)
        # windows instead can add a trailing tile that starts at (or just
        # before) the edge and is therefore empty or almost all padding.
        if end >= long_side:
            break
        i += 1

    # Pad the (usually shorter) last tile with black up to the full window length.
    last_img = split_result[-1]["img"]
    axis = 0 if vertical else 1
    pad_width = [(0, 0)] * last_img.ndim
    pad_width[axis] = (0, round(window - last_img.shape[axis]))
    split_result[-1]["img"] = np.pad(last_img, pad_width, mode="constant", constant_values=0)
    return split_result
# Lazily created text-orientation classifier, shared by all calls of
# get_image_rotation_angles() instead of being rebuilt for every tile.
_ORIENTATION_MODEL = None


def _get_orientation_model():
    """Return the shared PaddleClas "text_image_orientation" model, creating it on first use."""
    global _ORIENTATION_MODEL
    if _ORIENTATION_MODEL is None:
        _ORIENTATION_MODEL = paddleclas.PaddleClas(model_name="text_image_orientation")
    return _ORIENTATION_MODEL


# Get the rotation angle of an image.
def get_image_rotation_angles(img):
    """Classify the text orientation of *img*.

    Args:
        img: image array passed straight to PaddleClas ``predict``.

    Returns:
        The model's ``label_names`` for the first prediction, as angle strings.
        The caller uses element 0 via ``int(...)``. The labels are presumably
        "0"/"90"/"180"/"270"; TODO confirm against the model config.
        Returns the fallback ``['0', '90']`` when the top score is below 0.5
        or when reading the prediction fails (the error is logged).
    """
    angles = ['0', '90']
    model = _get_orientation_model()
    result = model.predict(input_data=img)
    try:
        result = next(result)[0]
        # Low-confidence prediction: keep the fallback (element 0 is '0', i.e. no rotation).
        if result["scores"][0] < 0.5:
            return angles
        angles = result["label_names"]
    except Exception as e:
        logging.error("获取图片旋转角度失败", exc_info=e)
    return angles
def rotate_image(img, angle):
    """Rotate *img* by *angle* degrees about its center without cropping.

    The output canvas is sized to the axis-aligned bounding box of the rotated
    image. Width and height therefore swap for 90/270 and are kept for 180.
    Uncovered areas get cv2.warpAffine's default constant border (black).

    Args:
        img: image array, H x W or H x W x C.
        angle: rotation in degrees. Positive values rotate counter-clockwise
            (OpenCV getRotationMatrix2D convention).

    Returns:
        The rotated image. *img* itself is returned when ``angle == 0``.
    """
    if angle == 0:
        return img
    # shape[:2] so single-channel (2-D) images work as well as colour ones.
    height, width = img.shape[:2]
    radians = math.radians(angle)
    # Snap cos/sin so right angles give exact 0/1 and the canvas size is exact.
    cos_a = abs(round(math.cos(radians), 12))
    sin_a = abs(round(math.sin(radians), 12))
    new_width = round(width * cos_a + height * sin_a)
    new_height = round(width * sin_a + height * cos_a)
    # Rotate about the image center (arguments: center, degrees, scale)...
    matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)
    # ...then shift so the rotated image is centered on the new canvas.
    matrix[0, 2] += (new_width - width) / 2
    matrix[1, 2] += (new_height - height) / 2
    return cv2.warpAffine(img, matrix, (new_width, new_height))
def rotate_rectangle(rectangle, center, angle):
    """Map a rectangle found in a rotated image back to the unrotated image.

    This is the inverse of rotate_image(). That function turns an image of size
    ``(2*cx, 2*cy)`` by *angle* degrees (counter-clockwise, OpenCV convention)
    about its center. It then re-centers the result on a canvas sized to the
    rotated bounding box: unchanged for 0/180, width/height swapped for 90/270.
    Given an axis-aligned rectangle on that rotated canvas, this returns the
    axis-aligned bounding box of the same region in the original image.

    Args:
        rectangle: ``(x1, y1, x2, y2)`` in rotated-image coordinates.
        center: ``(cx, cy)`` center of the *unrotated* image, i.e.
            ``(width / 2, height / 2)``.
        angle: the angle, in degrees, that was passed to rotate_image().

    Returns:
        ``[x1, y1, x2, y2]`` as Python ints. Minimums are floored and maximums
        ceiled, so the box fully covers the mapped region.
    """
    cx, cy = center
    radians = math.radians(angle)
    # Snap so right angles give exact 0/±1 and corners don't drift by float noise.
    cos_a = round(math.cos(radians), 12)
    sin_a = round(math.sin(radians), 12)
    # Center of the canvas rotate_image() produces for this angle.
    width, height = 2 * cx, 2 * cy
    new_cx = round(width * abs(cos_a) + height * abs(sin_a)) / 2
    new_cy = round(width * abs(sin_a) + height * abs(cos_a)) / 2

    def _unrotate(x, y):
        # Inverse of: x' = c*(x-cx) + s*(y-cy) + new_cx,  y' = -s*(x-cx) + c*(y-cy) + new_cy
        u, v = x - new_cx, y - new_cy
        return cos_a * u - sin_a * v + cx, sin_a * u + cos_a * v + cy

    x1, y1, x2, y2 = rectangle
    corners = [_unrotate(x, y) for x in (x1, x2) for y in (y1, y2)]
    xs = [px for px, _ in corners]
    ys = [py for _, py in corners]
    return [math.floor(min(xs)), math.floor(min(ys)), math.ceil(max(xs)), math.ceil(max(ys))]
def get_ocr_layout(ocr, img_path):
    """Run OCR on an image file and return its text lines with axis-aligned boxes.

    Each OCR segment's 4-point polygon (as returned by ``ocr.ocr``) is reduced to
    ``[x1, y1, x2, y2]``:
    - x1 is the smaller x of points 0 and 3;
    - y1 is the smaller y of points 0 and 1;
    - x2 is the larger x of points 1 and 2;
    - y2 is the larger y of points 2 and 3.

    Segments whose resulting width or height is negative are skipped; zero is kept.

    Args:
        ocr: a PaddleOCR-like object; ``ocr.ocr(path, cls=False)[0]`` must be the
            list of ``[polygon, (text, score)]`` segments for the page (or falsy).
        img_path: path of the image file to recognise.

    Returns:
        list of ``(box, text)`` tuples; empty when nothing was recognised.
    """
    page = ocr.ocr(img_path, cls=False)[0]
    if not page:
        return []
    lines = []
    for segment in page:
        quad = segment[0]
        left = min(quad[0][0], quad[3][0])
        top = min(quad[0][1], quad[1][1])
        right = max(quad[1][0], quad[2][0])
        bottom = max(quad[2][1], quad[3][1])
        # Degenerate polygon (inverted corners): not a usable box.
        if bottom < top or right < left:
            continue
        lines.append(([left, top, right, bottom], segment[1][0]))
    return lines
def find_box_of_content(content, layout, start=0):
    """Estimate the part of an OCR line's box occupied by *content*.

    OCR returns one box per text line, so the horizontal position of *content*
    is interpolated by assuming every character has the same width
    (line box width / number of characters).

    Args:
        content: substring to locate.
        layout: ``(box, text)`` pair as produced by get_ocr_layout, with
            ``box = [x1, y1, x2, y2]``.
        start: index in *text* from which to search; 0 finds the first occurrence.

    Returns:
        ``(x1, y1, x2, y2)`` of the estimated sub-box. y1/y2 are the line's own.

    Raises:
        ValueError: if *content* does not occur in *text* at or after *start*.
    """
    full_box, text = layout[0], layout[1]
    index = text.index(content, start)
    # text can only be empty here when content is empty too; avoid dividing by zero.
    char_len = (full_box[2] - full_box[0]) / len(text) if text else 0
    return (
        full_box[0] + index * char_len,
        full_box[1],
        full_box[0] + (index + len(content)) * char_len,
        full_box[3],
    )


def mask_image(image, content):
    """OCR *image* and return a box for every occurrence of *content*.

    The image is written to a temporary JPEG (the OCR call takes a path),
    recognised with the shared ``OCR`` engine, and the temporary file is always
    removed afterwards.

    Args:
        image: image array to search (already rotated upright by the caller).
        content: text to locate. An empty string yields no boxes.

    Returns:
        list of ``(x1, y1, x2, y2)`` boxes in *image* coordinates, covering every
        non-overlapping occurrence on every OCR line. The list is empty when
        nothing matches. If OCR fails, the error is logged and the boxes found
        so far are returned, so callers can always iterate the result.
    """
    result = []
    if not content:
        return result
    # Close the handle before writing: re-opening an open NamedTemporaryFile
    # by name is not portable (it fails on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    try:
        if not cv2.imwrite(temp_path, image):
            logging.error(f"写入临时文件 {temp_path} 失败")
            return result
        for layout in get_ocr_layout(OCR, temp_path):
            text = layout[1]
            # Mask every occurrence, not only the first one on the line.
            pos = text.find(content)
            while pos != -1:
                result.append(find_box_of_content(content, layout, pos))
                pos = text.find(content, pos + len(content))
    except Exception as e:
        logging.error("涂抹时出错", exc_info=e)
    finally:
        try:
            os.remove(temp_path)
        except OSError as e:
            logging.info(f"删除临时文件 {temp_path} 时出错", exc_info=e)
    return result
def _find_mask_boxes(image, content):
    """Locate *content* in *image* and return the rectangles to paint, in *image* coordinates.

    The image is split into tiles (split_image). Each tile is rotated according
    to the orientation classifier and searched with mask_image. Every hit is
    mapped back through rotate_rectangle and shifted by the tile offset.
    """
    boxes = []
    for tile in split_image(image):
        tile_img = tile["img"]
        angle = int(get_image_rotation_angles(tile_img)[0])
        rotated_img = rotate_image(tile_img, angle)
        # rotate_rectangle needs the center of the tile *before* rotation.
        height, width = tile_img.shape[:2]
        center = (width / 2, height / 2)
        # Defensive: treat a None result from mask_image as "no hits".
        for box in mask_image(rotated_img, content) or []:
            x1, y1, x2, y2 = rotate_rectangle(box, center, angle)
            boxes.append((
                x1 + tile["x_offset"],
                y1 + tile["y_offset"],
                x2 + tile["x_offset"],
                y2 + tile["y_offset"],
            ))
    return boxes


def _upload_masked_image(key, image, attempts=3):
    """Encode *image* to a temporary JPEG and upload it as *key*, trying up to *attempts* times.

    The temporary file is always deleted afterwards.

    Returns:
        True if an upload succeeded, False otherwise (the failure is logged).
    """
    # Close the handle before writing: re-opening an open NamedTemporaryFile
    # by name is not portable (it fails on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    try:
        if not cv2.imwrite(temp_path, image):
            logging.error(f"写入临时文件 {temp_path} 失败")
            return False
        for _ in range(attempts):
            if ucloud.upload_file(key, temp_path):
                return True
        logging.error(f"上传涂抹后的图片失败:{key}")
        return False
    finally:
        try:
            os.remove(temp_path)
        except OSError as e:
            logging.info(f"删除临时文件 {temp_path} 时出错", exc_info=e)


def photo_mask(pk_phhd, content, debug_dir=None):
    """Paint over every occurrence of *content* in one case's photos and upload them back.

    Selects the ZxPhrec rows of the case whose cRectype is "3" or "4". For each one:
    - ``ucloud.copy_file(BUCKET, key, "drg2015", key)`` is called first, and the
      record is skipped if it fails (presumably a backup of the unmasked
      original; confirm);
    - the image is fetched through a private URL;
    - every hit is covered with a filled white rectangle;
    - the result is uploaded under the same key.

    Args:
        pk_phhd: primary key of the ZxPhhd case.
        content: text to mask (``__main__`` passes the case's cXm).
        debug_dir: if given, also write each masked image to
            ``<debug_dir>/<cfjaddress>.jpg`` for inspection.

    Raises:
        ValueError: if a downloaded image cannot be decoded.
    """
    session = MysqlSession()
    try:
        phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cfjaddress) \
            .filter(ZxPhrec.pk_phhd == pk_phhd) \
            .filter(ZxPhrec.cRectype.in_(["3", "4"])) \
            .all()
    finally:
        session.close()

    for phrec in phrecs:
        key = phrec.cfjaddress
        if not ucloud.copy_file(BUCKET, key, "drg2015", key):
            continue
        img_url = ucloud.get_private_url(key)
        if not img_url:
            continue
        image = open_image(img_url)
        if image is None:
            # Fail loudly rather than marking the case done with an unmasked photo.
            raise ValueError(f"无法读取图片:{key}")
        for x1, y1, x2, y2 in _find_mask_boxes(image, content):
            # thickness=-1 draws a filled rectangle.
            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)),
                          (255, 255, 255), -1, 0)
        if debug_dir:
            cv2.imwrite(os.path.join(debug_dir, f"{key}.jpg"), image)
        _upload_masked_image(key, image)
def _claim_batch():
    """Fetch up to PHHD_BATCH_SIZE cases with paint_flag "1" and set them to "2" (painting in progress).

    Returns:
        list of rows exposing ``pk_phhd`` and ``cXm``; empty when there is nothing to do.
    """
    session = MysqlSession()
    try:
        phhds = session.query(ZxPhhd.pk_phhd, ZxPhhd.cXm).filter(
            ZxPhhd.paint_flag == "1"
        ).limit(PHHD_BATCH_SIZE).all()
        # An empty batch would make the UPDATE a no-op round-trip, so skip it.
        if phhds:
            pk_phhd_values = [phhd.pk_phhd for phhd in phhds]
            session.execute(update(ZxPhhd).where(ZxPhhd.pk_phhd.in_(pk_phhd_values)).values(paint_flag="2"))
            session.commit()
        return phhds
    finally:
        session.close()


def _mark_done(pk_phhd):
    """Set paint_flag to "8" (painting finished) for one case."""
    session = MysqlSession()
    try:
        session.execute(update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(paint_flag="8"))
        session.commit()
    finally:
        session.close()


def main():
    """Poll for cases to mask until an error occurs; the error is then logged and e-mailed, and the loop stops."""
    logging.config.dictConfig(LOGGING_CONFIG)
    try:
        while True:
            phhds = _claim_batch()
            if not phhds:
                # No new case yet: wait a while before polling again.
                logging.info(f"暂未查询到需要涂抹的案子,等待{SLEEP_MINUTES}分钟...")
                sleep(SLEEP_MINUTES * 60)
                continue
            for phhd in phhds:
                pk_phhd = phhd.pk_phhd
                logging.info(f"开始涂抹:{pk_phhd}")
                photo_mask(pk_phhd, phhd.cXm)
                _mark_done(pk_phhd)
    except Exception as e:
        # NOTE(review): cases of the current batch already set to "2" stay "2"
        # after a crash and are not polled again (only "1" is) — confirm intent.
        logging.error(traceback.format_exc())
        send_an_error_email(program_name='照片涂抹脚本', error_name=repr(e), error_detail=traceback.format_exc())


if __name__ == '__main__':
    main()