Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/mars/agent_step.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def agent(klass = nil)
end
end

def run(input)
self.class.agent.new.ask(input).content
def run(context)
self.class.agent.new.ask(context.current_input).content
end
end
end
7 changes: 4 additions & 3 deletions lib/mars/execution_context.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

module MARS
class ExecutionContext
attr_reader :current_input, :outputs, :global_state
attr_reader :outputs, :global_state
attr_accessor :current_input

def initialize(input: nil, global_state: {})
@current_input = input
Expand All @@ -19,8 +20,8 @@ def record(step_name, output)
@current_input = output
end

def fork(input: current_input)
self.class.new(input: input, global_state: global_state)
def fork(input: current_input, state: {})
self.class.new(input: input, global_state: global_state.merge(state))
end

def merge(child_contexts)
Expand Down
5 changes: 3 additions & 2 deletions lib/mars/workflows/parallel.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def aggregate_results(results)
def execute_steps(context, errors, child_contexts)
Async do |workflow|
tasks = steps.map do |step|
child_ctx = context.fork
child_ctx = context.fork(state: step.state)
child_contexts << child_ctx

workflow.async do
Expand All @@ -54,7 +54,8 @@ def workflow_step(step, child_ctx)
step.run_before_hooks(child_ctx)

step_input = step.formatter.format_input(child_ctx)
result = step.run(step_input)
child_ctx.current_input = step_input
result = step.run(child_ctx)

if result.is_a?(Halt)
step.run_after_hooks(child_ctx, result)
Expand Down
3 changes: 2 additions & 1 deletion lib/mars/workflows/sequential.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def run(input)
step.run_before_hooks(context)

step_input = step.formatter.format_input(context)
result = step.run(step_input)
context.current_input = step_input
result = step.run(context)

if result.is_a?(Halt)
if result.global?
Expand Down
2 changes: 1 addition & 1 deletion spec/mars/agent_step_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

it "creates a new agent instance and calls ask" do
step = step_class.new
result = step.run("hello")
result = step.run(MARS::ExecutionContext.new(input: "hello"))

expect(result).to eq("agent response")
expect(mock_agent_class).to have_received(:new)
Expand Down
24 changes: 12 additions & 12 deletions spec/mars/workflows/sequential_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def initialize(value, **kwargs)
end

def run(input)
input + @value
input.current_input + @value
end
end
end
Expand All @@ -22,7 +22,7 @@ def initialize(multiplier, **kwargs)
end

def run(input)
input * @multiplier
input.current_input * @multiplier
end
end
end
Expand Down Expand Up @@ -67,11 +67,11 @@ def run(_input)

it "records outputs in context accessible by step name" do
step1 = Class.new(MARS::Runnable) do
def run(input) = "from_step1:#{input}"
def run(input) = "from_step1:#{input.current_input}"
end.new(name: "step1")

step2 = Class.new(MARS::Runnable) do
def run(input) = "from_step2:#{input}"
def run(input) = "from_step2:#{input.current_input}"
end.new(name: "step2")

context = MARS::ExecutionContext.new(input: "hello")
Expand All @@ -84,7 +84,7 @@ def run(input) = "from_step2:#{input}"

it "wraps raw input in ExecutionContext automatically" do
step = Class.new(MARS::Runnable) do
def run(input) = "processed:#{input}"
def run(input) = "processed:#{input.current_input}"
end.new(name: "step")

workflow = described_class.new("auto_wrap", steps: [step])
Expand All @@ -100,7 +100,7 @@ def format_output(output)
end

step = Class.new(MARS::Runnable) do
def run(input) = "result:#{input}"
def run(input) = "result:#{input.current_input}"
end.new(name: "step", formatter: uppercase_formatter.new)

workflow = described_class.new("fmt_workflow", steps: [step])
Expand All @@ -115,7 +115,7 @@ def run(input) = "result:#{input}"
before_run { |_ctx, step| hook_log << "before:#{step.name}" }
after_run { |_ctx, _result, step| hook_log << "after:#{step.name}" }

def run(input) = input
def run(input) = input.current_input
end

step = step_class.new(name: "hooked")
Expand All @@ -133,7 +133,7 @@ def run(input) = input
fallbacks: {
branch: Class.new(MARS::Runnable) do
def run(input)
"branched:#{input}"
"branched:#{input.current_input}"
end
end.new(name: "branch_step")
}
Expand All @@ -156,7 +156,7 @@ def run(input)
fallbacks: {
branch: Class.new(MARS::Runnable) do
def run(input)
"branched:#{input}"
"branched:#{input.current_input}"
end
end.new(name: "branch_step")
},
Expand All @@ -179,7 +179,7 @@ def run(input)
fallbacks: {
stop: Class.new(MARS::Runnable) do
def run(input)
"stopped:#{input}"
"stopped:#{input.current_input}"
end
end.new(name: "stop_step")
},
Expand All @@ -202,7 +202,7 @@ def run(input)
fallbacks: {
stop: Class.new(MARS::Runnable) do
def run(input)
"stopped:#{input}"
"stopped:#{input.current_input}"
end
end.new(name: "stop_step")
}
Expand All @@ -212,7 +212,7 @@ def run(input)

string_step = Class.new(MARS::Runnable) do
def run(input)
"after:#{input}"
"after:#{input.current_input}"
end
end.new(name: "after_step")

Expand Down
Loading