shuqianpinggu/src/agents.py
WIN-T9V3LPE8NK2\Administrator 4911f3024f 主分支
2025-06-17 17:46:44 +08:00

126 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from src.patient import Patient
from src.chat import llm_answer
from src.utils import read_knowledge_json, read_rule_json
from src.prompts import *
from typing import List
class Agent:
    """Agent bound to one task prompt, one knowledge table and one rule table.

    ``knowledge`` maps an option name to a dict whose ``"必问"`` entry holds
    the option's state: ``classify_options`` treats 1 as must-ask and 0 as
    may-ask, and ``update_knowledge`` writes -1 once information has been
    obtained for the option. ``rule`` is indexed by option name in
    ``process_query``.
    """

    def __init__(self, task, knowledge, rule):
        self.task = task
        self.knowledge = knowledge
        self.rule = rule

    def choose_option(self, history, patient: Patient, logger):
        """Pick the next option to ask about.

        Must-ask options take priority and are returned in table order
        without calling the LLM. Otherwise the LLM is asked to choose among
        the may-ask options, given the last 30 history entries and the
        patient's recorded info.

        Returns:
            The chosen option name, or ``''`` when nothing is left to ask or
            the LLM's answer is not a key of ``self.knowledge``.
        """
        must_ask_options, may_ask_options = classify_options(self.knowledge)
        if must_ask_options:
            return must_ask_options[0]
        if not may_ask_options:
            return ''

        user_content = info_combination.format(
            str(history[-30:]), patient.recorded_info, str(may_ask_options))
        logger.info('***CHOOSE OPTION【1】***:\n' + user_content.strip())
        option = llm_answer(system_content=self.task, user_content=user_content)

        # Filter the LLM answer: only an exact knowledge key is accepted.
        if not isinstance(option, str):
            return ''
        # Trim surrounding whitespace/newlines so an otherwise valid answer
        # is not rejected (process_query strips the reply the same way).
        option = option.strip()
        return option if option in self.knowledge else ''

    def update_knowledge(self, option, info):
        """Mark ``option`` as answered (``"必问" = -1``).

        Nothing happens when ``info`` is falsy or ``option`` is not a key of
        ``self.knowledge``.
        """
        if info and option in self.knowledge:
            self.knowledge[option]["必问"] = -1

    @staticmethod
    def _strip_code_fence(text):
        """Remove a surrounding Markdown code fence and a leading ``json`` tag.

        Only the wrapper is removed. The payload itself is left untouched, so
        values that happen to contain the word "json" survive intact.
        """
        cleaned = text.strip()
        if cleaned.startswith('```'):
            cleaned = cleaned[3:]
        if cleaned.endswith('```'):
            cleaned = cleaned[:-3]
        cleaned = cleaned.strip()
        # A JSON document can never start with the bare word "json", so a
        # leading one is always the fence's language tag.
        if cleaned.startswith('json'):
            cleaned = cleaned[4:]
        return cleaned.strip()

    def process_query(self, option, patient, history, logger, overall=False):
        """Ask the LLM to extract structured info for ``option`` from the dialogue.

        The prompt is built from the patient's recorded info, the rule entry
        for ``option``, the option name and the last 30 history entries.
        ``overall`` selects ``extract_overall_info_task`` instead of
        ``extract_info_task`` as the system prompt.

        Returns:
            The decoded JSON object when it is a dict containing a
            ``'status'`` key; ``None`` when the reply is not a string, is not
            valid JSON, or lacks that key.

        Note:
            ``self.rule[option]`` is indexed directly, so an option without a
            rule entry raises before the LLM is called.
        """
        user_content = extract_info_content.format(
            patient.recorded_info, self.rule[option], option, history[-30:])
        logger.info('***PROCESS QUERY【3】***:\n总体评估({}'.format(overall) + user_content.strip())

        system_content = extract_overall_info_task if overall else extract_info_task
        llm_respond = llm_answer(system_content, user_content=user_content)
        if not isinstance(llm_respond, str):
            logger.error('json loads error, llm respond: {!r}'.format(llm_respond))
            return None

        llm_respond = self._strip_code_fence(llm_respond)
        try:
            info = json.loads(llm_respond)
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass.
            logger.error('json loads error, llm respond: ' + llm_respond)
            return None

        # Valid JSON may still be a list, string or number; only an object
        # carrying 'status' is a usable result.
        if isinstance(info, dict) and 'status' in info:
            return info
        return None

    def all_asked(self):
        """Return True when every knowledge entry has ``"必问" == -1``."""
        return all(entry['必问'] == -1 for entry in self.knowledge.values())
class CirculatorySystemAgent(Agent):
    """Agent subclass that init_multi_agents pairs with circulatory_system_task.

    Adds no behaviour of its own; construction takes the same
    ``(task, knowledge, rule)`` arguments and is handled by ``Agent``.
    """
class RespiratorySystemAgent(Agent):
    """Agent subclass that init_multi_agents pairs with respiratory_system_task.

    Adds no behaviour of its own; construction takes the same
    ``(task, knowledge, rule)`` arguments and is handled by ``Agent``.
    """
class NervousSystemAgent(Agent):
    """Agent subclass that init_multi_agents pairs with nervous_system_task.

    Adds no behaviour of its own; construction takes the same
    ``(task, knowledge, rule)`` arguments and is handled by ``Agent``.
    """
def init_multi_agents(agent_names):
    """Build the three system agents from the first three entries of ``agent_names``.

    Entry 0 feeds the circulatory agent, entry 1 the respiratory agent and
    entry 2 the nervous-system agent. Each name is passed to
    ``read_knowledge_json`` and ``read_rule_json`` to load that agent's
    tables.

    Returns:
        A list of the three agents in that order.
    """
    # (agent class, task prompt) in the order the names are consumed.
    blueprints = (
        (CirculatorySystemAgent, circulatory_system_task),
        (RespiratorySystemAgent, respiratory_system_task),
        (NervousSystemAgent, nervous_system_task),
    )
    multi_agents = []
    for position, (agent_cls, task) in enumerate(blueprints):
        name = agent_names[position]
        multi_agents.append(
            agent_cls(task, read_knowledge_json(name), read_rule_json(name)))
    return multi_agents
def doctor_state(history, patient: Patient, option, q_templates, missing_info, logger):
    """Produce the doctor's next utterance for ``option``.

    When ``option`` contains '预问诊' and ``missing_info`` is falsy, the first
    question template is returned as-is and the LLM is not called. In every
    other case a prompt is built from the patient's recorded info, the
    missing info, the option, the templates and the last 30 history entries,
    and whatever ``llm_answer`` returns (called with ``stream=True``) is
    passed back unchanged.
    """
    if '预问诊' not in option or missing_info:
        prompt = consultation_content.format(
            patient.recorded_info, missing_info, option, q_templates, history[-30:])
        logger.info('***ASK QUESTION【2】***:\n' + prompt.strip())
        return llm_answer(consultation_task, user_content=prompt,
                          temperature=0.8, max_tokens=256, stream=True)
    # Pre-consultation option with nothing missing: use the canned question.
    return q_templates[0]
def choose_agent(option, agent_order, agents_list: List[Agent]):
    """Decide which agent handles the next turn.

    Advances to the following agent when there is no option to ask, or the
    current agent reports ``all_asked()``, provided a following agent
    exists. Otherwise stays on the current agent if an option is set.

    Returns:
        ``(index, agent)``; ``agent`` is ``None`` when there is no option
        and no further agent to move to.
    """
    # An empty option short-circuits, so all_asked() is only consulted when
    # an option is actually set.
    move_on = (not option) or agents_list[agent_order].all_asked()
    successor = agent_order + 1
    if move_on and successor < len(agents_list):
        return successor, agents_list[successor]
    if option:
        return agent_order, agents_list[agent_order]
    return agent_order, None
def classify_options(knowledge):
    """Split option names by their ``"必问"`` flag.

    Returns:
        ``(must_ask_options, may_ask_options)``: names whose flag equals 1
        and names whose flag equals 0, each in the table's iteration order.
        Options with any other flag value appear in neither list.
    """
    must_ask_options = [name for name, entry in knowledge.items() if entry['必问'] == 1]
    may_ask_options = [name for name, entry in knowledge.items() if entry['必问'] == 0]
    return must_ask_options, may_ask_options