119 lines
4.2 KiB
Python
119 lines
4.2 KiB
Python
![]() |
import ast
import logging
import os

from dotenv import load_dotenv

from src.case_info import format_to_report
from src.session import SessionState
|
||
|
|
||
|
_logger = logging.getLogger(__name__)


class Consultation:
    """Hold one ``SessionState`` per session id and drive consultation turns.

    Sessions are created with :meth:`init_session` and advanced one turn at a
    time with :meth:`qa_chat`.
    """

    # Option text ("overall assessment") that switches query processing into
    # overall mode; it is matched as a substring of the current option.
    _OVERALL_OPTION = '总体评估'
    # Option value meaning "none"; when it is chosen, the agent/option choice
    # is made once more.
    _NO_OPTION = '无'

    def __init__(self):
        # session_id -> SessionState
        self.session_map = {}

    def _get_session(self, session_id):
        """Return the stored session for *session_id*.

        Raises:
            KeyError: if no session was created for *session_id* with
                :meth:`init_session`.
        """
        try:
            return self.session_map[session_id]
        except KeyError:
            raise KeyError(
                f'unknown session_id {session_id!r}; call init_session() first'
            ) from None

    def init_session(self, session_id, case_data):
        """Create (or replace) the session *session_id* built from *case_data*.

        ``load_dotenv()`` is called first so settings from a ``.env`` file are
        in the environment before ``SessionState`` is constructed.
        """
        load_dotenv()
        self.session_map[session_id] = SessionState(case_data)

    def qa_chat(self, session_id, content, msgList, question=None):
        """Run one consultation turn for *session_id*.

        Args:
            session_id: key of a session created by :meth:`init_session`.
            content: the user's message for this turn; passed to
                ``process_query_task``.
            msgList: conversation history; stored on the session as
                ``history`` before the turn is processed.
            question: unused; kept for backward compatibility.

        Returns:
            ``(cur_card, stream_state)``. ``cur_card`` is
            ``[info, option, current agent name]``, captured right after the
            query is processed. ``stream_state`` is whatever
            ``doctor_state_task()`` returns.

        Raises:
            KeyError: if *session_id* is unknown.
        """
        session_state = self._get_session(session_id)
        # Session state may contain patient data, so it is only emitted at
        # DEBUG level (and formatted lazily) instead of printed unconditionally.
        _logger.debug('session state before turn: %s', session_state.__dict__)

        session_state.history = msgList
        _logger.debug('current option: %s', session_state.option)

        # The option as it stands *before* this query is processed decides
        # whether the query runs in overall-assessment mode.
        option = session_state.option
        overall = bool(option) and self._OVERALL_OPTION in option
        session_state.process_query_task(content, overall=overall)
        _logger.debug('session state after query: %s', session_state.__dict__)

        cur_card = [
            session_state.info,
            session_state.option,
            session_state.agent_names[session_state.agent_order],
        ]
        _logger.debug('current card: %s', cur_card)

        info = session_state.info
        if not info or info['status'] == 'success':
            # Update recorded information and the agents' knowledge, then
            # choose the next agent and option.
            session_state.update()
            session_state.choose_agent_and_option_task()
            if session_state.option == self._NO_OPTION:
                # "none" was chosen; choose once more.
                session_state.choose_agent_and_option_task()
        elif info['status'] == 'need_clarification':
            session_state.missing_info = info['missing_info']

        # The doctor asks the next question.
        stream_state = session_state.doctor_state_task()

        # TODO(review): the end-of-consultation result is not acted on yet;
        # the call is kept in case if_end() has side effects.
        session_state.if_end()

        return cur_card, stream_state

    def save_result(self, session_id):
        """Return the patient's recorded information for *session_id*.

        Despite the name, this only returns the data; nothing is persisted.
        """
        return self._get_session(session_id).patient.recorded_info

    def format_report(self, session_id):
        """Return *session_id*'s recorded information passed through ``format_to_report``."""
        return format_to_report(self._get_session(session_id).patient.recorded_info)
|
||
|
|
||
|
|
||
|
# NOTE: dead code was removed from here: a commented-out copy of the
# __main__ guard below, and an earlier single-session Consultation kept in an
# unused string literal. Neither had any runtime effect.
|
||
|
|
||
|
if __name__ == '__main__':
    # Load .env before reading CASE. Otherwise values defined only in .env are
    # not visible here, because load_dotenv() is only called by init_session().
    load_dotenv()
    print(os.getenv('CASE'))
    consultation = Consultation()
    # qa_chat() requires session_id, content and msgList for a session that
    # was first created with init_session(session_id, case_data). Calling it
    # with no arguments always raised TypeError, so no demo turn is run here.
    # TODO(review): drive init_session()/qa_chat() with real case data.
|