-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathswarm_overwatch.py
More file actions
739 lines (646 loc) · 30.7 KB
/
swarm_overwatch.py
File metadata and controls
739 lines (646 loc) · 30.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
"""
Swarm Overwatch - Real-time monitoring interface for rollout pool information
Similar to htop/nvidia-smi for AgentJet swarm servers
"""
import time
from datetime import datetime
from typing import Optional
import httpx
import numpy as np
from rich.console import Console
from rich.live import Live
from rich.table import Table
from rich.panel import Panel
from rich.layout import Layout
from rich.text import Text
from loguru import logger
from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
CurrentBatchRolloutPoolInformation,
RewardHistoryResponse,
)
class SwarmOverwatch:
    """Real-time monitoring interface for swarm rollout pool"""

    def __init__(self, server_url: str, refresh_interval: float = 2.0):
        """
        Set up connection settings, the console, and the request counters.

        Args:
            server_url: Base URL of the swarm server (e.g., http://localhost:10086).
                Trailing "/" characters are stripped before it is stored.
            refresh_interval: Refresh interval in seconds (default: 2.0)
        """
        # Connection settings.
        self.refresh_interval = refresh_interval
        self.server_url = server_url.rstrip("/")

        # Bookkeeping updated by the fetch methods.
        self.total_requests = 0
        self.error_count = 0
        self.last_update_time = None  # set to a datetime after a successful fetch

        # Output console and a reusable HTTP client with a 5 second timeout.
        self.console = Console()
        self._httpx_client = httpx.Client(timeout=5.0)
def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
    """Request the current batch rollout pool snapshot from the server.

    Returns:
        The validated ``CurrentBatchRolloutPoolInformation`` on success, or
        ``None`` if anything in the request/validation sequence raises.

    Side effects:
        On a successful HTTP response ``total_requests`` is incremented and
        ``last_update_time`` is set to the current time (before the payload
        is validated). On any exception ``error_count`` is incremented.
    """
    try:
        endpoint = f"{self.server_url}/get_current_batch_rollout_pool_information"
        reply = self._httpx_client.get(endpoint, timeout=5.0)
        reply.raise_for_status()

        self.total_requests += 1
        self.last_update_time = datetime.now()

        return CurrentBatchRolloutPoolInformation.model_validate(reply.json())
    except Exception:
        # Failures are only counted here; no log line is emitted.
        self.error_count += 1
        return None
def fetch_reward_history(self) -> Optional[RewardHistoryResponse]:
    """Request the reward history used for the reward-curve view.

    Returns:
        The validated ``RewardHistoryResponse`` on success, or ``None`` after
        logging the error if the request or validation raises.
    """
    try:
        endpoint = f"{self.server_url}/get_reward_history"
        reply = self._httpx_client.get(endpoint, timeout=5.0)
        reply.raise_for_status()
        return RewardHistoryResponse.model_validate(reply.json())
    except Exception as e:
        logger.error(f"Failed to fetch reward history: {e}")
        return None
def create_header(
    self, info: Optional[CurrentBatchRolloutPoolInformation] = None
) -> Panel:
    """Build the header panel: title, server URL, clock, and request count.

    Args:
        info: Latest pool information, if any. When given, the engine status,
            the booting duration (only while the status is "ENGINE.BOOTING"
            and a booting start time is known) and the global step are
            appended to the header.

    Returns:
        A ``Panel`` wrapping the styled header text.
    """
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if self.last_update_time:
        last_update = self.last_update_time.strftime("%H:%M:%S")
    else:
        last_update = "Never"

    # (text, style) pairs, appended to the header in this order.
    # The error counter is tracked on the instance but not rendered here.
    segments = [
        ("AgentJet Swarm Overwatch", "bold cyan"),
        (f"\nServer: {self.server_url}", "dim"),
        (f"\nCurrent Time: {now}", "green"),
        (f" | Last Update: {last_update}", "yellow"),
        (f" | Refresh: {self.refresh_interval}s", "blue"),
        (f"\nRequests: {self.total_requests}", "magenta"),
    ]

    if info:
        status = info.engine_status
        if status:
            segments.append((f"\nEngine Status: {status}", "bold yellow"))
            # While booting, also show how long the engine has been booting.
            if status == "ENGINE.BOOTING" and info.booting_start_time is not None:
                elapsed = int(time.time() - info.booting_start_time)
                segments.append((f" | Booting Time: {elapsed}s", "bold cyan"))
        if info.global_step is not None:
            segments.append(
                (
                    f" | Global Step (Model's Weight Version): {info.global_step:,}",
                    "bold blue",
                )
            )

    header_text = Text()
    for chunk, chunk_style in segments:
        header_text.append(chunk, style=chunk_style)

    return Panel(header_text, border_style="bright_blue", padding=(0, 1))
def create_progress_bar(self, current: int, target: int, title: str) -> tuple:
    """Return ``(current, target, percentage)`` for one progress row.

    ``percentage`` is ``current / target * 100``, or ``0.0`` when ``target``
    is zero. ``title`` is accepted but not used in the computation.
    """
    percentage = 0.0 if target == 0 else (current / target) * 100
    return current, target, percentage
def create_summary_table(self, info: CurrentBatchRolloutPoolInformation) -> Table:
    """Build the summary table of completed-episode pool progress.

    One row each for completed episodes, tasks and non-dummy tasks, showing
    current value, target, percentage and a text bar. The row whose metric
    matches ``info.sample_collection_method`` is highlighted in bold green and
    labelled as "(chosen)". A final row shows the average number of episodes
    per completed task next to ``info.task_expected_num_repeat``.
    """
    # Dim the table and prefix its title while the engine is not rolling.
    rolling = info.engine_status in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]
    if rolling:
        border_style, title_prefix = "blue", ""
    else:
        border_style, title_prefix = "dim", "[WAITING ENGINE.ROLLING] "

    table = Table(
        title=f"{title_prefix}Completed Episode Pool Summary (Progress to Hit Next Weight Update)",
        show_header=True,
        header_style="bold magenta",
        border_style=border_style,
        expand=True,
    )
    table.add_column("Metric", style="cyan", width=30)
    table.add_column("Current", justify="right", style="green", width=15)
    table.add_column("Target", justify="right", style="yellow", width=15)
    table.add_column("Progress", justify="right", style="blue", width=15)
    table.add_column("Bar", width=30)

    collection_method = info.sample_collection_method

    # (plain label, highlighted label, collection method that selects it,
    #  current value, target value)
    progress_rows = (
        (
            "Completed Episodes",
            "-> *Completed Episodes (chosen)*",
            "rollout_until_finish_enough_episodes",
            info.completed_episodes,
            info.completed_episode_target,
        ),
        (
            "Completed Tasks",
            "-> *Completed Tasks (chosen)*",
            "rollout_until_finish_enough_tasks",
            info.completed_tasks,
            info.completed_task_target,
        ),
        (
            "Completed Non-Dummy Tasks",
            "-> *Completed Non-Dummy Tasks (chosen)*",
            "rollout_until_finish_enough_non_dummy_tasks",
            info.completed_non_dummy_tasks,
            info.completed_non_dummy_task_target,
        ),
    )

    for plain_label, chosen_label, selecting_method, done, goal in progress_rows:
        cur, tgt, pct = self.create_progress_bar(done, goal, plain_label)
        bar = self._create_text_bar(pct)
        if collection_method == selecting_method:
            row_style = "bold green"
            metric_cell = f"[{row_style}]{chosen_label}[/{row_style}]"
        else:
            row_style = None
            metric_cell = plain_label
        table.add_row(
            metric_cell,
            f"{cur:,}",
            f"{tgt:,}",
            f"{pct:.1f}%",
            bar,
            style=row_style,
        )

    # Average episodes per task
    avg_episodes_per_task = 0.0
    if info.completed_tasks_details:
        episode_counts = [
            len(episode_list) for episode_list in info.completed_tasks_details.values()
        ]
        if episode_counts:
            avg_episodes_per_task = sum(episode_counts) / len(episode_counts)

    table.add_row(
        "Average Episode Per Task",
        f"{avg_episodes_per_task:.2f}",
        f"{info.task_expected_num_repeat:,}",
        "-",
        "-",
    )
    return table
def _create_text_bar(self, percentage: float, width: int = 20) -> str:
    """Create a text-based progress bar wrapped in rich colour markup.

    Args:
        percentage: Progress in percent. Values outside 0-100 are allowed;
            the filled portion is clamped so the bar is always exactly
            ``width`` characters long.
        width: Total number of bar characters (default: 20).

    Returns:
        A markup string: green when ``percentage`` is at least 100,
        yellow otherwise.
    """
    # Clamp: progress can exceed the target (or be negative), which would
    # otherwise produce a bar longer than `width`.
    filled = max(0, min(width, int((percentage / 100) * width)))
    bar = "█" * filled + "░" * (width - filled)
    colour = "green" if percentage >= 100 else "yellow"
    return f"[{colour}]{bar}[/]"
def create_running_episodes_table(
    self, info: CurrentBatchRolloutPoolInformation
) -> Table:
    """Create the table of currently running episodes.

    Episodes are sorted by time since last activity (longest idle first) and
    at most 30 are listed; a trailing row reports how many were omitted.

    Args:
        info: Current pool information. ``running_episode_details`` is read
            as a mapping of episode UUID to a per-episode dict.
            NOTE(review): the per-episode keys used here
            ("time_since_last_activity", "discard_episode_timeout",
            "llm_call_count", "episode_status") are assumed to hold strings
            such as "12.3s"; confirm against the server schema.

    Returns:
        A ``Table`` ready to be placed in the dashboard layout.
    """
    max_rows = 30

    # Determine border style based on engine status
    is_active = info.engine_status in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]
    border_style = "blue" if is_active else "dim"
    title_prefix = "" if is_active else "[WAITING ENGINE.ROLLING] "

    running = info.running_episode_details

    # Add the episode count to the title if running episodes exist
    title = f"{title_prefix}Running Episodes"
    if running:
        title = f"{title_prefix}Running Episodes (Episodes: {len(running)})"

    table = Table(
        title=title,
        show_header=True,
        header_style="bold magenta",
        border_style=border_style,
        expand=True,
    )
    table.add_column("Episode UUID", style="cyan", no_wrap=True, width=20, overflow="ellipsis")
    table.add_column("Status", style="green")
    table.add_column("LLM Calls", style="magenta", justify="right")
    table.add_column("Last Req / Patience", style="yellow")

    if not running:
        table.add_row("[dim]No running episodes[/dim]", "", "", "")
        return table

    def idle_seconds(item) -> float:
        """Parse an "<number>s" idle time; unparseable values sort last."""
        raw = item[1].get("time_since_last_activity", "")
        try:
            return float(str(raw).rstrip("s"))
        except ValueError:
            return float("-inf")

    # Sort by time since last activity (descending)
    sorted_episodes = sorted(running.items(), key=idle_seconds, reverse=True)

    for episode_uuid, details in sorted_episodes[:max_rows]:
        last_req = details.get("time_since_last_activity", "N/A")
        patience = details.get("discard_episode_timeout", "N/A")
        llm_calls = details.get("llm_call_count", "0")
        # Cells are coerced to str: rich rejects non-renderable values
        # such as a bare int.
        table.add_row(
            episode_uuid[:40],
            str(details.get("episode_status", "N/A")),
            str(llm_calls),
            f"{last_req} / {patience}",
        )

    hidden = len(sorted_episodes) - max_rows
    if hidden > 0:
        table.add_row(f"[dim]... and {hidden} more episodes[/dim]", "", "", "")

    return table
def create_task_details_table(
    self, info: CurrentBatchRolloutPoolInformation
) -> Table:
    """Build the per-task completion table.

    Tasks are ordered by number of completed episodes (most first) and at
    most 30 are listed. Each row shows the task id (cut to 40 characters),
    the episode count, mean ± std of the task's rewards when available, and
    a preview of the first three episode UUIDs.
    """
    # Dim the table and prefix its title while the engine is not rolling.
    engine_rolling = info.engine_status in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]
    if engine_rolling:
        border_style, title_prefix = "blue", ""
    else:
        border_style, title_prefix = "dim", "[WAITING ENGINE.ROLLING] "

    table = Table(
        title=f"{title_prefix}Task Completion Details",
        show_header=True,
        header_style="bold magenta",
        border_style=border_style,
        expand=True,
    )
    table.add_column("Task ID", style="cyan", no_wrap=True, overflow="ellipsis")
    table.add_column("Episodes", justify="right", style="green")
    table.add_column("Reward", justify="right", style="yellow")
    table.add_column("Episode UUIDs (first 3)", style="dim", overflow="fold")

    if not info.completed_tasks_details:
        table.add_row("[dim]No task details available[/dim]", "", "", "")
        return table

    # Most-completed tasks first.
    ranked_tasks = sorted(
        info.completed_tasks_details.items(),
        key=lambda pair: len(pair[1]),
        reverse=True,
    )

    # Only the top 30 tasks are listed so the table fits in a terminal.
    for task_id, episode_uuids in ranked_tasks[:30]:
        uuid_str = ", ".join(f"{episode_id[:8]}..." for episode_id in episode_uuids[:3])
        extra = len(episode_uuids) - 3
        if extra > 0:
            uuid_str += f" (+{extra} more)"

        # Reward statistics, or "-" when this task has no recorded rewards.
        reward_str = "-"
        if info.completed_tasks_rewards and task_id in info.completed_tasks_rewards:
            rewards = info.completed_tasks_rewards[task_id]
            if rewards:
                reward_str = f"{np.mean(rewards):.3f} ± {np.std(rewards):.3f}"

        table.add_row(
            task_id[:40],
            f"{len(episode_uuids):,}",
            reward_str,
            uuid_str,
        )

    omitted = len(ranked_tasks) - 30
    if omitted > 0:
        table.add_row(f"[dim]... and {omitted} more tasks[/dim]", "", "", "")

    return table
def create_logo_panel(self, info: CurrentBatchRolloutPoolInformation) -> Text:
    """Build the logo screen shown for the OFFLINE and BOOTING engine states.

    The logo is always included. For "ENGINE.OFFLINE" it is followed by the
    client commands needed to start the engine; for "ENGINE.BOOTING" by a
    wait message and, when a booting start time is known, the elapsed
    seconds. Any other status yields the logo alone.
    """
    logo = """
█████╗ ██████╗ ███████╗███╗ ██╗████████╗ ██╗███████╗████████╗
██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝ ██║██╔════╝╚══██╔══╝
███████║██║ ███╗█████╗ ██╔██╗ ██║ ██║ ██║█████╗ ██║
██╔══██║██║ ██║██╔══╝ ██║╚██╗██║ ██║ ██ ██║██╔══╝ ██║
██║ ██║╚██████╔╝███████╗██║ ╚████║ ██║ ╚█████╔╝███████╗ ██║
╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═╝ ╚════╝ ╚══════╝ ╚═╝
"""
    # (text, style) pairs appended in order; a style of None means unstyled.
    pieces = [(logo, "bold cyan"), ("\n\n", None)]

    status = info.engine_status
    if status == "ENGINE.OFFLINE":
        pieces += [
            ("Engine Status: ", "bold white"),
            ("OFFLINE", "bold yellow"),
            ("\n\n", None),
            ("Ready to accept commands:\n", "bold green"),
            (f" • swarm_client = SwarmClient(\"{self.server_url}\")\n", "cyan"),
            (" • swarm_client.sync_train_config()\n", "cyan"),
            (" • swarm_client.start_engine()\n", "cyan"),
            ("\n", None),
            (
                "Please sync your training configuration and start the engine to begin rollouts.",
                "dim",
            ),
        ]
    elif status == "ENGINE.BOOTING":
        pieces += [
            ("Engine Status: ", "bold white"),
            ("BOOTING", "bold yellow"),
        ]
        if info.booting_start_time is not None:
            elapsed = int(time.time() - info.booting_start_time)
            pieces.append((f" ({elapsed}s)", "bold cyan"))
        pieces += [
            ("\n\n", None),
            ("Engine is booting...\n", "bold green"),
            (
                "Please wait, we are loading model weights and ray workers, this will take a few minutes.\n",
                "dim",
            ),
        ]

    content = Text()
    for piece, piece_style in pieces:
        content.append(piece, style=piece_style)
    return content
def create_dashboard(
    self, info: Optional[CurrentBatchRolloutPoolInformation], init=False
) -> Layout:
    """Assemble the full-screen layout for the current state.

    Args:
        info: Latest pool information, or ``None`` when none is available.
        init: When ``info`` is ``None``, selects the welcome panel (truthy)
            instead of the connection-error panel (falsy).

    Returns:
        A ``Layout`` with the header on top and, below it, one of: the
        welcome panel, the error panel, the logo screen (engine OFFLINE or
        BOOTING), or the summary table above the running-episodes and
        task-details tables.
    """
    layout = Layout()
    header = self.create_header(info)

    # No data: either the first paint (welcome) or a failed fetch (error).
    if info is None:
        if init:
            notice = Panel(
                "[bold green]Welcome to AgentJet Swarm Overwatch![/bold green]\n\n"
                "Attempting to connect to server and fetch data...\n"
                f"[dim]Target server: {self.server_url}[/dim]\n",
                border_style="green",
                padding=(1, 2),
            )
        else:
            notice = Panel(
                "[bold red]Failed to fetch data from server, please check your connection or simply wait a moment...[/bold red]\n"
                f"[dim]Attempted to connect to: {self.server_url}[/dim]\n",
                border_style="red",
                padding=(1, 2),
            )
        layout.split_column(Layout(header, size=8), Layout(notice))
        return layout

    # Engine not running yet: hide the tables and show the logo screen.
    if info.engine_status in ["ENGINE.OFFLINE", "ENGINE.BOOTING"]:
        logo_display = self.create_logo_panel(info)
        layout.split_column(Layout(header, size=8), Layout(logo_display))
        return layout

    # Any other engine state: show the data tables.
    summary = self.create_summary_table(info)
    running_episodes = self.create_running_episodes_table(info)
    details = self.create_task_details_table(info)

    # Running episodes and task details sit side by side under the summary.
    bottom_row = Layout()
    bottom_row.split_row(
        Layout(running_episodes, name="running"),
        Layout(details, name="details"),
    )
    layout.split_column(
        Layout(header, size=8),
        Layout(summary, size=12),
        Layout(bottom_row),
    )
    return layout
def display_reward_curve(self):
    """Display an ASCII chart of mean reward versus global step.

    Clears the console, fetches the reward history, plots each entry as '*'
    joined by '.' segments, prints summary statistics, then blocks on
    ``input()`` until the user presses Enter. When no history is available a
    short notice is shown instead.
    """
    self.console.clear()

    # Fetch reward history
    history = self.fetch_reward_history()
    if history is None or not history.history:
        self.console.print("[bold yellow]No reward history available yet.[/bold yellow]")
        self.console.print("[dim]Reward history is recorded when training completes batches with rewards.[/dim]")
        self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
        input()
        return

    # Get terminal size
    terminal_width = self.console.width or 80
    terminal_height = self.console.height or 24

    # Reserve space for y-axis labels (width) and header/x-axis (height).
    # The lower bounds keep the grid non-empty on very small terminals,
    # where the subtraction alone could yield zero or negative sizes.
    chart_width = max(min(terminal_width - 15, 120), 10)
    chart_height = max(min(terminal_height - 10, 30), 5)

    # Extract data
    global_steps = [entry.global_step for entry in history.history]
    mean_rewards = [entry.mean_reward for entry in history.history]
    lowest_reward = min(mean_rewards)
    highest_reward = max(mean_rewards)

    # Calculate y-axis range with padding
    y_min, y_max = lowest_reward, highest_reward
    if y_max == y_min:
        # Flat series: use a unit-high window centred on the value.
        y_min -= 0.5
        y_max += 0.5
    else:
        # Add 10% padding on both sides
        padding = (y_max - y_min) * 0.1
        y_min -= padding
        y_max += padding
    y_range = y_max - y_min

    # Calculate x-axis range (a single distinct step still needs a non-zero span)
    x_min = min(global_steps)
    x_max = max(global_steps)
    x_range = (x_max - x_min) or 1

    def to_cell(step, reward):
        """Map a (step, reward) pair to a clamped (column, row) grid cell."""
        col = int((step - x_min) / x_range * (chart_width - 1))
        # Invert y because terminal rows grow downwards.
        row = chart_height - 1 - int((reward - y_min) / y_range * (chart_height - 1))
        return (
            max(0, min(chart_width - 1, col)),
            max(0, min(chart_height - 1, row)),
        )

    # Create the chart grid and plot the data points
    chart = [[' ' for _ in range(chart_width)] for _ in range(chart_height)]
    cells = [to_cell(step, reward) for step, reward in zip(global_steps, mean_rewards)]
    for col, row in cells:
        chart[row][col] = '*'

    # Connect consecutive points with a simple interpolated line; existing
    # marks are never overwritten.
    for (x1, y1), (x2, y2) in zip(cells, cells[1:]):
        steps_between = max(abs(x2 - x1), abs(y2 - y1))
        for s in range(1, steps_between):
            t = s / steps_between
            x = int(x1 + t * (x2 - x1))
            y = int(y1 + t * (y2 - y1))
            if chart[y][x] == ' ':
                chart[y][x] = '.'

    # Every chart row starts with a gutter of the same width so that the
    # labelled and unlabelled rows, the axis and the x labels line up.
    label_width = 8
    blank_gutter = " " * label_width + " |"
    labelled_rows = (0, chart_height // 2, chart_height - 1)

    # Build the output
    output = Text()
    output.append("\n Reward Curve (Mean Reward vs Global Step)\n", style="bold cyan")
    output.append(f" Server: {self.server_url}\n", style="dim")
    output.append(f" Data points: {len(global_steps)}\n\n", style="dim")

    # Draw y-axis labels and chart (only the top, middle and bottom rows are labelled)
    for i, row in enumerate(chart):
        if i in labelled_rows:
            y_val = y_max - (i / (chart_height - 1)) * y_range
            label = f"{y_val:{label_width}.3f} |"
        else:
            label = blank_gutter
        output.append(label, style="dim")
        output.append(''.join(row), style="green")
        output.append("\n")

    # X-axis: the '+' sits directly under the '|' column
    output.append(" " * (label_width + 1) + "+" + "-" * chart_width + "\n", style="dim")

    # X-axis labels: min, midpoint and max, each in a third of the chart width
    third = chart_width // 3
    mid_step = x_min + x_range // 2
    x_label_line = " " * (label_width + 2)
    x_label_line += f"{x_min:<{third}}"
    x_label_line += f"{mid_step:^{third}}"
    x_label_line += f"{x_max:>{third}}"
    output.append(x_label_line[:chart_width + 10] + "\n", style="dim")
    output.append(
        " " * (label_width + 2) + " " * (chart_width // 2 - 5) + "Global Step\n",
        style="dim cyan",
    )

    # Statistics
    lowest_step = global_steps[mean_rewards.index(lowest_reward)]
    highest_step = global_steps[mean_rewards.index(highest_reward)]
    output.append("\n Statistics:\n", style="bold yellow")
    output.append(f" Latest Global Step: {global_steps[-1]}\n", style="green")
    output.append(f" Latest Mean Reward: {mean_rewards[-1]:.4f}\n", style="green")
    output.append(f" Min Mean Reward: {lowest_reward:.4f} (step {lowest_step})\n", style="cyan")
    output.append(f" Max Mean Reward: {highest_reward:.4f} (step {highest_step})\n", style="cyan")

    self.console.print(output)
    self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
    input()
def display_latest_llm_call(self):
    """Poll the server for the latest LLM call and print it, forever.

    Every 5 seconds this POSTs to ``/replay_latest_llm_call``, clears the
    console and prints the raw "input" and "output" of the response followed
    by a condensed one-line-per-message view. If the response lacks either
    key, the response itself is printed in red. Request and JSON decoding
    errors are shown and retried on the next poll. The loop only ends when an
    exception such as ``KeyboardInterrupt`` propagates out.
    """
    # Local import: LLM text is untrusted and may contain "[...]" sequences
    # that rich would otherwise parse as markup tags (and raise on).
    from rich.markup import escape

    poll_seconds = 5
    hide_when_more_than_n_line_break = 4

    def flatten(content) -> str:
        """Collapse message content to one line, truncating long multi-line text."""
        if isinstance(content, list):
            # Multi-part content: use the text of the first part, if any.
            first_part = content[0] if content else {}
            content = first_part.get('text', '') if isinstance(first_part, dict) else first_part
        content = "" if content is None else str(content)
        if content.count('\n') >= hide_when_more_than_n_line_break:
            return content.replace('\n', ' ')[:200] + " ....."
        return content.replace('\n', ' ')

    while True:
        try:
            response = httpx.post(f"{self.server_url}/replay_latest_llm_call", timeout=30.0)
            structured_response = response.json()
        except (httpx.HTTPError, ValueError) as e:
            # Transient network or decoding failure: report and retry, in
            # line with how the overwatch loop treats errors.
            self.console.clear()
            self.console.print(f"[bold red]{escape(str(e))}[/bold red]")
            time.sleep(poll_seconds)
            continue

        self.console.clear()
        if (
            not isinstance(structured_response, dict)
            or "input" not in structured_response
            or "output" not in structured_response
        ):
            self.console.print(f"[bold red]{escape(str(structured_response))}[/bold red]")
            time.sleep(poll_seconds)
            continue

        llm_input = structured_response["input"]
        llm_output = structured_response["output"]
        self.console.print(f"\n[bold green]Input:[/bold green]\n{escape(str(llm_input))}")
        self.console.print(f"\n[bold green]Output:[/bold green]\n{escape(str(llm_output))}")

        try:
            input_items = "".join(
                f"[bold blue]@{escape(str(item['role']))}:[/bold blue] "
                f"{escape(flatten(item['content']))}\n"
                for item in llm_input['messages']
            )
            output_items = "".join(
                f"[bold red]@{escape(str(item['message']['role']))}:[/bold red] "
                f"{escape(flatten(item['message']['content']))}\n"
                for item in llm_output['choices']
            )
        except (KeyError, TypeError, AttributeError, IndexError) as e:
            # The payload does not have the expected messages/choices shape;
            # the raw dump above is still shown.
            self.console.print(f"\n[dim]Simplified view unavailable: {escape(repr(e))}[/dim]")
        else:
            self.console.print("\n-------------------------------------------------------------")
            self.console.print(f"\n[bold green]Input Simlified:[/bold green]\n{input_items}")
            self.console.print(f"\n[bold green]Output Simlified:[/bold green]\n{output_items}")

        time.sleep(poll_seconds)
def choose_run(self) -> str:
    """Run the selected view and show an action menu whenever it is interrupted.

    Starts in "overwatch" mode. Pressing Ctrl+C inside a view opens a menu:
    'o' switches to the overwatch view, 't' to the latest-LLM-call replay,
    and 'c' shows the reward curve and then resumes the current view. An
    unrecognised choice re-prompts. Pressing Ctrl+C at the menu prompt lets
    ``KeyboardInterrupt`` propagate to the caller.

    Note: although annotated ``-> str``, this method contains no ``return``;
    it only ends when an exception propagates.
    """
    mode = "overwatch"
    while True:
        self.console.clear()
        try:
            if mode == "overwatch":
                self.run()
            elif mode == "replay_latest_llm_call":
                self.display_latest_llm_call()
        except KeyboardInterrupt:
            self.console.clear()
            self.console.print("\n[bold yellow]Overwatch stopped by user[/bold yellow]")
            self.console.print(
                f"[dim]Total requests: {self.total_requests}, Errors: {self.error_count}[/dim]\n"
            )
            self.console.print("\n[bold]Choose action:[/bold]")
            self.console.print(" [bold cyan]o[/bold cyan] - Return to overwatch")
            self.console.print(" [bold cyan]t[/bold cyan] - Show replay_latest_llm_call")
            self.console.print(" [bold cyan]c[/bold cyan] - Show reward curve")
            self.console.print(" [bold cyan]ctrl+c[/bold cyan] - Exit")

            # Keep asking until a valid choice is made. Without this inner
            # loop the outer loop would clear the screen right away, so the
            # "Invalid choice" hint would never be visible.
            while True:
                choice = input("\n> ").strip().lower()
                if choice == "o":
                    mode = "overwatch"
                    break
                if choice == "t":
                    mode = "replay_latest_llm_call"
                    break
                if choice == "c":
                    # Show the curve, then resume whichever view was active.
                    self.display_reward_curve()
                    break
                self.console.print("[yellow]Invalid choice. Please enter 'o', 't', or 'c'.[/yellow]")
def run(self):
    """Run the live dashboard loop until an exception escapes.

    Opens a full-screen ``Live`` display seeded with the welcome dashboard,
    then repeatedly fetches pool information and redraws, sleeping
    ``refresh_interval`` seconds between iterations. Ordinary exceptions are
    logged and the loop continues after the same delay.
    """
    initial_view = self.create_dashboard(None, init=True)
    with Live(
        initial_view,
        console=self.console,
        refresh_per_second=1,
        screen=True,
    ) as live:
        self.console.print("[bold green]Starting Swarm Overwatch...[/bold green]")
        self.console.print("[dim]Press Ctrl+C to exit[/dim]\n")
        time.sleep(1)

        while True:
            try:
                # Fetch the latest data, redraw, then wait for the next refresh.
                live.update(self.create_dashboard(self.fetch_pool_info()))
                time.sleep(self.refresh_interval)
            except Exception as e:
                # KeyboardInterrupt is not an Exception subclass, so Ctrl+C
                # still propagates to the caller from here.
                logger.error(f"Error in monitoring loop: {e}")
                time.sleep(self.refresh_interval)
def start_overwatch(server_url: str, refresh_interval: float = 2.0):
    """
    Create a SwarmOverwatch monitor and run its interactive loop.

    A KeyboardInterrupt escaping the loop is logged at info level and
    swallowed, so the function returns normally.

    Args:
        server_url: Base URL of the swarm server
        refresh_interval: Refresh interval in seconds (default: 2.0)
    """
    monitor = SwarmOverwatch(server_url, refresh_interval)
    try:
        monitor.choose_run()
    except KeyboardInterrupt:
        logger.info("Swarm Overwatch stopped by user")
if __name__ == "__main__":
    # Script entry point: monitor a swarm server on the default localhost port.
    start_overwatch(server_url="http://localhost:10086")