统一模型接口,新增文本分类接口
This commit is contained in:
28
services/paddle_services/clas_text.py
Normal file
28
services/paddle_services/clas_text.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import logging.config
|
||||
|
||||
from flask import Flask, request
|
||||
from paddlenlp import Taskflow
|
||||
|
||||
from log import LOGGING_CONFIG
|
||||
from utils import process_request
|
||||
|
||||
app = Flask(__name__)
|
||||
schema = ['基本医保结算单', '出院记录', '费用清单']
|
||||
CLAS = Taskflow('zero_shot_text_classification', model='utc-xbase', schema=schema,
|
||||
task_path='model/text_classification', precision='fp16')
|
||||
|
||||
|
||||
@app.route('/', methods=['POST'])
|
||||
@process_request
|
||||
def main():
|
||||
text = request.form.get('text')
|
||||
cls_result = CLAS(text)
|
||||
cls_result = cls_result[0].get('predictions')[0]
|
||||
if cls_result['score'] < 0.8:
|
||||
raise Exception(f'识别结果置信度过低!text: {text}')
|
||||
return cls_result['label']
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
app.run('0.0.0.0', 5008)
|
||||
Reference in New Issue
Block a user