shuqianpinggu/src/main.py

59 lines
2.1 KiB
Python
Raw Normal View History

2025-06-17 17:46:44 +08:00
import ast
import logging
import os

from dotenv import load_dotenv

from src.session import SessionState
from src.case_info import format_to_report
2025-06-17 17:49:40 +08:00
# import os
# from dotenv import load_dotenv
# from session import SessionState
# from case_info import format_to_report
2025-06-17 17:46:44 +08:00
_logger = logging.getLogger(__name__)


class Consultation:
    """Registry of consultation sessions, keyed by session id.

    Every value in ``session_map`` is the ``SessionState`` created by
    :meth:`init_session`; the remaining methods look a session up by id and
    operate on it.
    """

    def __init__(self):
        # session_id -> SessionState, populated by init_session()
        self.session_map = {}

    def _get_session(self, session_id):
        """Return the session stored under *session_id*.

        Raises:
            KeyError: if no session has been registered under that id.
        """
        try:
            return self.session_map[session_id]
        except KeyError:
            raise KeyError(
                f"unknown session_id {session_id!r}; call init_session() first"
            ) from None

    def init_session(self, session_id, case_data):
        """Create a ``SessionState`` from *case_data* and register it.

        ``load_dotenv()`` is called first, so variables from a ``.env`` file
        are in the environment before the session state is constructed. An
        existing session with the same id is replaced.
        """
        load_dotenv()
        self.session_map[session_id] = SessionState(case_data)

    def qa_chat(self, session_id, question=None):
        """Advance one session by a single question/answer turn.

        Args:
            session_id: id previously registered through init_session().
            question: value forwarded to the session's
                ``process_query_task``; defaults to ``None``.

        Returns:
            A ``(cur_card, stream_state)`` tuple. ``cur_card`` is the list
            ``[info, option, agent_name]`` captured right after the query is
            processed and before the session is updated; ``stream_state`` is
            the return value of the session's ``doctor_state_task()``.

        Raises:
            KeyError: if *session_id* is not registered.
        """
        session_state = self._get_session(session_id)
        # Diagnostics go to the debug log instead of stdout so session
        # contents are not printed on every turn.
        _logger.debug("session_map=%s", self.session_map)

        # '总体评估' is the "overall assessment" option label; an empty or
        # missing option never counts as overall.
        current_option = session_state.option
        if_overall = bool(current_option) and '总体评估' in current_option

        session_state.process_query_task(question, overall=if_overall)
        _logger.debug("session_state=%s", vars(session_state))

        # Snapshot taken before update()/choose_agent_and_option_task() run,
        # so it reflects the turn that was just processed.
        cur_card = [
            session_state.info,
            session_state.option,
            session_state.agent_names[session_state.agent_order],
        ]
        _logger.debug("cur_card=%s", cur_card)

        info = session_state.info
        if not info or info['status'] == 'success':
            # update recorded information and update agent's knowledge
            session_state.update()
            # choose agent and option; retry once if no option was selected
            session_state.choose_agent_and_option_task()
            if session_state.option == '':
                session_state.choose_agent_and_option_task()
        elif info['status'] == 'need_clarification':
            session_state.missing_info = info['missing_info']

        # doctor asks the question
        stream_state = session_state.doctor_state_task()

        # The result of if_end() was never acted on; the call is kept in case
        # it has side effects inside SessionState.
        # NOTE(review): confirm whether an ended session needs handling here.
        session_state.if_end()

        self.session_map[session_id] = session_state
        return cur_card, stream_state

    def save_result(self, session_id):
        """Return ``patient.recorded_info`` of the given session.

        Raises:
            KeyError: if *session_id* is not registered.
        """
        return self._get_session(session_id).patient.recorded_info

    def format_report(self, session_id):
        """Return ``format_to_report`` applied to the session's recorded info.

        Raises:
            KeyError: if *session_id* is not registered.
        """
        return format_to_report(self._get_session(session_id).patient.recorded_info)
if __name__ == '__main__':
    # Manual smoke run: register one session from the CASE environment
    # variable and drive a single qa_chat turn.
    # Load .env before reading CASE so a value defined there is visible.
    load_dotenv()
    case_data = os.getenv('CASE')
    print(case_data)
    if case_data is None:
        raise SystemExit('CASE environment variable is not set')

    consultation = Consultation()
    # qa_chat() needs the id of a session registered via init_session();
    # calling it with no arguments raised TypeError.
    # NOTE(review): CASE is passed through as the raw string -- confirm the
    # format SessionState expects for case_data.
    session_id = 'local'
    consultation.init_session(session_id, case_data)
    consultation.qa_chat(session_id)