# Listing metadata (from the original paste): 32 lines, 902 B, Python
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
from log import LOGGING_CONFIG
from utils import process_request


# Candidate document types for the zero-shot classifier. English glosses:
# basic medical-insurance settlement statement, discharge record,
# itemized expense list. The labels themselves stay in Chinese because they
# are runtime values handed to the model (and presumably echoed back as the
# predicted label).
schema = ['基本医保结算单', '出院记录', '费用清单']

# Flask application that exposes the classification endpoint.
app = Flask(__name__)

# Classifier built once at import time and shared by every request.
# NOTE(review): task_path presumably points at fine-tuned 'utc-xbase'
# weights, and precision='fp16' presumably needs a GPU-capable runtime —
# confirm against the deployment environment.
CLAS = Taskflow(
    'zero_shot_text_classification',
    model='utc-xbase',
    schema=schema,
    task_path='model/text_classification',
    precision='fp16',
)


@app.route('/', methods=['POST'])
@process_request
def main():
    """Classify the POSTed ``text`` form field with the zero-shot classifier.

    Reads ``text`` from the request's form data and runs it through ``CLAS``.

    Returns:
        The top prediction's ``label`` when its ``score`` is at least 0.8.
        Otherwise ``None``, which is also returned when:

        * the ``text`` field is missing or empty,
        * the classifier returns nothing, or
        * the result carries no predictions.

        NOTE(review): turning ``None`` into an HTTP response is presumably
        handled by ``process_request``; confirm this, since a bare Flask view
        is not allowed to return ``None``.
    """
    text = request.form.get('text')
    # Nothing to classify: answer "no result" instead of handing None/''
    # to the model.
    if not text:
        return None

    cls_result = CLAS(text)
    if not cls_result:
        return None

    # The result is indexed as a list of dicts carrying a 'predictions'
    # list; only one text is sent, so only the first entry is inspected.
    # A missing or empty 'predictions' list is treated as "no result"
    # rather than raising TypeError/IndexError.
    predictions = cls_result[0].get('predictions')
    if not predictions:
        return None

    top = predictions[0]
    score = top['score']
    if score < 0.8:
        # Low-confidence prediction: log it with the input for later review.
        logging.info(f"识别结果置信度{score}过低!text: {text}")
        return None

    return top['label']


if __name__ == '__main__':
    # Logging is configured only when this module is executed directly;
    # a different launcher would have to apply LOGGING_CONFIG itself.
    logging.config.dictConfig(LOGGING_CONFIG)
    # Serve on every interface at port 5008.
    # NOTE(review): 0.0.0.0 exposes the endpoint to the whole network —
    # confirm that is intended.
    app.run(host='0.0.0.0', port=5008)