Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cli/azd/pkg/account/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Manager interface {
GetSubscriptions(ctx context.Context) ([]Subscription, error)
GetSubscriptionsWithDefaultSet(ctx context.Context) ([]Subscription, error)
GetLocations(ctx context.Context, subscriptionId string) ([]Location, error)
GetTenantDisplayNames(ctx context.Context) (map[string]string, error)
SetDefaultSubscription(ctx context.Context, subscriptionId string) (*Subscription, error)
SetDefaultLocation(ctx context.Context, subscriptionId string, location string) (*Location, error)
}
Expand Down Expand Up @@ -140,6 +141,11 @@ func (m *manager) GetSubscriptions(ctx context.Context) ([]Subscription, error)
return m.subManager.GetSubscriptions(ctx)
}

// GetTenantDisplayNames returns a map of tenant ID to display name.
func (m *manager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) {
return m.subManager.GetTenantDisplayNames(ctx)
}

// Gets the available Azure locations for the specified Azure subscription.
func (m *manager) GetLocations(ctx context.Context, subscriptionId string) ([]Location, error) {
locations, err := m.subManager.ListLocations(ctx, subscriptionId)
Expand Down
22 changes: 22 additions & 0 deletions cli/azd/pkg/account/subscriptions_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,28 @@ func (m *SubscriptionsManager) getSubscription(ctx context.Context, subscription
return &sub, nil
}

// GetTenantDisplayNames returns a map of tenant ID to display name for all tenants
// accessible by the current account.
func (m *SubscriptionsManager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) {
tenants, err := m.service.ListTenants(ctx)
if err != nil {
return nil, fmt.Errorf("listing tenants: %w", err)
}

result := make(map[string]string, len(tenants))
for _, t := range tenants {
if t.TenantID != nil {
name := *t.TenantID
if t.DisplayName != nil && *t.DisplayName != "" {
name = *t.DisplayName
}
result[*t.TenantID] = name
}
}

return result, nil
}

func toSubscriptions(azSubs []*armsubscriptions.Subscription, userAccessTenantId string) []Subscription {
if azSubs == nil {
return nil
Expand Down
100 changes: 81 additions & 19 deletions cli/azd/pkg/prompt/prompt_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"fmt"
"io"
"log"
"os"
"slices"
"strconv"
Expand Down Expand Up @@ -157,6 +158,7 @@ type ResourceService interface {
type SubscriptionManager interface {
GetSubscriptions(ctx context.Context) ([]account.Subscription, error)
GetLocations(ctx context.Context, subscriptionId string) ([]account.Location, error)
GetTenantDisplayNames(ctx context.Context) (map[string]string, error)
}

// PromptServiceInterface defines the methods that the PromptService must implement.
Expand Down Expand Up @@ -211,6 +213,8 @@ func NewPromptService(
}

// PromptSubscription prompts the user to select an Azure subscription.
// If the user has access to multiple tenants, a tenant selection prompt is shown first
// to scope down the subscription list.
func (ps *promptService) PromptSubscription(
ctx context.Context,
selectorOptions *SelectOptions,
Expand All @@ -235,6 +239,30 @@ func (ps *promptService) PromptSubscription(
return nil, err
}

// Load subscriptions under a spinner first
var subscriptionList []account.Subscription
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: mergedOptions.LoadingMessage,
})
Comment thread
vhvb1989 marked this conversation as resolved.

err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
var loadErr error
subscriptionList, loadErr = ps.subscriptionManager.GetSubscriptions(ctx)
return loadErr
})
if err != nil {
return nil, err
}
Comment thread
vhvb1989 marked this conversation as resolved.

// Apply tenant filtering (after spinner is done so the prompt doesn't overlap)
subscriptionList = filterByTenantEnvVar(subscriptionList)
if !ps.console.IsNoPromptMode() {
subscriptionList, err = ps.promptAndFilterByTenant(ctx, subscriptionList)
if err != nil {
return nil, err
}
Comment thread
vhvb1989 marked this conversation as resolved.
}

// Get default subscription from user config
var defaultSubscriptionId = ""
userConfig, err := ps.userConfigManager.Load()
Expand All @@ -247,19 +275,18 @@ func (ps *promptService) PromptSubscription(

hideId := isDemoModeEnabled()

// Use PromptCustomResource with pre-loaded data
subscriptions := make([]*account.Subscription, len(subscriptionList))
for i := range subscriptionList {
subscriptions[i] = &subscriptionList[i]
}

// Clear loading message since data is already loaded (avoids a redundant spinner)
mergedOptions.LoadingMessage = ""

return PromptCustomResource(ctx, CustomResourceOptions[account.Subscription]{
SelectorOptions: mergedOptions,
LoadData: func(ctx context.Context) ([]*account.Subscription, error) {
subscriptionList, err := ps.subscriptionManager.GetSubscriptions(ctx)
if err != nil {
return nil, err
}

subscriptions := make([]*account.Subscription, len(subscriptionList))
for i, subscription := range subscriptionList {
subscriptions[i] = &subscription
}

return subscriptions, nil
},
Comment thread
vhvb1989 marked this conversation as resolved.
DisplayResource: func(subscription *account.Subscription) (string, error) {
Expand All @@ -271,6 +298,33 @@ func (ps *promptService) PromptSubscription(
})
}

// promptAndFilterByTenant prompts the user to select a tenant when subscriptions span multiple tenants.
func (ps *promptService) promptAndFilterByTenant(
ctx context.Context,
subscriptions []account.Subscription,
) ([]account.Subscription, error) {
tenants := extractUniqueTenants(subscriptions, nil)
if len(tenants) <= 1 {
return subscriptions, nil
}

// Only fetch tenant display names when we actually need to prompt
tenantNames, err := ps.subscriptionManager.GetTenantDisplayNames(ctx)
if err != nil {
log.Printf("failed to fetch tenant display names, using tenant IDs: %v", err)
tenantNames = map[string]string{}
}

tenants = extractUniqueTenants(subscriptions, tenantNames)

selectedTenantId, err := promptTenantSelection(ctx, ps.console, tenants)
if err != nil {
return nil, err
}

return filterSubscriptionsByTenant(subscriptions, selectedTenantId), nil
}

// PromptLocation prompts the user to select an Azure location.
func (ps *promptService) PromptLocation(
ctx context.Context,
Expand Down Expand Up @@ -768,21 +822,29 @@ func PromptCustomResource[T any](ctx context.Context, options CustomResourceOpti
allowNewResource = true
selectedIndex = new(0)
} else {
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: options.SelectorOptions.LoadingMessage,
})

err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
loadData := func(ctx context.Context) error {
resourceList, err := options.LoadData(ctx)
if err != nil {
return err
}

resources = resourceList
return nil
})
if err != nil {
return nil, err
}

// Skip the spinner when loading message is empty (data is pre-loaded)
if mergedSelectorOptions.LoadingMessage != "" {
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: mergedSelectorOptions.LoadingMessage,
})
if err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
return loadData(ctx)
}); err != nil {
return nil, err
Comment thread
vhvb1989 marked this conversation as resolved.
Outdated
}
} else {
if err := loadData(ctx); err != nil {
return nil, err
}
}

if !allowNewResource && len(resources) == 0 {
Expand Down
122 changes: 81 additions & 41 deletions cli/azd/pkg/prompt/prompter.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ func NewDefaultPrompter(
}

func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (subscriptionId string, err error) {
subscriptionOptions, subscriptions, defaultSubscription, err := p.getSubscriptionOptions(ctx)
subscriptionInfos, err := p.accountManager.GetSubscriptions(ctx)
if err != nil {
return "", err
return "", fmt.Errorf("listing accounts: %w", err)
Comment thread
vhvb1989 marked this conversation as resolved.
Outdated
}

if len(subscriptionOptions) == 0 {
if len(subscriptionInfos) == 0 {
// NOTE: Error text must contain "no subscriptions found" to match the
// pattern in error_suggestions.yaml. Update both if rewording.
return "", errors.New(heredoc.Docf(
Expand All @@ -87,6 +87,31 @@ func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (s
))
}

// Filter by AZURE_TENANT_ID if set (works in both prompt and no-prompt modes)
subscriptionInfos = filterByTenantEnvVar(subscriptionInfos)

// Tenant selection: if multiple tenants, prompt user to pick one
if !p.console.IsNoPromptMode() {
subscriptionInfos, err = p.promptAndFilterByTenant(ctx, subscriptionInfos)
if err != nil {
return "", err
}
}
Comment thread
vhvb1989 marked this conversation as resolved.

slices.SortFunc(subscriptionInfos, func(a, b account.Subscription) int {
return stringutil.CompareLower(a.Name, b.Name)
})

// The default value is based on AZURE_SUBSCRIPTION_ID, falling back to whatever default subscription in
// set in azd's config.
defaultSubscriptionId := os.Getenv(environment.SubscriptionIdEnvVarName)
if defaultSubscriptionId == "" {
defaultSubscriptionId = p.accountManager.GetDefaultSubscriptionID(ctx)
}
Comment thread
vhvb1989 marked this conversation as resolved.

subscriptionOptions, subscriptions, defaultSubscription :=
formatSubscriptionOptions(subscriptionInfos, defaultSubscriptionId)

for subscriptionId == "" {
subscriptionSelectionIndex, err := p.console.Select(ctx, input.ConsoleOptions{
Message: msg,
Expand All @@ -110,6 +135,59 @@ func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (s
return subscriptionId, nil
}

// promptAndFilterByTenant prompts the user to select a tenant when subscriptions span multiple tenants.
func (p *DefaultPrompter) promptAndFilterByTenant(
ctx context.Context,
subscriptions []account.Subscription,
) ([]account.Subscription, error) {
// Quick check without display names to avoid unnecessary API call
tenants := extractUniqueTenants(subscriptions, nil)
if len(tenants) <= 1 {
return subscriptions, nil
}

// Only fetch tenant display names when we actually need to prompt
tenantNames, err := p.accountManager.GetTenantDisplayNames(ctx)
if err != nil {
log.Printf("failed to fetch tenant display names, using tenant IDs: %v", err)
tenantNames = map[string]string{}
}

tenants = extractUniqueTenants(subscriptions, tenantNames)

selectedTenantId, err := promptTenantSelection(ctx, p.console, tenants)
if err != nil {
return nil, err
}

return filterSubscriptionsByTenant(subscriptions, selectedTenantId), nil
}
Comment thread
vhvb1989 marked this conversation as resolved.
Outdated

// formatSubscriptionOptions formats subscription infos into display options.
func formatSubscriptionOptions(
subscriptionInfos []account.Subscription,
defaultSubscriptionId string,
) (options []string, ids []string, defaultOption any) {
options = make([]string, len(subscriptionInfos))
ids = make([]string, len(subscriptionInfos))

for index, info := range subscriptionInfos {
if v, err := strconv.ParseBool(os.Getenv("AZD_DEMO_MODE")); err == nil && v {
options[index] = fmt.Sprintf("%2d. %s", index+1, info.Name)
} else {
options[index] = fmt.Sprintf("%2d. %s (%s)", index+1, info.Name, info.Id)
}
Comment thread
vhvb1989 marked this conversation as resolved.

ids[index] = info.Id

if info.Id == defaultSubscriptionId {
defaultOption = options[index]
}
}

return options, ids, defaultOption
}

func (p *DefaultPrompter) PromptLocation(
ctx context.Context,
subId string,
Expand Down Expand Up @@ -246,44 +324,6 @@ func (p *DefaultPrompter) PromptResourceGroupFrom(
return name, nil
}

func (p *DefaultPrompter) getSubscriptionOptions(ctx context.Context) ([]string, []string, any, error) {
subscriptionInfos, err := p.accountManager.GetSubscriptions(ctx)
if err != nil {
return nil, nil, nil, fmt.Errorf("listing accounts: %w", err)
}

slices.SortFunc(subscriptionInfos, func(a, b account.Subscription) int {
return stringutil.CompareLower(a.Name, b.Name)
})

// The default value is based on AZURE_SUBSCRIPTION_ID, falling back to whatever default subscription in
// set in azd's config.
defaultSubscriptionId := os.Getenv(environment.SubscriptionIdEnvVarName)
if defaultSubscriptionId == "" {
defaultSubscriptionId = p.accountManager.GetDefaultSubscriptionID(ctx)
}

var subscriptionOptions = make([]string, len(subscriptionInfos))
var subscriptions = make([]string, len(subscriptionInfos))
var defaultSubscription any

for index, info := range subscriptionInfos {
if v, err := strconv.ParseBool(os.Getenv("AZD_DEMO_MODE")); err == nil && v {
subscriptionOptions[index] = fmt.Sprintf("%2d. %s", index+1, info.Name)
} else {
subscriptionOptions[index] = fmt.Sprintf("%2d. %s (%s)", index+1, info.Name, info.Id)
}

subscriptions[index] = info.Id

if info.Id == defaultSubscriptionId {
defaultSubscription = subscriptionOptions[index]
}
}

return subscriptionOptions, subscriptions, defaultSubscription, nil
}

func (p *DefaultPrompter) IsNoPromptMode() bool {
return p.console.IsNoPromptMode()
}
Loading
Loading