diff --git a/cli/azd/pkg/account/manager.go b/cli/azd/pkg/account/manager.go index dbee60ea52a..28ec3d7143d 100644 --- a/cli/azd/pkg/account/manager.go +++ b/cli/azd/pkg/account/manager.go @@ -37,6 +37,7 @@ type Manager interface { GetSubscriptions(ctx context.Context) ([]Subscription, error) GetSubscriptionsWithDefaultSet(ctx context.Context) ([]Subscription, error) GetLocations(ctx context.Context, subscriptionId string) ([]Location, error) + GetTenantDisplayNames(ctx context.Context) (map[string]string, error) SetDefaultSubscription(ctx context.Context, subscriptionId string) (*Subscription, error) SetDefaultLocation(ctx context.Context, subscriptionId string, location string) (*Location, error) } @@ -140,6 +141,11 @@ func (m *manager) GetSubscriptions(ctx context.Context) ([]Subscription, error) return m.subManager.GetSubscriptions(ctx) } +// GetTenantDisplayNames returns a map of tenant ID to display name. +func (m *manager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) { + return m.subManager.GetTenantDisplayNames(ctx) +} + // Gets the available Azure locations for the specified Azure subscription. func (m *manager) GetLocations(ctx context.Context, subscriptionId string) ([]Location, error) { locations, err := m.subManager.ListLocations(ctx, subscriptionId) diff --git a/cli/azd/pkg/account/subscriptions_manager.go b/cli/azd/pkg/account/subscriptions_manager.go index de5685d0a9f..1a0cda5848d 100644 --- a/cli/azd/pkg/account/subscriptions_manager.go +++ b/cli/azd/pkg/account/subscriptions_manager.go @@ -408,6 +408,28 @@ func (m *SubscriptionsManager) getSubscription(ctx context.Context, subscription return &sub, nil } +// GetTenantDisplayNames returns a map of tenant ID to display name for all tenants +// accessible by the current account. +func (m *SubscriptionsManager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) { + tenants, err := m.service.ListTenants(ctx) + if err != nil { + return nil, fmt.Errorf("listing tenants: %w", err) + } + + result := make(map[string]string, len(tenants)) + for _, t := range tenants { + if t.TenantID != nil { + name := *t.TenantID + if t.DisplayName != nil && *t.DisplayName != "" { + name = *t.DisplayName + } + result[*t.TenantID] = name + } + } + + return result, nil +} + func toSubscriptions(azSubs []*armsubscriptions.Subscription, userAccessTenantId string) []Subscription { if azSubs == nil { return nil diff --git a/cli/azd/pkg/prompt/prompt_service.go b/cli/azd/pkg/prompt/prompt_service.go index a1d25b0c18c..ad0b3d7117d 100644 --- a/cli/azd/pkg/prompt/prompt_service.go +++ b/cli/azd/pkg/prompt/prompt_service.go @@ -101,6 +101,9 @@ type SelectOptions struct { HelpMessage string // LoadingMessage is the loading message to display to the user. LoadingMessage string + // SkipLoadingSpinner skips the loading spinner in PromptCustomResource. + // Use this when data is pre-loaded and LoadData returns immediately. + SkipLoadingSpinner bool // DisplayNumbers specifies whether to display numbers next to the choices. DisplayNumbers *bool // DisplayCount is the number of choices to display at a time. @@ -157,6 +160,7 @@ type ResourceService interface { type SubscriptionManager interface { GetSubscriptions(ctx context.Context) ([]account.Subscription, error) GetLocations(ctx context.Context, subscriptionId string) ([]account.Location, error) + GetTenantDisplayNames(ctx context.Context) (map[string]string, error) } // PromptServiceInterface defines the methods that the PromptService must implement. @@ -211,6 +215,8 @@ func NewPromptService( } // PromptSubscription prompts the user to select an Azure subscription. +// If the user has access to multiple tenants, a tenant selection prompt is shown first +// to scope down the subscription list. func (ps *promptService) PromptSubscription( ctx context.Context, selectorOptions *SelectOptions, @@ -235,6 +241,31 @@ func (ps *promptService) PromptSubscription( return nil, err } + // Load subscriptions under a spinner first + var subscriptionList []account.Subscription + loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{ + Text: mergedOptions.LoadingMessage, + }) + + err := loadingSpinner.Run(ctx, func(ctx context.Context) error { + var loadErr error + subscriptionList, loadErr = ps.subscriptionManager.GetSubscriptions(ctx) + return loadErr + }) + if err != nil { + return nil, fmt.Errorf("listing subscriptions: %w", err) + } + + // Apply tenant filtering (after spinner is done so the prompt doesn't overlap) + subscriptionList = filterByTenantEnvVar(subscriptionList) + if !ps.console.IsNoPromptMode() { + subscriptionList, err = promptAndFilterByTenant( + ctx, ps.console, subscriptionList, ps.subscriptionManager.GetTenantDisplayNames) + if err != nil { + return nil, err + } + } + // Get default subscription from user config var defaultSubscriptionId = "" userConfig, err := ps.userConfigManager.Load() @@ -247,19 +278,19 @@ func (ps *promptService) PromptSubscription( hideId := isDemoModeEnabled() - return PromptCustomResource(ctx, CustomResourceOptions[account.Subscription]{ - SelectorOptions: mergedOptions, - LoadData: func(ctx context.Context) ([]*account.Subscription, error) { - subscriptionList, err := ps.subscriptionManager.GetSubscriptions(ctx) - if err != nil { - return nil, err - } + // Use PromptCustomResource with pre-loaded data + subscriptions := make([]*account.Subscription, len(subscriptionList)) + for i := range subscriptionList { + subscriptions[i] = &subscriptionList[i] + } - subscriptions := make([]*account.Subscription, len(subscriptionList)) - for i, subscription := range subscriptionList { - subscriptions[i] = &subscription - } + // Create selector options with spinner disabled since data is already loaded + resourceSelectorOptions := *mergedOptions + resourceSelectorOptions.SkipLoadingSpinner = true + return PromptCustomResource(ctx, CustomResourceOptions[account.Subscription]{ + SelectorOptions: &resourceSelectorOptions, + LoadData: func(ctx context.Context) ([]*account.Subscription, error) { return subscriptions, nil }, DisplayResource: func(subscription *account.Subscription) (string, error) { @@ -768,21 +799,29 @@ func PromptCustomResource[T any](ctx context.Context, options CustomResourceOpti allowNewResource = true selectedIndex = new(0) } else { - loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{ - Text: options.SelectorOptions.LoadingMessage, - }) - - err := loadingSpinner.Run(ctx, func(ctx context.Context) error { + loadData := func(ctx context.Context) error { resourceList, err := options.LoadData(ctx) if err != nil { return err } - resources = resourceList return nil - }) - if err != nil { - return nil, err + } + + // Skip the spinner when data is pre-loaded + if mergedSelectorOptions.SkipLoadingSpinner { + if err := loadData(ctx); err != nil { + return nil, err + } + } else { + loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{ + Text: mergedSelectorOptions.LoadingMessage, + }) + if err := loadingSpinner.Run(ctx, func(ctx context.Context) error { + return loadData(ctx) + }); err != nil { + return nil, err + } } if !allowNewResource && len(resources) == 0 { diff --git a/cli/azd/pkg/prompt/prompt_service_extra_test.go b/cli/azd/pkg/prompt/prompt_service_extra_test.go index ce6c3e25972..d250297a693 100644 --- a/cli/azd/pkg/prompt/prompt_service_extra_test.go +++ b/cli/azd/pkg/prompt/prompt_service_extra_test.go @@ -252,6 +252,27 @@ func TestPromptCustomResource_NilSelectorOptions_UsesDefaultsAndForce(t *testing require.Equal(t, 42, *result) } +func TestPromptCustomResource_SkipLoadingSpinner(t *testing.T) { + t.Parallel() + + loaded := false + _, err := PromptCustomResource(t.Context(), CustomResourceOptions[string]{ + SelectorOptions: &SelectOptions{ + SkipLoadingSpinner: true, + AllowNewResource: new(false), + }, + LoadData: func(ctx context.Context) ([]*string, error) { + loaded = true + return nil, nil + }, + }) + + // LoadData should have been called directly (without spinner) + require.True(t, loaded) + // No resources and AllowNewResource=false returns the sentinel error + require.ErrorIs(t, err, ErrNoResourcesFound) +} + // helpers func emptySubs() []account.Subscription { return []account.Subscription{} } diff --git a/cli/azd/pkg/prompt/prompter.go b/cli/azd/pkg/prompt/prompter.go index 6fb8c241486..09f2ada4dac 100644 --- a/cli/azd/pkg/prompt/prompter.go +++ b/cli/azd/pkg/prompt/prompter.go @@ -10,7 +10,6 @@ import ( "log" "os" "slices" - "strconv" "github.com/MakeNowJust/heredoc/v2" "github.com/azure/azure-dev/cli/azd/pkg/account" @@ -71,12 +70,12 @@ func NewDefaultPrompter( } func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (subscriptionId string, err error) { - subscriptionOptions, subscriptions, defaultSubscription, err := p.getSubscriptionOptions(ctx) + subscriptionInfos, err := p.accountManager.GetSubscriptions(ctx) if err != nil { - return "", err + return "", fmt.Errorf("listing subscriptions: %w", err) } - if len(subscriptionOptions) == 0 { + if len(subscriptionInfos) == 0 { // NOTE: Error text must contain "no subscriptions found" to match the // pattern in error_suggestions.yaml. Update both if rewording. return "", errors.New(heredoc.Docf( @@ -87,6 +86,32 @@ func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (s )) } + // Filter by AZURE_TENANT_ID if set (works in both prompt and no-prompt modes) + subscriptionInfos = filterByTenantEnvVar(subscriptionInfos) + + // Tenant selection: if multiple tenants, prompt user to pick one + if !p.console.IsNoPromptMode() { + subscriptionInfos, err = promptAndFilterByTenant( + ctx, p.console, subscriptionInfos, p.accountManager.GetTenantDisplayNames) + if err != nil { + return "", err + } + } + + slices.SortFunc(subscriptionInfos, func(a, b account.Subscription) int { + return stringutil.CompareLower(a.Name, b.Name) + }) + + // The default value is based on AZURE_SUBSCRIPTION_ID, falling back to whatever default subscription in + // set in azd's config. + defaultSubscriptionId := os.Getenv(environment.SubscriptionIdEnvVarName) + if defaultSubscriptionId == "" { + defaultSubscriptionId = p.accountManager.GetDefaultSubscriptionID(ctx) + } + + subscriptionOptions, subscriptions, defaultSubscription := + formatSubscriptionOptions(subscriptionInfos, defaultSubscriptionId) + for subscriptionId == "" { subscriptionSelectionIndex, err := p.console.Select(ctx, input.ConsoleOptions{ Message: msg, @@ -110,6 +135,34 @@ func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (s return subscriptionId, nil } +// formatSubscriptionOptions formats subscription infos into display options. +func formatSubscriptionOptions( + subscriptionInfos []account.Subscription, + defaultSubscriptionId string, +) (options []string, ids []string, defaultOption any) { + options = make([]string, len(subscriptionInfos)) + ids = make([]string, len(subscriptionInfos)) + + hideId := isDemoModeEnabled() + + for index, info := range subscriptionInfos { + if hideId { + options[index] = fmt.Sprintf("%2d. %s", index+1, info.Name) + } else { + options[index] = fmt.Sprintf( + "%2d. %s (%s)", index+1, info.Name, info.Id) + } + + ids[index] = info.Id + + if info.Id == defaultSubscriptionId { + defaultOption = options[index] + } + } + + return options, ids, defaultOption +} + func (p *DefaultPrompter) PromptLocation( ctx context.Context, subId string, @@ -246,44 +299,6 @@ func (p *DefaultPrompter) PromptResourceGroupFrom( return name, nil } -func (p *DefaultPrompter) getSubscriptionOptions(ctx context.Context) ([]string, []string, any, error) { - subscriptionInfos, err := p.accountManager.GetSubscriptions(ctx) - if err != nil { - return nil, nil, nil, fmt.Errorf("listing accounts: %w", err) - } - - slices.SortFunc(subscriptionInfos, func(a, b account.Subscription) int { - return stringutil.CompareLower(a.Name, b.Name) - }) - - // The default value is based on AZURE_SUBSCRIPTION_ID, falling back to whatever default subscription in - // set in azd's config. - defaultSubscriptionId := os.Getenv(environment.SubscriptionIdEnvVarName) - if defaultSubscriptionId == "" { - defaultSubscriptionId = p.accountManager.GetDefaultSubscriptionID(ctx) - } - - var subscriptionOptions = make([]string, len(subscriptionInfos)) - var subscriptions = make([]string, len(subscriptionInfos)) - var defaultSubscription any - - for index, info := range subscriptionInfos { - if v, err := strconv.ParseBool(os.Getenv("AZD_DEMO_MODE")); err == nil && v { - subscriptionOptions[index] = fmt.Sprintf("%2d. %s", index+1, info.Name) - } else { - subscriptionOptions[index] = fmt.Sprintf("%2d. %s (%s)", index+1, info.Name, info.Id) - } - - subscriptions[index] = info.Id - - if info.Id == defaultSubscriptionId { - defaultSubscription = subscriptionOptions[index] - } - } - - return subscriptionOptions, subscriptions, defaultSubscription, nil -} - func (p *DefaultPrompter) IsNoPromptMode() bool { return p.console.IsNoPromptMode() } diff --git a/cli/azd/pkg/prompt/prompter_extra_test.go b/cli/azd/pkg/prompt/prompter_extra_test.go index 2d3d0de0a2e..a65cb300934 100644 --- a/cli/azd/pkg/prompt/prompter_extra_test.go +++ b/cli/azd/pkg/prompt/prompter_extra_test.go @@ -49,8 +49,8 @@ func TestDefaultPrompter_PromptSubscription_HappyPath(t *testing.T) { mockAccount := &mockaccount.MockAccountManager{ Subscriptions: []account.Subscription{ - {Id: "sub-alpha", Name: "Alpha", TenantId: "tenant-1"}, - {Id: "sub-bravo", Name: "Bravo", TenantId: "tenant-2"}, + {Id: "sub-alpha", Name: "Alpha", TenantId: "tenant-1", UserAccessTenantId: "tenant-1"}, + {Id: "sub-bravo", Name: "Bravo", TenantId: "tenant-1", UserAccessTenantId: "tenant-1"}, }, } p, mockCtx := newTestPrompter(t, mockAccount) @@ -224,18 +224,14 @@ func TestDefaultPrompter_PromptLocation_WithDefaultSelectedLocation(t *testing.T require.Contains(t, defaultValue.(string), "West US") } -func TestDefaultPrompter_GetSubscriptionOptions_DemoMode(t *testing.T) { +func TestDefaultPrompter_FormatSubscriptionOptions_DemoMode(t *testing.T) { t.Setenv("AZD_DEMO_MODE", "true") - mockAccount := &mockaccount.MockAccountManager{ - Subscriptions: []account.Subscription{ - {Id: "sub-secret", Name: "Display Only"}, - }, + subscriptions := []account.Subscription{ + {Id: "sub-secret", Name: "Display Only"}, } - p, _ := newTestPrompter(t, mockAccount) - opts, subs, _, err := p.getSubscriptionOptions(t.Context()) - require.NoError(t, err) + opts, subs, _ := formatSubscriptionOptions(subscriptions, "") require.Len(t, opts, 1) require.Len(t, subs, 1) // In demo mode, id must not be exposed. @@ -243,20 +239,14 @@ func TestDefaultPrompter_GetSubscriptionOptions_DemoMode(t *testing.T) { require.Contains(t, opts[0], "Display Only") } -func TestDefaultPrompter_GetSubscriptionOptions_EnvVarDefault(t *testing.T) { - t.Setenv(environment.SubscriptionIdEnvVarName, "sub-env") - - mockAccount := &mockaccount.MockAccountManager{ - DefaultSubscription: "sub-config", // env var takes precedence - Subscriptions: []account.Subscription{ - {Id: "sub-env", Name: "From Env"}, - {Id: "sub-config", Name: "From Config"}, - }, +func TestDefaultPrompter_FormatSubscriptionOptions_EnvVarDefault(t *testing.T) { + subscriptions := []account.Subscription{ + {Id: "sub-env", Name: "From Env"}, + {Id: "sub-config", Name: "From Config"}, } - p, _ := newTestPrompter(t, mockAccount) - _, _, def, err := p.getSubscriptionOptions(t.Context()) - require.NoError(t, err) + // env var default takes precedence + _, _, def := formatSubscriptionOptions(subscriptions, "sub-env") require.NotNil(t, def) require.Contains(t, def.(string), "From Env") } diff --git a/cli/azd/pkg/prompt/prompter_test.go b/cli/azd/pkg/prompt/prompter_test.go index 5e59f5e2a27..cb4c118ba44 100644 --- a/cli/azd/pkg/prompt/prompter_test.go +++ b/cli/azd/pkg/prompt/prompter_test.go @@ -7,84 +7,49 @@ import ( "testing" "github.com/azure/azure-dev/cli/azd/pkg/account" - "github.com/azure/azure-dev/cli/azd/pkg/azapi" - "github.com/azure/azure-dev/cli/azd/pkg/cloud" - "github.com/azure/azure-dev/cli/azd/pkg/environment" - "github.com/azure/azure-dev/cli/azd/test/mocks" - "github.com/azure/azure-dev/cli/azd/test/mocks/mockaccount" "github.com/stretchr/testify/require" ) -func Test_getSubscriptionOptions(t *testing.T) { +func Test_formatSubscriptionOptions(t *testing.T) { t.Run("no default config set", func(t *testing.T) { - mockContext := mocks.NewMockContext(t.Context()) - env := environment.New("test") - resourceService := azapi.NewResourceService(mockContext.SubscriptionCredentialProvider, mockContext.ArmClientOptions) - mockAccount := &mockaccount.MockAccountManager{ - Subscriptions: []account.Subscription{ - { - Id: "1", - Name: "sub1", - TenantId: "", - UserAccessTenantId: "", - IsDefault: false, - }, + subscriptions := []account.Subscription{ + { + Id: "1", + Name: "sub1", + TenantId: "", + UserAccessTenantId: "", + IsDefault: false, }, } - prompter := NewDefaultPrompter( - env, - mockContext.Console, - mockAccount, - resourceService, - cloud.AzurePublic(), - ).(*DefaultPrompter) - subList, subs, result, err := prompter.getSubscriptionOptions(*mockContext.Context) + subList, subs, result := formatSubscriptionOptions(subscriptions, "") - require.Nil(t, err) require.EqualValues(t, 1, len(subList)) require.EqualValues(t, 1, len(subs)) require.EqualValues(t, nil, result) }) t.Run("default value set", func(t *testing.T) { - // mocked config defaultSubId := "SUBSCRIPTION_DEFAULT" - mockContext := mocks.NewMockContext(t.Context()) - env := environment.New("test") - resourceService := azapi.NewResourceService(mockContext.SubscriptionCredentialProvider, mockContext.ArmClientOptions) - mockAccount := &mockaccount.MockAccountManager{ - DefaultLocation: "location", - DefaultSubscription: defaultSubId, - Subscriptions: []account.Subscription{ - { - Id: defaultSubId, - Name: "DISPLAY DEFAULT", - TenantId: "TENANT", - UserAccessTenantId: "USER_TENANT", - IsDefault: true, - }, - { - Id: "SUBSCRIPTION_OTHER", - Name: "DISPLAY OTHER", - TenantId: "TENANT", - UserAccessTenantId: "USER_TENANT", - IsDefault: false, - }, + subscriptions := []account.Subscription{ + { + Id: defaultSubId, + Name: "DISPLAY DEFAULT", + TenantId: "TENANT", + UserAccessTenantId: "USER_TENANT", + IsDefault: true, + }, + { + Id: "SUBSCRIPTION_OTHER", + Name: "DISPLAY OTHER", + TenantId: "TENANT", + UserAccessTenantId: "USER_TENANT", + IsDefault: false, }, - Locations: []account.Location{}, } - prompter := NewDefaultPrompter( - env, - mockContext.Console, - mockAccount, - resourceService, - cloud.AzurePublic(), - ).(*DefaultPrompter) - subList, subs, result, err := prompter.getSubscriptionOptions(*mockContext.Context) + subList, subs, result := formatSubscriptionOptions(subscriptions, defaultSubId) - require.Nil(t, err) require.EqualValues(t, 2, len(subList)) require.EqualValues(t, 2, len(subs)) require.NotNil(t, result) diff --git a/cli/azd/pkg/prompt/tenant_helper.go b/cli/azd/pkg/prompt/tenant_helper.go new file mode 100644 index 00000000000..9b1c665ae40 --- /dev/null +++ b/cli/azd/pkg/prompt/tenant_helper.go @@ -0,0 +1,222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package prompt + +import ( + "cmp" + "context" + "fmt" + "log" + "os" + "slices" + "strings" + + "github.com/azure/azure-dev/cli/azd/pkg/account" + "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/output" +) + +// tenantInfo holds display metadata for a tenant extracted from the subscription list. +type tenantInfo struct { + // Id is the tenant ID (GUID). + Id string + // DisplayName is the friendly name of the tenant, or the ID if no name is available. + DisplayName string + // SubscriptionCount is the number of subscriptions accessible via this tenant. + SubscriptionCount int +} + +// extractUniqueTenants extracts unique tenants from a list of subscriptions, +// grouped by UserAccessTenantId (falling back to TenantId when UserAccessTenantId is empty). +// The returned list is sorted by DisplayName. +// Tenant display names are resolved from the provided tenantDisplayNames map; +// if a tenant ID is not in the map, the ID itself is used as the display name. +func extractUniqueTenants( + subscriptions []account.Subscription, + tenantDisplayNames map[string]string, +) []tenantInfo { + tenantMap := make(map[string]*tenantInfo) + + for _, sub := range subscriptions { + tid := sub.UserAccessTenantId + if tid == "" { + tid = sub.TenantId + } + if tid == "" { + continue + } + + if info, exists := tenantMap[tid]; exists { + info.SubscriptionCount++ + } else { + displayName := tid + if name, ok := tenantDisplayNames[tid]; ok && name != "" { + displayName = name + } + tenantMap[tid] = &tenantInfo{ + Id: tid, + DisplayName: displayName, + SubscriptionCount: 1, + } + } + } + + tenants := make([]tenantInfo, 0, len(tenantMap)) + for _, info := range tenantMap { + tenants = append(tenants, *info) + } + + slices.SortFunc(tenants, func(a, b tenantInfo) int { + if c := cmp.Compare( + strings.ToLower(a.DisplayName), + strings.ToLower(b.DisplayName), + ); c != 0 { + return c + } + return cmp.Compare(a.Id, b.Id) + }) + + return tenants +} + +// filterSubscriptionsByTenant filters subscriptions to only those accessible +// through the specified tenant ID. If tenantId is empty, all subscriptions are returned. +func filterSubscriptionsByTenant( + subscriptions []account.Subscription, + tenantId string, +) []account.Subscription { + if tenantId == "" { + return subscriptions + } + + filtered := make([]account.Subscription, 0, len(subscriptions)) + for _, sub := range subscriptions { + accessTenant := sub.UserAccessTenantId + if accessTenant == "" { + accessTenant = sub.TenantId + } + if accessTenant == tenantId { + filtered = append(filtered, sub) + } + } + return filtered +} + +// filterByTenantEnvVar filters subscriptions by AZURE_TENANT_ID if set. +// This is applied in both prompt and no-prompt modes. +// If the env var is set but no subscriptions match (e.g. stale tenant ID), +// the filter is a no-op and returns all subscriptions to avoid blocking the user. +func filterByTenantEnvVar(subscriptions []account.Subscription) []account.Subscription { + tenantId := os.Getenv(environment.TenantIdEnvVarName) + if tenantId == "" { + return subscriptions + } + + filtered := filterSubscriptionsByTenant(subscriptions, tenantId) + // If filtering produces no results, fall back to showing all subscriptions + // rather than erroring out — the tenant ID may be stale + if len(filtered) == 0 { + log.Println("AZURE_TENANT_ID did not match any subscription tenants, showing all subscriptions") + return subscriptions + } + + return filtered +} + +// promptTenantSelection prompts the user to select a tenant when multiple tenants are available. +// Returns the selected tenant ID, or empty string if the user chose "All tenants". +// If there is only one tenant, it is returned automatically without prompting. +func promptTenantSelection( + ctx context.Context, + console input.Console, + tenants []tenantInfo, +) (string, error) { + if len(tenants) <= 1 { + if len(tenants) == 1 { + return tenants[0].Id, nil + } + return "", nil + } + + allTenantsLabel := fmt.Sprintf( + "%2d. All tenants", + len(tenants)+1, + ) + + options := make([]string, len(tenants)+1) + for i, t := range tenants { + options[i] = formatTenantOption(i+1, t) + } + options[len(tenants)] = allTenantsLabel + + selectedIndex, err := console.Select(ctx, input.ConsoleOptions{ + Message: "Select a tenant", + Options: options, + }) + if err != nil { + return "", fmt.Errorf("selecting tenant: %w", err) + } + + // Last option = "All tenants" + if selectedIndex == len(tenants) { + return "", nil + } + + return tenants[selectedIndex].Id, nil +} + +// TenantDisplayNameProvider is a function that fetches tenant display names. +type TenantDisplayNameProvider func(ctx context.Context) (map[string]string, error) + +// promptAndFilterByTenant prompts the user to select a tenant when subscriptions span multiple tenants. +// It extracts unique tenants, fetches display names only when needed, and returns filtered subscriptions. +func promptAndFilterByTenant( + ctx context.Context, + console input.Console, + subscriptions []account.Subscription, + getTenantNames TenantDisplayNameProvider, +) ([]account.Subscription, error) { + // Quick check without display names to avoid unnecessary API call + tenants := extractUniqueTenants(subscriptions, nil) + if len(tenants) <= 1 { + return subscriptions, nil + } + + // Only fetch tenant display names when we actually need to prompt + var tenantNames map[string]string + if getTenantNames != nil { + var err error + tenantNames, err = getTenantNames(ctx) + if err != nil { + log.Printf("failed to fetch tenant display names: %v", err) + tenantNames = nil + } + } + + tenants = extractUniqueTenants(subscriptions, tenantNames) + + selectedTenantId, err := promptTenantSelection(ctx, console, tenants) + if err != nil { + return nil, err + } + + return filterSubscriptionsByTenant(subscriptions, selectedTenantId), nil +} + +func formatTenantOption(index int, t tenantInfo) string { + subCountLabel := fmt.Sprintf( + "%d subscription", t.SubscriptionCount, + ) + if t.SubscriptionCount != 1 { + subCountLabel += "s" + } + + return fmt.Sprintf( + "%2d. %s %s", + index, + t.DisplayName, + output.WithGrayFormat("(%s)", subCountLabel), + ) +} diff --git a/cli/azd/pkg/prompt/tenant_helper_test.go b/cli/azd/pkg/prompt/tenant_helper_test.go new file mode 100644 index 00000000000..f52366c06d7 --- /dev/null +++ b/cli/azd/pkg/prompt/tenant_helper_test.go @@ -0,0 +1,302 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package prompt + +import ( + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/account" + "github.com/azure/azure-dev/cli/azd/pkg/azapi" + "github.com/azure/azure-dev/cli/azd/pkg/cloud" + "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/azure/azure-dev/cli/azd/test/mocks/mockaccount" + "github.com/stretchr/testify/require" +) + +func TestExtractUniqueTenants_Empty(t *testing.T) { + tenants := extractUniqueTenants(nil, nil) + require.Empty(t, tenants) +} + +func TestExtractUniqueTenants_SingleTenant(t *testing.T) { + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", UserAccessTenantId: "tid-1"}, + } + + tenants := extractUniqueTenants(subs, map[string]string{"tid-1": "Contoso"}) + require.Len(t, tenants, 1) + require.Equal(t, "tid-1", tenants[0].Id) + require.Equal(t, "Contoso", tenants[0].DisplayName) + require.Equal(t, 2, tenants[0].SubscriptionCount) +} + +func TestExtractUniqueTenants_MultipleTenants(t *testing.T) { + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", UserAccessTenantId: "tid-2"}, + {Id: "sub-3", UserAccessTenantId: "tid-1"}, + } + + names := map[string]string{ + "tid-1": "Contoso", + "tid-2": "Fabrikam", + } + + tenants := extractUniqueTenants(subs, names) + require.Len(t, tenants, 2) + // Sorted alphabetically by display name + require.Equal(t, "Contoso", tenants[0].DisplayName) + require.Equal(t, 2, tenants[0].SubscriptionCount) + require.Equal(t, "Fabrikam", tenants[1].DisplayName) + require.Equal(t, 1, tenants[1].SubscriptionCount) +} + +func TestExtractUniqueTenants_FallbackToTenantId(t *testing.T) { + subs := []account.Subscription{ + {Id: "sub-1", TenantId: "tid-1", UserAccessTenantId: ""}, + } + + tenants := extractUniqueTenants(subs, nil) + require.Len(t, tenants, 1) + require.Equal(t, "tid-1", tenants[0].Id) + // Display name falls back to the ID when no names provided + require.Equal(t, "tid-1", tenants[0].DisplayName) +} + +func TestExtractUniqueTenants_NoDisplayNames(t *testing.T) { + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", UserAccessTenantId: "tid-2"}, + } + + tenants := extractUniqueTenants(subs, nil) + require.Len(t, tenants, 2) + require.Equal(t, "tid-1", tenants[0].DisplayName) + require.Equal(t, "tid-2", tenants[1].DisplayName) +} + +func TestFilterSubscriptionsByTenant_EmptyTenantId(t *testing.T) { + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", UserAccessTenantId: "tid-2"}, + } + + result := filterSubscriptionsByTenant(subs, "") + require.Len(t, result, 2) +} + +func TestFilterSubscriptionsByTenant_Filtered(t *testing.T) { + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", UserAccessTenantId: "tid-2"}, + {Id: "sub-3", UserAccessTenantId: "tid-1"}, + } + + result := filterSubscriptionsByTenant(subs, "tid-1") + require.Len(t, result, 2) + require.Equal(t, "sub-1", result[0].Id) + require.Equal(t, "sub-3", result[1].Id) +} + +func TestFilterSubscriptionsByTenant_NoMatch(t *testing.T) { + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + } + + result := filterSubscriptionsByTenant(subs, "tid-unknown") + require.Empty(t, result) +} + +func TestFilterByTenantEnvVar_NotSet(t *testing.T) { + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", UserAccessTenantId: "tid-2"}, + } + + result := filterByTenantEnvVar(subs) + require.Len(t, result, 2) +} + +func TestFilterByTenantEnvVar_Set(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "tid-1") + + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", UserAccessTenantId: "tid-2"}, + } + + result := filterByTenantEnvVar(subs) + require.Len(t, result, 1) + require.Equal(t, "sub-1", result[0].Id) +} + +func TestFilterByTenantEnvVar_NoMatchFallsBack(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "tid-unknown") + + subs := []account.Subscription{ + {Id: "sub-1", UserAccessTenantId: "tid-1"}, + } + + // Falls back to showing all when the env var doesn't match + result := filterByTenantEnvVar(subs) + require.Len(t, result, 1) +} + +func TestPromptTenantSelection_SingleTenant(t *testing.T) { + mockContext := mocks.NewMockContext(t.Context()) + + tenants := []tenantInfo{ + {Id: "tid-1", DisplayName: "Contoso", SubscriptionCount: 3}, + } + + selected, err := promptTenantSelection(t.Context(), mockContext.Console, tenants) + require.NoError(t, err) + require.Equal(t, "tid-1", selected) +} + +func TestPromptTenantSelection_MultipleTenants_SelectFirst(t *testing.T) { + mockContext := mocks.NewMockContext(t.Context()) + + mockContext.Console.WhenSelect(func(opts input.ConsoleOptions) bool { + return strings.Contains(opts.Message, "Select a tenant") + }).Respond(0) // pick first tenant + + tenants := []tenantInfo{ + {Id: "tid-1", DisplayName: "Contoso", SubscriptionCount: 3}, + {Id: "tid-2", DisplayName: "Fabrikam", SubscriptionCount: 1}, + } + + selected, err := promptTenantSelection(t.Context(), mockContext.Console, tenants) + require.NoError(t, err) + require.Equal(t, "tid-1", selected) +} + +func TestPromptTenantSelection_MultipleTenants_SelectAllTenants(t *testing.T) { + mockContext := mocks.NewMockContext(t.Context()) + + mockContext.Console.WhenSelect(func(opts input.ConsoleOptions) bool { + return strings.Contains(opts.Message, "Select a tenant") + }).Respond(2) // pick "All tenants" (third option with 2 tenants) + + tenants := []tenantInfo{ + {Id: "tid-1", DisplayName: "Contoso", SubscriptionCount: 3}, + {Id: "tid-2", DisplayName: "Fabrikam", SubscriptionCount: 1}, + } + + selected, err := promptTenantSelection(t.Context(), mockContext.Console, tenants) + require.NoError(t, err) + require.Empty(t, selected) // empty string = all tenants +} + +func TestPromptTenantSelection_NoTenants(t *testing.T) { + mockContext := mocks.NewMockContext(t.Context()) + + selected, err := promptTenantSelection(t.Context(), mockContext.Console, nil) + require.NoError(t, err) + require.Empty(t, selected) +} + +func TestPromptSubscription_MultiTenant_TenantPickerShown(t *testing.T) { + mockContext := mocks.NewMockContext(t.Context()) + mockAccount := &mockaccount.MockAccountManager{ + Subscriptions: []account.Subscription{ + {Id: "sub-1", Name: "Alpha", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", Name: "Bravo", UserAccessTenantId: "tid-2"}, + {Id: "sub-3", Name: "Charlie", UserAccessTenantId: "tid-1"}, + }, + } + + p, _ := newTestPrompterWithCtx(t, mockAccount, mockContext) + + // First prompt: select tenant (pick tid-1) + mockContext.Console.WhenSelect(func(opts input.ConsoleOptions) bool { + return strings.Contains(opts.Message, "Select a tenant") + }).Respond(0) // first tenant + + // Second prompt: select subscription from filtered list + mockContext.Console.WhenSelect(func(opts input.ConsoleOptions) bool { + return strings.Contains(opts.Message, "Select a subscription") + }).Respond(1) // pick second in filtered list (Charlie after Alpha alphabetically) + + subId, err := p.PromptSubscription(t.Context(), "Select a subscription") + require.NoError(t, err) + // After filtering to tid-1: Alpha (sub-1) and Charlie (sub-3), sorted + require.Equal(t, "sub-3", subId) // Charlie is second alphabetically +} + +func TestPromptSubscription_MultiTenant_AllTenantsOption(t *testing.T) { + mockContext := mocks.NewMockContext(t.Context()) + mockAccount := &mockaccount.MockAccountManager{ + Subscriptions: []account.Subscription{ + {Id: "sub-1", Name: "Alpha", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", Name: "Bravo", UserAccessTenantId: "tid-2"}, + }, + } + + p, _ := newTestPrompterWithCtx(t, mockAccount, mockContext) + + // First prompt: "All tenants" (last option, index 2 with 2 tenants) + mockContext.Console.WhenSelect(func(opts input.ConsoleOptions) bool { + return strings.Contains(opts.Message, "Select a tenant") + }).Respond(2) + + // Second prompt: select subscription from full list + mockContext.Console.WhenSelect(func(opts input.ConsoleOptions) bool { + return strings.Contains(opts.Message, "Select a subscription") + }).Respond(0) // pick first subscription + + subId, err := p.PromptSubscription(t.Context(), "Select a subscription") + require.NoError(t, err) + require.Equal(t, "sub-1", subId) // Alpha is first alphabetically +} + +func TestPromptSubscription_NoPromptMode_SkipsTenantPicker(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "tid-1") + + mockContext := mocks.NewMockContext(t.Context()) + mockContext.Console.SetNoPromptMode(true) + + mockAccount := &mockaccount.MockAccountManager{ + Subscriptions: []account.Subscription{ + {Id: "sub-1", Name: "Alpha", UserAccessTenantId: "tid-1"}, + {Id: "sub-2", Name: "Bravo", UserAccessTenantId: "tid-2"}, + {Id: "sub-3", Name: "Charlie", UserAccessTenantId: "tid-1"}, + }, + } + + p, _ := newTestPrompterWithCtx(t, mockAccount, mockContext) + + // In no-prompt mode the tenant picker is skipped, but AZURE_TENANT_ID + // filtering still applies. Subscription selection still goes through + // Console.Select (not bypassed by no-prompt in this legacy prompter path). + mockContext.Console.WhenSelect(func(opts input.ConsoleOptions) bool { + return strings.Contains(opts.Message, "Select a subscription") + }).Respond(0) + + subId, err := p.PromptSubscription(t.Context(), "Select a subscription") + require.NoError(t, err) + // Should be filtered to tid-1 only: Alpha and Charlie + require.Equal(t, "sub-1", subId) +} + +func newTestPrompterWithCtx( + t *testing.T, + mockAccount *mockaccount.MockAccountManager, + mockCtx *mocks.MockContext, +) (*DefaultPrompter, *mocks.MockContext) { + t.Helper() + env := environment.New("test") + resourceService := azapi.NewResourceService( + mockCtx.SubscriptionCredentialProvider, mockCtx.ArmClientOptions) + + p := NewDefaultPrompter( + env, mockCtx.Console, mockAccount, resourceService, cloud.AzurePublic(), + ).(*DefaultPrompter) + + return p, mockCtx +} diff --git a/cli/azd/test/mocks/mockaccount/mock_manager.go b/cli/azd/test/mocks/mockaccount/mock_manager.go index 1b1f21a9c6a..44fafd093d9 100644 --- a/cli/azd/test/mocks/mockaccount/mock_manager.go +++ b/cli/azd/test/mocks/mockaccount/mock_manager.go @@ -100,6 +100,23 @@ func (a *MockAccountManager) SetDefaultLocation( return nil, nil } +func (a *MockAccountManager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) { + result := make(map[string]string) + for _, sub := range a.Subscriptions { + tid := sub.UserAccessTenantId + if tid == "" { + tid = sub.TenantId + } + if tid == "" { + continue + } + if _, exists := result[tid]; !exists { + result[tid] = tid + } + } + return result, nil +} + // SubscriptionCredentialProviderFunc implements [account.SubscriptionCredentialProvider] using the provided function. type SubscriptionCredentialProviderFunc func(ctx context.Context, subscriptionId string) (azcore.TokenCredential, error) diff --git a/cli/azd/test/mocks/mockaccount/mock_subscriptions.go b/cli/azd/test/mocks/mockaccount/mock_subscriptions.go index 60d51a722ae..0034ed0c5da 100644 --- a/cli/azd/test/mocks/mockaccount/mock_subscriptions.go +++ b/cli/azd/test/mocks/mockaccount/mock_subscriptions.go @@ -24,6 +24,14 @@ func (m *MockSubscriptionManager) ListLocations(ctx context.Context, subscriptio return args.Get(0).([]account.Location), args.Error(1) } +func (m *MockSubscriptionManager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(map[string]string), args.Error(1) +} + func (m *MockSubscriptionManager) GetLocations(ctx context.Context, subscriptionId string) ([]account.Location, error) { args := m.Called(ctx, subscriptionId) return args.Get(0).([]account.Location), args.Error(1)