多代理程式系統 – 監督器 (Supervisor) 模式範例程式碼

此範例採用監督器 (Supervisor) 代理程式模式,並以 SQLite 實作資料代理程式。範例中的 SQL 函數可以改用 SQL 工具取代,以存取 Oracle 資料庫。此處使用 SQLite 僅作示範之用,因為它不需額外安裝即可使用。

可以嘗試以下問題:
  • EMEA 的總收入是多少?
  • 2024 年 EMEA 的收入與 APAC 相比如何?
此代理程式流程的設計如下:
  • 路由器 (supervisor)
    • → SQL 代理程式
      • → 比較代理程式
        • → 洞察分析代理程式
        • → 最終 (final)
      • → 最終 (final)
    • → 其他代理程式 → 最終 (final)
(註:在目前的程式碼中,SQL 代理程式一律依序經過比較代理程式與洞察分析代理程式,再進入最終節點。)
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.graph import StateGraph, MessagesState, START, END
import sqlite3
from typing import TypedDict, Literal
from aidputils.agents.toolkit.agent_helper import init_oci_llm, pre_invoke_setup
from aidputils.agents.toolkit.configs import AIDPToolConf, OCIAIConf, ModelArgs
from typing import TypedDict, Literal
from typing import List
from langchain_core.messages import BaseMessage

##  Replace compartment id and endpoint
##  (fill in your compartment OCID and the <oci-region> in the endpoint URL)
compartment_id = '<your-compartment-ocid>'
endpoint = 'https://inference.generativeai.<oci-region>.oci.oraclecloud.com'
####

# Reuse a `checkpointer` already present in this module's globals
# (presumably injected by the hosting runtime -- confirm); otherwise None,
# and the graph is compiled without persistence.
checkpointer = globals().get("checkpointer", None)
# Shared-cache in-memory SQLite database: every connection opened with this
# same URI (execute_sql opens its own) sees the same data. SQLite frees a
# shared in-memory database when its last connection closes, so this
# module-level connection must stay open for the data to survive.
conn = sqlite3.connect("file::memory:?cache=shared", uri=True)

def setup_db(db=None) -> None:
    """Create the demo ``orders`` table and seed it with sample rows.

    Safe to call repeatedly: a seed row is inserted only if its ``order_id``
    is not already present. Repeated calls (e.g. several
    ``AgentBasic.setup()`` calls) therefore do not duplicate the data and
    inflate revenue totals.

    Args:
        db: sqlite3 connection to seed. Defaults to the module-level
            shared in-memory connection ``conn``.
    """
    if db is None:
        db = conn

    seed_orders = (
        (1, "EMEA", 1200, "2024-10-01"),
        (2, "EMEA", 900, "2024-10-02"),
        (3, "AMER", 1500, "2024-10-01"),
        (4, "EMEA", 400, "2024-10-03"),
        (5, "APAC", 800, "2024-10-02"),
        (6, "EMEA", 1400, "2025-10-01"),
        (7, "AMER", 1200, "2025-10-01"),
        (8, "EMEA", 900, "2025-10-03"),
        (9, "APAC", 300, "2025-10-02"),
    )

    # `with db:` commits on success and rolls back on error, but does NOT
    # close the connection. The connection is deliberately left open: a
    # shared-cache in-memory database disappears when its last connection
    # closes.
    with db:
        db.execute(
            """
            CREATE TABLE IF NOT EXISTS orders (
              order_id INTEGER,
              region TEXT,
              revenue REAL,
              order_date TEXT
            )
            """
        )
        # Per-row guard instead of a plain INSERT so re-seeding is a no-op.
        # The trailing parameter repeats order_id for the NOT EXISTS check.
        db.executemany(
            "INSERT INTO orders (order_id, region, revenue, order_date) "
            "SELECT ?, ?, ?, ? "
            "WHERE NOT EXISTS (SELECT 1 FROM orders WHERE order_id = ?)",
            [(*row, row[0]) for row in seed_orders],
        )

class State(TypedDict):
    """Shared LangGraph state flowing between the agent nodes.

    Most nodes return a partial dict with only the keys they update;
    ``final`` returns the whole state.
    """
    # User question for the current turn (set by AgentBasic.invoke).
    question: str
    # Routing decision written by `supervisor` and read by the conditional
    # edge. NOTE(review): TypedDict annotations are not enforced at runtime,
    # so this holds whatever string supervisor produced.
    route: Literal["sql", "other"]
    # SQL text generated by `sql_agent`.
    sql: str
    # Query results from `sql_agent`, one {column: value} dict per row.
    rows: list
    # Answer text from `sql_agent` (formatted) or `other_agent`.
    content: str
    # Comparative analysis text from `comparison_agent`.
    comparison: str
    # Business insight text from `insight_agent`.
    insight: str
    # Conversation history; `final` appends the question and combined answer.
    messages: List[BaseMessage]

# Extra model arguments handed to OCIAIConf; left empty here.
model_args = {}
# Guardrails configuration with no policies.
guardrails_config = {
    "name": "Default Guardrails",
    "description": "Default empty guardrails configuration",
    "policies": [],
}

# Chat model shared by every agent node. `compartment_id` and `endpoint`
# come from the module-level settings at the top of the file, so replacing
# those values is all that is needed to point at your tenancy.
llm_conf = OCIAIConf(
    model_provider='generic',
    compartment_id=compartment_id,
    model_args=model_args,
    endpoint=endpoint,
    model_id='xai.grok-4',
    guardrails_config=guardrails_config,
)
llm = init_oci_llm(llm_conf)


def _normalize_route(reply: str) -> str:
    """Map the router LLM's free-text reply onto a route key.

    Scans the reply's words and returns the first one that is "sql" or
    "other". Matching ignores case and surrounding quotes, backticks,
    asterisks, dashes and punctuation. Falls back to "other" when neither
    word appears.
    """
    for word in reply.lower().split():
        token = word.strip("`'\"*-.,:;!()")
        if token in ("sql", "other"):
            return token
    return "other"


def supervisor(state: State) -> State:
    """Decide the route for the question by asking the LLM.

    Returns the partial state update ``{"route": "sql"}`` or
    ``{"route": "other"}``. The reply is normalized because the conditional
    edge after this node maps only the exact keys "sql" and "other". A reply
    such as "SQL." or "- sql" would otherwise name an unmapped route.
    """
    messages = [
        HumanMessage(
            content=(
                "Decide whether the following question requires SQL analysis.\n\n"
                "Respond with ONLY one word:\n"
                "- sql\n"
                "- other\n\n"
                f"Question:\n{state['question']}"
            )
        )
    ]

    response: AIMessage = llm.invoke(messages)
    return {"route": _normalize_route(response.content)}


def other_agent(state: State) -> State:
    """Answer a question that does not need SQL directly with the LLM.

    Sends the question as a single user message and returns the partial
    state update ``{"content": <stripped reply text>}``.
    """
    question_msg = HumanMessage(content=state["question"])
    reply: AIMessage = llm.invoke([question_msg])
    answer_text = reply.content.strip()
    return {"content": answer_text}

def final(state: State) -> State:
    """Assemble this turn's answer and append the exchange to the history.

    The answer is ``state["content"]``. When this turn took the SQL route and
    produced an insight other than the placeholder "No additional insight.",
    the insight is appended under an "Insight:" label.

    Returns the full state with ``messages`` replaced by a new list: the
    prior history plus this turn's question and combined answer.
    """
    # Copy rather than append in place: the incoming list belongs to the
    # graph state (and, with a checkpointer, to the saved previous turn).
    history = list(state.get("messages") or [])

    combined_answer = state["content"]
    insight = state.get("insight")
    # Only the SQL route runs insight_agent. With a checkpointer, LangGraph
    # restores the thread's saved state each turn, so an "other" turn would
    # otherwise pick up the insight left over from an earlier SQL turn.
    if (
        state.get("route") == "sql"
        and insight
        and insight != "No additional insight."
    ):
        combined_answer += f"\n\nInsight: {insight}"

    history.append(HumanMessage(content=state["question"]))
    history.append(AIMessage(content=combined_answer))

    return {
        **state,
        "messages": history,
    }



def execute_sql(query: str, database: str = "file::memory:?cache=shared"):
    """Run a single read-only SQL statement and return its columns and rows.

    The query text comes from the LLM (driven by user input), so the
    connection is switched to ``PRAGMA query_only`` first. Any attempt to
    modify the database then raises ``sqlite3.OperationalError`` instead of
    changing data.

    Args:
        query: one SQL statement (sqlite3 refuses multiple statements here).
        database: SQLite database name/URI, opened with ``uri=True``.
            Defaults to the module's shared in-memory database.

    Returns:
        ``(columns, rows)``: a list of column names and a list of row tuples.
        Both are empty for statements that produce no result set.
    """
    db = sqlite3.connect(database, uri=True)
    try:
        db.execute("PRAGMA query_only = ON")
        cur = db.execute(query)
        if cur.description is None:
            # Statement returned no result set (e.g. a bare PRAGMA).
            return [], []
        columns = [desc[0] for desc in cur.description]
        rows = cur.fetchall()
    finally:
        # Always release the connection, even if the query fails.
        db.close()

    return columns, rows



def _extract_sql(reply: str) -> str:
    """Return the bare SQL statement from an LLM reply.

    If the reply contains a Markdown code fence (optionally tagged ``sql`` or
    ``sqlite``), only the fenced text is returned. Otherwise the stripped
    reply is returned unchanged. LLMs may wrap code in fences even when told
    not to, and SQLite would reject the backticks as a syntax error.
    """
    import re

    text = reply.strip()
    fenced = re.search(
        r"```[ \t]*(?:sqlite|sql)?\s*(.*?)\s*```",
        text,
        re.DOTALL | re.IGNORECASE,
    )
    return fenced.group(1).strip() if fenced else text


def sql_agent(state: State) -> State:
    """Answer the question from the ``orders`` table.

    1. Ask the LLM for a SQLite query over ``orders``.
    2. Run it with ``execute_sql`` and turn each row into a {column: value}
       dict.
    3. Ask the LLM to phrase the result as a short Markdown answer.

    Returns the partial state update with ``sql``, ``rows`` and ``content``.
    Errors raised while executing the query are not caught here.
    """
    # 1. Generate SQL
    sql_messages = [
        HumanMessage(
            content=(
                "You are a senior data analyst.\n\n"
                "Database schema:\n"
                "orders(order_id, region, revenue, order_date)\n\n"
                "Write a SQLite-compatible SQL query that answers the question below.\n"
                "Return ONLY the SQL starting with the SELECT statement.\n\n"
                f"Question:\n{state['question']}"
            )
        )
    ]

    sql_response: AIMessage = llm.invoke(sql_messages)
    sql = _extract_sql(sql_response.content)

    # 2. Execute SQL and key each row by column name.
    columns, rows = execute_sql(sql)
    results = [dict(zip(columns, row)) for row in rows]

    # 3. Format content (LLM owns presentation)
    # NOTE(review): the rules ask for headers and a bulleted list, while the
    # "Output Format (EXACT)" is a headline plus one sentence -- these conflict.
    format_messages = [
        HumanMessage(
            content=(
                "You are a professional analytics assistant.\n\n"
                "The following data is the FINAL, correct result.\n\n"
                f"Data:\n{results}\n\n"
                f"User Question:\n{state['question']}\n\n"
                "Formatting Rules:\n"
                "- Format entire response using Markdown syntax. Include headers, bold text, and a bulleted list\n"
                "- Start with a short headline (max 12 words)\n"
                "- On the next line, give a complete sentence answering the question\n"
                "- Clearly state the numeric value\n"
                "- Use commas in numbers\n"
                "- Do NOT mention SQL, databases, tables, or queries\n"
                "- Do NOT explain how the data was obtained\n"
                "- Do NOT add disclaimers\n\n"
                "Output Format (EXACT):\n"
                "<Headline>\n"
                "<Sentence>"
            )
        )
    ]

    formatted_response: AIMessage = llm.invoke(format_messages)
    content = formatted_response.content.strip()

    return {
        "sql": sql,
        "rows": results,
        "content": content,
    }


def comparison_agent(state: State) -> State:
    """Have the LLM compare the SQL result rows and describe what stands out.

    Reads ``rows`` (an empty list when absent) and returns the partial state
    update ``{"comparison": <stripped reply text>}``.
    """
    result_rows = state.get("rows", [])
    # NOTE(review): "Format entire response using Markdown ... bulleted list"
    # and "One concise sentence" are contradictory instructions.
    prompt = (
        "You are a business analyst.\n\n"
        "Analyze the result data and produce a comparative insight.\n\n"
        f"Result Data:\n{result_rows}\n\n"
        "Rules:\n"
        "- Format entire response using Markdown syntax. Include headers, bold text, and a bulleted list\n"
        "- If multiple rows exist, identify the highest, lowest, or notable difference\n"
        "- If only one value exists, explain what it represents and how it could be compared\n"
        "- Do NOT mention SQL or databases\n"
        "- One concise sentence\n"
        "- Never say 'no additional insight'\n"
    )
    reply: AIMessage = llm.invoke([HumanMessage(content=prompt)])
    comparison_text = reply.content.strip()
    return {"comparison": comparison_text}

def insight_agent(state: State) -> State:
    """Turn the comparison text into a one-sentence business insight.

    Reads ``question`` and, if present, ``comparison`` (empty string
    otherwise), and returns ``{"insight": <stripped reply text>}``.
    """
    comparison_text = state.get('comparison', '')
    prompt = (
        "You are a senior analytics advisor.\n\n"
        f"User Question:\n{state['question']}\n\n"
        f"Comparison Insight:\n{comparison_text}\n\n"
        "Turn this into a clear business insight.\n"
        "- One sentence\n"
        "- No speculation\n"
        "- No technical language\n"
    )
    reply: AIMessage = llm.invoke([HumanMessage(content=prompt)])
    return {"insight": reply.content.strip()}



class AgentBasic:
    """Supervisor-style multi-agent graph over the demo orders database.

    Call ``setup()`` to seed the database and compile the graph, then
    ``await invoke(question)``. ``invoke`` runs ``setup()`` itself if it has
    not been called yet.
    """

    def __init__(self) -> None:
        # Compiled LangGraph graph; None until setup() runs.
        self.graph = None

    @staticmethod
    def _next_after_supervisor(state) -> str:
        """Return the conditional-edge key for the supervisor's decision.

        The route is trimmed and lower-cased. Anything other than "sql" or
        "other" goes to "other", so an unexpected router reply falls back to
        the general-purpose agent instead of failing on an unmapped key.
        """
        route = str(state.get("route", "")).strip().lower()
        return route if route in ("sql", "other") else "other"

    def setup(self) -> None:
        """Seed the demo database and build and compile the agent graph.

        Flow: supervisor -> (sql_agent -> comparison_agent -> insight_agent
        | other_agent) -> final -> END. The graph is compiled with the
        module-level ``checkpointer`` when one is set.
        """
        setup_db()
        builder = StateGraph(State)

        builder.add_node("supervisor", supervisor)
        builder.add_node("sql_agent", sql_agent)
        builder.add_node("comparison_agent", comparison_agent)
        builder.add_node("insight_agent", insight_agent)
        builder.add_node("other_agent", other_agent)
        builder.add_node("final", final)

        builder.set_entry_point("supervisor")

        builder.add_conditional_edges(
            "supervisor",
            self._next_after_supervisor,
            {
                "sql": "sql_agent",
                "other": "other_agent",
            },
        )

        builder.add_edge("sql_agent", "comparison_agent")
        builder.add_edge("comparison_agent", "insight_agent")
        builder.add_edge("insight_agent", "final")
        builder.add_edge("other_agent", "final")
        builder.add_edge("final", END)

        if checkpointer:
            self.graph = builder.compile(checkpointer=checkpointer)
        else:
            self.graph = builder.compile()

    async def invoke(self, user_query: str, **kwargs):
        """Run the graph for one question.

        ``kwargs`` are forwarded to ``pre_invoke_setup`` to build the run
        config. Returns the final graph state, or None when the graph run
        raises. That error is logged with its traceback rather than
        propagated.
        """
        if self.graph is None:
            # Previously this failed on `None.ainvoke`; build on first use.
            self.setup()
        config = pre_invoke_setup(**kwargs)
        initial_state = {"question": user_query}
        try:
            return await self.graph.ainvoke(initial_state, config=config)
        except Exception:
            import logging

            logging.getLogger(__name__).exception(
                "Exception while calling invoke"
            )
            return None

import asyncio

async def main():
    """Demo entry point: build the agent, ask one sample question, print the result."""
    agent = AgentBasic()
    agent.setup()
    question = "What was the total revenue in EMEA?"
    answer = await agent.invoke(question)
    print("Agent response:", answer)


if __name__ == "__main__":
    asyncio.run(main())