统一模型接口,新增文本分类接口

This commit is contained in:
2024-09-27 13:50:55 +08:00
parent 117b29a737
commit f1149854ce
13 changed files with 144 additions and 97 deletions

View File

@@ -1,4 +1,4 @@
import logging
import logging.config
from flask import Flask, request
from paddleclas import PaddleClas
@@ -10,9 +10,9 @@ app = Flask(__name__)
CLAS = PaddleClas(model_name='text_image_orientation')
@app.route('/clas/orientation', methods=['POST'])
@app.route(rule='/', methods=['POST'])
@process_request
def orientation():
def main():
"""
判断图片旋转角度逆时针旋转该角度后为正可能值['0', '90', '180', '270']
:return: 最有可能的两个角度

View File

@@ -0,0 +1,28 @@
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
from log import LOGGING_CONFIG
from utils import process_request
app = Flask(__name__)
schema = ['基本医保结算单', '出院记录', '费用清单']
CLAS = Taskflow('zero_shot_text_classification', model='utc-xbase', schema=schema,
task_path='model/text_classification', precision='fp16')
@app.route('/', methods=['POST'])
@process_request
def main():
text = request.form.get('text')
cls_result = CLAS(text)
cls_result = cls_result[0].get('predictions')[0]
if cls_result['score'] < 0.8:
raise Exception(f'识别结果置信度过低text: {text}')
return cls_result['label']
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5008)

View File

@@ -1,4 +1,4 @@
import logging
import logging.config
import os.path
import cv2
@@ -11,9 +11,9 @@ from utils import process_request, parse_img_path
app = Flask(__name__)
@app.route('/det/books', methods=['POST'])
@app.route('/', methods=['POST'])
@process_request
def books():
def main():
img_path = request.form.get('img_path')
result = detector.get_book_areas(img_path)

View File

@@ -1,4 +1,4 @@
import logging
import logging.config
import os
import cv2
@@ -11,9 +11,9 @@ from utils import process_request, parse_img_path
app = Flask(__name__)
@app.route('/dewarp', methods=['POST'])
@app.route('/', methods=['POST'])
@process_request
def dewarp():
def main():
img_path = request.form.get('img_path')
img = cv2.imread(img_path)
dewarped_img = dewarper.dewarp_image(img)

View File

@@ -1,5 +1,5 @@
import json
import logging
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
@@ -14,9 +14,9 @@ COST = Taskflow('information_extraction', schema=COST_LIST_SCHEMA, model='uie-x-
task_path='model/cost_list_model', layout_analysis=False, precision='fp16')
@app.route('/nlp/cost', methods=['POST'])
@app.route('/', methods=['POST'])
@process_request
def cost():
def main():
img_path = request.form.get('img_path')
layout = request.form.get('layout')
return COST({'doc': img_path, 'layout': json.loads(layout)})

View File

@@ -1,5 +1,5 @@
import json
import logging
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
@@ -16,9 +16,9 @@ DISCHARGE = Taskflow('information_extraction', schema=DISCHARGE_RECORD_SCHEMA, m
task_path='model/discharge_record_model', layout_analysis=False, precision='fp16')
@app.route('/nlp/discharge', methods=['POST'])
@app.route('/', methods=['POST'])
@process_request
def discharge():
def main():
img_path = request.form.get('img_path')
layout = request.form.get('layout')
return DISCHARGE({'doc': img_path, 'layout': json.loads(layout)})

View File

@@ -1,5 +1,5 @@
import json
import logging
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
@@ -20,9 +20,9 @@ SETTLEMENT_IE = Taskflow('information_extraction', schema=SETTLEMENT_LIST_SCHEMA
task_path='model/settlement_list_model', layout_analysis=False, precision='fp16')
@app.route('/nlp/settlement', methods=['POST'])
@app.route('/', methods=['POST'])
@process_request
def settlement():
def main():
img_path = request.form.get('img_path')
layout = request.form.get('layout')
return SETTLEMENT_IE({'doc': img_path, 'layout': json.loads(layout)})

View File

@@ -0,0 +1 @@
文本分类模型存放目录

View File

@@ -1,4 +1,4 @@
import logging
import logging.config
from flask import Flask, request
from paddleocr import PaddleOCR
@@ -10,9 +10,9 @@ app = Flask(__name__)
OCR = PaddleOCR(use_angle_cls=False, show_log=False, gpu_id=0, det_db_box_thresh=0.3)
@app.route('/ocr', methods=['POST'])
@app.route('/', methods=['POST'])
@process_request
def ocr():
def main():
img_path = request.form.get('img_path')
return OCR.ocr(img_path, cls=False)