Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions internal/actionplan/actionplan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ var mockCommandFailOnDemand CommandExecutor = func(command string, workingDir st
// implements IHostGaCommunicator
type NoopHostGaCommunicator struct{}

func (downloader *NoopHostGaCommunicator) DownloadPackage(logger *logging.ExtensionLogger, appName string, dst string) error {
func (downloader *NoopHostGaCommunicator) DownloadPackage(logger *logging.ExtensionLogger, appName string, appVersion string, dst string) error {
return nil
}
func (downloader *NoopHostGaCommunicator) DownloadConfig(logger *logging.ExtensionLogger, appName string, dst string) error {
func (downloader *NoopHostGaCommunicator) DownloadConfig(logger *logging.ExtensionLogger, appName string, appVersion string, dst string) error {
return nil
}
func (downloader *NoopHostGaCommunicator) GetVMAppInfo(logger *logging.ExtensionLogger, appName string) (*hostgacommunicator.VMAppMetadata, error) {
func (downloader *NoopHostGaCommunicator) GetVMAppInfo(logger *logging.ExtensionLogger, appName string, appVersion string) (*hostgacommunicator.VMAppMetadata, error) {
return nil, nil
}

Expand Down
4 changes: 2 additions & 2 deletions internal/actionplan/executehelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (actionPlan *ActionPlan) executeHelper(registryHandler packageregistry.IPac
if err == nil {
// download packages
downloadPackageFileName := path.Join(downloadPath, vmAppPackageCurrent.PackageFileName)
if err := actionPlan.hostGaCommunicator.DownloadPackage(actionPlan.logger, vmAppPackageCurrent.ApplicationName, downloadPackageFileName); err != nil {
if err := actionPlan.hostGaCommunicator.DownloadPackage(actionPlan.logger, vmAppPackageCurrent.ApplicationName, version, downloadPackageFileName); err != nil {
actionPlan.logger.Error("Failed to download package for application %v, version %v. Error: %v", appName, version, err.Error())
errorMessageToReturn = extensionerrors.CombineErrors(errorMessageToReturn, errors.Wrapf(err, "failed to download package file %s", downloadPackageFileName))
}
Expand All @@ -111,7 +111,7 @@ func (actionPlan *ActionPlan) executeHelper(registryHandler packageregistry.IPac
// download configuration
if vmAppPackageCurrent.ConfigExists {
downloadConfigFileName := path.Join(downloadPath, vmAppPackageCurrent.ConfigFileName)
if err := actionPlan.hostGaCommunicator.DownloadConfig(actionPlan.logger, vmAppPackageCurrent.ApplicationName, downloadConfigFileName); err != nil {
if err := actionPlan.hostGaCommunicator.DownloadConfig(actionPlan.logger, vmAppPackageCurrent.ApplicationName, version, downloadConfigFileName); err != nil {
actionPlan.logger.Error("Failed to download config for application %v, version %v. Error: %v", appName, version, err.Error())
errorMessageToReturn = extensionerrors.CombineErrors(errorMessageToReturn, errors.Wrapf(err, "failed to download config file %s", downloadConfigFileName))
}
Expand Down
16 changes: 10 additions & 6 deletions internal/actionplan/executehelper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,21 @@ type mockHostGaCommunicator struct {
DownloadConfigCount int
}

func (mockCommunicator *mockHostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string) (*hostgacommunicator.VMAppMetadata, error) {
func (mockCommunicator *mockHostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string, appVersion string) (*hostgacommunicator.VMAppMetadata, error) {
return &hostgacommunicator.VMAppMetadata{}, nil
}

func (mockCommunicator *mockHostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error {
mockCommunicator.DownloadPackageCount++
func (mockCommunicator *mockHostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error {
if appVersion == vmAppPackageCurrent.Version {
mockCommunicator.DownloadPackageCount++
}
return copyFile(mockCommunicator.pkgFileSourcePath, dst)
}

func (mockCommunicator *mockHostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, dst string) error {
mockCommunicator.DownloadConfigCount++
func (mockCommunicator *mockHostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error {
if appVersion == vmAppPackageCurrent.Version {
mockCommunicator.DownloadConfigCount++
}
return copyFile(mockCommunicator.configFileSourcePath, dst)
}

Expand Down Expand Up @@ -88,7 +92,7 @@ var extensionLogger = logging.New(nil)
var extensionEventManager = extensionevents.New(extensionLogger, &handlerEnvironment)
var vmAppPackageCurrent = packageregistry.VMAppPackageCurrent{
ApplicationName: "test app",
Version: "1.0.0",
Version: "1.5.0",
InstallCommand: "install",
RemoveCommand: "remove",
UpdateCommand: "update",
Expand Down
1 change: 1 addition & 0 deletions internal/extdeserialization/extdeserialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type VmAppSetting struct {
Order *int `json:"order"`
TreatFailureAsDeploymentFailure bool `json:"treatFailureAsDeploymentFailure"`
Actions []*ActionSetting `json:"actions"`
Version string `json:"version"`
}

func GetParameterNames(settings ActionSetting) []string {
Expand Down
11 changes: 6 additions & 5 deletions internal/hostgacommunicator/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package hostgacommunicator

import (
"fmt"
"github.com/Azure/azure-extension-platform/pkg/constants"
"io"
"net/http"
"os"
"time"

"github.com/Azure/azure-extension-platform/pkg/constants"

"github.com/Azure/VMApplication-Extension/internal/requesthelper"
"github.com/Azure/azure-extension-platform/pkg/logging"
"github.com/pkg/errors"
Expand All @@ -34,8 +35,8 @@ type downloadRequestFactory struct {
downloadedBytes int64
}

func newPackageDownloadRequestFactory(el *logging.ExtensionLogger, appName string) (*downloadRequestFactory, error) {
downloadURL, err := getOperationURI(el, appName, packageOperation)
func newPackageDownloadRequestFactory(el *logging.ExtensionLogger, appName string, appVersion string) (*downloadRequestFactory, error) {
downloadURL, err := getOperationURI(el, appName, appVersion, packageOperation)
if err != nil {
return nil, errors.Wrapf(err, "failed to obtain operationURI")
}
Expand All @@ -48,8 +49,8 @@ func newPackageDownloadRequestFactory(el *logging.ExtensionLogger, appName strin
return &drf, nil
}

func newConfigDownloadRequestFactory(el *logging.ExtensionLogger, appName string) (*downloadRequestFactory, error) {
downloadURL, err := getOperationURI(el, appName, configOperation)
func newConfigDownloadRequestFactory(el *logging.ExtensionLogger, appName string, appVersion string) (*downloadRequestFactory, error) {
downloadURL, err := getOperationURI(el, appName, appVersion, configOperation)
if err != nil {
return nil, errors.Wrapf(err, "failed to obtain operationURI")
}
Expand Down
36 changes: 21 additions & 15 deletions internal/hostgacommunicator/hostgacommunicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,32 @@ package hostgacommunicator
import (
"encoding/json"
"fmt"
"net/url"
"os"

"github.com/Azure/VMApplication-Extension/internal/requesthelper"
"github.com/Azure/azure-extension-platform/pkg/logging"
"github.com/pkg/errors"
"net/url"
"os"
)

const hostGaPluginPort = "32526"
const WireProtocolAddress = "AZURE_GUEST_AGENT_WIRE_PROTOCOL_ADDRESS"
const wireServerFallbackAddress = "http://168.63.129.16:32526"
const versionQueryParameterForHostGaRequests = "version"

type IHostGaCommunicator interface {
DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error
DownloadConfig(el *logging.ExtensionLogger, appName string, dst string) error
GetVMAppInfo(el *logging.ExtensionLogger, appName string) (*VMAppMetadata, error)
DownloadPackage(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error
DownloadConfig(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error
GetVMAppInfo(el *logging.ExtensionLogger, appName string, appVersion string) (*VMAppMetadata, error)
}

// HostGaCommunicator provides methods for retrieving application metadata and packages
// from the HostGaPlugin
type HostGaCommunicator struct{}

// GetVMAppInfo returns the metadata for the application
func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string) (*VMAppMetadata, error) {
requestManager, err := getMetadataRequestManager(el, appName)
func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string, appVersion string) (*VMAppMetadata, error) {
requestManager, err := getMetadataRequestManager(el, appName, appVersion)
if err != nil {
return nil, errors.Wrapf(err, "Could not create the request manager")
}
Expand All @@ -51,8 +53,8 @@ func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName str
// DownloadPackage downloads the application package through HostGaPlugin to the specified
// file. If the download fails, it automatically retrieves at the last received bytes
// and rebuilds the file from downloaded parts
func (*HostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error {
requestFactory, err := newPackageDownloadRequestFactory(el, appName)
func (*HostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error {
requestFactory, err := newPackageDownloadRequestFactory(el, appName, appVersion)
if err != nil {
return errors.Wrapf(err, "Could not create the request factory")
}
Expand All @@ -64,8 +66,8 @@ func (*HostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName
// DownloadConfig downloads the application config through HostGaPlugin to the specified
// file. If the download fails, it automatically retrieves at the last received bytes
// and rebuilds the file from downloaded parts
func (*HostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, dst string) error {
requestFactory, err := newConfigDownloadRequestFactory(el, appName)
func (*HostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, appVersion string, dst string) error {
requestFactory, err := newConfigDownloadRequestFactory(el, appName, appVersion)
if err != nil {
return errors.Wrapf(err, "Could not create the request factory")
}
Expand All @@ -74,13 +76,11 @@ func (*HostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName s
return err
}

func getOperationURI(el *logging.ExtensionLogger, appName string, operation string) (string, error) {
func getOperationURI(el *logging.ExtensionLogger, appName string, appVersion string, operation string) (string, error) {
baseAddress := os.Getenv(WireProtocolAddress)
if baseAddress == "" {
el.Warn("environment variable %s not set, using WireProtocol fallback address", WireProtocolAddress)
uri, _ := url.Parse(wireServerFallbackAddress)
uri.Path = fmt.Sprintf("applications/%s/%s", appName, operation)
return uri.String(), nil
baseAddress = wireServerFallbackAddress
}

uri, err := url.Parse(baseAddress)
Expand Down Expand Up @@ -111,5 +111,11 @@ func getOperationURI(el *logging.ExtensionLogger, appName string, operation stri
uri.Scheme = "http"
}

if appVersion != "" {
q := uri.Query()
q.Set(versionQueryParameterForHostGaRequests, appVersion)
uri.RawQuery = q.Encode()
}

return uri.String(), nil
}
57 changes: 32 additions & 25 deletions internal/hostgacommunicator/hostgacommunicator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (
)

const (
myAppName = "chipmunkdetector"
myAppName = "chipmunkdetector"
myAppVersion = "1.0.0"
)

var testDirPath string
Expand All @@ -42,7 +43,7 @@ func cleanupTestDir() {
func TestGetVmAppInfo_InvalidUri(t *testing.T) {
os.Setenv(WireProtocolAddress, "h%t!p:notgoingtohappen!")
hgc := &HostGaCommunicator{}
_, err := hgc.GetVMAppInfo(nopLog(), myAppName)
_, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion)
require.NotNil(t, err, "did not fail")
require.Contains(t, err.Error(), "Could not parse the HostGA URI", "Wrong message for invalid uri")
}
Expand All @@ -56,7 +57,7 @@ func TestGetVmAppInfo_RequestFailed(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
_, err := hgc.GetVMAppInfo(nopLog(), myAppName)
_, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion)
require.NotNil(t, err, "did not fail")
require.Contains(t, err.Error(), "Metadata request failed with retries.", "Wrong message for failed request")
}
Expand All @@ -72,7 +73,7 @@ func TestGetVmAppInfo_CouldNotDecodeResponse(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
_, err := hgc.GetVMAppInfo(nopLog(), myAppName)
_, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion)
require.NotNil(t, err, "did not fail")
require.Contains(t, err.Error(), "failed to decode response body", "Wrong message for invalid response")
}
Expand All @@ -96,7 +97,7 @@ func TestGetVmAppInfo_MissingProperties(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
actual, err := hgc.GetVMAppInfo(nopLog(), myAppName)
actual, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion)
require.Nil(t, err, "request failed")
require.Equal(t, expected.ApplicationName, actual.ApplicationName)
require.Equal(t, expected.Version, actual.Version)
Expand Down Expand Up @@ -125,7 +126,7 @@ func TestGetVmAppInfo_ValidResponse(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
actual, err := hgc.GetVMAppInfo(nopLog(), myAppName)
actual, err := hgc.GetVMAppInfo(nopLog(), myAppName, myAppVersion)
require.Nil(t, err, "request failed")
require.Equal(t, expected.ApplicationName, actual.ApplicationName)
require.Equal(t, expected.Version, actual.Version)
Expand Down Expand Up @@ -155,7 +156,7 @@ func TestDownloadPackage_CannotRemoveExistingFile(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
err = hgc.DownloadPackage(nopLog(), myAppName, filePath)
err = hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath)
require.NotNil(t, err, "did not fail")
require.Contains(t, err.Error(), "Could not remove the existing file", "Wrong message for failing to remove locked file")
}
Expand All @@ -171,7 +172,7 @@ func TestDownloadPackage_InvalidPath(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
err := hgc.DownloadPackage(nopLog(), myAppName, filePath)
err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath)
require.NotNil(t, err, "did not fail")
require.Contains(t, err.Error(), "Cannot retrieve file information", "Wrong message for invalid file path")
}
Expand All @@ -192,7 +193,7 @@ func TestDownloadPackage_SingeCallDownload(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
err := hgc.DownloadPackage(nopLog(), myAppName, filePath)
err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath)
require.Nil(t, err, "Download failed")
verifyFileContents(t, filePath, expected)
}
Expand Down Expand Up @@ -223,7 +224,7 @@ func TestDownloadPackage_TooManyTries(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
err := hgc.DownloadPackage(nopLog(), myAppName, filePath)
err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath)
require.NotNil(t, err, "did not fail")
require.Contains(t, err.Error(), "Failed to completely download the file", "Wrong message for incomplete file")
}
Expand Down Expand Up @@ -254,7 +255,7 @@ func TestDownloadPackage_IntermediateCallFails(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
err := hgc.DownloadPackage(nopLog(), myAppName, filePath)
err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath)
require.NotNil(t, err, "did not fail")
require.Contains(t, err.Error(), "Unrecoverable error while downloading the file", "Wrong message for failure mid-retries")
}
Expand Down Expand Up @@ -290,7 +291,7 @@ func TestDownloadPackage_MultipleCallDownload(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
err := hgc.DownloadPackage(nopLog(), myAppName, filePath)
err := hgc.DownloadPackage(nopLog(), myAppName, myAppVersion, filePath)
require.Nil(t, err, "Download failed")
require.Equal(t, expectedCallCount, callCount)
verifyFileContents(t, filePath, expected)
Expand All @@ -312,46 +313,52 @@ func TestDownloadConfig_SingeCallDownload(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
err := hgc.DownloadConfig(nopLog(), myAppName, filePath)
err := hgc.DownloadConfig(nopLog(), myAppName, myAppVersion, filePath)
require.Nil(t, err, "Download failed")
verifyFileContents(t, filePath, expected)
}

func TestGetOperationUri(t *testing.T) {
appName := "myApp"
appVersion := "1.0.0"
operation := "metadata"

el := logging.New(nil)
os.Setenv(WireProtocolAddress, "10.0.0.1")
uri, err := getOperationURI(el, appName, operation)
uri, err := getOperationURI(el, appName, appVersion, operation)
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("http://10.0.0.1:%s/applications/%s/%s", hostGaPluginPort, appName, operation), uri)
assert.Equal(t, fmt.Sprintf("http://10.0.0.1:%s/applications/%s/%s?version=%s", hostGaPluginPort, appName, operation, appVersion), uri)

os.Setenv(WireProtocolAddress, "10.0.0.1:1234")
uri, err = getOperationURI(el, appName, operation)
uri, err = getOperationURI(el, appName, appVersion, operation)
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("http://10.0.0.1:1234/applications/%s/%s", appName, operation), uri)
assert.Equal(t, fmt.Sprintf("http://10.0.0.1:1234/applications/%s/%s?version=%s", appName, operation, appVersion), uri)

os.Setenv(WireProtocolAddress, "foo.bar.com")
uri, err = getOperationURI(el, appName, operation)
uri, err = getOperationURI(el, appName, appVersion, operation)
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("http://foo.bar.com:%s/applications/%s/%s", hostGaPluginPort, appName, operation), uri)
assert.Equal(t, fmt.Sprintf("http://foo.bar.com:%s/applications/%s/%s?version=%s", hostGaPluginPort, appName, operation, appVersion), uri)

os.Setenv(WireProtocolAddress, "foo.bar.com:1568")
uri, err = getOperationURI(el, appName, operation)
uri, err = getOperationURI(el, appName, appVersion, operation)
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("http://foo.bar.com:1568/applications/%s/%s", appName, operation), uri)
assert.Equal(t, fmt.Sprintf("http://foo.bar.com:1568/applications/%s/%s?version=%s", appName, operation, appVersion), uri)

os.Setenv(WireProtocolAddress, "https://foo.bar.com:1568")
uri, err = getOperationURI(el, appName, operation)
uri, err = getOperationURI(el, appName, appVersion, operation)
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("https://foo.bar.com:1568/applications/%s/%s?version=%s", appName, operation, appVersion), uri)

// test with no app version
uri, err = getOperationURI(el, appName, "", operation)
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("https://foo.bar.com:1568/applications/%s/%s", appName, operation), uri)

// test fallback address for Wire Server
os.Setenv(WireProtocolAddress, "")
uri, err = getOperationURI(el, appName, operation)
uri, err = getOperationURI(el, appName, appVersion, operation)
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("%s/applications/%s/%s", wireServerFallbackAddress, appName, operation), uri)
assert.Equal(t, fmt.Sprintf("%s/applications/%s/%s?version=%s", wireServerFallbackAddress, appName, operation, appVersion), uri)
}

func TestGetGetVmAppInfo(t *testing.T) {
Expand Down Expand Up @@ -384,7 +391,7 @@ func TestGetGetVmAppInfo(t *testing.T) {

os.Setenv(WireProtocolAddress, srv.URL)
hgc := &HostGaCommunicator{}
vmAppMetadata, err := hgc.GetVMAppInfo(nopLog(), "advancedsettingsapp")
vmAppMetadata, err := hgc.GetVMAppInfo(nopLog(), "advancedsettingsapp", myAppVersion)
assert.NoError(t, err)
assert.Equal(t, "flarg.exe", vmAppMetadata.PackageFileName)
assert.Equal(t, "flarg.cfg", vmAppMetadata.ConfigFileName)
Expand Down
Loading