diff --git a/internal/actionplan/actionplan_test.go b/internal/actionplan/actionplan_test.go index 103e29a..6d7d656 100644 --- a/internal/actionplan/actionplan_test.go +++ b/internal/actionplan/actionplan_test.go @@ -62,13 +62,13 @@ var mockCommandFailOnDemand CommandExecutor = func(command string, workingDir st // implements IHostGaCommunicator type NoopHostGaCommunicator struct{} -func (downloader *NoopHostGaCommunicator) DownloadPackage(logger *logging.ExtensionLogger, appName string, dst string) error { +func (downloader *NoopHostGaCommunicator) DownloadPackage(logger *logging.ExtensionLogger, appName string, appVersion string, dst string) error { return nil } -func (downloader *NoopHostGaCommunicator) DownloadConfig(logger *logging.ExtensionLogger, appName string, dst string) error { +func (downloader *NoopHostGaCommunicator) DownloadConfig(logger *logging.ExtensionLogger, appName string, appVersion string, dst string) error { return nil } -func (downloader *NoopHostGaCommunicator) GetVMAppInfo(logger *logging.ExtensionLogger, appName string) (*hostgacommunicator.VMAppMetadata, error) { +func (downloader *NoopHostGaCommunicator) GetVMAppInfo(logger *logging.ExtensionLogger, appName string, appVersion string) (*hostgacommunicator.VMAppMetadata, error) { return nil, nil } diff --git a/internal/actionplan/executehelper.go b/internal/actionplan/executehelper.go index b96d742..bfbcfbe 100644 --- a/internal/actionplan/executehelper.go +++ b/internal/actionplan/executehelper.go @@ -96,7 +96,7 @@ func (actionPlan *ActionPlan) executeHelper(registryHandler packageregistry.IPac if err == nil { // download packages downloadPackageFileName := path.Join(downloadPath, vmAppPackageCurrent.PackageFileName) - if err := actionPlan.hostGaCommunicator.DownloadPackage(actionPlan.logger, vmAppPackageCurrent.ApplicationName, downloadPackageFileName); err != nil { + if err := actionPlan.hostGaCommunicator.DownloadPackage(actionPlan.logger, vmAppPackageCurrent.ApplicationName, version, downloadPackageFileName); err != nil { actionPlan.logger.Error("Failed to download package for application %v, version %v. Error: %v", appName, version, err.Error()) errorMessageToReturn = extensionerrors.CombineErrors(errorMessageToReturn, errors.Wrapf(err, "failed to download package file %s", downloadPackageFileName)) } @@ -111,7 +111,7 @@ func (actionPlan *ActionPlan) executeHelper(registryHandler packageregistry.IPac // download configuration if vmAppPackageCurrent.ConfigExists { downloadConfigFileName := path.Join(downloadPath, vmAppPackageCurrent.ConfigFileName) - if err := actionPlan.hostGaCommunicator.DownloadConfig(actionPlan.logger, vmAppPackageCurrent.ApplicationName, downloadConfigFileName); err != nil { + if err := actionPlan.hostGaCommunicator.DownloadConfig(actionPlan.logger, vmAppPackageCurrent.ApplicationName, version, downloadConfigFileName); err != nil { actionPlan.logger.Error("Failed to download config for application %v, version %v. Error: %v", appName, version, err.Error()) errorMessageToReturn = extensionerrors.CombineErrors(errorMessageToReturn, errors.Wrapf(err, "failed to download config file %s", downloadConfigFileName)) } diff --git a/internal/actionplan/executehelper_test.go b/internal/actionplan/executehelper_test.go index 28b0906..04f1a41 100644 --- a/internal/actionplan/executehelper_test.go +++ b/internal/actionplan/executehelper_test.go @@ -27,17 +27,21 @@ type mockHostGaCommunicator struct { DownloadConfigCount int } -func (mockCommunicator *mockHostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string) (*hostgacommunicator.VMAppMetadata, error) { +func (mockCommunicator *mockHostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string, appVersion string) (*hostgacommunicator.VMAppMetadata, error) { return &hostgacommunicator.VMAppMetadata{}, nil } -func (mockCommunicator *mockHostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error { - mockCommunicator.DownloadPackageCount++ +func (mockCommunicator *mockHostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error { + if appVersion == vmAppPackageCurrent.Version { + mockCommunicator.DownloadPackageCount++ + } return copyFile(mockCommunicator.pkgFileSourcePath, dst) } -func (mockCommunicator *mockHostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, dst string) error { - mockCommunicator.DownloadConfigCount++ +func (mockCommunicator *mockHostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error { + if appVersion == vmAppPackageCurrent.Version { + mockCommunicator.DownloadConfigCount++ + } return copyFile(mockCommunicator.configFileSourcePath, dst) } @@ -88,7 +92,7 @@ var extensionLogger = logging.New(nil) var extensionEventManager = extensionevents.New(extensionLogger, &handlerEnvironment) var vmAppPackageCurrent = packageregistry.VMAppPackageCurrent{ ApplicationName: "test app", - Version: "1.0.0", + Version: "1.5.0", InstallCommand: "install", RemoveCommand: "remove", UpdateCommand: "update", diff --git a/internal/extdeserialization/extdeserialization.go b/internal/extdeserialization/extdeserialization.go index 22307c3..ae540d1 100644 --- a/internal/extdeserialization/extdeserialization.go +++ b/internal/extdeserialization/extdeserialization.go @@ -25,6 +25,7 @@ type VmAppSetting struct { Order *int `json:"order"` TreatFailureAsDeploymentFailure bool `json:"treatFailureAsDeploymentFailure"` Actions []*ActionSetting `json:"actions"` + Version string `json:"version"` } func GetParameterNames(settings ActionSetting) []string { diff --git a/internal/hostgacommunicator/download.go b/internal/hostgacommunicator/download.go index 2ef6413..04370f3 100644 --- a/internal/hostgacommunicator/download.go +++ b/internal/hostgacommunicator/download.go @@ -2,12 +2,13 @@ package hostgacommunicator import ( "fmt" - "github.com/Azure/azure-extension-platform/pkg/constants" "io" "net/http" "os" "time" + "github.com/Azure/azure-extension-platform/pkg/constants" + "github.com/Azure/VMApplication-Extension/internal/requesthelper" "github.com/Azure/azure-extension-platform/pkg/logging" "github.com/pkg/errors" @@ -34,8 +35,8 @@ type downloadRequestFactory struct { downloadedBytes int64 } -func newPackageDownloadRequestFactory(el *logging.ExtensionLogger, appName string) (*downloadRequestFactory, error) { - downloadURL, err := getOperationURI(el, appName, packageOperation) +func newPackageDownloadRequestFactory(el *logging.ExtensionLogger, appName string, appVersion string) (*downloadRequestFactory, error) { + downloadURL, err := getOperationURI(el, appName, appVersion, packageOperation) if err != nil { return nil, errors.Wrapf(err, "failed to obtain operationURI") } @@ -48,8 +49,8 @@ func newPackageDownloadRequestFactory(el *logging.ExtensionLogger, appName strin return &drf, nil } -func newConfigDownloadRequestFactory(el *logging.ExtensionLogger, appName string) (*downloadRequestFactory, error) { - downloadURL, err := getOperationURI(el, appName, configOperation) +func newConfigDownloadRequestFactory(el *logging.ExtensionLogger, appName string, appVersion string) (*downloadRequestFactory, error) { + downloadURL, err := getOperationURI(el, appName, appVersion, configOperation) if err != nil { return nil, errors.Wrapf(err, "failed to obtain operationURI") } diff --git a/internal/hostgacommunicator/hostgacommunicator.go b/internal/hostgacommunicator/hostgacommunicator.go index cc71bd0..3761bc2 100644 --- a/internal/hostgacommunicator/hostgacommunicator.go +++ b/internal/hostgacommunicator/hostgacommunicator.go @@ -3,21 +3,23 @@ package hostgacommunicator import ( "encoding/json" "fmt" + "net/url" + "os" + "github.com/Azure/VMApplication-Extension/internal/requesthelper" "github.com/Azure/azure-extension-platform/pkg/logging" "github.com/pkg/errors" - "net/url" - "os" ) const hostGaPluginPort = "32526" const WireProtocolAddress = "AZURE_GUEST_AGENT_WIRE_PROTOCOL_ADDRESS" const wireServerFallbackAddress = "http://168.63.129.16:32526" +const versionQueryParameterForHostGaRequests = "version" type IHostGaCommunicator interface { - DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error - DownloadConfig(el *logging.ExtensionLogger, appName string, dst string) error - GetVMAppInfo(el *logging.ExtensionLogger, appName string) (*VMAppMetadata, error) + DownloadPackage(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error + DownloadConfig(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error + GetVMAppInfo(el *logging.ExtensionLogger, appName string, appVersion string) (*VMAppMetadata, error) } // HostGaCommunicator provides methods for retrieving application metadata and packages @@ -25,8 +27,8 @@ type IHostGaCommunicator interface { type HostGaCommunicator struct{} // GetVMAppInfo returns the metadata for the application -func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string) (*VMAppMetadata, error) { - requestManager, err := getMetadataRequestManager(el, appName) +func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string, appVersion string) (*VMAppMetadata, error) { + requestManager, err := getMetadataRequestManager(el, appName, appVersion) if err != nil { return nil, errors.Wrapf(err, "Could not create the request manager") } @@ -51,8 +53,8 @@ func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName str // DownloadPackage downloads the application package through HostGaPlugin to the specified // file. If the download fails, it automatically retrieves at the last received bytes // and rebuilds the file from downloaded parts -func (*HostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error { - requestFactory, err := newPackageDownloadRequestFactory(el, appName) +func (*HostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error { + requestFactory, err := newPackageDownloadRequestFactory(el, appName, appVersion) if err != nil { return errors.Wrapf(err, "Could not create the request factory") } @@ -64,8 +66,8 @@ func (*HostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName // DownloadConfig downloads the application config through HostGaPlugin to the specified // file. If the download fails, it automatically retrieves at the last received bytes // and rebuilds the file from downloaded parts -func (*HostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, dst string) error { - requestFactory, err := newConfigDownloadRequestFactory(el, appName) +func (*HostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error { + requestFactory, err := newConfigDownloadRequestFactory(el, appName, appVersion) if err != nil { return errors.Wrapf(err, "Could not create the request factory") } @@ -74,13 +76,11 @@ func (*HostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName s return err } -func getOperationURI(el *logging.ExtensionLogger, appName string, operation string) (string, error) { +func getOperationURI(el *logging.ExtensionLogger, appName string, appVersion string, operation string) (string, error) { baseAddress := os.Getenv(WireProtocolAddress) if baseAddress == "" { el.Warn("environment variable %s not set, using WireProtocol fallback address", WireProtocolAddress) - uri, _ := url.Parse(wireServerFallbackAddress) - uri.Path = fmt.Sprintf("applications/%s/%s", appName, operation) - return uri.String(), nil + baseAddress = wireServerFallbackAddress } uri, err := url.Parse(baseAddress) @@ -111,5 +111,11 @@ func getOperationURI(el *logging.ExtensionLogger, appName string, operation stri uri.Scheme = "http" } + if appVersion != "" { + q := uri.Query() + q.Set(versionQueryParameterForHostGaRequests, appVersion) + uri.RawQuery = q.Encode() + } + return uri.String(), nil } diff --git a/internal/hostgacommunicator/hostgacommunicator_test.go b/internal/hostgacommunicator/hostgacommunicator_test.go index d9bf14e..cfbc2da 100644 --- a/internal/hostgacommunicator/hostgacommunicator_test.go +++ b/internal/hostgacommunicator/hostgacommunicator_test.go @@ -18,7 +18,8 @@ import ( ) const ( - myAppName = "chipmunkdetector" + myAppName = "chipmunkdetector" + myAppVersion = "1.0.0" ) var testDirPath string @@ -42,7 +43,7 @@ func cleanupTestDir() { func TestGetVmAppInfo_InvalidUri(t *testing.T) { os.Setenv(WireProtocolAddress, "h%t!p:notgoingtohappen!") hgc := &HostGaCommunicator{} - _, err := hgc.GetVMAppInfo(nopLog(), myAppName) + _, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion) require.NotNil(t, err, "did not fail") require.Contains(t, err.Error(), "Could not parse the HostGA URI", "Wrong message for invalid uri") } @@ -56,7 +57,7 @@ func TestGetVmAppInfo_RequestFailed(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - _, err := hgc.GetVMAppInfo(nopLog(), myAppName) + _, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion) require.NotNil(t, err, "did not fail") require.Contains(t, err.Error(), "Metadata request failed with retries.", "Wrong message for failed request") } @@ -72,7 +73,7 @@ func TestGetVmAppInfo_CouldNotDecodeResponse(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - _, err := hgc.GetVMAppInfo(nopLog(), myAppName) + _, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion) require.NotNil(t, err, "did not fail") require.Contains(t, err.Error(), "failed to decode response body", "Wrong message for invalid response") } @@ -96,7 +97,7 @@ func TestGetVmAppInfo_MissingProperties(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - actual, err := hgc.GetVMAppInfo(nopLog(), myAppName) + actual, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion) require.Nil(t, err, "request failed") require.Equal(t, expected.ApplicationName, actual.ApplicationName) require.Equal(t, expected.Version, actual.Version) @@ -125,7 +126,7 @@ func TestGetVmAppInfo_ValidResponse(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - actual, err := hgc.GetVMAppInfo(nopLog(), myAppName) + actual, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion) require.Nil(t, err, "request failed") require.Equal(t, expected.ApplicationName, actual.ApplicationName) require.Equal(t, expected.Version, actual.Version) @@ -155,7 +156,7 @@ func TestDownloadPackage_CannotRemoveExistingFile(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - err = hgc.DownloadPackage(nopLog(), myAppName, filePath) + err = hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath) require.NotNil(t, err, "did not fail") require.Contains(t, err.Error(), "Could not remove the existing file", "Wrong message for failing to remove locked file") } @@ -171,7 +172,7 @@ func TestDownloadPackage_InvalidPath(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - err := hgc.DownloadPackage(nopLog(), myAppName, filePath) + err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath) require.NotNil(t, err, "did not fail") require.Contains(t, err.Error(), "Cannot retrieve file information", "Wrong message for invalid file path") } @@ -192,7 +193,7 @@ func TestDownloadPackage_SingeCallDownload(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - err := hgc.DownloadPackage(nopLog(), myAppName, filePath) + err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath) require.Nil(t, err, "Download failed") verifyFileContents(t, filePath, expected) } @@ -223,7 +224,7 @@ func TestDownloadPackage_TooManyTries(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - err := hgc.DownloadPackage(nopLog(), myAppName, filePath) + err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath) require.NotNil(t, err, "did not fail") require.Contains(t, err.Error(), "Failed to completely download the file", "Wrong message for incomplete file") } @@ -254,7 +255,7 @@ func TestDownloadPackage_IntermediateCallFails(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - err := hgc.DownloadPackage(nopLog(), myAppName, filePath) + err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath) require.NotNil(t, err, "did not fail") require.Contains(t, err.Error(), "Unrecoverable error while downloading the file", "Wrong message for failure mid-retries") } @@ -290,7 +291,7 @@ func TestDownloadPackage_MultipleCallDownload(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - err := hgc.DownloadPackage(nopLog(), myAppName, filePath) + err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath) require.Nil(t, err, "Download failed") require.Equal(t, expectedCallCount, callCount) verifyFileContents(t, filePath, expected) @@ -312,46 +313,52 @@ func TestDownloadConfig_SingeCallDownload(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - err := hgc.DownloadConfig(nopLog(), myAppName, filePath) + err := hgc.DownloadConfig(nopLog(), myAppName, myAppVersion, filePath) require.Nil(t, err, "Download failed") verifyFileContents(t, filePath, expected) } func TestGetOperationUri(t *testing.T) { appName := "myApp" + appVersion := "1.0.0" operation := "metadata" el := logging.New(nil) os.Setenv(WireProtocolAddress, "10.0.0.1") - uri, err := getOperationURI(el, appName, operation) + uri, err := getOperationURI(el, appName, appVersion, operation) assert.NoError(t, err) - assert.Equal(t, fmt.Sprintf("http://10.0.0.1:%s/applications/%s/%s", hostGaPluginPort, appName, operation), uri) + assert.Equal(t, fmt.Sprintf("http://10.0.0.1:%s/applications/%s/%s?version=%s", hostGaPluginPort, appName, operation, appVersion), uri) os.Setenv(WireProtocolAddress, "10.0.0.1:1234") - uri, err = getOperationURI(el, appName, operation) + uri, err = getOperationURI(el, appName, appVersion, operation) assert.NoError(t, err) - assert.Equal(t, fmt.Sprintf("http://10.0.0.1:1234/applications/%s/%s", appName, operation), uri) + assert.Equal(t, fmt.Sprintf("http://10.0.0.1:1234/applications/%s/%s?version=%s", appName, operation, appVersion), uri) os.Setenv(WireProtocolAddress, "foo.bar.com") - uri, err = getOperationURI(el, appName, operation) + uri, err = getOperationURI(el, appName, appVersion, operation) assert.NoError(t, err) - assert.Equal(t, fmt.Sprintf("http://foo.bar.com:%s/applications/%s/%s", hostGaPluginPort, appName, operation), uri) + assert.Equal(t, fmt.Sprintf("http://foo.bar.com:%s/applications/%s/%s?version=%s", hostGaPluginPort, appName, operation, appVersion), uri) os.Setenv(WireProtocolAddress, "foo.bar.com:1568") - uri, err = getOperationURI(el, appName, operation) + uri, err = getOperationURI(el, appName, appVersion, operation) assert.NoError(t, err) - assert.Equal(t, fmt.Sprintf("http://foo.bar.com:1568/applications/%s/%s", appName, operation), uri) + assert.Equal(t, fmt.Sprintf("http://foo.bar.com:1568/applications/%s/%s?version=%s", appName, operation, appVersion), uri) os.Setenv(WireProtocolAddress, "https://foo.bar.com:1568") - uri, err = getOperationURI(el, appName, operation) + uri, err = getOperationURI(el, appName, appVersion, operation) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("https://foo.bar.com:1568/applications/%s/%s?version=%s", appName, operation, appVersion), uri) + + // test with no app version + uri, err = getOperationURI(el, appName, "", operation) assert.NoError(t, err) assert.Equal(t, fmt.Sprintf("https://foo.bar.com:1568/applications/%s/%s", appName, operation), uri) // test fallback address for Wire Server os.Setenv(WireProtocolAddress, "") - uri, err = getOperationURI(el, appName, operation) + uri, err = getOperationURI(el, appName, appVersion, operation) assert.NoError(t, err) - assert.Equal(t, fmt.Sprintf("%s/applications/%s/%s", wireServerFallbackAddress, appName, operation), uri) + assert.Equal(t, fmt.Sprintf("%s/applications/%s/%s?version=%s", wireServerFallbackAddress, appName, operation, appVersion), uri) } func TestGetGetVmAppInfo(t *testing.T) { @@ -384,7 +391,7 @@ func TestGetGetVmAppInfo(t *testing.T) { os.Setenv(WireProtocolAddress, srv.URL) hgc := &HostGaCommunicator{} - vmAppMetadata, err := hgc.GetVMAppInfo(nopLog(), "advancedsettingsapp") + vmAppMetadata, err := hgc.GetVMAppInfo(nopLog(), "advancedsettingsapp", myAppVersion) assert.NoError(t, err) assert.Equal(t, "flarg.exe", vmAppMetadata.PackageFileName) assert.Equal(t, "flarg.cfg", vmAppMetadata.ConfigFileName) diff --git a/internal/hostgacommunicator/metadata.go b/internal/hostgacommunicator/metadata.go index 38ba6f6..cd12d3d 100644 --- a/internal/hostgacommunicator/metadata.go +++ b/internal/hostgacommunicator/metadata.go @@ -74,8 +74,8 @@ type metadataRequestFactory struct { url string } -func newMetadataRequestFactory(el *logging.ExtensionLogger, appName string) (*metadataRequestFactory, error) { - url, err := getOperationURI(el, appName, metadataOperation) +func newMetadataRequestFactory(el *logging.ExtensionLogger, appName string, appVersion string) (*metadataRequestFactory, error) { + url, err := getOperationURI(el, appName, appVersion, metadataOperation) if err != nil { return nil, errors.Wrapf(err, "failed to obtain operationURI") } @@ -88,8 +88,8 @@ func (u metadataRequestFactory) GetRequest() (*http.Request, error) { return http.NewRequest("GET", u.url, nil) } -func getMetadataRequestManager(el *logging.ExtensionLogger, appName string) (*requesthelper.RequestManager, error) { - factory, err := newMetadataRequestFactory(el, appName) +func getMetadataRequestManager(el *logging.ExtensionLogger, appName string, appVersion string) (*requesthelper.RequestManager, error) { + factory, err := newMetadataRequestFactory(el, appName, appVersion) if err != nil { return nil, errors.Wrapf(err, "failed to create request factory") } diff --git a/internal/requesthelper/retry.go b/internal/requesthelper/retry.go index 484e976..78e8d46 100644 --- a/internal/requesthelper/retry.go +++ b/internal/requesthelper/retry.go @@ -90,6 +90,11 @@ func WithRetries(el *logging.ExtensionLogger, rm *RequestManager, sf SleepFunc) func isTransientHTTPStatusCode(statusCode int) bool { switch statusCode { case + // A 404 error from HGAP will be considered a transient error. This either means: + // - Gallery application data for the container was not found + // - Metadata for the specific applicaton version requested was not found + // This is not an expected error and is likely an indication of a race condition, so the request will be retried. + http.StatusNotFound, // 404 http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 http.StatusInternalServerError, // 500 diff --git a/internal/requesthelper/retry_test.go b/internal/requesthelper/retry_test.go index 544e290..3a7d414 100644 --- a/internal/requesthelper/retry_test.go +++ b/internal/requesthelper/retry_test.go @@ -91,6 +91,19 @@ func TestWithRetries_failing_validateNumberOfCalls(t *testing.T) { require.EqualValues(t, 7, d.calls, "calls exactly expRetryN times") } +func TestRequestRetriedOnHttpNotFound(t *testing.T) { + srv := httptest.NewServer(httpbin.GetMux()) + defer srv.Close() + + d := NewTestURLRequest(srv.URL + "/status/404") + rm := requesthelper.GetRequestManager(d, testRequestTimeout) + + sr := new(sleepRecorder) + _, err := requesthelper.WithRetries(nopLog(), rm, sr.Sleep) + require.EqualError(t, err, "unexpected status code: actual=404 expected=200") + require.EqualValues(t, 7, d.calls, "Request should have been retried expRetryN times") +} + func TestWithRetries_failedCreateRequest(t *testing.T) { bd := &badDownloader{} rm := requesthelper.GetRequestManager(bd, testRequestTimeout) diff --git a/main/main_test.go b/main/main_test.go index 9d9c2cc..29f7640 100644 --- a/main/main_test.go +++ b/main/main_test.go @@ -37,15 +37,15 @@ type NoopHostGaCommunicator struct { ConfigFileNameUsed string } -func (communicator *NoopHostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error { +func (communicator *NoopHostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error { communicator.PackageFileNameUsed = dst return nil } -func (communicator *NoopHostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, dst string) error { +func (communicator *NoopHostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error { communicator.ConfigFileNameUsed = dst return nil } -func (communicator *NoopHostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string) (*hostgacommunicator.VMAppMetadata, error) { +func (communicator *NoopHostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string, appVersion string) (*hostgacommunicator.VMAppMetadata, error) { return communicator.MetadataToReturn, nil } @@ -162,11 +162,12 @@ func Test_getVMAppProtectedSettings_valid(t *testing.T) { ApplicationName: "iggy", Order: &order, Actions: []*extdeserialization.ActionSetting{&actions}, + Version: "2.0.1", } vmAppProtectedSettings := extdeserialization.VmAppProtectedSettings{&appSettings} testSettings := handlersettings.HandlerSettings{ PublicSettings: "{}", - ProtectedSettings: "[{\"applicationName\": \"iggy\", \"order\": 1, \"actions\": [{\"name\": \"logging\",\"script\": \"echo %CustomAction_blobURL%\",\"timestamp\": \"20210604T155300Z\",\"parameters\": [{\"name\": \"blobURL\",\"value\": \"myaccount.blob.core.windows.net\"}],\"tickCount\": 10193113}]}]", + ProtectedSettings: "[{\"applicationName\": \"iggy\", \"order\": 1, \"version\": \"2.0.1\", \"actions\": [{\"name\": \"logging\",\"script\": \"echo %CustomAction_blobURL%\",\"timestamp\": \"20210604T155300Z\",\"parameters\": [{\"name\": \"blobURL\",\"value\": \"myaccount.blob.core.windows.net\"}],\"tickCount\": 10193113}]}]", } out, err := extdeserialization.GetVMAppProtectedSettings(&testSettings) @@ -175,6 +176,21 @@ func Test_getVMAppProtectedSettings_valid(t *testing.T) { require.EqualValues(t, vmAppProtectedSettings[0].ApplicationName, out[0].ApplicationName) require.EqualValues(t, *vmAppProtectedSettings[0].Order, *out[0].Order) require.EqualValues(t, *vmAppProtectedSettings[0].Actions[0], *out[0].Actions[0]) + require.EqualValues(t, vmAppProtectedSettings[0].Version, *&out[0].Version) +} + +func TestGetVMAppProtectedSettingsWithoutVersionDefaultsToEmptyString(t *testing.T) { + testSettings := handlersettings.HandlerSettings{ + PublicSettings: "{}", + ProtectedSettings: "[{\"applicationName\": \"iggy\", \"order\": 1 }]", + } + + out, err := extdeserialization.GetVMAppProtectedSettings(&testSettings) + require.NoError(t, err) + + require.EqualValues(t, "iggy", out[0].ApplicationName) + require.EqualValues(t, 1, *out[0].Order) + require.Empty(t, *&out[0].Version) } func Test_getVMAppProtectedSettings_valid_no_custom_actions(t *testing.T) { diff --git a/main/resolvevmapp.go b/main/resolvevmapp.go index 3b23d85..19fb15d 100644 --- a/main/resolvevmapp.go +++ b/main/resolvevmapp.go @@ -16,7 +16,7 @@ func getVMAppIncomingCollection(settings extdeserialization.VmAppProtectedSettin if app.ApplicationName == "" { return nil, errors.New("missing application name") } - vmAppInfo, err := communicator.GetVMAppInfo(el, app.ApplicationName) + vmAppInfo, err := communicator.GetVMAppInfo(el, app.ApplicationName, app.Version) if err != nil { // TODO: ignore errors? return incomingCollection, err