shuqianpinggu/src/agents.py

126 lines
4.8 KiB
Python
Raw Normal View History

2025-06-17 17:46:44 +08:00
import json
from src.patient import Patient
from src.chat import llm_answer
from src.utils import read_knowledge_json, read_rule_json
from src.prompts import *
from typing import List
class Agent:
    """Base agent that walks a knowledge base of options and queries the LLM.

    Attributes (stored verbatim from the constructor):
        task: system prompt passed to ``llm_answer`` when choosing an option.
        knowledge: mapping of option name -> dict holding a ``'必问'`` flag
            (1 = must ask, 0 = may ask, -1 = already asked, as used by
            ``classify_options`` / ``update_knowledge`` / ``all_asked``).
        rule: mapping of option name -> extraction rule used in ``process_query``.
    """

    def __init__(self, task, knowledge, rule):
        self.task = task
        self.knowledge = knowledge
        self.rule = rule

    def choose_option(self, history, patient: Patient, logger):
        """Pick the next option to ask about.

        Returns the first must-ask option if one exists. Otherwise the LLM picks
        among the may-ask options. Returns '' when there is nothing left to ask
        or when the LLM's answer is not a known option.
        """
        must_ask_options, may_ask_options = classify_options(self.knowledge)
        if must_ask_options:
            return must_ask_options[0]
        if not may_ask_options:
            return ''
        user_content = info_combination.format(
            str(history[-30:]), patient.recorded_info, str(may_ask_options))
        logger.info('***CHOOSE OPTION【1】***:\n' + user_content.strip())
        option = llm_answer(system_content=self.task, user_content=user_content)
        # LLM replies often carry stray whitespace/newlines; normalise them so a
        # correct answer is not rejected by the exact-membership filter below.
        if isinstance(option, str):
            option = option.strip()
        # Only accept answers that name a known option.
        return option if option in self.knowledge else ''

    def update_knowledge(self, option, info):
        """Mark *option* as asked (flag -1) once non-empty *info* exists for it."""
        if info and option in self.knowledge:
            self.knowledge[option]["必问"] = -1

    def process_query(self, option, patient, history, logger, overall=False):
        """Ask the LLM to extract structured information about *option*.

        Args:
            option: key into ``self.rule`` naming the item being assessed.
            patient: object exposing ``recorded_info``.
            history: dialogue history; only the last 30 entries are sent.
            logger: logger receiving the prompt and any parse failure.
            overall: use the overall-assessment extraction prompt instead of
                the per-item one.

        Returns:
            The parsed JSON object if it is a dict containing a ``'status'``
            key, otherwise None.
        """
        user_content = extract_info_content.format(
            patient.recorded_info, self.rule[option], option, history[-30:])
        logger.info('***PROCESS QUERY【3】***:\n总体评估({})\n'.format(overall) + user_content.strip())
        task = extract_overall_info_task if overall else extract_info_task
        llm_respond = llm_answer(task, user_content=user_content)
        cleaned = self._strip_code_fence(llm_respond or '')
        try:
            info = json.loads(cleaned)
        except json.JSONDecodeError:
            logger.error('json loads error, llm respond: ' + str(llm_respond))
            return None
        # Only a JSON object carrying a 'status' field is usable downstream;
        # lists/strings/numbers would otherwise crash on attribute access.
        if isinstance(info, dict) and 'status' in info:
            return info
        return None

    @staticmethod
    def _strip_code_fence(text):
        """Return *text* without a surrounding Markdown code fence or 'json' tag.

        Unlike a blanket ``replace('json', '')``, this leaves occurrences of
        "json" inside the payload untouched.
        """
        text = text.strip()
        if '```' in text:
            # Keep only the body of the first fenced block.
            _, _, after_open = text.partition('```')
            text, _, _ = after_open.partition('```')
            text = text.strip()
        # Drop a leading language tag such as the "json" in ```json.
        if text[:4].lower() == 'json':
            text = text[4:]
        return text.strip()

    def all_asked(self):
        """Return True when every option's flag is -1 (True for empty knowledge)."""
        return all(option['必问'] == -1 for option in self.knowledge.values())
class CirculatorySystemAgent(Agent):
    """Agent type for the circulatory system.

    Adds no behaviour of its own; construction and all methods are inherited
    from ``Agent``. The redundant ``__init__`` that only called ``super()`` was
    dropped.
    """
class RespiratorySystemAgent(Agent):
    """Agent type for the respiratory system.

    Adds no behaviour of its own; construction and all methods are inherited
    from ``Agent``. The redundant ``__init__`` that only called ``super()`` was
    dropped.
    """
class NervousSystemAgent(Agent):
    """Agent type for the nervous system.

    Adds no behaviour of its own; construction and all methods are inherited
    from ``Agent``. The redundant ``__init__`` that only called ``super()`` was
    dropped.
    """
def init_multi_agents(agent_names):
    """Build the three system agents in a fixed order.

    ``agent_names[0]``, ``[1]`` and ``[2]`` feed the circulatory, respiratory
    and nervous agents respectively. Each agent receives its task prompt plus
    ``read_knowledge_json(name)`` and ``read_rule_json(name)``. Names beyond
    the third are ignored. A shorter sequence fails on indexing (IndexError
    for a list).
    """
    agent_specs = (
        (CirculatorySystemAgent, circulatory_system_task),
        (RespiratorySystemAgent, respiratory_system_task),
        (NervousSystemAgent, nervous_system_task),
    )
    agents = []
    for position, (agent_cls, system_task) in enumerate(agent_specs):
        name = agent_names[position]
        agents.append(agent_cls(system_task, read_knowledge_json(name), read_rule_json(name)))
    return agents
def doctor_state(history, patient: Patient, option, q_templates, missing_info, logger):
    """Produce the doctor's next utterance for *option*.

    If *option* contains '预问诊' (pre-consultation) and there is no missing
    info, the first question template is returned as-is. Otherwise the LLM is
    prompted with the patient's recorded info, the missing info, the option,
    the templates and the last 30 history entries. In that case the function
    returns whatever ``llm_answer(..., stream=True)`` returns.
    """
    use_template = '预问诊' in option and not missing_info
    if use_template:
        return q_templates[0]

    prompt = consultation_content.format(
        patient.recorded_info, missing_info, option, q_templates, history[-30:])
    logger.info('***ASK QUESTION【2】***:\n' + prompt.strip())
    return llm_answer(
        consultation_task,
        user_content=prompt,
        temperature=0.8,
        max_tokens=256,
        stream=True,
    )
def choose_agent(option, agent_order, agents_list: List[Agent]):
    """Decide which agent handles the next turn; returns ``(index, agent)``.

    - Move to the next agent when no option was chosen, or the current agent
      has asked everything, provided a next agent exists.
    - Otherwise stay with the current agent while an option is set.
    - Otherwise (no option and no next agent) return ``(agent_order, None)``.
    """
    # all_asked() is consulted only when an option was chosen (short-circuit).
    wants_next = not option or agents_list[agent_order].all_asked()
    if wants_next and agent_order < len(agents_list) - 1:
        next_order = agent_order + 1
        return next_order, agents_list[next_order]
    if option:
        return agent_order, agents_list[agent_order]
    return agent_order, None
def classify_options(knowledge):
    """Split option names by their '必问' flag.

    Returns ``(must_ask, may_ask)``: names whose flag equals 1 and names whose
    flag equals 0, each in the knowledge mapping's iteration order. Any other
    flag value (e.g. -1) is left out of both lists.
    """
    must_ask = [name for name, meta in knowledge.items() if meta['必问'] == 1]
    may_ask = [name for name, meta in knowledge.items() if meta['必问'] == 0]
    return must_ask, may_ask