126 lines
4.8 KiB
Python
126 lines
4.8 KiB
Python
import json
|
||
from src.patient import Patient
|
||
from src.chat import llm_answer
|
||
from src.utils import read_knowledge_json, read_rule_json
|
||
from src.prompts import *
|
||
from typing import List
|
||
|
||
|
||
class Agent:
    """LLM-driven interviewer that works through one checklist of options.

    ``knowledge`` maps each option name to a dict carrying a ``'必问'`` flag:
    ``classify_options`` treats ``1`` as must-ask and ``0`` as may-ask, and
    ``update_knowledge`` sets it to ``-1`` once the option has been covered.
    ``rule`` maps option names to the rule passed into the info-extraction
    prompt by ``process_query``.
    """

    # Number of most recent history entries included in the LLM prompts.
    history_window = 30

    def __init__(self, task, knowledge, rule):
        """Store the agent's prompt and per-option data.

        Args:
            task: System prompt used when asking the LLM to choose an option.
            knowledge: Mapping option name -> dict with a ``'必问'`` flag;
                mutated in place by ``update_knowledge``.
            rule: Mapping option name -> extraction rule used by ``process_query``.
        """
        self.task = task
        self.knowledge = knowledge
        self.rule = rule

    def choose_option(self, history, patient: "Patient", logger):
        """Pick the next option to ask about.

        Must-ask options are returned first, in knowledge order, without an LLM
        call. Otherwise the LLM chooses among the may-ask options, and its answer
        is accepted only if it is one of the options it was offered.

        Returns:
            An option name, or ``'无'`` when nothing is left to ask or the LLM
            answer is not one of the offered may-ask options.
        """
        must_ask_options, may_ask_options = classify_options(self.knowledge)
        if must_ask_options:
            return must_ask_options[0]
        # No must-ask options remain, so only the may-ask list matters here.
        if not may_ask_options:
            return '无'

        user_content = info_combination.format(
            str(history[-self.history_window:]), patient.recorded_info, str(may_ask_options))
        logger.info('***CHOOSE OPTION【1】***:\n' + user_content.strip())
        option = llm_answer(system_content=self.task, user_content=user_content)

        # Filter the LLM answer. Tolerate surrounding whitespace/newlines, and only
        # accept options that were actually offered; this prevents re-asking an option
        # already marked as covered (-1).
        if isinstance(option, str):
            option = option.strip()
        return option if option in may_ask_options else '无'

    def update_knowledge(self, option, info):
        """Mark ``option`` as covered (``'必问'`` = -1) when ``info`` is truthy.

        Falsy ``info`` (e.g. ``None`` from a failed ``process_query``) or an
        unknown option leaves the knowledge unchanged.
        """
        if info and option in self.knowledge:
            self.knowledge[option]["必问"] = -1

    @staticmethod
    def _strip_code_fence(text):
        """Remove Markdown code-fence markers from an LLM reply.

        Drops every ``` marker and a leading ``json`` language tag. Unlike a
        blanket ``replace('json', '')``, this keeps occurrences of "json"
        inside the payload intact (valid JSON never starts with "json").
        """
        cleaned = text.replace('```', '').strip()
        if cleaned[:4].lower() == 'json':
            cleaned = cleaned[4:].lstrip()
        return cleaned

    def process_query(self, option, patient, history, logger, overall=False):
        """Ask the LLM to extract structured info about ``option`` from the dialogue.

        Args:
            option: Option name; must be a key of ``self.rule``.
            patient: Object exposing ``recorded_info``.
            history: Dialogue history; only the last ``history_window`` entries are sent.
            logger: Logger used for prompt tracing and JSON parse errors.
            overall: If true, use the overall-assessment extraction prompt instead
                of the per-option one.

        Returns:
            The parsed JSON reply if it is a dict containing a ``'status'`` key,
            otherwise ``None`` (including when the reply is not valid JSON).
        """
        user_content = extract_info_content.format(
            patient.recorded_info, self.rule[option], option, history[-self.history_window:])
        logger.info('***PROCESS QUERY【3】***:\n总体评估({})'.format(overall) + user_content.strip())

        system_prompt = extract_overall_info_task if overall else extract_info_task
        llm_respond = self._strip_code_fence(llm_answer(system_prompt, user_content=user_content))

        try:
            info = json.loads(llm_respond)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            logger.error('json loads error, llm respond: ' + llm_respond)
            return None
        # Valid JSON may still be a list/string/number; only a dict can carry 'status'.
        if isinstance(info, dict) and 'status' in info:
            return info
        return None

    def all_asked(self):
        """Return True when every option's ``'必问'`` flag is -1 (True for empty knowledge)."""
        return all(spec['必问'] == -1 for spec in self.knowledge.values())
|
||
|
||
class CirculatorySystemAgent(Agent):
    """Agent for the circulatory-system checklist.

    Adds no behaviour of its own; the constructor and all methods are
    inherited unchanged from ``Agent`` (the former ``__init__`` override only
    delegated to ``super()``).
    """
|
||
|
||
|
||
class RespiratorySystemAgent(Agent):
    """Agent for the respiratory-system checklist.

    Adds no behaviour of its own; the constructor and all methods are
    inherited unchanged from ``Agent`` (the former ``__init__`` override only
    delegated to ``super()``).
    """
|
||
|
||
|
||
class NervousSystemAgent(Agent):
    """Agent for the nervous-system checklist.

    Adds no behaviour of its own; the constructor and all methods are
    inherited unchanged from ``Agent`` (the former ``__init__`` override only
    delegated to ``super()``).
    """
|
||
|
||
|
||
def init_multi_agents(agent_names):
    """Build the circulatory, respiratory and nervous-system agents, in that order.

    ``agent_names[i]`` is passed to both ``read_knowledge_json`` and
    ``read_rule_json`` to load the i-th agent's knowledge and rules. Only the
    first three names are used; fewer than three raises ``IndexError``.
    """
    agent_specs = (
        (CirculatorySystemAgent, circulatory_system_task),
        (RespiratorySystemAgent, respiratory_system_task),
        (NervousSystemAgent, nervous_system_task),
    )
    agents = []
    for position, (agent_cls, task_prompt) in enumerate(agent_specs):
        source_name = agent_names[position]
        agents.append(agent_cls(task_prompt,
                                read_knowledge_json(source_name),
                                read_rule_json(source_name)))
    return agents
|
||
|
||
|
||
def doctor_state(history, patient: Patient, option, q_templates, missing_info, logger):
    """Produce the doctor's next utterance for ``option``.

    For a pre-consultation option ('预问诊' in its name) with no missing info,
    the first question template is returned directly. Otherwise the LLM is
    asked to phrase a question from the patient's recorded info, the missing
    info, the templates and the last 30 history entries. The LLM call uses
    ``stream=True``, so the return type is whatever ``llm_answer`` yields in
    streaming mode (not visible here).
    """
    use_canned_question = '预问诊' in option and not missing_info
    if use_canned_question:
        return q_templates[0]

    prompt = consultation_content.format(patient.recorded_info, missing_info, option,
                                         q_templates, history[-30:])
    logger.info('***ASK QUESTION【2】***:\n' + prompt.strip())
    return llm_answer(consultation_task, user_content=prompt,
                      temperature=0.8, max_tokens=256, stream=True)
|
||
|
||
|
||
def choose_agent(option, agent_order, agents_list: List[Agent]):
    """Decide which agent handles the next turn.

    Move to the next agent when the current one has nothing to ask (``option``
    is falsy or ``'无'``) or has covered all of its options, as long as a next
    agent exists. Otherwise keep the current agent whenever an option is given.
    When there is no option and no next agent, ``None`` is returned in place
    of the agent.

    Returns:
        ``(agent_order, agent)``: the possibly advanced index and the chosen
        agent (or ``None``).
    """
    nothing_to_ask = not option or option == '无'
    # The short-circuit means all_asked() is only consulted when an option exists.
    if nothing_to_ask or agents_list[agent_order].all_asked():
        if agent_order < len(agents_list) - 1:
            next_order = agent_order + 1
            return next_order, agents_list[next_order]
    if nothing_to_ask:
        return agent_order, None
    return agent_order, agents_list[agent_order]
|
||
|
||
|
||
def classify_options(knowledge):
    """Split option names by their ``'必问'`` flag.

    Args:
        knowledge: Mapping option name -> dict containing a ``'必问'`` key.

    Returns:
        ``(must_ask, may_ask)``: lists of option names whose flag equals 1 and
        0 respectively, in the mapping's iteration order. Other flag values
        (e.g. -1) appear in neither list.
    """
    must_ask = [name for name, spec in knowledge.items() if spec['必问'] == 1]
    may_ask = [name for name, spec in knowledge.items() if spec['必问'] == 0]
    return must_ask, may_ask
|