提高文本分类精度

This commit is contained in:
2024-10-18 11:02:58 +08:00
parent 3f93bd476a
commit 6529dc3d98

View File

@@ -9,7 +9,7 @@ from utils import process_request
app = Flask(__name__) app = Flask(__name__)
schema = ['基本医保结算单', '出院记录', '费用清单'] schema = ['基本医保结算单', '出院记录', '费用清单']
CLAS = Taskflow('zero_shot_text_classification', model='utc-xbase', schema=schema, CLAS = Taskflow('zero_shot_text_classification', model='utc-xbase', schema=schema,
task_path='model/text_classification', precision='fp16') task_path='model/text_classification', precision='fp32')
@app.route('/', methods=['POST']) @app.route('/', methods=['POST'])
@@ -20,7 +20,7 @@ def main():
cls_result = cls_result[0].get('predictions') cls_result = cls_result[0].get('predictions')
if cls_result: if cls_result:
cls_result = cls_result[0] cls_result = cls_result[0]
if cls_result['score'] < 0.8: if cls_result['score'] and float(cls_result['score']) < 0.8:
logging.info(f"识别结果置信度{cls_result['score']}过低text: {text}") logging.info(f"识别结果置信度{cls_result['score']}过低text: {text}")
return None return None
return cls_result['label'] return cls_result['label']