test
This commit is contained in:
parent
4911f3024f
commit
e105003815
|
@ -1,134 +1,134 @@
|
|||
from flask import Blueprint, jsonify, request
|
||||
from ..models import Session, Message
|
||||
from .. import db
|
||||
from . import globals
|
||||
import json,ast
|
||||
# from .report_routes import consultation
|
||||
|
||||
message_routes = Blueprint('message', __name__)
|
||||
|
||||
|
||||
def filter_data(data):
|
||||
return {
|
||||
"content": data.get("content"),
|
||||
"remark": data.get("remark"),
|
||||
"role": data.get("role"),
|
||||
"sessionId": data.get("sessionId")
|
||||
}
|
||||
|
||||
|
||||
def validate_session_id(sessionId):
|
||||
"""验证sessionId是否提供且存在于Session表中"""
|
||||
if not sessionId:
|
||||
return jsonify({"error": "sessionId is required"}), 400
|
||||
session = Session.query.filter_by(id=sessionId).first()
|
||||
if not session:
|
||||
return jsonify({"error": "Session not found"}), 404
|
||||
return None
|
||||
|
||||
|
||||
@message_routes.route('/get-message/<int:sessionId>', methods=['GET'])
|
||||
def get_message(sessionId):
|
||||
"""获取当前用户的所有会话"""
|
||||
validation_result = validate_session_id(sessionId)
|
||||
if validation_result:
|
||||
return validation_result
|
||||
messages = Message.query.filter_by(sessionId=sessionId).all()
|
||||
return jsonify(messages=[message.to_dict() for message in messages]), 200
|
||||
|
||||
|
||||
@message_routes.route('/add-message', methods=['POST'])
|
||||
def add_message():
|
||||
"""添加会话"""
|
||||
data = request.get_json()
|
||||
print("22222222request",request)
|
||||
result = filter_data(data)
|
||||
validation_result = validate_session_id(result["sessionId"]) # 修改为字典访问方式
|
||||
if validation_result:
|
||||
return validation_result
|
||||
if not result["role"]: # 修改为字典访问方式
|
||||
return jsonify({"error": "role is required"}), 400
|
||||
|
||||
new_message = Message(**result)
|
||||
db.session.add(new_message)
|
||||
db.session.commit()
|
||||
return jsonify({'message': 'message added successfully', 'data': new_message.to_dict()}), 201
|
||||
|
||||
|
||||
@message_routes.route('/update-message/<int:messageId>/<int:sessionId>', methods=['PUT'])
|
||||
def update_message(messageId, sessionId):
|
||||
"""更新会话"""
|
||||
print("1111111111111111111request",request)
|
||||
data = request.get_json()
|
||||
remark = data.get("remark") # 修改为字典访问方式
|
||||
validation_result = validate_session_id(sessionId) # 修改为字典访问方式
|
||||
if validation_result:
|
||||
return validation_result # 修改为字典访问方式
|
||||
message = Message.query.filter_by(id=messageId).first() # 修改为字典访问方式
|
||||
if not message: # 修改为字典访问方式
|
||||
return jsonify({"error": "message not found"}), 404 # 修改为字典访问方式
|
||||
if remark:
|
||||
message.remark = remark # 修改为字典访问方式
|
||||
db.session.commit() # 修改为字典访问方式
|
||||
return jsonify({'message': 'message updated successfully'})
|
||||
|
||||
|
||||
@message_routes.route('/to-chat', methods=['POST'])
|
||||
def to_chat():
|
||||
print("请求头:", dict(request.headers)) # 检查Content-Type
|
||||
print("原始字节数据:", request.get_data()) # 检查数据是否合法
|
||||
print("1111111111111111111request",request)
|
||||
# if not request.is_json:
|
||||
# return jsonify({"error": "Content-Type must be application/json"}), 401
|
||||
|
||||
# # 2. 获取外层JSON数据
|
||||
# try:
|
||||
# data = request.json
|
||||
# print("外层JSON数据:", data)
|
||||
# except Exception as e:
|
||||
# return jsonify({"error": f"外层JSON解析失败: {str(e)}"}), 402
|
||||
|
||||
# # 3. 解析嵌套的JSON字符串(msgList和patientInfo)
|
||||
# try:
|
||||
# # 解析msgList(字符串 -> 列表)
|
||||
# msg_list = json.loads(data["msgList"])
|
||||
# print("解析后的msgList:", msg_list)
|
||||
|
||||
# # 解析patientInfo(字符串 -> 字典)
|
||||
# patient_info = json.loads(data["patientInfo"])
|
||||
# print("解析后的patientInfo:", patient_info)
|
||||
|
||||
# except json.JSONDecodeError as e:
|
||||
# return jsonify({"error": f"嵌套JSON解析失败: {str(e)}"}), 403
|
||||
# except KeyError as e:
|
||||
# return jsonify({"error": f"缺少必要字段: {str(e)}"}), 404
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
data = request.json
|
||||
print("1111111111111111111data",data)
|
||||
patientId = data['patientId']
|
||||
msgList = data['msgList']
|
||||
patient_info = data['patientInfo']
|
||||
print("222222222patient_info",patient_info)
|
||||
if len(msgList) == 0:
|
||||
globals.consultation.init_session(patientId,case_data=patient_info)
|
||||
content = None
|
||||
else:
|
||||
content = data['msgList'][-1]['content']
|
||||
print("333333333content",type(content))
|
||||
value, option, system = None, None, None
|
||||
|
||||
cur_card, answer = globals.consultation.qa_chat(patientId, content, msgList)
|
||||
|
||||
if cur_card and cur_card[0] and cur_card[0]['status'] == 'success':
|
||||
value, option, system = cur_card[0]['option_value'], cur_card[1], cur_card[2]
|
||||
|
||||
return jsonify({
|
||||
'answer': answer,
|
||||
'analysis': {
|
||||
'value': value, 'option': option, 'system': system
|
||||
}
|
||||
}), 201
|
||||
from flask import Blueprint, jsonify, request
|
||||
from ..models import Session, Message
|
||||
from .. import db
|
||||
from . import globals
|
||||
import json,ast
|
||||
# from .report_routes import consultation
|
||||
|
||||
message_routes = Blueprint('message', __name__)
|
||||
|
||||
|
||||
def filter_data(data):
|
||||
return {
|
||||
"content": data.get("content"),
|
||||
"remark": data.get("remark"),
|
||||
"role": data.get("role"),
|
||||
"sessionId": data.get("sessionId")
|
||||
}
|
||||
|
||||
|
||||
def validate_session_id(sessionId):
|
||||
"""验证sessionId是否提供且存在于Session表中"""
|
||||
if not sessionId:
|
||||
return jsonify({"error": "sessionId is required"}), 400
|
||||
session = Session.query.filter_by(id=sessionId).first()
|
||||
if not session:
|
||||
return jsonify({"error": "Session not found"}), 404
|
||||
return None
|
||||
|
||||
|
||||
@message_routes.route('/get-message/<int:sessionId>', methods=['GET'])
|
||||
def get_message(sessionId):
|
||||
"""获取当前用户的所有会话"""
|
||||
validation_result = validate_session_id(sessionId)
|
||||
if validation_result:
|
||||
return validation_result
|
||||
messages = Message.query.filter_by(sessionId=sessionId).all()
|
||||
return jsonify(messages=[message.to_dict() for message in messages]), 200
|
||||
|
||||
|
||||
@message_routes.route('/add-message', methods=['POST'])
|
||||
def add_message():
|
||||
"""添加会话"""
|
||||
data = request.get_json()
|
||||
print("22222222request",request)
|
||||
result = filter_data(data)
|
||||
validation_result = validate_session_id(result["sessionId"]) # 修改为字典访问方式
|
||||
if validation_result:
|
||||
return validation_result
|
||||
if not result["role"]: # 修改为字典访问方式
|
||||
return jsonify({"error": "role is required"}), 400
|
||||
|
||||
new_message = Message(**result)
|
||||
db.session.add(new_message)
|
||||
db.session.commit()
|
||||
return jsonify({'message': 'message added successfully', 'data': new_message.to_dict()}), 201
|
||||
|
||||
|
||||
@message_routes.route('/update-message/<int:messageId>/<int:sessionId>', methods=['PUT'])
|
||||
def update_message(messageId, sessionId):
|
||||
"""更新会话"""
|
||||
print("1111111111111111111request",request)
|
||||
data = request.get_json()
|
||||
remark = data.get("remark") # 修改为字典访问方式
|
||||
validation_result = validate_session_id(sessionId) # 修改为字典访问方式
|
||||
if validation_result:
|
||||
return validation_result # 修改为字典访问方式
|
||||
message = Message.query.filter_by(id=messageId).first() # 修改为字典访问方式
|
||||
if not message: # 修改为字典访问方式
|
||||
return jsonify({"error": "message not found"}), 404 # 修改为字典访问方式
|
||||
if remark:
|
||||
message.remark = remark # 修改为字典访问方式
|
||||
db.session.commit() # 修改为字典访问方式
|
||||
return jsonify({'message': 'message updated successfully'})
|
||||
|
||||
|
||||
@message_routes.route('/to-chat', methods=['POST'])
|
||||
def to_chat():
|
||||
print("请求头:", dict(request.headers)) # 检查Content-Type
|
||||
print("原始字节数据:", request.get_data()) # 检查数据是否合法
|
||||
print("1111111111111111111request",request)
|
||||
# if not request.is_json:
|
||||
# return jsonify({"error": "Content-Type must be application/json"}), 401
|
||||
|
||||
# # 2. 获取外层JSON数据
|
||||
# try:
|
||||
# data = request.json
|
||||
# print("外层JSON数据:", data)
|
||||
# except Exception as e:
|
||||
# return jsonify({"error": f"外层JSON解析失败: {str(e)}"}), 402
|
||||
|
||||
# # 3. 解析嵌套的JSON字符串(msgList和patientInfo)
|
||||
# try:
|
||||
# # 解析msgList(字符串 -> 列表)
|
||||
# msg_list = json.loads(data["msgList"])
|
||||
# print("解析后的msgList:", msg_list)
|
||||
|
||||
# # 解析patientInfo(字符串 -> 字典)
|
||||
# patient_info = json.loads(data["patientInfo"])
|
||||
# print("解析后的patientInfo:", patient_info)
|
||||
|
||||
# except json.JSONDecodeError as e:
|
||||
# return jsonify({"error": f"嵌套JSON解析失败: {str(e)}"}), 403
|
||||
# except KeyError as e:
|
||||
# return jsonify({"error": f"缺少必要字段: {str(e)}"}), 404
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
data = request.json
|
||||
print("1111111111111111111data",data)
|
||||
patientId = data['patientId']
|
||||
msgList = data['msgList']
|
||||
patient_info = data['patientInfo']
|
||||
print("222222222patient_info",patient_info)
|
||||
if len(msgList) == 0:
|
||||
globals.consultation.init_session(patientId,case_data=patient_info)
|
||||
content = None
|
||||
else:
|
||||
content = data['msgList'][-1]['content']
|
||||
print("333333333content",type(content))
|
||||
value, option, system = None, None, None
|
||||
|
||||
cur_card, answer = globals.consultation.qa_chat(patientId, content, msgList)
|
||||
|
||||
if cur_card and cur_card[0] and cur_card[0]['status'] == 'success':
|
||||
value, option, system = cur_card[0]['option_value'], cur_card[1], cur_card[2]
|
||||
|
||||
return jsonify({
|
||||
'answer': answer,
|
||||
'analysis': {
|
||||
'value': value, 'option': option, 'system': system
|
||||
}
|
||||
}), 201
|
||||
|
|
236
src/main.py
236
src/main.py
|
@ -1,119 +1,119 @@
|
|||
import os
|
||||
from dotenv import load_dotenv
|
||||
from src.session import SessionState
|
||||
from src.case_info import format_to_report
|
||||
import ast
|
||||
|
||||
class Consultation:
|
||||
def __init__(self):
|
||||
# initialize one patient and session state
|
||||
self.session_map = {}
|
||||
|
||||
def init_session(self, session_id, case_data):
|
||||
load_dotenv()
|
||||
self.session_map[session_id] = SessionState(case_data)
|
||||
|
||||
def qa_chat(self, session_id, content, msgList, question=None):
|
||||
session_state = self.session_map[session_id]
|
||||
print(777777,session_state.__dict__)
|
||||
session_state.history = msgList
|
||||
print("5555555session_state.option",session_state.option)
|
||||
if_overall = '总体评估' in session_state.option if session_state.option else False
|
||||
session_state.process_query_task(content, overall=if_overall)
|
||||
print(8888,session_state.__dict__)
|
||||
cur_card = [session_state.info, session_state.option, session_state.agent_names[session_state.agent_order]]
|
||||
print("66666cur_card",cur_card)
|
||||
if not session_state.info or session_state.info['status'] == 'success':
|
||||
# update recorded information and update agent's knowledge
|
||||
session_state.update()
|
||||
|
||||
# choose agent and
|
||||
session_state.choose_agent_and_option_task()
|
||||
if session_state.option == '无':
|
||||
session_state.choose_agent_and_option_task()
|
||||
|
||||
elif session_state.info['status'] == 'need_clarification':
|
||||
session_state.missing_info = session_state.info['missing_info']
|
||||
|
||||
# doctor asks the question
|
||||
stream_state = session_state.doctor_state_task()
|
||||
if session_state.if_end():
|
||||
pass
|
||||
|
||||
self.session_map[session_id] = session_state
|
||||
return cur_card, stream_state
|
||||
|
||||
def save_result(self, session_id):
|
||||
return self.session_map[session_id].patient.recorded_info
|
||||
|
||||
def format_report(self, session_id):
|
||||
return format_to_report(self.session_map[session_id].patient.recorded_info)
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# print(os.getenv('CASE'))
|
||||
# consultation = Consultation()
|
||||
# consultation.qa_chat()
|
||||
|
||||
"""import os
|
||||
from dotenv import load_dotenv
|
||||
from src.session import SessionState
|
||||
from src.case_info import format_to_report
|
||||
|
||||
# globals.consultation.init_session(case_data=data)
|
||||
|
||||
# import os
|
||||
# from dotenv import load_dotenv
|
||||
# from session import SessionState
|
||||
# from case_info import format_to_report
|
||||
|
||||
class Consultation:
|
||||
def __init__(self):
|
||||
# initialize one patient and session state
|
||||
self.session_map = {}
|
||||
|
||||
def init_session(self, case_data):
|
||||
load_dotenv()
|
||||
self.session_map = SessionState(case_data)
|
||||
|
||||
def qa_chat(self, question=None):
|
||||
session_state = self.session_map
|
||||
if_overall = '总体评估' in session_state.option if session_state.option else False
|
||||
session_state.process_query_task(question, overall=if_overall)
|
||||
cur_card = [session_state.info, session_state.option, session_state.agent_names[session_state.agent_order]]
|
||||
if not session_state.info or session_state.info['status'] == 'success':
|
||||
# update recorded information and update agent's knowledge
|
||||
session_state.update()
|
||||
|
||||
# choose agent and option
|
||||
session_state.choose_agent_and_option_task()
|
||||
if session_state.option == '无':
|
||||
session_state.choose_agent_and_option_task()
|
||||
|
||||
elif session_state.info['status'] == 'need_clarification':
|
||||
session_state.missing_info = session_state.info['missing_info']
|
||||
|
||||
# doctor asks the question
|
||||
stream_state = session_state.doctor_state_task()
|
||||
if session_state.if_end():
|
||||
pass
|
||||
self.session_map = session_state
|
||||
return cur_card, stream_state
|
||||
|
||||
def save_result(self):
|
||||
return self.session_map.patient.recorded_info
|
||||
|
||||
def format_report(self):
|
||||
return format_to_report(self.session_map.patient.recorded_info)
|
||||
|
||||
# def save_result(self):
|
||||
# return self.session_state.patient.recorded_info
|
||||
|
||||
# def format_report(self):
|
||||
# return format_to_report(self.session_state.patient.recorded_info)
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(os.getenv('CASE'))
|
||||
consultation = Consultation()
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from src.session import SessionState
|
||||
from src.case_info import format_to_report
|
||||
import ast
|
||||
|
||||
class Consultation:
|
||||
def __init__(self):
|
||||
# initialize one patient and session state
|
||||
self.session_map = {}
|
||||
|
||||
def init_session(self, session_id, case_data):
|
||||
load_dotenv()
|
||||
self.session_map[session_id] = SessionState(case_data)
|
||||
|
||||
def qa_chat(self, session_id, content, msgList, question=None):
|
||||
session_state = self.session_map[session_id]
|
||||
print(777777,session_state.__dict__)
|
||||
session_state.history = msgList
|
||||
print("5555555session_state.option",session_state.option)
|
||||
if_overall = '总体评估' in session_state.option if session_state.option else False
|
||||
session_state.process_query_task(content, overall=if_overall)
|
||||
print(8888,session_state.__dict__)
|
||||
cur_card = [session_state.info, session_state.option, session_state.agent_names[session_state.agent_order]]
|
||||
print("66666cur_card",cur_card)
|
||||
if not session_state.info or session_state.info['status'] == 'success':
|
||||
# update recorded information and update agent's knowledge
|
||||
session_state.update()
|
||||
|
||||
# choose agent and
|
||||
session_state.choose_agent_and_option_task()
|
||||
if session_state.option == '无':
|
||||
session_state.choose_agent_and_option_task()
|
||||
|
||||
elif session_state.info['status'] == 'need_clarification':
|
||||
session_state.missing_info = session_state.info['missing_info']
|
||||
|
||||
# doctor asks the question
|
||||
stream_state = session_state.doctor_state_task()
|
||||
if session_state.if_end():
|
||||
pass
|
||||
|
||||
self.session_map[session_id] = session_state
|
||||
return cur_card, stream_state
|
||||
|
||||
def save_result(self, session_id):
|
||||
return self.session_map[session_id].patient.recorded_info
|
||||
|
||||
def format_report(self, session_id):
|
||||
return format_to_report(self.session_map[session_id].patient.recorded_info)
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# print(os.getenv('CASE'))
|
||||
# consultation = Consultation()
|
||||
# consultation.qa_chat()
|
||||
|
||||
"""import os
|
||||
from dotenv import load_dotenv
|
||||
from src.session import SessionState
|
||||
from src.case_info import format_to_report
|
||||
|
||||
# globals.consultation.init_session(case_data=data)
|
||||
|
||||
# import os
|
||||
# from dotenv import load_dotenv
|
||||
# from session import SessionState
|
||||
# from case_info import format_to_report
|
||||
|
||||
class Consultation:
|
||||
def __init__(self):
|
||||
# initialize one patient and session state
|
||||
self.session_map = {}
|
||||
|
||||
def init_session(self, case_data):
|
||||
load_dotenv()
|
||||
self.session_map = SessionState(case_data)
|
||||
|
||||
def qa_chat(self, question=None):
|
||||
session_state = self.session_map
|
||||
if_overall = '总体评估' in session_state.option if session_state.option else False
|
||||
session_state.process_query_task(question, overall=if_overall)
|
||||
cur_card = [session_state.info, session_state.option, session_state.agent_names[session_state.agent_order]]
|
||||
if not session_state.info or session_state.info['status'] == 'success':
|
||||
# update recorded information and update agent's knowledge
|
||||
session_state.update()
|
||||
|
||||
# choose agent and option
|
||||
session_state.choose_agent_and_option_task()
|
||||
if session_state.option == '无':
|
||||
session_state.choose_agent_and_option_task()
|
||||
|
||||
elif session_state.info['status'] == 'need_clarification':
|
||||
session_state.missing_info = session_state.info['missing_info']
|
||||
|
||||
# doctor asks the question
|
||||
stream_state = session_state.doctor_state_task()
|
||||
if session_state.if_end():
|
||||
pass
|
||||
self.session_map = session_state
|
||||
return cur_card, stream_state
|
||||
|
||||
def save_result(self):
|
||||
return self.session_map.patient.recorded_info
|
||||
|
||||
def format_report(self):
|
||||
return format_to_report(self.session_map.patient.recorded_info)
|
||||
|
||||
# def save_result(self):
|
||||
# return self.session_state.patient.recorded_info
|
||||
|
||||
# def format_report(self):
|
||||
# return format_to_report(self.session_state.patient.recorded_info)
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(os.getenv('CASE'))
|
||||
consultation = Consultation()
|
||||
consultation.qa_chat()
|
180
src/session.py
180
src/session.py
|
@ -1,90 +1,90 @@
|
|||
from src.agents import doctor_state, choose_agent, init_multi_agents
|
||||
from src.utils import get_q_template, init_logger, process_stream
|
||||
from src.patient import Patient
|
||||
|
||||
# from agents import doctor_state, choose_agent, init_multi_agents
|
||||
# from utils import get_q_template, init_logger, process_stream
|
||||
# from patient import Patient
|
||||
|
||||
|
||||
class SessionState:
|
||||
def __init__(self, case):
|
||||
self.patient = Patient(case)
|
||||
self.agent_order = -1
|
||||
self.agent_names = ['circulatory_system', 'respiratory_system', 'nervous_system']
|
||||
self.cur_agent = None
|
||||
self.agents = init_multi_agents(self.agent_names)
|
||||
|
||||
self.query = None
|
||||
self.info = None
|
||||
self.option = None
|
||||
self.q_templates = None
|
||||
self.missing_info = None
|
||||
|
||||
self.history = []
|
||||
self.logger = init_logger()
|
||||
|
||||
def get_agent_name(self):
|
||||
return self.agent_names[self.agent_order]
|
||||
|
||||
def if_end(self):
|
||||
return self.option == '无' and self.agent_order == len(self.agent_names) - 1
|
||||
|
||||
def add_to_history(self, role, message):
|
||||
pass
|
||||
# self.history.append({"role": role, "content": message})
|
||||
# print("history:",self.history)
|
||||
|
||||
def process_query_task(self, cur_query, overall=False):
|
||||
|
||||
if cur_query:
|
||||
print(7777,cur_query)
|
||||
# self.add_to_history('user', cur_query)
|
||||
self.logger.info('query: ' + str(cur_query))
|
||||
else:
|
||||
self.logger.info('query: None')
|
||||
info = self.cur_agent.process_query(self.option, self.patient, self.history,
|
||||
self.logger, overall) if self.cur_agent else None
|
||||
self.query = cur_query
|
||||
self.info = info
|
||||
self.logger.info('information in the query: ' + str(info))
|
||||
|
||||
def choose_agent_and_option_task(self):
|
||||
# choose next system agent
|
||||
current_agent_order, chosen_agent = choose_agent(self.option, self.agent_order, self.agents)
|
||||
self.agent_order, self.cur_agent = current_agent_order, chosen_agent
|
||||
self.logger.info("chosen agent: " + self.agent_names[current_agent_order])
|
||||
|
||||
# choose option
|
||||
if self.cur_agent:
|
||||
print("!!!!!!!!!!!",self.cur_agent)
|
||||
self.option = self.cur_agent.choose_option(self.history, self.patient, self.logger)
|
||||
print("!!!!!!!!!!!self.option",self.option)
|
||||
self.q_templates = get_q_template(self.get_agent_name(), self.option)
|
||||
self.missing_info = None
|
||||
self.logger.info('current option and question templates:\n' + self.option + ' ' + str(self.q_templates))
|
||||
|
||||
|
||||
def doctor_state_task(self):
|
||||
n = 3
|
||||
stream_state = doctor_state(self.history, self.patient, self.option, self.q_templates, self.missing_info, self.logger)
|
||||
state_list = process_stream(stream_state) if type(stream_state) != str else [stream_state[i:i+3] for i in range(0, len(stream_state), 3)]
|
||||
|
||||
if len(self.history) == 0:
|
||||
s = '我是您的麻醉评估医生。我们接下来需要了解一些基本情况。'
|
||||
state = [s[i:i+3] for i in range(0, len(s), 3)] + state_list
|
||||
elif self.if_end():
|
||||
s = '谢谢你的配合,我们已经了解了你的整体情况,再见!'
|
||||
state = [s[i:i+3] for i in range(0, len(s), 3)]
|
||||
else:
|
||||
state = state_list
|
||||
# self.add_to_history('assistant', ''.join(state))
|
||||
self.missing_info = None
|
||||
self.logger.info('doctor statement: ' + ''.join(state))
|
||||
self.logger.info('HISTORY:\n' + str(self.history) + '\n\n')
|
||||
print('DOCTOR:', ''.join(state))
|
||||
return state
|
||||
|
||||
def update(self):
|
||||
if self.info: self.patient.update_info(self.option, self.info)
|
||||
if self.cur_agent: self.cur_agent.update_knowledge(self.option, self.info)
|
||||
from src.agents import doctor_state, choose_agent, init_multi_agents
|
||||
from src.utils import get_q_template, init_logger, process_stream
|
||||
from src.patient import Patient
|
||||
|
||||
# from agents import doctor_state, choose_agent, init_multi_agents
|
||||
# from utils import get_q_template, init_logger, process_stream
|
||||
# from patient import Patient
|
||||
|
||||
|
||||
class SessionState:
|
||||
def __init__(self, case):
|
||||
self.patient = Patient(case)
|
||||
self.agent_order = -1
|
||||
self.agent_names = ['circulatory_system', 'respiratory_system', 'nervous_system']
|
||||
self.cur_agent = None
|
||||
self.agents = init_multi_agents(self.agent_names)
|
||||
|
||||
self.query = None
|
||||
self.info = None
|
||||
self.option = None
|
||||
self.q_templates = None
|
||||
self.missing_info = None
|
||||
|
||||
self.history = []
|
||||
self.logger = init_logger()
|
||||
|
||||
def get_agent_name(self):
|
||||
return self.agent_names[self.agent_order]
|
||||
|
||||
def if_end(self):
|
||||
return self.option == '无' and self.agent_order == len(self.agent_names) - 1
|
||||
|
||||
def add_to_history(self, role, message):
|
||||
pass
|
||||
# self.history.append({"role": role, "content": message})
|
||||
# print("history:",self.history)
|
||||
|
||||
def process_query_task(self, cur_query, overall=False):
|
||||
|
||||
if cur_query:
|
||||
print(7777,cur_query)
|
||||
# self.add_to_history('user', cur_query)
|
||||
self.logger.info('query: ' + str(cur_query))
|
||||
else:
|
||||
self.logger.info('query: None')
|
||||
info = self.cur_agent.process_query(self.option, self.patient, self.history,
|
||||
self.logger, overall) if self.cur_agent else None
|
||||
self.query = cur_query
|
||||
self.info = info
|
||||
self.logger.info('information in the query: ' + str(info))
|
||||
|
||||
def choose_agent_and_option_task(self):
|
||||
# choose next system agent
|
||||
current_agent_order, chosen_agent = choose_agent(self.option, self.agent_order, self.agents)
|
||||
self.agent_order, self.cur_agent = current_agent_order, chosen_agent
|
||||
self.logger.info("chosen agent: " + self.agent_names[current_agent_order])
|
||||
|
||||
# choose option
|
||||
if self.cur_agent:
|
||||
print("!!!!!!!!!!!",self.cur_agent)
|
||||
self.option = self.cur_agent.choose_option(self.history, self.patient, self.logger)
|
||||
print("!!!!!!!!!!!self.option",self.option)
|
||||
self.q_templates = get_q_template(self.get_agent_name(), self.option)
|
||||
self.missing_info = None
|
||||
self.logger.info('current option and question templates:\n' + self.option + ' ' + str(self.q_templates))
|
||||
|
||||
|
||||
def doctor_state_task(self):
|
||||
n = 3
|
||||
stream_state = doctor_state(self.history, self.patient, self.option, self.q_templates, self.missing_info, self.logger)
|
||||
state_list = process_stream(stream_state) if type(stream_state) != str else [stream_state[i:i+3] for i in range(0, len(stream_state), 3)]
|
||||
|
||||
if len(self.history) == 0:
|
||||
s = '我是您的麻醉评估医生。我们接下来需要了解一些基本情况。'
|
||||
state = [s[i:i+3] for i in range(0, len(s), 3)] + state_list
|
||||
elif self.if_end():
|
||||
s = '谢谢你的配合,我们已经了解了你的整体情况,再见!'
|
||||
state = [s[i:i+3] for i in range(0, len(s), 3)]
|
||||
else:
|
||||
state = state_list
|
||||
# self.add_to_history('assistant', ''.join(state))
|
||||
self.missing_info = None
|
||||
self.logger.info('doctor statement: ' + ''.join(state))
|
||||
self.logger.info('HISTORY:\n' + str(self.history) + '\n\n')
|
||||
print('DOCTOR:', ''.join(state))
|
||||
return state
|
||||
|
||||
def update(self):
|
||||
if self.info: self.patient.update_info(self.option, self.info)
|
||||
if self.cur_agent: self.cur_agent.update_knowledge(self.option, self.info)
|
||||
|
|
Loading…
Reference in New Issue
Block a user