From 9823a019d109081fa4d17ac7f49888fd9e8f0693 Mon Sep 17 00:00:00 2001 From: Lubosz Kosnik <6192897+luboszk@users.noreply.github.com> Date: Fri, 22 May 2026 07:58:16 -0400 Subject: [PATCH] feat: support `case` expression function Adds the `case(pred1, val1, ..., default)` function recently introduced to GitHub Actions expressions, returning the value paired with the first truthy predicate (or the trailing default). Co-Authored-By: Claude Opus 4.7 --- pkg/exprparser/functions.go | 16 ++++++++++++ pkg/exprparser/functions_test.go | 36 +++++++++++++++++++++++++ pkg/exprparser/interpreter.go | 2 ++ pkg/schema/schema.go | 1 + pkg/schema/schema_test.go | 45 ++++++++++++++++++++++++++++++++ 5 files changed, 100 insertions(+) diff --git a/pkg/exprparser/functions.go b/pkg/exprparser/functions.go index 83b2a0802f3..99f853300c9 100644 --- a/pkg/exprparser/functions.go +++ b/pkg/exprparser/functions.go @@ -252,6 +252,22 @@ func (impl *interperterImpl) getNeedsTransitive(job *model.Job) []string { return needs } +func (impl *interperterImpl) caseFunc(args []reflect.Value) (interface{}, error) { + if len(args) < 3 { + return nil, fmt.Errorf("case() requires at least 3 arguments, got %d", len(args)) + } + if len(args)%2 == 0 { + return nil, fmt.Errorf("case() requires an odd number of arguments (pairs of predicate/value plus a default), got %d", len(args)) + } + + for i := 0; i < len(args)-1; i += 2 { + if IsTruthy(impl.getSafeValue(args[i])) { + return impl.getSafeValue(args[i+1]), nil + } + } + return impl.getSafeValue(args[len(args)-1]), nil +} + func (impl *interperterImpl) always() (bool, error) { return true, nil } diff --git a/pkg/exprparser/functions_test.go b/pkg/exprparser/functions_test.go index 7241ddf82e6..0b9cc811dcd 100644 --- a/pkg/exprparser/functions_test.go +++ b/pkg/exprparser/functions_test.go @@ -257,6 +257,42 @@ func TestFunctionFormat(t *testing.T) { } } +func TestFunctionCase(t *testing.T) { + table := []struct { + input string + expected interface{} + error interface{} + name string + }{ + {"case(true, 'yes', 'no') }}", "yes", nil, "case-single-pred-true"}, + {"case(false, 'yes', 'no') }}", "no", nil, "case-single-pred-false-default"}, + {"case(false, 'a', true, 'b', 'default') }}", "b", nil, "case-second-pred-true"}, + {"case(false, 'a', false, 'b', 'default') }}", "default", nil, "case-no-match-default"}, + {"case(true, 'first', true, 'second', 'default') }}", "first", nil, "case-first-true-wins"}, + {"case(1 == 1, 'eq', 'neq') }}", "eq", nil, "case-equality-predicate"}, + {"case('' , 'empty-truthy', 'not-empty') }}", "not-empty", nil, "case-empty-string-is-falsy"}, + {"case('x', 'truthy', 'falsy') }}", "truthy", nil, "case-non-empty-string-is-truthy"}, + {"case(false, 'a', null) }}", nil, nil, "case-default-null"}, + {"case(false, 1, false, 2, 3) }}", 3, nil, "case-numeric-default"}, + {"case(false, 'a') }}", nil, "case() requires at least 3 arguments, got 2", "case-too-few-args"}, + {"case(false, 'a', true, 'b') }}", nil, "case() requires an odd number of arguments (pairs of predicate/value plus a default), got 4", "case-even-args"}, + } + + env := &EvaluationEnvironment{} + + for _, tt := range table { + t.Run(tt.name, func(t *testing.T) { + output, err := NewInterpeter(env, Config{}).Evaluate(tt.input, DefaultStatusCheckNone) + if tt.error != nil { + assert.Equal(t, tt.error, err.Error()) + } else { + assert.Nil(t, err) + assert.Equal(t, tt.expected, output) + } + }) + } +} + func TestMapContains(t *testing.T) { env := &EvaluationEnvironment{ Needs: map[string]Needs{ diff --git a/pkg/exprparser/interpreter.go b/pkg/exprparser/interpreter.go index 8d50913d9d3..644f4cfa0fb 100644 --- a/pkg/exprparser/interpreter.go +++ b/pkg/exprparser/interpreter.go @@ -638,6 +638,8 @@ func (impl *interperterImpl) evaluateFuncCall(funcCallNode *actionlint.FuncCallN return impl.toJSON(args[0]) case "fromjson": return impl.fromJSON(args[0]) + case "case": + return impl.caseFunc(args) case "hashfiles": if impl.env.HashFiles != nil { return impl.env.HashFiles(args) diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index a523e721f6d..d2a94d980b4 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -173,6 +173,7 @@ func (s *Node) GetFunctions() *[]FunctionInfo { AddFunction(funcs, "startsWith", 2, 2) AddFunction(funcs, "toJson", 1, 1) AddFunction(funcs, "fromJson", 1, 1) + AddFunction(funcs, "case", 3, math.MaxInt32) for _, v := range s.Context { i := strings.Index(v, "(") if i == -1 { diff --git a/pkg/schema/schema_test.go b/pkg/schema/schema_test.go index ce571c96533..496454fbe60 100644 --- a/pkg/schema/schema_test.go +++ b/pkg/schema/schema_test.go @@ -90,3 +90,48 @@ jobs: }).UnmarshalYAML(&node) assert.NoError(t, err) } + +func TestCaseFunctionSchema(t *testing.T) { + var node yaml.Node + err := yaml.Unmarshal([]byte(` +on: push +jobs: + main: + runs-on: ubuntu-latest + env: + TEST: ${{ case(github.event_name == 'workflow_dispatch', 'dispatch', 'other') }} + MULTI: ${{ case(github.ref == 'refs/heads/main', 'production', github.ref == 'refs/heads/staging', 'staging', 'development') }} + steps: + - run: echo $TEST +`), &node) + if !assert.NoError(t, err) { + return + } + err = (&Node{ + Definition: "workflow-root-strict", + Schema: GetWorkflowSchema(), + }).UnmarshalYAML(&node) + assert.NoError(t, err) +} + +func TestCaseFunctionSchemaTooFewArgs(t *testing.T) { + var node yaml.Node + err := yaml.Unmarshal([]byte(` +on: push +jobs: + main: + runs-on: ubuntu-latest + env: + TEST: ${{ case(true, 'only-two') }} + steps: + - run: echo $TEST +`), &node) + if !assert.NoError(t, err) { + return + } + err = (&Node{ + Definition: "workflow-root-strict", + Schema: GetWorkflowSchema(), + }).UnmarshalYAML(&node) + assert.Error(t, err) +}