diff --git a/services/paddle_services/clas_text.py b/services/paddle_services/clas_text.py index 9eb77fd..6c12158 100644 --- a/services/paddle_services/clas_text.py +++ b/services/paddle_services/clas_text.py @@ -9,7 +9,7 @@ from utils import process_request app = Flask(__name__) schema = ['基本医保结算单', '出院记录', '费用清单'] CLAS = Taskflow('zero_shot_text_classification', model='utc-xbase', schema=schema, - task_path='model/text_classification', precision='fp16') + task_path='model/text_classification', precision='fp32') @app.route('/', methods=['POST']) @@ -20,7 +20,7 @@ def main(): cls_result = cls_result[0].get('predictions') if cls_result: cls_result = cls_result[0] - if cls_result['score'] < 0.8: + if cls_result['score'] and float(cls_result['score']) < 0.8: logging.info(f"识别结果置信度{cls_result['score']}过低!text: {text}") return None return cls_result['label']