Skip to content

Commit 941cb0e

Browse files
refactor
1 parent b49a275 commit 941cb0e

File tree

9 files changed

+1029
-25
lines changed

9 files changed

+1029
-25
lines changed

lagent/actions/mcp_client.py

Lines changed: 409 additions & 0 deletions
Large diffs are not rendered by default.

lagent/actions/web_visitor.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import asyncio
2+
import json
3+
import re
4+
import traceback
5+
import warnings
6+
from typing import Any, List
7+
8+
from transformers import AutoTokenizer
9+
10+
from lagent.actions import AsyncActionMixin, BaseAction
11+
from lagent.schema import ActionStatusCode, ActionValidCode, AgentMessage
12+
from lagent.utils import create_object
13+
14+
15+
def extract_last_json(text: str) -> dict | None:
16+
"""
17+
Extracts the last valid JSON object from a string.
18+
Handles Markdown code blocks (```json ... ```) and raw JSON strings.
19+
"""
20+
try:
21+
# 1. Try to find JSON within Markdown code blocks first
22+
# Look for ```json ... ``` or just ``` ... ```
23+
code_block_pattern = re.compile(r'```(?:json)?\s*(\{.*?\})\s*```', re.DOTALL)
24+
matches = code_block_pattern.findall(text)
25+
if matches:
26+
return json.loads(matches[-1])
27+
28+
# 2. If no code blocks, try to find the last outermost pair of braces
29+
# This regex looks for { ... } lazily but we want the last one.
30+
# A simple approach for nested JSON is tricky with regex,
31+
# so we scan from right to left for the last '}' and find its matching '{'.
32+
33+
stack, end_idx = 0, -1
34+
# Reverse search to find the last valid JSON structure
35+
for i in range(len(text) - 1, -1, -1):
36+
char = text[i]
37+
if char == '}':
38+
if stack == 0:
39+
end_idx = i
40+
stack += 1
41+
elif char == '{':
42+
if stack > 0:
43+
stack -= 1
44+
if stack == 0 and end_idx != -1:
45+
# Found a potential outermost JSON object
46+
candidate = text[i : end_idx + 1]
47+
try:
48+
return json.loads(candidate)
49+
except json.JSONDecodeError:
50+
# If this chunk isn't valid, reset and keep searching backwards
51+
# (or you might decide to stop here depending on strictness)
52+
stack, end_idx = 0, -1
53+
return None
54+
except Exception:
55+
return None
56+
57+
58+
class WebVisitor(AsyncActionMixin, BaseAction):
    """Action that fetches webpages and asks an LLM to extract goal-relevant content.

    Each URL is fetched with ``browse_tool``. The page text, optionally
    truncated, is inserted into ``EXTRACTION_PROMPT`` together with the goal.
    The LLM reply is expected to contain a JSON object with the keys
    ``rational``, ``evidence`` and ``summary``.
    """

    EXTRACTION_PROMPT = """Please process the following webpage content and user goal to extract relevant information:

## **Webpage Content**
{webpage_content}

## **User Goal**
{goal}

## **Task Guidelines**
1. **Content Scanning for Rationale**: Locate the **specific sections/data** directly related to the user's goal within the webpage content
2. **Key Extraction for Evidence**: Identify and extract the **most relevant information** from the content, you never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs.
3. **Summary Output for Summary**: Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal.

**Final Output Format using JSON format has "rational", "evidence", "summary" feilds**
"""

    def __init__(
        self,
        browse_tool: BaseAction | dict,
        llm: Any,
        max_browse_attempts: int = 3,
        max_extract_attempts: int = 3,
        sleep_interval: int = 3,
        truncate_browse_response_length: int | None = None,
        tokenizer_path: str | None = None,
        name: str = 'visit',
    ):
        """Configure the visitor.

        Args:
            browse_tool: Action (or a config passed to ``create_object``) used
                to fetch a page. It must be a single-tool action whose only
                required argument is ``url``.
            llm: Model (or a config passed to ``create_object``) exposing an
                awaitable ``chat`` method.
            max_browse_attempts: Maximum number of fetch attempts per URL.
            max_extract_attempts: Maximum number of LLM extraction attempts.
            sleep_interval: Seconds slept after a failed attempt.
            truncate_browse_response_length: Optional cap on the page text,
                counted in tokens when a tokenizer is loaded and in
                characters otherwise.
            tokenizer_path: Optional path handed to
                ``AutoTokenizer.from_pretrained``.
            name: Tool name exposed in the action description.
        """
        super().__init__(
            description={
                'name': name,
                'description': 'Visit webpage(s) and return the summary of the content.',
                'parameters': [
                    {
                        'name': 'url',
                        'type': ['STRING', 'ARRAY'],
                        "items": {"type": "string"},
                        "minItems": 1,
                        'description': 'The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs.',
                    },
                    {'name': 'goal', 'type': 'STRING', 'description': 'The goal of the visit for webpage(s).'},
                ],
                'required': ['url', 'goal'],
            }
        )
        browse_tool = create_object(browse_tool)
        # NOTE(review): this check is skipped when Python runs with -O; kept as
        # an assert so the exception type seen by existing callers is unchanged.
        assert not browse_tool.is_toolkit and browse_tool.description['required'] == [
            'url'
        ], "browse_tool must be a single-tool action with only 'url' as required argument."
        self.browse_tool = browse_tool
        self.llm = create_object(llm)
        self.max_browse_attempts = max_browse_attempts
        self.max_extract_attempts = max_extract_attempts
        self.sleep_interval = sleep_interval
        self.truncate_browse_response_length = truncate_browse_response_length
        self.tokenizer = (
            AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) if tokenizer_path else None
        )
        if self.truncate_browse_response_length is not None and self.tokenizer is None:
            warnings.warn(
                'truncate_browse_response_length is set but tokenizer_path is not provided. '
                'The raw webpage content will be truncated by characters instead of tokens.'
            )

    async def run(self, url: str | List[str], goal: str) -> str:
        """Visit one or more URLs concurrently and join the per-page reports.

        Args:
            url: A single URL or a list of URLs.
            goal: What the caller wants to learn from the page(s).

        Returns:
            The per-URL reports joined by a ``=======`` separator line. A URL
            whose processing raised contributes an ``Error fetching ...``
            line instead of a report.
        """
        if isinstance(url, str):
            url = [url]

        async def _inner_call(single_url: str) -> str:
            # Isolate failures so one bad URL does not cancel the others.
            try:
                return await self._read_webpage(single_url, goal)
            except Exception as e:
                return f"Error fetching {single_url}: {str(e)}"

        response = await asyncio.gather(*[_inner_call(single_url) for single_url in url])
        return "\n=======\n".join(response).strip()

    @staticmethod
    def _format_result(url: str, goal: str, evidence: Any, summary: Any) -> str:
        """Render the final report for one URL.

        Everything is interpolated in a single f-string so that literal
        ``{`` / ``}`` characters inside ``url`` or ``goal`` are emitted
        verbatim instead of being parsed as ``str.format`` fields.
        """
        return (
            f"The useful information in {url} for user goal {goal} as follows: \n\n"
            f"Evidence in page: \n{evidence}\n\nSummary: \n{summary}\n\n"
        )

    async def _read_webpage(self, url: str, goal: str) -> str:
        """Fetch ``url`` and return an evidence/summary report for ``goal``.

        Returns a fixed failure report when the page cannot be fetched within
        ``max_browse_attempts`` or when no well-formed extraction is obtained
        within ``max_extract_attempts``.
        """
        tool_response = compressed = None
        for _ in range(self.max_browse_attempts):
            resp = await self.browse_tool({'url': url})
            if resp.valid == ActionValidCode.OPEN and resp.state == ActionStatusCode.SUCCESS:
                tool_response = resp.format_result()
                break
            await asyncio.sleep(self.sleep_interval)
        else:
            # Every fetch attempt failed.
            return self._format_result(
                url,
                goal,
                evidence="The provided webpage content could not be accessed. Please check the URL or file format.",
                summary="The webpage content could not be processed, and therefore, no information is available.",
            )

        if self.truncate_browse_response_length is not None:
            tool_response = (
                self.tokenizer.decode(
                    self.tokenizer.encode(
                        tool_response,
                        max_length=self.truncate_browse_response_length,
                        truncation=True,
                        add_special_tokens=False,
                    )
                )
                if self.tokenizer is not None
                else tool_response[: self.truncate_browse_response_length]
            )

        for _ in range(self.max_extract_attempts):
            try:
                prompt = self.EXTRACTION_PROMPT.format(webpage_content=tool_response, goal=goal)
                llm_response = await self.llm.chat([{'role': 'user', 'content': prompt}])
                if llm_response and not isinstance(llm_response, str):
                    # Unwrap either an AgentMessage or an object exposing
                    # ``choices[0].message.content``.
                    llm_response = (
                        llm_response.content
                        if isinstance(llm_response, AgentMessage)
                        else llm_response.choices[0].message.content
                    )
                if not llm_response or len(llm_response) < 10:
                    # Empty or near-empty reply: retry immediately with the
                    # page text cut to 70% of its current length.
                    tool_response = tool_response[: int(len(tool_response) * 0.7)]
                    continue
                compressed = extract_last_json(llm_response)
                if isinstance(compressed, dict) and all(
                    key in compressed for key in ['rational', 'evidence', 'summary']
                ):
                    break
            except Exception:
                print(f"Error in extracting information: {traceback.format_exc()}")
            await asyncio.sleep(self.sleep_interval)
        else:
            # No attempt produced a JSON object with all required keys.
            return self._format_result(
                url,
                goal,
                evidence="Failed to extract relevant information from the webpage content.",
                summary="The webpage content could not be processed, and therefore, no information is available.",
            )
        return self._format_result(url, goal, evidence=compressed['evidence'], summary=compressed['summary'])

lagent/agents/agent.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessa
7373
self.update_memory(message, session_id=session_id)
7474
response_message = self.forward(*message, session_id=session_id, **kwargs)
7575
if not isinstance(response_message, AgentMessage):
76-
response_message = AgentMessage(sender=self.name, content=response_message)
76+
response_message = AgentMessage.from_model_response(response_message, self.name)
7777
self.update_memory(response_message, session_id=session_id)
7878
response_message = copy.deepcopy(response_message)
7979
for hook in self._hooks.values():
@@ -158,6 +158,28 @@ def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = F
158158
for agent in getattr(self, '_agents', {}).values():
159159
agent.reset(session_id, recursive=True)
160160

161+
def get_messages(self, session_id=0, keypath: Optional[str] = None) -> List[dict]:
    """Return the aggregated messages stored in memory for one session.

    Args:
        session_id (int): The session id of the memory.
        keypath (Optional[str]): Dot-separated names of nested sub-agents.
            When given, the lookup walks ``_agents`` one name at a time and
            the messages of the final sub-agent are returned. Default is None.

    Returns:
        List[dict]: Whatever the target agent's aggregator produces from its
        memory, name, output format and template.

    Raises:
        KeyError: If a name in ``keypath`` is not a sub-agent of the agent
            reached so far.
        ValueError: If the target agent has no aggregator.
    """
    if not keypath:
        # No delegation requested: answer from this agent's own memory.
        if not self.aggregator:
            raise ValueError(f'{self.name} has no aggregator to get messages')
        return self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )

    # Walk down the sub-agent tree one path segment at a time.
    target = self
    for segment in keypath.split('.'):
        children = getattr(target, '_agents', {})
        if segment not in children:
            raise KeyError(f'No sub-agent named {segment} in {target}')
        target = children[segment]
    return target.get_messages(session_id=session_id)
182+
161183
def __repr__(self):
162184

163185
def _rcsv_repr(agent, n_indent=1):
@@ -186,7 +208,7 @@ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Agen
186208
self.update_memory(message, session_id=session_id)
187209
response_message = await self.forward(*message, session_id=session_id, **kwargs)
188210
if not isinstance(response_message, AgentMessage):
189-
response_message = AgentMessage(sender=self.name, content=response_message)
211+
response_message = AgentMessage.from_model_response(response_message, self.name)
190212
self.update_memory(response_message, session_id=session_id)
191213
response_message = copy.deepcopy(response_message)
192214
for hook in self._hooks.values():
Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,42 @@
1-
from typing import Dict, List
1+
from typing import List
22

33
from lagent.memory import Memory
44
from lagent.prompts import StrParser
5+
from lagent.schema import ActionReturn
56

67

78
class DefaultAggregator:
89

9-
def aggregate(self, messages: Memory, name: str, parser: StrParser = None, system_instruction=None) -> List[dict]:
    """Convert stored memory into a chat-style list of message dicts.

    Args:
        messages: Memory whose ``get_memory()`` yields the stored messages.
        name: Name of the agent doing the aggregation. Messages sent by this
            name are rendered through their own ``to_model_request()``.
        parser: Accepted for interface compatibility; not used here.
        system_instruction: Optional instruction; when truthy it is expanded
            by ``aggregate_system_intruction`` and placed first.

    Returns:
        The message dicts in order. Messages from other senders become
        ``tool`` entries (one per item when the content is a list) or
        ``user`` entries, where consecutive user contents are concatenated.
    """
    history = messages.get_memory()
    formatted = []
    if system_instruction:
        formatted.extend(self.aggregate_system_intruction(system_instruction))

    for entry in history:
        if entry.sender == name:
            formatted.append(entry.to_model_request())
            continue

        content, extra_info = entry.content, entry.extra_info
        if isinstance(content, list):
            # A list payload is treated as a batch of tool results.
            for result in content:
                if isinstance(result, dict):
                    result = ActionReturn(**result)
                assert isinstance(result, ActionReturn), f"Expected m to be ActionReturn, but got {type(result)}"
                formatted.append(
                    {
                        'role': 'tool',
                        'tool_call_id': result.tool_call_id,
                        'content': result.format_result(),
                        'name': result.type,
                        'extra_info': extra_info,
                    }
                )
            continue

        if formatted and formatted[-1]['role'] == 'user':
            # Fold back-to-back user turns into the previous entry; the most
            # recent extra_info replaces the earlier one.
            # NOTE(review): if the previous entry came from the system
            # instruction, this updates that dict in place; confirm intended.
            formatted[-1]['content'] += content
            formatted[-1]['extra_info'] = extra_info
        else:
            formatted.append({'role': 'user', 'content': content, 'extra_info': extra_info})
    return formatted
3041

3142
@staticmethod
@@ -39,6 +50,5 @@ def aggregate_system_intruction(system_intruction) -> List[dict]:
3950
if not isinstance(msg, dict):
4051
raise TypeError(f'Unsupported message type: {type(msg)}')
4152
if not ('role' in msg and 'content' in msg):
42-
raise KeyError(
43-
f"Missing required key 'role' or 'content': {msg}")
53+
raise KeyError(f"Missing required key 'role' or 'content': {msg}")
4454
return system_intruction

0 commit comments

Comments
 (0)