diff --git a/app/routes/message_routes.py b/app/routes/message_routes.py index 4e7bc50..050468e 100644 --- a/app/routes/message_routes.py +++ b/app/routes/message_routes.py @@ -1,134 +1,134 @@ -from flask import Blueprint, jsonify, request -from ..models import Session, Message -from .. import db -from . import globals -import json,ast -# from .report_routes import consultation - -message_routes = Blueprint('message', __name__) - - -def filter_data(data): - return { - "content": data.get("content"), - "remark": data.get("remark"), - "role": data.get("role"), - "sessionId": data.get("sessionId") - } - - -def validate_session_id(sessionId): - """验证sessionId是否提供且存在于Session表中""" - if not sessionId: - return jsonify({"error": "sessionId is required"}), 400 - session = Session.query.filter_by(id=sessionId).first() - if not session: - return jsonify({"error": "Session not found"}), 404 - return None - - -@message_routes.route('/get-message/', methods=['GET']) -def get_message(sessionId): - """获取当前用户的所有会话""" - validation_result = validate_session_id(sessionId) - if validation_result: - return validation_result - messages = Message.query.filter_by(sessionId=sessionId).all() - return jsonify(messages=[message.to_dict() for message in messages]), 200 - - -@message_routes.route('/add-message', methods=['POST']) -def add_message(): - """添加会话""" - data = request.get_json() - print("22222222request",request) - result = filter_data(data) - validation_result = validate_session_id(result["sessionId"]) # 修改为字典访问方式 - if validation_result: - return validation_result - if not result["role"]: # 修改为字典访问方式 - return jsonify({"error": "role is required"}), 400 - - new_message = Message(**result) - db.session.add(new_message) - db.session.commit() - return jsonify({'message': 'message added successfully', 'data': new_message.to_dict()}), 201 - - -@message_routes.route('/update-message//', methods=['PUT']) -def update_message(messageId, sessionId): - """更新会话""" - print("1111111111111111111request",request) - data = request.get_json() - remark = data.get("remark") # 修改为字典访问方式 - validation_result = validate_session_id(sessionId) # 修改为字典访问方式 - if validation_result: - return validation_result # 修改为字典访问方式 - message = Message.query.filter_by(id=messageId).first() # 修改为字典访问方式 - if not message: # 修改为字典访问方式 - return jsonify({"error": "message not found"}), 404 # 修改为字典访问方式 - if remark: - message.remark = remark # 修改为字典访问方式 - db.session.commit() # 修改为字典访问方式 - return jsonify({'message': 'message updated successfully'}) - - -@message_routes.route('/to-chat', methods=['POST']) -def to_chat(): - print("请求头:", dict(request.headers)) # 检查Content-Type - print("原始字节数据:", request.get_data()) # 检查数据是否合法 - print("1111111111111111111request",request) - # if not request.is_json: - # return jsonify({"error": "Content-Type must be application/json"}), 401 - - # # 2. 获取外层JSON数据 - # try: - # data = request.json - # print("外层JSON数据:", data) - # except Exception as e: - # return jsonify({"error": f"外层JSON解析失败: {str(e)}"}), 402 - - # # 3. 解析嵌套的JSON字符串(msgList和patientInfo) - # try: - # # 解析msgList(字符串 -> 列表) - # msg_list = json.loads(data["msgList"]) - # print("解析后的msgList:", msg_list) - - # # 解析patientInfo(字符串 -> 字典) - # patient_info = json.loads(data["patientInfo"]) - # print("解析后的patientInfo:", patient_info) - - # except json.JSONDecodeError as e: - # return jsonify({"error": f"嵌套JSON解析失败: {str(e)}"}), 403 - # except KeyError as e: - # return jsonify({"error": f"缺少必要字段: {str(e)}"}), 404 - - - - - - data = request.json - print("1111111111111111111data",data) - patientId = data['patientId'] - msgList = data['msgList'] - patient_info = data['patientInfo'] - print("222222222patient_info",patient_info) - if len(msgList) == 0: - globals.consultation.init_session(patientId,case_data=patient_info) - content = None - else: - content = data['msgList'][-1]['content'] - print("333333333content",type(content)) - value, option, system = None, None, None - - cur_card, answer = globals.consultation.qa_chat(patientId, content, msgList) - - if cur_card and cur_card[0] and cur_card[0]['status'] == 'success': - value, option, system = cur_card[0]['option_value'], cur_card[1], cur_card[2] - - return jsonify({ - 'answer': answer, - 'analysis': { - 'value': value, 'option': option, 'system': system - } - }), 201 +from flask import Blueprint, jsonify, request +from ..models import Session, Message +from .. import db +from . import globals +import json,ast +# from .report_routes import consultation + +message_routes = Blueprint('message', __name__) + + +def filter_data(data): + return { + "content": data.get("content"), + "remark": data.get("remark"), + "role": data.get("role"), + "sessionId": data.get("sessionId") + } + + +def validate_session_id(sessionId): + """验证sessionId是否提供且存在于Session表中""" + if not sessionId: + return jsonify({"error": "sessionId is required"}), 400 + session = Session.query.filter_by(id=sessionId).first() + if not session: + return jsonify({"error": "Session not found"}), 404 + return None + + +@message_routes.route('/get-message/', methods=['GET']) +def get_message(sessionId): + """获取当前用户的所有会话""" + validation_result = validate_session_id(sessionId) + if validation_result: + return validation_result + messages = Message.query.filter_by(sessionId=sessionId).all() + return jsonify(messages=[message.to_dict() for message in messages]), 200 + + +@message_routes.route('/add-message', methods=['POST']) +def add_message(): + """添加会话""" + data = request.get_json() + print("22222222request",request) + result = filter_data(data) + validation_result = validate_session_id(result["sessionId"]) # 修改为字典访问方式 + if validation_result: + return validation_result + if not result["role"]: # 修改为字典访问方式 + return jsonify({"error": "role is required"}), 400 + + new_message = Message(**result) + db.session.add(new_message) + db.session.commit() + return jsonify({'message': 'message added successfully', 'data': new_message.to_dict()}), 201 + + +@message_routes.route('/update-message//', methods=['PUT']) +def update_message(messageId, sessionId): + """更新会话""" + print("1111111111111111111request",request) + data = request.get_json() + remark = data.get("remark") # 修改为字典访问方式 + validation_result = validate_session_id(sessionId) # 修改为字典访问方式 + if validation_result: + return validation_result # 修改为字典访问方式 + message = Message.query.filter_by(id=messageId).first() # 修改为字典访问方式 + if not message: # 修改为字典访问方式 + return jsonify({"error": "message not found"}), 404 # 修改为字典访问方式 + if remark: + message.remark = remark # 修改为字典访问方式 + db.session.commit() # 修改为字典访问方式 + return jsonify({'message': 'message updated successfully'}) + + +@message_routes.route('/to-chat', methods=['POST']) +def to_chat(): + print("请求头:", dict(request.headers)) # 检查Content-Type + print("原始字节数据:", request.get_data()) # 检查数据是否合法 + print("1111111111111111111request",request) + # if not request.is_json: + # return jsonify({"error": "Content-Type must be application/json"}), 401 + + # # 2. 获取外层JSON数据 + # try: + # data = request.json + # print("外层JSON数据:", data) + # except Exception as e: + # return jsonify({"error": f"外层JSON解析失败: {str(e)}"}), 402 + + # # 3. 解析嵌套的JSON字符串(msgList和patientInfo) + # try: + # # 解析msgList(字符串 -> 列表) + # msg_list = json.loads(data["msgList"]) + # print("解析后的msgList:", msg_list) + + # # 解析patientInfo(字符串 -> 字典) + # patient_info = json.loads(data["patientInfo"]) + # print("解析后的patientInfo:", patient_info) + + # except json.JSONDecodeError as e: + # return jsonify({"error": f"嵌套JSON解析失败: {str(e)}"}), 403 + # except KeyError as e: + # return jsonify({"error": f"缺少必要字段: {str(e)}"}), 404 + + + + + + data = request.json + print("1111111111111111111data",data) + patientId = data['patientId'] + msgList = data['msgList'] + patient_info = data['patientInfo'] + print("222222222patient_info",patient_info) + if len(msgList) == 0: + globals.consultation.init_session(patientId,case_data=patient_info) + content = None + else: + content = data['msgList'][-1]['content'] + print("333333333content",type(content)) + value, option, system = None, None, None + + cur_card, answer = globals.consultation.qa_chat(patientId, content, msgList) + + if cur_card and cur_card[0] and cur_card[0]['status'] == 'success': + value, option, system = cur_card[0]['option_value'], cur_card[1], cur_card[2] + + return jsonify({ + 'answer': answer, + 'analysis': { + 'value': value, 'option': option, 'system': system + } + }), 201 diff --git a/src/main.py b/src/main.py index d7e4435..0b2754e 100644 --- a/src/main.py +++ b/src/main.py @@ -1,119 +1,119 @@ -import os -from dotenv import load_dotenv -from src.session import SessionState -from src.case_info import format_to_report -import ast - -class Consultation: - def __init__(self): - # initialize one patient and session state - self.session_map = {} - - def init_session(self, session_id, case_data): - load_dotenv() - self.session_map[session_id] = SessionState(case_data) - - def qa_chat(self, session_id, content, msgList, question=None): - session_state = self.session_map[session_id] - print(777777,session_state.__dict__) - session_state.history = msgList - print("5555555session_state.option",session_state.option) - if_overall = '总体评估' in session_state.option if session_state.option else False - session_state.process_query_task(content, overall=if_overall) - print(8888,session_state.__dict__) - cur_card = [session_state.info, session_state.option, session_state.agent_names[session_state.agent_order]] - print("66666cur_card",cur_card) - if not session_state.info or session_state.info['status'] == 'success': - # update recorded information and update agent's knowledge - session_state.update() - - # choose agent and - session_state.choose_agent_and_option_task() - if session_state.option == '无': - session_state.choose_agent_and_option_task() - - elif session_state.info['status'] == 'need_clarification': - session_state.missing_info = session_state.info['missing_info'] - - # doctor asks the question - stream_state = session_state.doctor_state_task() - if session_state.if_end(): - pass - - self.session_map[session_id] = session_state - return cur_card, stream_state - - def save_result(self, session_id): - return self.session_map[session_id].patient.recorded_info - - def format_report(self, session_id): - return format_to_report(self.session_map[session_id].patient.recorded_info) - - -# if __name__ == '__main__': -# print(os.getenv('CASE')) -# consultation = Consultation() -# consultation.qa_chat() - -"""import os -from dotenv import load_dotenv -from src.session import SessionState -from src.case_info import format_to_report - -# globals.consultation.init_session(case_data=data) - -# import os -# from dotenv import load_dotenv -# from session import SessionState -# from case_info import format_to_report - -class Consultation: - def __init__(self): - # initialize one patient and session state - self.session_map = {} - - def init_session(self, case_data): - load_dotenv() - self.session_map = SessionState(case_data) - - def qa_chat(self, question=None): - session_state = self.session_map - if_overall = '总体评估' in session_state.option if session_state.option else False - session_state.process_query_task(question, overall=if_overall) - cur_card = [session_state.info, session_state.option, session_state.agent_names[session_state.agent_order]] - if not session_state.info or session_state.info['status'] == 'success': - # update recorded information and update agent's knowledge - session_state.update() - - # choose agent and option - session_state.choose_agent_and_option_task() - if session_state.option == '无': - session_state.choose_agent_and_option_task() - - elif session_state.info['status'] == 'need_clarification': - session_state.missing_info = session_state.info['missing_info'] - - # doctor asks the question - stream_state = session_state.doctor_state_task() - if session_state.if_end(): - pass - self.session_map = session_state - return cur_card, stream_state - - def save_result(self): - return self.session_map.patient.recorded_info - - def format_report(self): - return format_to_report(self.session_map.patient.recorded_info) - - # def save_result(self): - # return self.session_state.patient.recorded_info - - # def format_report(self): - # return format_to_report(self.session_state.patient.recorded_info) -""" - -if __name__ == '__main__': - print(os.getenv('CASE')) - consultation = Consultation() +import os +from dotenv import load_dotenv +from src.session import SessionState +from src.case_info import format_to_report +import ast + +class Consultation: + def __init__(self): + # initialize one patient and session state + self.session_map = {} + + def init_session(self, session_id, case_data): + load_dotenv() + self.session_map[session_id] = SessionState(case_data) + + def qa_chat(self, session_id, content, msgList, question=None): + session_state = self.session_map[session_id] + print(777777,session_state.__dict__) + session_state.history = msgList + print("5555555session_state.option",session_state.option) + if_overall = '总体评估' in session_state.option if session_state.option else False + session_state.process_query_task(content, overall=if_overall) + print(8888,session_state.__dict__) + cur_card = [session_state.info, session_state.option, session_state.agent_names[session_state.agent_order]] + print("66666cur_card",cur_card) + if not session_state.info or session_state.info['status'] == 'success': + # update recorded information and update agent's knowledge + session_state.update() + + # choose agent and + session_state.choose_agent_and_option_task() + if session_state.option == '无': + session_state.choose_agent_and_option_task() + + elif session_state.info['status'] == 'need_clarification': + session_state.missing_info = session_state.info['missing_info'] + + # doctor asks the question + stream_state = session_state.doctor_state_task() + if session_state.if_end(): + pass + + self.session_map[session_id] = session_state + return cur_card, stream_state + + def save_result(self, session_id): + return self.session_map[session_id].patient.recorded_info + + def format_report(self, session_id): + return format_to_report(self.session_map[session_id].patient.recorded_info) + + +# if __name__ == '__main__': +# print(os.getenv('CASE')) +# consultation = Consultation() +# consultation.qa_chat() + +"""import os +from dotenv import load_dotenv +from src.session import SessionState +from src.case_info import format_to_report + +# globals.consultation.init_session(case_data=data) + +# import os +# from dotenv import load_dotenv +# from session import SessionState +# from case_info import format_to_report + +class Consultation: + def __init__(self): + # initialize one patient and session state + self.session_map = {} + + def init_session(self, case_data): + load_dotenv() + self.session_map = SessionState(case_data) + + def qa_chat(self, question=None): + session_state = self.session_map + if_overall = '总体评估' in session_state.option if session_state.option else False + session_state.process_query_task(question, overall=if_overall) + cur_card = [session_state.info, session_state.option, session_state.agent_names[session_state.agent_order]] + if not session_state.info or session_state.info['status'] == 'success': + # update recorded information and update agent's knowledge + session_state.update() + + # choose agent and option + session_state.choose_agent_and_option_task() + if session_state.option == '无': + session_state.choose_agent_and_option_task() + + elif session_state.info['status'] == 'need_clarification': + session_state.missing_info = session_state.info['missing_info'] + + # doctor asks the question + stream_state = session_state.doctor_state_task() + if session_state.if_end(): + pass + self.session_map = session_state + return cur_card, stream_state + + def save_result(self): + return self.session_map.patient.recorded_info + + def format_report(self): + return format_to_report(self.session_map.patient.recorded_info) + + # def save_result(self): + # return self.session_state.patient.recorded_info + + # def format_report(self): + # return format_to_report(self.session_state.patient.recorded_info) +""" + +if __name__ == '__main__': + print(os.getenv('CASE')) + consultation = Consultation() consultation.qa_chat() \ No newline at end of file diff --git a/src/session.py b/src/session.py index d1958d0..08ac521 100644 --- a/src/session.py +++ b/src/session.py @@ -1,90 +1,90 @@ -from src.agents import doctor_state, choose_agent, init_multi_agents -from src.utils import get_q_template, init_logger, process_stream -from src.patient import Patient - -# from agents import doctor_state, choose_agent, init_multi_agents -# from utils import get_q_template, init_logger, process_stream -# from patient import Patient - - -class SessionState: - def __init__(self, case): - self.patient = Patient(case) - self.agent_order = -1 - self.agent_names = ['circulatory_system', 'respiratory_system', 'nervous_system'] - self.cur_agent = None - self.agents = init_multi_agents(self.agent_names) - - self.query = None - self.info = None - self.option = None - self.q_templates = None - self.missing_info = None - - self.history = [] - self.logger = init_logger() - - def get_agent_name(self): - return self.agent_names[self.agent_order] - - def if_end(self): - return self.option == '无' and self.agent_order == len(self.agent_names) - 1 - - def add_to_history(self, role, message): - pass - # self.history.append({"role": role, "content": message}) - # print("history:",self.history) - - def process_query_task(self, cur_query, overall=False): - - if cur_query: - print(7777,cur_query) - # self.add_to_history('user', cur_query) - self.logger.info('query: ' + str(cur_query)) - else: - self.logger.info('query: None') - info = self.cur_agent.process_query(self.option, self.patient, self.history, - self.logger, overall) if self.cur_agent else None - self.query = cur_query - self.info = info - self.logger.info('information in the query: ' + str(info)) - - def choose_agent_and_option_task(self): - # choose next system agent - current_agent_order, chosen_agent = choose_agent(self.option, self.agent_order, self.agents) - self.agent_order, self.cur_agent = current_agent_order, chosen_agent - self.logger.info("chosen agent: " + self.agent_names[current_agent_order]) - - # choose option - if self.cur_agent: - print("!!!!!!!!!!!",self.cur_agent) - self.option = self.cur_agent.choose_option(self.history, self.patient, self.logger) - print("!!!!!!!!!!!self.option",self.option) - self.q_templates = get_q_template(self.get_agent_name(), self.option) - self.missing_info = None - self.logger.info('current option and question templates:\n' + self.option + ' ' + str(self.q_templates)) - - - def doctor_state_task(self): - n = 3 - stream_state = doctor_state(self.history, self.patient, self.option, self.q_templates, self.missing_info, self.logger) - state_list = process_stream(stream_state) if type(stream_state) != str else [stream_state[i:i+3] for i in range(0, len(stream_state), 3)] - - if len(self.history) == 0: - s = '我是您的麻醉评估医生。我们接下来需要了解一些基本情况。' - state = [s[i:i+3] for i in range(0, len(s), 3)] + state_list - elif self.if_end(): - s = '谢谢你的配合,我们已经了解了你的整体情况,再见!' - state = [s[i:i+3] for i in range(0, len(s), 3)] - else: - state = state_list - # self.add_to_history('assistant', ''.join(state)) - self.missing_info = None - self.logger.info('doctor statement: ' + ''.join(state)) - self.logger.info('HISTORY:\n' + str(self.history) + '\n\n') - print('DOCTOR:', ''.join(state)) - return state - - def update(self): - if self.info: self.patient.update_info(self.option, self.info) - if self.cur_agent: self.cur_agent.update_knowledge(self.option, self.info) +from src.agents import doctor_state, choose_agent, init_multi_agents +from src.utils import get_q_template, init_logger, process_stream +from src.patient import Patient + +# from agents import doctor_state, choose_agent, init_multi_agents +# from utils import get_q_template, init_logger, process_stream +# from patient import Patient + + +class SessionState: + def __init__(self, case): + self.patient = Patient(case) + self.agent_order = -1 + self.agent_names = ['circulatory_system', 'respiratory_system', 'nervous_system'] + self.cur_agent = None + self.agents = init_multi_agents(self.agent_names) + + self.query = None + self.info = None + self.option = None + self.q_templates = None + self.missing_info = None + + self.history = [] + self.logger = init_logger() + + def get_agent_name(self): + return self.agent_names[self.agent_order] + + def if_end(self): + return self.option == '无' and self.agent_order == len(self.agent_names) - 1 + + def add_to_history(self, role, message): + pass + # self.history.append({"role": role, "content": message}) + # print("history:",self.history) + + def process_query_task(self, cur_query, overall=False): + + if cur_query: + print(7777,cur_query) + # self.add_to_history('user', cur_query) + self.logger.info('query: ' + str(cur_query)) + else: + self.logger.info('query: None') + info = self.cur_agent.process_query(self.option, self.patient, self.history, + self.logger, overall) if self.cur_agent else None + self.query = cur_query + self.info = info + self.logger.info('information in the query: ' + str(info)) + + def choose_agent_and_option_task(self): + # choose next system agent + current_agent_order, chosen_agent = choose_agent(self.option, self.agent_order, self.agents) + self.agent_order, self.cur_agent = current_agent_order, chosen_agent + self.logger.info("chosen agent: " + self.agent_names[current_agent_order]) + + # choose option + if self.cur_agent: + print("!!!!!!!!!!!",self.cur_agent) + self.option = self.cur_agent.choose_option(self.history, self.patient, self.logger) + print("!!!!!!!!!!!self.option",self.option) + self.q_templates = get_q_template(self.get_agent_name(), self.option) + self.missing_info = None + self.logger.info('current option and question templates:\n' + self.option + ' ' + str(self.q_templates)) + + + def doctor_state_task(self): + n = 3 + stream_state = doctor_state(self.history, self.patient, self.option, self.q_templates, self.missing_info, self.logger) + state_list = process_stream(stream_state) if type(stream_state) != str else [stream_state[i:i+3] for i in range(0, len(stream_state), 3)] + + if len(self.history) == 0: + s = '我是您的麻醉评估医生。我们接下来需要了解一些基本情况。' + state = [s[i:i+3] for i in range(0, len(s), 3)] + state_list + elif self.if_end(): + s = '谢谢你的配合,我们已经了解了你的整体情况,再见!' + state = [s[i:i+3] for i in range(0, len(s), 3)] + else: + state = state_list + # self.add_to_history('assistant', ''.join(state)) + self.missing_info = None + self.logger.info('doctor statement: ' + ''.join(state)) + self.logger.info('HISTORY:\n' + str(self.history) + '\n\n') + print('DOCTOR:', ''.join(state)) + return state + + def update(self): + if self.info: self.patient.update_info(self.option, self.info) + if self.cur_agent: self.cur_agent.update_knowledge(self.option, self.info) diff --git a/test.py b/test.py new file mode 100644 index 0000000..eddc4c8 --- /dev/null +++ b/test.py @@ -0,0 +1 @@ +print(112) \ No newline at end of file