Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cli/azd/pkg/account/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Manager interface {
GetSubscriptions(ctx context.Context) ([]Subscription, error)
GetSubscriptionsWithDefaultSet(ctx context.Context) ([]Subscription, error)
GetLocations(ctx context.Context, subscriptionId string) ([]Location, error)
GetTenantDisplayNames(ctx context.Context) (map[string]string, error)
SetDefaultSubscription(ctx context.Context, subscriptionId string) (*Subscription, error)
SetDefaultLocation(ctx context.Context, subscriptionId string, location string) (*Location, error)
}
Expand Down Expand Up @@ -140,6 +141,11 @@ func (m *manager) GetSubscriptions(ctx context.Context) ([]Subscription, error)
return m.subManager.GetSubscriptions(ctx)
}

// GetTenantDisplayNames returns a map of tenant ID to display name.
func (m *manager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) {
return m.subManager.GetTenantDisplayNames(ctx)
}

// Gets the available Azure locations for the specified Azure subscription.
func (m *manager) GetLocations(ctx context.Context, subscriptionId string) ([]Location, error) {
locations, err := m.subManager.ListLocations(ctx, subscriptionId)
Expand Down
22 changes: 22 additions & 0 deletions cli/azd/pkg/account/subscriptions_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,28 @@ func (m *SubscriptionsManager) getSubscription(ctx context.Context, subscription
return &sub, nil
}

// GetTenantDisplayNames returns a map of tenant ID to display name for all tenants
// accessible by the current account.
func (m *SubscriptionsManager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) {
tenants, err := m.service.ListTenants(ctx)
if err != nil {
return nil, fmt.Errorf("listing tenants: %w", err)
}

result := make(map[string]string, len(tenants))
for _, t := range tenants {
if t.TenantID != nil {
name := *t.TenantID
if t.DisplayName != nil && *t.DisplayName != "" {
name = *t.DisplayName
}
result[*t.TenantID] = name
}
}

return result, nil
}

func toSubscriptions(azSubs []*armsubscriptions.Subscription, userAccessTenantId string) []Subscription {
if azSubs == nil {
return nil
Expand Down
79 changes: 59 additions & 20 deletions cli/azd/pkg/prompt/prompt_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ type SelectOptions struct {
HelpMessage string
// LoadingMessage is the loading message to display to the user.
LoadingMessage string
// SkipLoadingSpinner skips the loading spinner in PromptCustomResource.
// Use this when data is pre-loaded and LoadData returns immediately.
SkipLoadingSpinner bool
Comment thread
vhvb1989 marked this conversation as resolved.
// DisplayNumbers specifies whether to display numbers next to the choices.
DisplayNumbers *bool
// DisplayCount is the number of choices to display at a time.
Expand Down Expand Up @@ -157,6 +160,7 @@ type ResourceService interface {
type SubscriptionManager interface {
GetSubscriptions(ctx context.Context) ([]account.Subscription, error)
GetLocations(ctx context.Context, subscriptionId string) ([]account.Location, error)
GetTenantDisplayNames(ctx context.Context) (map[string]string, error)
}

// PromptServiceInterface defines the methods that the PromptService must implement.
Expand Down Expand Up @@ -211,6 +215,8 @@ func NewPromptService(
}

// PromptSubscription prompts the user to select an Azure subscription.
// If the user has access to multiple tenants, a tenant selection prompt is shown first
// to scope down the subscription list.
func (ps *promptService) PromptSubscription(
ctx context.Context,
selectorOptions *SelectOptions,
Expand All @@ -235,6 +241,31 @@ func (ps *promptService) PromptSubscription(
return nil, err
}

// Load subscriptions under a spinner first
var subscriptionList []account.Subscription
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: mergedOptions.LoadingMessage,
})
Comment thread
vhvb1989 marked this conversation as resolved.

err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
var loadErr error
subscriptionList, loadErr = ps.subscriptionManager.GetSubscriptions(ctx)
return loadErr
})
if err != nil {
return nil, fmt.Errorf("listing subscriptions: %w", err)
}
Comment thread
vhvb1989 marked this conversation as resolved.

// Apply tenant filtering (after spinner is done so the prompt doesn't overlap)
subscriptionList = filterByTenantEnvVar(subscriptionList)
if !ps.console.IsNoPromptMode() {
subscriptionList, err = promptAndFilterByTenant(
ctx, ps.console, subscriptionList, ps.subscriptionManager.GetTenantDisplayNames)
if err != nil {
return nil, err
}
Comment thread
vhvb1989 marked this conversation as resolved.
}

// Get default subscription from user config
var defaultSubscriptionId = ""
userConfig, err := ps.userConfigManager.Load()
Expand All @@ -247,19 +278,19 @@ func (ps *promptService) PromptSubscription(

hideId := isDemoModeEnabled()

return PromptCustomResource(ctx, CustomResourceOptions[account.Subscription]{
SelectorOptions: mergedOptions,
LoadData: func(ctx context.Context) ([]*account.Subscription, error) {
subscriptionList, err := ps.subscriptionManager.GetSubscriptions(ctx)
if err != nil {
return nil, err
}
// Use PromptCustomResource with pre-loaded data
subscriptions := make([]*account.Subscription, len(subscriptionList))
for i := range subscriptionList {
subscriptions[i] = &subscriptionList[i]
}

subscriptions := make([]*account.Subscription, len(subscriptionList))
for i, subscription := range subscriptionList {
subscriptions[i] = &subscription
}
// Create selector options with spinner disabled since data is already loaded
resourceSelectorOptions := *mergedOptions
resourceSelectorOptions.SkipLoadingSpinner = true

return PromptCustomResource(ctx, CustomResourceOptions[account.Subscription]{
SelectorOptions: &resourceSelectorOptions,
LoadData: func(ctx context.Context) ([]*account.Subscription, error) {
return subscriptions, nil
},
Comment thread
vhvb1989 marked this conversation as resolved.
DisplayResource: func(subscription *account.Subscription) (string, error) {
Expand Down Expand Up @@ -768,21 +799,29 @@ func PromptCustomResource[T any](ctx context.Context, options CustomResourceOpti
allowNewResource = true
selectedIndex = new(0)
} else {
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: options.SelectorOptions.LoadingMessage,
})

err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
loadData := func(ctx context.Context) error {
resourceList, err := options.LoadData(ctx)
if err != nil {
return err
}

resources = resourceList
return nil
})
if err != nil {
return nil, err
}

// Skip the spinner when data is pre-loaded
if mergedSelectorOptions.SkipLoadingSpinner {
if err := loadData(ctx); err != nil {
return nil, err
}
Comment thread
vhvb1989 marked this conversation as resolved.
} else {
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: mergedSelectorOptions.LoadingMessage,
})
if err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
return loadData(ctx)
}); err != nil {
return nil, err
}
}

if !allowNewResource && len(resources) == 0 {
Expand Down
21 changes: 21 additions & 0 deletions cli/azd/pkg/prompt/prompt_service_extra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,27 @@ func TestPromptCustomResource_NilSelectorOptions_UsesDefaultsAndForce(t *testing
require.Equal(t, 42, *result)
}

func TestPromptCustomResource_SkipLoadingSpinner(t *testing.T) {
t.Parallel()

loaded := false
_, err := PromptCustomResource(t.Context(), CustomResourceOptions[string]{
SelectorOptions: &SelectOptions{
SkipLoadingSpinner: true,
AllowNewResource: new(false),
},
LoadData: func(ctx context.Context) ([]*string, error) {
loaded = true
return nil, nil
},
})

// LoadData should have been called directly (without spinner)
require.True(t, loaded)
// No resources and AllowNewResource=false returns the sentinel error
require.ErrorIs(t, err, ErrNoResourcesFound)
}

// helpers

func emptySubs() []account.Subscription { return []account.Subscription{} }
99 changes: 57 additions & 42 deletions cli/azd/pkg/prompt/prompter.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"log"
"os"
"slices"
"strconv"

"github.com/MakeNowJust/heredoc/v2"
"github.com/azure/azure-dev/cli/azd/pkg/account"
Expand Down Expand Up @@ -71,12 +70,12 @@ func NewDefaultPrompter(
}

func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (subscriptionId string, err error) {
subscriptionOptions, subscriptions, defaultSubscription, err := p.getSubscriptionOptions(ctx)
subscriptionInfos, err := p.accountManager.GetSubscriptions(ctx)
if err != nil {
return "", err
return "", fmt.Errorf("listing subscriptions: %w", err)
}

if len(subscriptionOptions) == 0 {
if len(subscriptionInfos) == 0 {
// NOTE: Error text must contain "no subscriptions found" to match the
// pattern in error_suggestions.yaml. Update both if rewording.
return "", errors.New(heredoc.Docf(
Expand All @@ -87,6 +86,32 @@ func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (s
))
}

// Filter by AZURE_TENANT_ID if set (works in both prompt and no-prompt modes)
subscriptionInfos = filterByTenantEnvVar(subscriptionInfos)

// Tenant selection: if multiple tenants, prompt user to pick one
if !p.console.IsNoPromptMode() {
subscriptionInfos, err = promptAndFilterByTenant(
ctx, p.console, subscriptionInfos, p.accountManager.GetTenantDisplayNames)
if err != nil {
return "", err
}
}
Comment thread
vhvb1989 marked this conversation as resolved.

slices.SortFunc(subscriptionInfos, func(a, b account.Subscription) int {
return stringutil.CompareLower(a.Name, b.Name)
})

// The default value is based on AZURE_SUBSCRIPTION_ID, falling back to whatever default subscription in
// set in azd's config.
defaultSubscriptionId := os.Getenv(environment.SubscriptionIdEnvVarName)
if defaultSubscriptionId == "" {
defaultSubscriptionId = p.accountManager.GetDefaultSubscriptionID(ctx)
}
Comment thread
vhvb1989 marked this conversation as resolved.

subscriptionOptions, subscriptions, defaultSubscription :=
formatSubscriptionOptions(subscriptionInfos, defaultSubscriptionId)

for subscriptionId == "" {
subscriptionSelectionIndex, err := p.console.Select(ctx, input.ConsoleOptions{
Message: msg,
Expand All @@ -110,6 +135,34 @@ func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (s
return subscriptionId, nil
}

// formatSubscriptionOptions formats subscription infos into display options.
func formatSubscriptionOptions(
subscriptionInfos []account.Subscription,
defaultSubscriptionId string,
) (options []string, ids []string, defaultOption any) {
options = make([]string, len(subscriptionInfos))
ids = make([]string, len(subscriptionInfos))

hideId := isDemoModeEnabled()

for index, info := range subscriptionInfos {
if hideId {
options[index] = fmt.Sprintf("%2d. %s", index+1, info.Name)
} else {
options[index] = fmt.Sprintf(
"%2d. %s (%s)", index+1, info.Name, info.Id)
}
Comment thread
vhvb1989 marked this conversation as resolved.

ids[index] = info.Id

if info.Id == defaultSubscriptionId {
defaultOption = options[index]
}
}

return options, ids, defaultOption
}

func (p *DefaultPrompter) PromptLocation(
ctx context.Context,
subId string,
Expand Down Expand Up @@ -246,44 +299,6 @@ func (p *DefaultPrompter) PromptResourceGroupFrom(
return name, nil
}

func (p *DefaultPrompter) getSubscriptionOptions(ctx context.Context) ([]string, []string, any, error) {
subscriptionInfos, err := p.accountManager.GetSubscriptions(ctx)
if err != nil {
return nil, nil, nil, fmt.Errorf("listing accounts: %w", err)
}

slices.SortFunc(subscriptionInfos, func(a, b account.Subscription) int {
return stringutil.CompareLower(a.Name, b.Name)
})

// The default value is based on AZURE_SUBSCRIPTION_ID, falling back to whatever default subscription in
// set in azd's config.
defaultSubscriptionId := os.Getenv(environment.SubscriptionIdEnvVarName)
if defaultSubscriptionId == "" {
defaultSubscriptionId = p.accountManager.GetDefaultSubscriptionID(ctx)
}

var subscriptionOptions = make([]string, len(subscriptionInfos))
var subscriptions = make([]string, len(subscriptionInfos))
var defaultSubscription any

for index, info := range subscriptionInfos {
if v, err := strconv.ParseBool(os.Getenv("AZD_DEMO_MODE")); err == nil && v {
subscriptionOptions[index] = fmt.Sprintf("%2d. %s", index+1, info.Name)
} else {
subscriptionOptions[index] = fmt.Sprintf("%2d. %s (%s)", index+1, info.Name, info.Id)
}

subscriptions[index] = info.Id

if info.Id == defaultSubscriptionId {
defaultSubscription = subscriptionOptions[index]
}
}

return subscriptionOptions, subscriptions, defaultSubscription, nil
}

func (p *DefaultPrompter) IsNoPromptMode() bool {
return p.console.IsNoPromptMode()
}
Loading
Loading