# Python source, 126 lines, 4.8 KiB
import json
|
|||
|
from src.patient import Patient
|
|||
|
from src.chat import llm_answer
|
|||
|
from src.utils import read_knowledge_json, read_rule_json
|
|||
|
from src.prompts import *
|
|||
|
from typing import List
|
|||
|
|
|||
|
|
|||
|
class Agent:
    """Base consultation agent driven by a per-option knowledge table.

    Attributes:
        task: System prompt handed to the LLM when choosing the next option.
        knowledge: Mapping of option name -> dict holding a '必问' flag.
            This class sets the flag to -1 once information for the option
            has been obtained (see ``update_knowledge``).
        rule: Mapping of option name -> extraction rule text used when
            building the information-extraction prompt.
    """

    def __init__(self, task, knowledge, rule):
        self.task = task
        self.knowledge = knowledge
        self.rule = rule

    def choose_option(self, history, patient: "Patient", logger):
        """Pick the next option to ask about.

        Args:
            history: Conversation history; only the last 30 entries are sent
                to the LLM.
            patient: Object exposing ``recorded_info``.
            logger: Logger exposing ``info``.

        Returns:
            The first must-ask option when one exists; otherwise the LLM's
            choice if it is a key of ``self.knowledge``; otherwise the
            sentinel string '无' ("none").
        """
        must_ask_options, may_ask_options = classify_options(self.knowledge)

        # Must-ask options take priority and need no LLM round-trip.
        if must_ask_options:
            return must_ask_options[0]
        if not may_ask_options:
            return '无'

        user_content = info_combination.format(
            str(history[-30:]), patient.recorded_info, str(may_ask_options)
        )
        logger.info('***CHOOSE OPTION【1】***:\n' + user_content.strip())
        option = llm_answer(system_content=self.task, user_content=user_content)

        # LLM output frequently carries stray whitespace/newlines; normalise
        # before validating so an otherwise correct answer is not discarded.
        if isinstance(option, str):
            option = option.strip()

        # Reject anything that is not a known option.
        return option if option in self.knowledge else '无'

    def update_knowledge(self, option, info):
        """Mark ``option`` as handled (flag -1) when ``info`` is truthy.

        Unknown options and falsy ``info`` leave the table untouched.
        """
        if info and option in self.knowledge:
            self.knowledge[option]["必问"] = -1

    def process_query(self, option, patient, history, logger, overall=False):
        """Ask the LLM to extract structured info for ``option`` from history.

        Args:
            option: Key into ``self.rule`` selecting the extraction rule.
            patient: Object exposing ``recorded_info``.
            history: Conversation history; only the last 30 entries are used.
            logger: Logger exposing ``info`` and ``error``.
            overall: When true, the overall-assessment task prompt is used
                instead of the per-option one.

        Returns:
            The parsed JSON object when it is a dict containing a 'status'
            key, otherwise ``None``.

        Raises:
            KeyError: If ``option`` is not a key of ``self.rule``.
        """
        user_content = extract_info_content.format(
            patient.recorded_info, self.rule[option], option, history[-30:]
        )
        logger.info('***PROCESS QUERY【3】***:\n总体评估({})'.format(overall) + user_content.strip())

        task = extract_overall_info_task if overall else extract_info_task
        llm_respond = llm_answer(task, user_content=user_content)
        return self._parse_llm_json(llm_respond, logger)

    @staticmethod
    def _parse_llm_json(llm_respond, logger=None):
        """Parse an LLM reply into a dict that carries a 'status' key.

        Only the Markdown code fence and a leading ``json`` language tag are
        removed; occurrences of "json" inside the payload are left intact.
        If the cleaned text does not parse, the span from the first '{' to
        the last '}' is tried as a fallback.

        Returns:
            The parsed dict, or ``None`` when the reply is not a string, is
            not valid JSON, is not a JSON object, or lacks a 'status' key.
        """
        if not isinstance(llm_respond, str):
            if logger is not None:
                logger.error('json loads error, llm respond: ' + str(llm_respond))
            return None

        text = llm_respond.strip()
        if text.startswith('```'):
            text = text[3:]
        if text.endswith('```'):
            text = text[:-3]
        text = text.strip()
        if text.startswith('json'):
            text = text[len('json'):].strip()

        candidates = [text]
        start, end = text.find('{'), text.rfind('}')
        if 0 <= start < end and text[start:end + 1] != text:
            # Fallback for replies that wrap the object in extra prose.
            candidates.append(text[start:end + 1])

        for candidate in candidates:
            try:
                info = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            break
        else:
            if logger is not None:
                logger.error('json loads error, llm respond: ' + text)
            return None

        # A JSON list/scalar has no keys; treat it like a missing 'status'.
        if isinstance(info, dict) and 'status' in info:
            return info
        return None

    def all_asked(self):
        """Return True when every option's '必问' flag equals -1."""
        return all(entry['必问'] == -1 for entry in self.knowledge.values())
|
|||
|
|
|||
|
class CirculatorySystemAgent(Agent):
    """Agent specialisation for the circulatory system; adds no behaviour."""

    def __init__(self, task, knowledge, rule):
        # Pure pass-through to the base initialiser.
        super().__init__(task=task, knowledge=knowledge, rule=rule)
|
|||
|
|
|||
|
|
|||
|
class RespiratorySystemAgent(Agent):
    """Agent specialisation for the respiratory system; adds no behaviour."""

    def __init__(self, task, knowledge, rule):
        # Pure pass-through to the base initialiser.
        super().__init__(task=task, knowledge=knowledge, rule=rule)
|
|||
|
|
|||
|
|
|||
|
class NervousSystemAgent(Agent):
    """Agent specialisation for the nervous system; adds no behaviour."""

    def __init__(self, task, knowledge, rule):
        # Pure pass-through to the base initialiser.
        super().__init__(task=task, knowledge=knowledge, rule=rule)
|
|||
|
|
|||
|
|
|||
|
def init_multi_agents(agent_names):
    """Build the circulatory, respiratory and nervous agents, in that order.

    Args:
        agent_names: Indexable holding at least three names. Entry ``i`` is
            passed to ``read_knowledge_json`` and ``read_rule_json`` to load
            the knowledge and rule tables of the i-th agent.

    Returns:
        A list of three agents: circulatory, respiratory, nervous.
    """
    agent_specs = (
        (CirculatorySystemAgent, circulatory_system_task),
        (RespiratorySystemAgent, respiratory_system_task),
        (NervousSystemAgent, nervous_system_task),
    )

    multi_agents = []
    for position, (agent_cls, system_task) in enumerate(agent_specs):
        # Index rather than zip so a too-short agent_names still raises.
        name = agent_names[position]
        knowledge = read_knowledge_json(name)
        rule = read_rule_json(name)
        multi_agents.append(agent_cls(system_task, knowledge, rule))
    return multi_agents
|
|||
|
|
|||
|
|
|||
|
def doctor_state(history, patient: "Patient", option, q_templates, missing_info, logger):
    """Produce the doctor's next statement for ``option``.

    When ``option`` contains '预问诊' and ``missing_info`` is falsy, the first
    template is returned as-is. Otherwise a prompt is built from the patient's
    recorded info, the missing info, the option, the templates and the last
    30 history entries, logged, and sent to the LLM.

    Returns:
        ``q_templates[0]`` on the template path, otherwise whatever
        ``llm_answer`` returns (called with ``stream=True``).
    """
    if '预问诊' not in option or missing_info:
        prompt = consultation_content.format(
            patient.recorded_info, missing_info, option, q_templates, history[-30:]
        )
        logger.info('***ASK QUESTION【2】***:\n' + prompt.strip())
        return llm_answer(
            consultation_task,
            user_content=prompt,
            temperature=0.8,
            max_tokens=256,
            stream=True,
        )

    # Pre-consultation option with nothing missing: use the canned question.
    return q_templates[0]
|
|||
|
|
|||
|
|
|||
|
def choose_agent(option, agent_order, agents_list: "List[Agent]"):
    """Decide which agent handles the next turn.

    Args:
        option: Option chosen in the previous step; falsy or '无' means no
            option is pending.
        agent_order: Index of the current agent in ``agents_list``.
        agents_list: Agents in consultation order.

    Returns:
        ``(index, agent)``. Advances to the next agent when no option is
        pending or the current agent reports ``all_asked()``, provided a next
        agent exists. Otherwise stays on the current agent if an option is
        pending, or returns ``(agent_order, None)`` when nothing is left.
    """
    no_option = not option or option == '无'

    # all_asked() is only consulted when an option is pending.
    wants_next = no_option or agents_list[agent_order].all_asked()
    if wants_next and agent_order < len(agents_list) - 1:
        following = agent_order + 1
        return following, agents_list[following]

    if no_option:
        return agent_order, None

    # An option is pending and there is no next agent to move to.
    return agent_order, agents_list[agent_order]
|
|||
|
|
|||
|
|
|||
|
def classify_options(knowledge):
    """Split option names by their '必问' flag.

    Args:
        knowledge: Mapping of option name -> dict containing a '必问' entry.

    Returns:
        ``(must_ask_options, may_ask_options)``: names whose flag equals 1
        and names whose flag equals 0, each in mapping order. Options with
        any other flag value appear in neither list.
    """
    must_ask_options = [name for name, meta in knowledge.items() if meta['必问'] == 1]
    may_ask_options = [name for name, meta in knowledge.items() if meta['必问'] == 0]
    return must_ask_options, may_ask_options
|