diff --git a/.gitignore b/.gitignore index ec5907e..9b52c50 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ bundle extension-launcher* internal/customactionplan/testdir/ licenses +main.exe +internal/hostgacommunicator/TestArtifacts/ diff --git a/Makefile b/Makefile index b030809..c45a40a 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,6 @@ BUNDLEDIR=bundle/linux/prod BUNDLEDIR_TEST=bundle/linux/test BINDIR=$(BUNDLEDIR)/bin BINDIR_TEST=$(BUNDLEDIR_TEST)/bin -EXTENSIONVERSION=1.0.18 ALLOWED_EXT1=Microsoft.CPlat.Core.VMApplicationManagerLinux ALLOWED_EXT2=Microsoft.CPlat.Core.EDP.VMApplicationManagerLinux @@ -20,20 +19,20 @@ clean: -rm -Rf $(BUNDLEDIR_TEST) -rm -Rf licenses -extension-launcher: validate-extension-name +extension-launcher: validate-extension-name validate-extension-version GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o extension-launcher -ldflags="-X 'main.ExtensionName=$(EXTENSIONNAME)' -X 'main.ExtensionVersion=$(EXTENSIONVERSION)' -X 'main.ExecutableName=vm-application-manager'" ./launcher -extension-launcher-arm64: validate-extension-name +extension-launcher-arm64: validate-extension-name validate-extension-version GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o extension-launcher-arm64 -ldflags="-X 'main.ExtensionName=$(EXTENSIONNAME)' -X 'main.ExtensionVersion=$(EXTENSIONVERSION)' -X 'main.ExecutableName=vm-application-manager'" ./launcher # For ARM64 machines, install command will rename vm-application-manager-arm64 to vm-application-manager -vm-application-manager: validate-extension-name +vm-application-manager: validate-extension-name validate-extension-version GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o vm-application-manager -ldflags="-X 'main.ExtensionName=$(EXTENSIONNAME)' -X 'main.ExtensionVersion=$(EXTENSIONVERSION)'" ./main -vm-application-manager-arm64: validate-extension-name +vm-application-manager-arm64: validate-extension-name validate-extension-version GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o vm-application-manager-arm64 -ldflags="-X 'main.ExtensionName=$(EXTENSIONNAME)' -X 'main.ExtensionVersion=$(EXTENSIONVERSION)'" ./main -.PHONY: validate-extension-name +.PHONY: validate-extension-name validate-extension-version validate-extension-name: @case "$(EXTENSIONNAME)" in \ "$(ALLOWED_EXT1)"|"$(ALLOWED_EXT2)" ) ;; \ @@ -44,6 +43,18 @@ validate-extension-name: exit 1 ;; \ esac +validate-extension-version: + @if [ -z "$(EXTENSIONVERSION)" ]; then \ + echo "Error: EXTENSIONVERSION parameter is required"; \ + echo "Usage: make EXTENSIONVERSION="; \ + exit 1; \ + fi + @echo "$(EXTENSIONVERSION)" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+$$' || { \ + echo "Error: EXTENSIONVERSION '$(EXTENSIONVERSION)' does not match required pattern n.n.n (e.g., 1.0.18)"; \ + exit 1; \ + } + @echo "Using EXTENSIONVERSION: $(EXTENSIONVERSION)" + collect-licenses: @echo "Collecting open source licenses..." @if [ ! -f "$$(go env GOPATH)/bin/go-licenses" ]; then \ @@ -51,8 +62,8 @@ collect-licenses: go install github.com/google/go-licenses@latest; \ fi mkdir -p licenses/reports - $$(go env GOPATH)/bin/go-licenses save ./main --save_path=licenses/texts - $$(go env GOPATH)/bin/go-licenses csv ./main > licenses/reports/THIRD_PARTY_LICENSES.csv + -$$(go env GOPATH)/bin/go-licenses save ./main --save_path=licenses/texts --ignore=std --ignore=golang.org/x/sys + -$$(go env GOPATH)/bin/go-licenses csv ./main --ignore=std --ignore=golang.org/x/sys > licenses/reports/THIRD_PARTY_LICENSES.csv @echo "License collection complete!" bundle-prod: extension-launcher extension-launcher-arm64 vm-application-manager vm-application-manager-arm64 @@ -70,7 +81,7 @@ bundle-prod: extension-launcher extension-launcher-arm64 vm-application-manager bundle-test: @echo "Building and packaging TEST bundle into $(BUNDLEDIR_TEST) with EXTENSIONNAME=$(ALLOWED_EXT2)" - $(MAKE) EXTENSIONNAME=$(ALLOWED_EXT2) extension-launcher extension-launcher-arm64 vm-application-manager vm-application-manager-arm64 + $(MAKE) EXTENSIONNAME=$(ALLOWED_EXT2) EXTENSIONVERSION=$(EXTENSIONVERSION) extension-launcher extension-launcher-arm64 vm-application-manager vm-application-manager-arm64 mkdir -p $(BINDIR_TEST) mv extension-launcher "$(BINDIR_TEST)/" mv extension-launcher-arm64 "$(BINDIR_TEST)/" diff --git a/README.md b/README.md index c5c9532..c41e027 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,9 @@ Trademarks This project may contain trademarks or logos for projects, products, - OS-specific tests may be run on an Azure VM or local VM (eg. WSL on Windows) - To build the extension zip packages - Windows: - - execute `nmake -f makefile.win` + - execute `nmake -f makefile.win EXTENSIONVERSION=` - Linux: - - execute `make` + - execute `make EXTENSIONVERSION=` - Please do not check-in vendor files diff --git a/go.mod b/go.mod index cad4b2c..60d34fc 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.24.0 toolchain go1.24.3 require ( - github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f + github.com/Azure/azure-extension-platform v0.0.0-20260406194436-44ca1f420dd8 github.com/ahmetalpbalkan/go-httpbin v0.0.0-20200921172446-862fbad56b77 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 5fa4a51..8df5944 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f h1:ddsUz/suc9txCMz/xWOslqNMvzhbWFMTflUrbcMNoSw= -github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= +github.com/Azure/azure-extension-platform v0.0.0-20260406194436-44ca1f420dd8 h1:MwcGMMMhzVioChv9aIe5pbl85WiQuWaR9t+sdZpK3/U= +github.com/Azure/azure-extension-platform v0.0.0-20260406194436-44ca1f420dd8/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY= github.com/ahmetalpbalkan/go-httpbin v0.0.0-20200921172446-862fbad56b77 h1:QLWeOzO9GTjP14jyM0g7IHhYbnWWR3Wi4kipv3iDOJY= github.com/ahmetalpbalkan/go-httpbin v0.0.0-20200921172446-862fbad56b77/go.mod h1:Rg55S63lgqSBCawY/oTm7jdFSySp6jwIqgHMB2IeHK8= github.com/ahmetb/go-httpbin v0.0.0-20200921172446-862fbad56b77 h1:tLnVshegsavDh3VnYwLVgYe7i5/O61LrhKGU+cTR95E= diff --git a/internal/packageregistry/packageregistry.go b/internal/packageregistry/packageregistry.go index c4ecddc..c5b8d10 100644 --- a/internal/packageregistry/packageregistry.go +++ b/internal/packageregistry/packageregistry.go @@ -5,7 +5,6 @@ package packageregistry import ( "encoding/json" - "io/ioutil" "os" "path" "time" @@ -168,7 +167,7 @@ func (self *PackageRegistry) GetExistingPackages() (CurrentPackageRegistry, erro _, err := os.Stat(localApplicationRegistryFilePath) if err == nil { // The file exists - fileBytes, err := ioutil.ReadFile(localApplicationRegistryFilePath) + fileBytes, err := os.ReadFile(localApplicationRegistryFilePath) if err != nil { return currentPackageRegistry, err } @@ -210,7 +209,7 @@ func (self *PackageRegistry) WriteToDisk(packageRegistry CurrentPackageRegistry) return err } - err = ioutil.WriteFile(regFile, bytes, constants.FilePermissions_UserOnly_ReadWrite) + err = os.WriteFile(regFile, bytes, constants.FilePermissions_UserOnly_ReadWrite) self.logger.Info("Wrote package registry to %v", regFile) if doesBackupFileExist { diff --git a/launcher/main.go b/launcher/main.go index 8f79159..1aa23b7 100644 --- a/launcher/main.go +++ b/launcher/main.go @@ -49,6 +49,13 @@ func main() { eh.Exit(exithelper.MiscError) } + // validate ExtensionVersion against the version reported by Guest Agent + if extensionVersionFromEnv, err := vmextension.GetGuestAgentEnvironmetVariable(vmextension.GuestAgentEnvVarExtensionVersion); err == nil { + if extensionVersionFromEnv != ExtensionVersion { + el.Warn("ExtensionVersion mismatch: compile-time ExtensionVersion value '%s' does not match value '%s' in environment variable '%s'", ExtensionVersion, extensionVersionFromEnv, vmextension.GuestAgentEnvVarExtensionVersion) + } + } + arg := args[1] switch arg { diff --git a/main/main.go b/main/main.go index 00d3507..f811d75 100644 --- a/main/main.go +++ b/main/main.go @@ -25,8 +25,8 @@ import ( ) var ( - ExtensionName string // assign at compile time - ExtensionVersion = "1.0.10" // should be assigned at compile time, do not edit in code + ExtensionName string // assign at compile time, it is the ExtensionPublisher.ExtensionType + ExtensionVersion = "1.0.10" // should be assigned at compile time, do not edit in code outside of unit tests reportStatusFunc = utils.ReportStatus getVMExtensionFunc = getVMExtension customEnableFunc = customEnable @@ -34,10 +34,7 @@ var ( ) const ( - vmPackagesSetting = "vmPackages" - operationInstall = "install" - operationUpdate = "update" - operationRemove = "remove" + argVersion = "version" filelockTimeoutDuration = 45 * time.Minute ) @@ -49,12 +46,26 @@ func main() { } func getExtensionAndRun(arguments []string) error { + if len(arguments) == 2 && strings.EqualFold(arguments[1], argVersion) { + fmt.Println("Extension version is", ExtensionVersion) + return nil + } + // require SeqNoChange is set to false because we want the extension to ensure that the packages are in sync with the desired packages ext, err := getVMExtensionFunc() if err != nil { return err } + // validate ExtensionVersion against the version reported by Guest Agent + if extVersionInEnvVariable, err := vmextensionhelper.GetGuestAgentEnvironmetVariable(vmextensionhelper.GuestAgentEnvVarExtensionVersion); err == nil { + if extVersionInEnvVariable != ExtensionVersion { + msg := fmt.Sprintf("ExtensionVersion mismatch: compile-time ExtensionVersion value '%s' does not match value '%s' in environment variable '%s'", ExtensionVersion, extVersionInEnvVariable, vmextensionhelper.GuestAgentEnvVarExtensionVersion) + ext.ExtensionLogger.Warn(msg) + ext.ExtensionEvents.LogWarningEvent("ExtensionVersion", msg) + } + } + if len(arguments) != 2 { ext.ExtensionLogger.Error("ExtensionError", "vm-application-manager requires an argument") ext.ExtensionEvents.LogCriticalEvent("ExtensionError", "vm-application-manager requires an argument") @@ -62,6 +73,16 @@ func getExtensionAndRun(arguments []string) error { } command := arguments[1] + if command == vmextensionhelper.UpdateOperation.ToString() { + if updateToVersion, err := vmextensionhelper.GetGuestAgentEnvironmetVariable(vmextensionhelper.GuestAgentEnvVarUpdateToVersion); err == nil { + if updateToVersion != ExtensionVersion { + msg := fmt.Sprintf("ExtensionVersion mismatch: compile-time ExtensionVersion value '%s' does not match value '%s' in environment variable '%s'", ExtensionVersion, updateToVersion, vmextensionhelper.GuestAgentEnvVarUpdateToVersion) + ext.ExtensionLogger.Warn(msg) + ext.ExtensionEvents.LogWarningEvent("ExtensionVersion", msg) + } + } + } + pid := os.Getpid() ext.ExtensionEvents.LogInformationalEvent("vm-application-manager-process", fmt.Sprintf("VmApplications extension starting, PID: %d, Command: %s", pid, command)) defer ext.ExtensionEvents.LogInformationalEvent("vm-application-manager-process", fmt.Sprintf("VmApplications extension exiting, PID: %d, Command: %s", pid, command)) @@ -124,7 +145,7 @@ func getVMExtension() (*vmextensionhelper.VMExtension, error) { return nil, err } - ii.UninstallCallback = vmAppUninstallCallback + ii.UninstallCallback = nil // no need to do any special handling on uninstall, so we can set the callback to nil ii.UpdateCallback = vmAppUpdateCallback ii.LogFileNamePattern = "VmAppExt_%v.log" @@ -257,42 +278,3 @@ func customEnable(ext *vmextensionhelper.VMExtension, hostgaCommunicator hostgac return nil } - -// Callback indicating the extension is being removed -func vmAppUninstallCallback(ext *vmextensionhelper.VMExtension) error { - ext.ExtensionEvents.LogInformationalEvent("Uninstalling", "VmApplications extension - removing all applications for uninstall") - hostGaCommunicator := hostgacommunicator.HostGaCommunicator{} - err := doVmAppUninstallCallback(ext, &hostGaCommunicator) - if err == nil { - ext.ExtensionEvents.LogInformationalEvent("Completed", "VmApplications extension uninstalled. Result=Success") - } else { - ext.ExtensionEvents.LogInformationalEvent( - "Completed", - fmt.Sprintf("VmApplications extension uninstall finished. Result=Failure;Reason=%v", err.Error())) - } - return err -} - -func doVmAppUninstallCallback(ext *vmextensionhelper.VMExtension, hostGaCommunicator hostgacommunicator.IHostGaCommunicator) error { - packageRegistry, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, filelockTimeoutDuration) - if err != nil { - return errors.Wrapf(err, "Could not create package registry") - } - defer packageRegistry.Close() - - currentPackageRegistry, err := packageRegistry.GetExistingPackages() - if err != nil { - return errors.Wrapf(err, "Could not read current package registry") - } - - // Create an empty incoming collection so we'll create an action plan to remove all applications - emptyIncomingCollection := make(packageregistry.VMAppPackageIncomingCollection, 0) - - actionPlan := actionplan.New(currentPackageRegistry, emptyIncomingCollection, ext.HandlerEnv, hostGaCommunicator, ext.ExtensionLogger) - commandHandler := commandhandler.CommandHandler{} - - // Removing applications is best effort, so even if there are errors here, we ignore them - _, _ = actionPlan.Execute(packageRegistry, ext.ExtensionEvents, &commandHandler) - - return nil -} diff --git a/main/main_test.go b/main/main_test.go index ddf848f..2310aa4 100644 --- a/main/main_test.go +++ b/main/main_test.go @@ -29,7 +29,6 @@ import ( handlersettings "github.com/Azure/azure-extension-platform/pkg/settings" "github.com/Azure/azure-extension-platform/pkg/status" "github.com/Azure/azure-extension-platform/vmextension" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -466,84 +465,6 @@ func Test_main_nothingToProcess_withStatus(t *testing.T) { require.Equal(t, requestedSequenceNumber, currentSequenceNumber) } -func Test_uninstall_cannotCreatePackageRegistry(t *testing.T) { - vmApplications := []extdeserialization.VmAppSetting{} - ext := createTestVMExtension(t, vmApplications) - hostGaCommunicator := NoopHostGaCommunicator{} - - // Set the config folder to an invalid path so we can't create a package registry - ext.HandlerEnv.ConfigFolder = "/yabaflarg/flarpaglarp" - - err := doVmAppUninstallCallback(ext, &hostGaCommunicator) - require.Error(t, err) - require.EqualError(t, err, cannotCreatePackageRegistryError) -} - -func Test_uninstall_cannotReadPackageRegistry(t *testing.T) { - vmApplications := []extdeserialization.VmAppSetting{} - ext := createTestVMExtension(t, vmApplications) - hostGaCommunicator := NoopHostGaCommunicator{} - - // Write an invalid registry so we can't create a package registry - appRegistryFilePath := path.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) - ioutil.WriteFile(appRegistryFilePath, []byte("}"), 0644) - defer os.Remove(appRegistryFilePath) - - err := doVmAppUninstallCallback(ext, &hostGaCommunicator) - require.Error(t, err) - require.EqualError(t, err, "Could not read current package registry: invalid character '}' looking for beginning of value") -} - -func Test_uninstall_noAppsToUninstall(t *testing.T) { - vmApplications := []extdeserialization.VmAppSetting{} - ext := createTestVMExtension(t, vmApplications) - hostGaCommunicator := NoopHostGaCommunicator{} - - package1 := path.Join(ext.HandlerEnv.ConfigFolder, "package1") - package2 := path.Join(ext.HandlerEnv.ConfigFolder, "package2") - package1Quotes := fmt.Sprintf("\"%v\"", package1) - package2Quotes := fmt.Sprintf("\"%v\"", package2) - - // Create a package registry where the remove commands will write their respective files - reg := packageregistry.CurrentPackageRegistry{"package1": &packageregistry.VMAppPackageCurrent{ - ApplicationName: "package1", - DirectDownloadOnly: false, - InstallCommand: "dontcare", - RemoveCommand: "echo moein > " + package1Quotes, - UpdateCommand: "dontcare", - Version: "1.2.3.1", - }, "package2": &packageregistry.VMAppPackageCurrent{ - ApplicationName: "package2", - DirectDownloadOnly: true, - InstallCommand: "dontcare", - RemoveCommand: "echo moein > " + package2Quotes, - UpdateCommand: "dontcare", - Version: "1.2.3.2", - }} - - pkgHndlr, err := packageregistry.New(nopLog(), ext.HandlerEnv, time.Second) - assert.NoError(t, err, "operation should not throw error") - err = pkgHndlr.WriteToDisk(reg) - assert.NoError(t, err, "Should be able to write package registry to disk") - pkgHndlr.Close() - - err = doVmAppUninstallCallback(ext, &hostGaCommunicator) - require.NoError(t, err) - - // Verify we removed both apps, which deleted the files - require.True(t, fileExists(package1), "First application was not removed") - require.True(t, fileExists(package2), "Second application was not removed") -} - -func Test_uninstall_uninstallApps(t *testing.T) { - vmApplications := []extdeserialization.VmAppSetting{} - ext := createTestVMExtension(t, vmApplications) - hostGaCommunicator := NoopHostGaCommunicator{} - - err := doVmAppUninstallCallback(ext, &hostGaCommunicator) - require.NoError(t, err) -} - func fileExists(filePath string) bool { if _, err := os.Stat(filePath); errors.Is(err, os.ErrNotExist) { return false diff --git a/main/update.go b/main/update.go new file mode 100644 index 0000000..b69e145 --- /dev/null +++ b/main/update.go @@ -0,0 +1,101 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + "sort" + + "github.com/Azure/VMApplication-Extension/internal/packageregistry" + "github.com/pkg/errors" +) + +var ( + errorExtensionVersionDirNotFound = errors.New("could not find the directory that contains all the extension versions") + errorNoOlderPakcageRegistryFileFound = errors.New(fmt.Sprintf("could not find an older '%s' file", packageregistry.LocalApplicationRegistryFileName)) + emptyPackageRegistryContent = []byte("[]") +) + +type FileInfoWithFilePath struct { + fileInfo os.FileInfo + filePath string +} + +type SortableFileInfoImpl struct { + FileInfoArray []FileInfoWithFilePath +} +type SortableFileInfo interface { + Len() int + Less(i, j int) bool + Swap(i, j int) +} + +func (sortableFileInfo SortableFileInfoImpl) Len() int { + return len(sortableFileInfo.FileInfoArray) +} + +func (sortableFileInfo SortableFileInfoImpl) Less(i, j int) bool { + return sortableFileInfo.FileInfoArray[i].fileInfo.ModTime().Before(sortableFileInfo.FileInfoArray[j].fileInfo.ModTime()) +} + +func (sortableFileInfo SortableFileInfoImpl) Swap(i, j int) { + swapVar := sortableFileInfo.FileInfoArray[i] + sortableFileInfo.FileInfoArray[i] = sortableFileInfo.FileInfoArray[j] + sortableFileInfo.FileInfoArray[j] = swapVar +} + +// splitPathAroundVersionedDir splits dirpath into (head, versionedDirName, tail) by walking up to find an ancestor +// directory whose name matches one of the version-checking functions. +func splitPathAroundVersionedDir(dirpath string, + dirnameCheckers []func(currentFolderName string) bool) ( + head, + versionedDirName, + tail string, + errorToReturn error) { + // contains an array of comparison functions that will be run to determine the version dir + // to have robustness, if the first way of comparison fails, use the next one + + for _, checkDirName := range dirnameCheckers { + relativePathToConfigFolder := "" + for currentFolderPath := dirpath; currentFolderPath != filepath.Dir(currentFolderPath); currentFolderPath = filepath.Dir(currentFolderPath) { + currentFolderName := filepath.Base(currentFolderPath) + if checkDirName(currentFolderName) { + head = filepath.Dir(currentFolderPath) + versionedDirName = currentFolderName + tail = relativePathToConfigFolder + errorToReturn = nil + return + } + relativePathToConfigFolder = filepath.Join(currentFolderName, relativePathToConfigFolder) + } + } + head = "" + versionedDirName = "" + tail = "" + errorToReturn = errorExtensionVersionDirNotFound + return +} + +func getMostRecentlyUpdatedPackageRegistryFile(dirContainingAllVersions string, intermediatePath string, expectedDirNamePatternChecker func(string) bool) (string, error) { + fileInfo, err := os.ReadDir(dirContainingAllVersions) //reads directory and returns content in sorted order + if err != nil { + return "", err + } + sortableRegistryFileInfo := SortableFileInfoImpl{ + FileInfoArray: []FileInfoWithFilePath{}, + } + for _, fileInfo := range fileInfo { + if fileInfo.IsDir() && fileInfo.Name() != ExtensionVersion && expectedDirNamePatternChecker(fileInfo.Name()) { + registryFilePath := filepath.Join(dirContainingAllVersions, fileInfo.Name(), intermediatePath, packageregistry.LocalApplicationRegistryFileName) + registryFileInfo, err := os.Stat(registryFilePath) + if err == nil { + sortableRegistryFileInfo.FileInfoArray = append(sortableRegistryFileInfo.FileInfoArray, FileInfoWithFilePath{registryFileInfo, registryFilePath}) + } + } + } + if sortableRegistryFileInfo.Len() < 1 { + return "", errorNoOlderPakcageRegistryFileFound + } + sort.Sort(sortableRegistryFileInfo) + return sortableRegistryFileInfo.FileInfoArray[len(sortableRegistryFileInfo.FileInfoArray)-1].filePath, nil +} diff --git a/main/update_linux.go b/main/update_linux.go index 1c0fea4..3c13e87 100644 --- a/main/update_linux.go +++ b/main/update_linux.go @@ -4,10 +4,112 @@ package main import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/Azure/VMApplication-Extension/internal/packageregistry" + "github.com/Azure/VMApplication-Extension/pkg/utils" vmextensionhelper "github.com/Azure/azure-extension-platform/vmextension" ) +// package registry file is in the config dir, which has the pattern +// /var/lib/waagent/Microsoft.CPlat.Core.VMApplicationManagerLinux-/config +// need to move it from an older version to the current one, if it exists func vmAppUpdateCallback(ext *vmextensionhelper.VMExtension) error { - //no-op function + + packageRegistryFilePathForCurrentVersion := filepath.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) + _, err := os.Stat(packageRegistryFilePathForCurrentVersion) + if !os.IsNotExist(err) { + msg := fmt.Sprintf("package registry file '%s' already exists for current version, no need to copy from older version, update operation completed.", packageRegistryFilePathForCurrentVersion) + ext.ExtensionLogger.Info(msg) + ext.ExtensionEvents.LogInformationalEvent("ExtensionUpdate", msg) + return nil + } + + head, versionedDirName, tail, err := splitPathAroundVersionedDirLinux(ext.HandlerEnv.ConfigFolder) + if err != nil { + return err + } + dirnameChecker := getDirNameCheckerWithKnownExtensionVersion(ExtensionVersion) + if !dirnameChecker(versionedDirName) { + msg := fmt.Sprintf("ExtensionVersion '%s' is not part of the ext.HandlerEnv.ConfigFolder path '%s'", ExtensionVersion, ext.HandlerEnv.ConfigFolder) + ext.ExtensionLogger.Warn(msg) + ext.ExtensionEvents.LogWarningEvent("ExtensionUpdate", msg) + } + + previousPackageRegistryFilePath, err := getMostRecentlyUpdatedPackageRegistryFile(head, tail, getDirNameCheckerWithExtensionVersionPattern()) + if err != nil { + return err + } + + previousPackageRegistryContent, err := os.ReadFile(previousPackageRegistryFilePath) + if err != nil { + return err + } + + // Creates and writes previous registry content to package registry file for new extension version + err = os.WriteFile(packageRegistryFilePathForCurrentVersion, previousPackageRegistryContent, 0666) + if err != nil { + return err + } + msg := fmt.Sprintf("successfully copied package registry file from '%s' to '%s'", previousPackageRegistryFilePath, packageRegistryFilePathForCurrentVersion) + ext.ExtensionLogger.Info(msg) + ext.ExtensionEvents.LogInformationalEvent("ExtensionUpdate", msg) + + // Overwrite the package registry for older version to be an empty list of applications + err = os.WriteFile(previousPackageRegistryFilePath, emptyPackageRegistryContent, 0666) + if err == nil { + msg = fmt.Sprintf("successfully cleared package registry file for older version at '%s'", previousPackageRegistryFilePath) + ext.ExtensionLogger.Info(msg) + ext.ExtensionEvents.LogInformationalEvent("ExtensionUpdate", msg) + } + return nil } + +// splitPathAroundVersionedDirLinux splits dirpath into (head, versionedDirName, tail) by walking up to find an ancestor +// directory whose name matches ExtensionName- (e.g. "Microsoft.CPlat.Core.VMApplicationManagerLinux-1.0.10"). +func splitPathAroundVersionedDirLinux(dirpath string) (head, versionedDirName, tail string, errorToReturn error) { + // contains an array of comparison functions that will be run to determine the version dir + // to have robustness, if the first way of comparison fails, use the next one + var dirnameCheckers []func(currentFolderName string) bool + + currentExtensionVersion, err := vmextensionhelper.GetGuestAgentEnvironmetVariable(vmextensionhelper.GuestAgentEnvVarExtensionVersion) + if err == nil { + // checks against 'current extension version' populated by Guest Agent + dirnameCheckers = append(dirnameCheckers, getDirNameCheckerWithKnownExtensionVersion(currentExtensionVersion)) + } + + updateExtensionVersion, err := vmextensionhelper.GetGuestAgentEnvironmetVariable(vmextensionhelper.GuestAgentEnvVarUpdateToVersion) + if err == nil { + // checks against 'extension version to update' populated by Guest Agent + dirnameCheckers = append(dirnameCheckers, getDirNameCheckerWithKnownExtensionVersion(updateExtensionVersion)) + } + + // check against extension version variable + dirnameCheckers = append(dirnameCheckers, getDirNameCheckerWithKnownExtensionVersion(ExtensionVersion)) + + // check against extension version pattern + dirnameCheckers = append(dirnameCheckers, getDirNameCheckerWithExtensionVersionPattern()) + + return splitPathAroundVersionedDir(dirpath, dirnameCheckers) +} + +func getDirNameCheckerWithKnownExtensionVersion(extensionVersion string) func(currentDirName string) bool { + expectedDirName := ExtensionName + "-" + extensionVersion + return func(currentDirName string) bool { + return strings.EqualFold(expectedDirName, currentDirName) + } +} + +func getDirNameCheckerWithExtensionVersionPattern() func(currentDirName string) bool { + return func(currentDirName string) bool { + if strings.HasPrefix(currentDirName, ExtensionName+"-") { + versionPart := strings.TrimPrefix(currentDirName, ExtensionName+"-") + return utils.IsValidVersionString(versionPart) + } + return false + } +} diff --git a/main/update_linux_test.go b/main/update_linux_test.go new file mode 100644 index 0000000..7c7e805 --- /dev/null +++ b/main/update_linux_test.go @@ -0,0 +1,304 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package main + +import ( + "bytes" + "os" + "path" + "path/filepath" + "testing" + "time" + + "github.com/Azure/VMApplication-Extension/internal/extdeserialization" + "github.com/Azure/VMApplication-Extension/internal/packageregistry" + vmextension "github.com/Azure/azure-extension-platform/vmextension" + "github.com/stretchr/testify/require" +) + +// createTestFilesLinux creates a directory structure simulating multiple extension versions on Linux: +// +// folderPath/ExtensionName-1.0.1/configFolderName/fileName (badcontent) +// folderPath/ExtensionName-1.0.3/configFolderName/fileName (badcontent) +// folderPath/ExtensionName-0.0.1/configFolderName/fileName (Test File Contents) — most recently modified +func createTestFilesLinux(folderPath, configFolderName, fileName string) error { + for _, ver := range []string{"1.0.1", "0.0.1", "1.0.3"} { + dirName := ExtensionName + "-" + ver + err := os.MkdirAll(filepath.Join(folderPath, dirName, configFolderName), 0755) + if err != nil { + return err + } + } + + testContent := []byte("badcontent") + err := os.WriteFile(filepath.Join(folderPath, ExtensionName+"-1.0.1", configFolderName, fileName), testContent, 0777) + if err != nil { + return err + } + err = os.WriteFile(filepath.Join(folderPath, ExtensionName+"-1.0.3", configFolderName, fileName), testContent, 0777) + if err != nil { + return err + } + testContent = []byte("Test File Contents") + time.Sleep(time.Second) + err = os.WriteFile(filepath.Join(folderPath, ExtensionName+"-0.0.1", configFolderName, fileName), testContent, 0777) + if err != nil { + return err + } + + return nil +} + +func Test_noInfiniteLoops(t *testing.T) { + ExtensionName = "TestExtension" + defer func() { ExtensionName = "" }() + + order := 1 + vmApplications := []extdeserialization.VmAppSetting{ + { + ApplicationName: "iggy", + Order: &order, + }, + } + ext := createTestVMExtension(t, vmApplications) + + // this overwrite creates a path that does not contain a version folder, so the update function should return an error instead of infinitely looping + ext.HandlerEnv.ConfigFolder = filepath.Join(ext.HandlerEnv.ConfigFolder, "someRandomFolder", "random2", "random3", "config") + + //call update + err := vmAppUpdateCallback(ext) + require.ErrorIs(t, err, errorExtensionVersionDirNotFound) +} + +func Test_findVersionDir_fallsBackThroughComparisonFunctions(t *testing.T) { + ExtensionName = "TestExtension" + extensionVersionOriginalValue := ExtensionVersion + + defer func() { + // revert it to what the other tests might expect after this test is run + ExtensionVersion = "1.0.10" + ExtensionName = "" + }() + + // Create a directory structure: /ExtensionName-1.0.10/config + root := t.TempDir() + expectedVersionedDirName := ExtensionName + "-" + extensionVersionOriginalValue + versionDir := filepath.Join(root, expectedVersionedDirName, "config") + err := os.MkdirAll(versionDir, 0755) + require.NoError(t, err) + + // Subtest 1: env vars not set — falls back to ExtensionVersion match + t.Run("no_env_vars_uses_ExtensionVersion_match", func(t *testing.T) { + os.Unsetenv(string(vmextension.GuestAgentEnvVarExtensionVersion)) + os.Unsetenv(string(vmextension.GuestAgentEnvVarUpdateToVersion)) + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirLinux(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "config", relPath) + require.Equal(t, expectedVersionedDirName, dirWithVersion) + }) + + // Subtest 2: AZURE_GUEST_AGENT_EXTENSION_VERSION matches — uses first checker + t.Run("extension_version_env_var_matches", func(t *testing.T) { + t.Setenv(string(vmextension.GuestAgentEnvVarExtensionVersion), extensionVersionOriginalValue) + os.Unsetenv(string(vmextension.GuestAgentEnvVarUpdateToVersion)) + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirLinux(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "config", relPath) + extensionVersionFromEnv, err := vmextension.GetGuestAgentEnvironmetVariable(vmextension.GuestAgentEnvVarExtensionVersion) + require.NoError(t, err) + require.Equal(t, ExtensionName+"-"+extensionVersionFromEnv, dirWithVersion) + }) + + // Subtest 3: VERSION env var matches — uses second checker + t.Run("update_to_version_env_var_matches", func(t *testing.T) { + os.Unsetenv(string(vmextension.GuestAgentEnvVarExtensionVersion)) + t.Setenv(string(vmextension.GuestAgentEnvVarUpdateToVersion), extensionVersionOriginalValue) + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirLinux(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "config", relPath) + updateToVersionFromEnv, err := vmextension.GetGuestAgentEnvironmetVariable(vmextension.GuestAgentEnvVarUpdateToVersion) + require.NoError(t, err) + require.Equal(t, ExtensionName+"-"+updateToVersionFromEnv, dirWithVersion) + }) + + // Subtest 4: env vars set to wrong values — falls back to ExtensionVersion match + t.Run("env_vars_wrong_falls_back_to_ExtensionVersion", func(t *testing.T) { + t.Setenv(string(vmextension.GuestAgentEnvVarExtensionVersion), "9.9.9") + t.Setenv(string(vmextension.GuestAgentEnvVarUpdateToVersion), "8.8.8") + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirLinux(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "config", relPath) + require.Equal(t, expectedVersionedDirName, dirWithVersion) + }) + + // Subtest 5: env vars set to wrong values, ExtensionVersion doesn't match — falls back to pattern match + t.Run("env_vars_wrong_falls_back_to_pattern", func(t *testing.T) { + t.Setenv(string(vmextension.GuestAgentEnvVarExtensionVersion), "9.9.9") + t.Setenv(string(vmextension.GuestAgentEnvVarUpdateToVersion), "8.8.8") + + ExtensionVersion = "1.0.0" // the directory was created with extension version 1.0.10, this should fail to match + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirLinux(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "config", relPath) + require.Equal(t, expectedVersionedDirName, dirWithVersion) // should still find the version dir based on pattern match even though env vars and ExtensionVersion value don't match + }) + + // Subtest 6: no version dir in path and no env vars — should return error + t.Run("no_version_dir_returns_error", func(t *testing.T) { + os.Unsetenv(string(vmextension.GuestAgentEnvVarExtensionVersion)) + os.Unsetenv(string(vmextension.GuestAgentEnvVarUpdateToVersion)) + + noVersionPath := filepath.Join(t.TempDir(), "noVersion", "data") + err := os.MkdirAll(noVersionPath, 0755) + require.NoError(t, err) + + _, _, _, err = splitPathAroundVersionedDirLinux(noVersionPath) + require.ErrorIs(t, err, errorExtensionVersionDirNotFound) + }) +} + +func Test_cannotFindPackageConfigFile(t *testing.T) { + ExtensionName = "TestExtension" + defer func() { ExtensionName = "" }() + + order := 1 + vmApplications := []extdeserialization.VmAppSetting{ + { + ApplicationName: "iggy", + Order: &order, + }, + } + ext := createTestVMExtension(t, vmApplications) + + //set up test files + configFolderName := "config" + ext.HandlerEnv.ConfigFolder = filepath.Join(ext.HandlerEnv.ConfigFolder, ExtensionName+"-"+ExtensionVersion, configFolderName) + + //call update + err := vmAppUpdateCallback(ext) + require.ErrorIs(t, err, errorNoOlderPakcageRegistryFileFound) +} + +func Test_existingPackageRegistryFileIsNotOverwritten(t *testing.T) { + ExtensionName = "TestExtension" + defer func() { ExtensionName = "" }() + + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + + configFolderName := "config" + testFolderPath := t.TempDir() + ext.HandlerEnv.ConfigFolder = filepath.Join(testFolderPath, ExtensionName+"-"+ExtensionVersion, configFolderName) + err := os.MkdirAll(ext.HandlerEnv.ConfigFolder, 0755) + require.NoError(t, err) + fileName := packageregistry.LocalApplicationRegistryFileName + err = createTestFilesLinux(testFolderPath, configFolderName, fileName) + require.NoError(t, err) + + fileBytes := []byte("special message") + packageRegistryFilePath := path.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) + err = os.WriteFile(packageRegistryFilePath, fileBytes, 0777) + require.NoError(t, err) + err = vmAppUpdateCallback(ext) + require.NoError(t, err) + // verify file was not overwritten + readBytes, err := os.ReadFile(packageRegistryFilePath) + require.NoError(t, err) + require.True(t, bytes.Equal(fileBytes, readBytes)) +} + +func Test_vmAppUpdateCallback_endToEnd(t *testing.T) { + ExtensionName = "TestExtension" + defer func() { ExtensionName = "" }() + + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + + // --- Set up config folder structure with multiple old versions --- + configRoot := t.TempDir() + configFolderName := "config" + currentConfigDir := filepath.Join(configRoot, ExtensionName+"-"+ExtensionVersion, configFolderName) + err := os.MkdirAll(currentConfigDir, 0755) + require.NoError(t, err) + ext.HandlerEnv.ConfigFolder = currentConfigDir + + // Creates ExtensionName-{1.0.1,0.0.1,1.0.3}/config/ with placeholder files; + // 0.0.1 is written last so it has the most recent modification time + fileName := packageregistry.LocalApplicationRegistryFileName + err = createTestFilesLinux(configRoot, configFolderName, fileName) + require.NoError(t, err) + + // Overwrite 1.0.3's registry file with real package data using the package registry API. + // WriteToDisk updates the file's modification time, making 1.0.3 the most recently updated. + mostRecentOldConfigDir := filepath.Join(configRoot, ExtensionName+"-1.0.3", configFolderName) + ext.HandlerEnv.ConfigFolder = mostRecentOldConfigDir + oldPkr, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, 1*time.Second) + require.NoError(t, err) + + oldPackages := packageregistry.CurrentPackageRegistry{ + "appA": &packageregistry.VMAppPackageCurrent{ + ApplicationName: "appA", + Version: "1.0", + DownloadDir: "/var/lib/waagent/downloads/appA/1.0", + }, + "appB": &packageregistry.VMAppPackageCurrent{ + ApplicationName: "appB", + Version: "2.0", + DownloadDir: "/var/lib/waagent/downloads/appB/2.0", + }, + } + time.Sleep(time.Second) // sleep before writing to ensure this file has the most recent mod time + err = oldPkr.WriteToDisk(oldPackages) + require.NoError(t, err) + err = oldPkr.Close() + require.NoError(t, err) + + // Read back the old registry content before update + oldRegistryPath := filepath.Join(mostRecentOldConfigDir, packageregistry.LocalApplicationRegistryFileName) + oldFileContents, err := os.ReadFile(oldRegistryPath) + require.NoError(t, err) + + // Restore ConfigFolder to the current version + ext.HandlerEnv.ConfigFolder = currentConfigDir + + // --- Call vmAppUpdateCallback --- + err = vmAppUpdateCallback(ext) + require.NoError(t, err) + + // --- Validation 1: Package registry file was copied from the old version --- + newFileContents, err := os.ReadFile(filepath.Join(currentConfigDir, packageregistry.LocalApplicationRegistryFileName)) + require.NoError(t, err) + require.True(t, bytes.Equal(oldFileContents, newFileContents), + "package registry content should be copied from the old version") + + // --- Validation 2: Old version's registry should be emptied --- + oldFileContentsAfterUpdate, err := os.ReadFile(oldRegistryPath) + require.NoError(t, err) + require.True(t, bytes.Equal([]byte("[]"), oldFileContentsAfterUpdate), + "old version's package registry should be overwritten with empty list") + + // --- Validation 3: Copied registry should be readable and contain the expected packages --- + pkr, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, 1*time.Second) + require.NoError(t, err) + defer pkr.Close() + + packages, err := pkr.GetExistingPackages() + require.NoError(t, err) + require.Len(t, packages, len(oldPackages)) + for name, expected := range oldPackages { + actual, ok := packages[name] + require.True(t, ok, "expected package %s not found in copied registry", name) + require.Equal(t, expected.Version, actual.Version) + require.Equal(t, expected.DownloadDir, actual.DownloadDir) + require.Equal(t, expected.ApplicationName, actual.ApplicationName) + } +} diff --git a/main/update_windows.go b/main/update_windows.go index dd02699..24b0764 100644 --- a/main/update_windows.go +++ b/main/update_windows.go @@ -5,108 +5,75 @@ package main import ( "fmt" - "io/ioutil" "os" - "path" + "os/exec" "path/filepath" "regexp" - "sort" "strings" "github.com/Azure/VMApplication-Extension/internal/packageregistry" + "github.com/Azure/VMApplication-Extension/pkg/utils" vmextensionhelper "github.com/Azure/azure-extension-platform/vmextension" - "github.com/pkg/errors" ) -var ( - errorExtensionVersionDirNotFound = errors.New("could not find the directory that contains all the extension versions") - errorNoOlderPakcageRegistryFileFound = errors.New(fmt.Sprintf("could not find an older '%s' file", packageregistry.LocalApplicationRegistryFileName)) - versionNumberRegx, _ = regexp.Compile(`[0-9]+\.[0-9]+\.[0-9]+`) - emptyPackageRegistryContent = []byte("[]") -) - -type FileInfoWithFilePath struct { - fileInfo os.FileInfo - filePath string -} +// splitPathAroundVersionedDirWindows splits dirpath into (head, versionedDirName, tail) by walking up to find an ancestor +// directory whose name is a bare version string (e.g. "1.0.10"). +func splitPathAroundVersionedDirWindows(dirpath string) (head, versionedDirName, tail string, errorToReturn error) { + // contains an array of comparison functions that will be run to determine the version dir + // to have robustness, if the first way of comparison fails, use the next one + var dirnameCheckers []func(currentFolderName string) bool + + currentExtensionVersion, err := vmextensionhelper.GetGuestAgentEnvironmetVariable(vmextensionhelper.GuestAgentEnvVarExtensionVersion) + if err == nil { + // checks against 'current extension version' populated by Guest Agent + dirnameCheckers = append(dirnameCheckers, getCaseInsensitiveStringEqualityChecker(currentExtensionVersion)) + } -type SortableFileInfoImpl struct { - FileInfoArray []FileInfoWithFilePath -} -type SortableFileInfo interface { - Len() int - Less(i, j int) bool - Swap(i, j int) -} + updateExtensionVersion, err := vmextensionhelper.GetGuestAgentEnvironmetVariable(vmextensionhelper.GuestAgentEnvVarUpdateToVersion) + if err == nil { + // checks against 'extension version to update' populated by Guest Agent + dirnameCheckers = append(dirnameCheckers, getCaseInsensitiveStringEqualityChecker(updateExtensionVersion)) + } -func (sortableFileInfo SortableFileInfoImpl) Len() int { - return len(sortableFileInfo.FileInfoArray) -} + // check against extension version variable + dirnameCheckers = append(dirnameCheckers, getCaseInsensitiveStringEqualityChecker(ExtensionVersion)) -func (sortableFileInfo SortableFileInfoImpl) Less(i, j int) bool { - return sortableFileInfo.FileInfoArray[i].fileInfo.ModTime().Before(sortableFileInfo.FileInfoArray[j].fileInfo.ModTime()) -} + // check against extension version pattern + dirnameCheckers = append(dirnameCheckers, utils.IsValidVersionString) -func (sortableFileInfo SortableFileInfoImpl) Swap(i, j int) { - swapVar := sortableFileInfo.FileInfoArray[i] - sortableFileInfo.FileInfoArray[i] = sortableFileInfo.FileInfoArray[j] - sortableFileInfo.FileInfoArray[j] = swapVar + return splitPathAroundVersionedDir(dirpath, dirnameCheckers) } -func getMostRecentlyUpdatedPackageRegistryFile(dirContainingAllVersions string, intermediatePath string) (string, error) { - fileInfo, err := ioutil.ReadDir(dirContainingAllVersions) //reads directory and returns content in sorted order - if err != nil { - return "", err +func getCaseInsensitiveStringEqualityChecker(knownString string) func(currentString string) bool { + return func(currentString string) bool { + return strings.EqualFold(knownString, currentString) } - sortableRegistryFileInfo := SortableFileInfoImpl{ - FileInfoArray: []FileInfoWithFilePath{}, - } - for _, fileInfo := range fileInfo { - if fileInfo.IsDir() && fileInfo.Name() != ExtensionVersion && versionNumberRegx.MatchString(fileInfo.Name()) { - registryFilePath := path.Join(dirContainingAllVersions, fileInfo.Name(), intermediatePath, packageregistry.LocalApplicationRegistryFileName) - registryFileInfo, err := os.Stat(registryFilePath) - if err == nil { - sortableRegistryFileInfo.FileInfoArray = append(sortableRegistryFileInfo.FileInfoArray, FileInfoWithFilePath{registryFileInfo, registryFilePath}) - } - } - } - if sortableRegistryFileInfo.Len() < 1 { - return "", errorNoOlderPakcageRegistryFileFound - } - sort.Sort(sortableRegistryFileInfo) - return sortableRegistryFileInfo.FileInfoArray[len(sortableRegistryFileInfo.FileInfoArray)-1].filePath, nil } func vmAppUpdateCallback(ext *vmextensionhelper.VMExtension) error { // for extension update on windows, we retrieve the applicationRegistry.active file from a previous version of the extension - folderPath := ext.HandlerEnv.ConfigFolder - currentFolderName := "" - pathToFile := "" packageRegistryFilePathForCurrentVersion := filepath.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) _, err := os.Stat(packageRegistryFilePathForCurrentVersion) if !os.IsNotExist(err) { - // a package registry file already exists for current version, nothing to do + msg := fmt.Sprintf("package registry file '%s' already exists for current version, no need to copy from older version, update operation completed.", packageRegistryFilePathForCurrentVersion) + ext.ExtensionLogger.Info(msg) + ext.ExtensionEvents.LogInformationalEvent("ExtensionUpdate", msg) return nil } - //loop to find directory that contains current version - breakLoopAfter := 5 - for i := 0; ; i++ { - currentFolderName = filepath.Base(folderPath) - if strings.Contains(currentFolderName, ExtensionVersion) { - break - } - pathToFile = filepath.Join(currentFolderName, pathToFile) //keeping track of full path to file - folderPath = filepath.Dir(folderPath) //update folderpath to walk up directory - if i > breakLoopAfter { - return errorExtensionVersionDirNotFound - } + folderPathThatContainsAllTheVersions, versionedDirName, relativePathToConfigFolder, err := splitPathAroundVersionedDirWindows(ext.HandlerEnv.ConfigFolder) + if err != nil { + return err + } + dirnameChecker := getCaseInsensitiveStringEqualityChecker(ExtensionVersion) + if !dirnameChecker(versionedDirName) { + msg := fmt.Sprintf("ExtensionVersion '%s' is not part of the ext.HandlerEnv.ConfigFolder path '%s'", ExtensionVersion, ext.HandlerEnv.ConfigFolder) + ext.ExtensionLogger.Warn(msg) + ext.ExtensionEvents.LogWarningEvent("ExtensionUpdate", msg) } - folderPath = filepath.Dir(folderPath) //folder that contains all the versions - - previousPackageRegistryFilePath, err := getMostRecentlyUpdatedPackageRegistryFile(folderPath, pathToFile) + previousPackageRegistryFilePath, err := getMostRecentlyUpdatedPackageRegistryFile(folderPathThatContainsAllTheVersions, relativePathToConfigFolder, utils.IsValidVersionString) if err != nil { return err } @@ -116,20 +83,157 @@ func vmAppUpdateCallback(ext *vmextensionhelper.VMExtension) error { return err } + // Creates and writes previous registry content to package registry file for new extension version + err = os.WriteFile(packageRegistryFilePathForCurrentVersion, previousPackageRegistryContent, 0666) + if err != nil { + return err + } + msg := fmt.Sprintf("successfully copied package registry file from '%s' to '%s'", previousPackageRegistryFilePath, packageRegistryFilePathForCurrentVersion) + ext.ExtensionLogger.Info(msg) + ext.ExtensionEvents.LogInformationalEvent("ExtensionUpdate", msg) + // Overwrite the package registry for older version to be an empty list of applications - // This prevents the uninstall operation for older extension removing installed VM Apps - // Set file contents for older package registry prior to newer one in order to ensure most recently - // updated package registry corresponds to the newest version err = os.WriteFile(previousPackageRegistryFilePath, emptyPackageRegistryContent, 0666) + if err == nil { + msg = fmt.Sprintf("successfully cleared package registry file for older version at '%s'", previousPackageRegistryFilePath) + ext.ExtensionLogger.Info(msg) + ext.ExtensionEvents.LogInformationalEvent("ExtensionUpdate", msg) + } + + // do the following operations in a best effort manner + err = moveDownloadDirToCurrentVersion(ext) + if err != nil { + ext.ExtensionLogger.Warn("Failed to move download directory to current version with error: %v", err) + ext.ExtensionEvents.LogWarningEvent("vm-application-manager-update", fmt.Sprintf("Failed to move download directory to current version with error: %v", err)) + } else { + msg = "successfully moved download directory to current version" + ext.ExtensionLogger.Info(msg) + ext.ExtensionEvents.LogInformationalEvent("ExtensionUpdate", msg) + + if err = updateDownloadDirInPackageRegistryFile(ext); err != nil { + ext.ExtensionLogger.Warn("Failed to update download directory in package registry file with error: %v", err) + ext.ExtensionEvents.LogWarningEvent("vm-application-manager-update", fmt.Sprintf("Failed to update download directory in package registry file with error: %v", err)) + } else { + msg = "successfully updated download directory paths in package registry file" + ext.ExtensionLogger.Info(msg) + ext.ExtensionEvents.LogInformationalEvent("ExtensionUpdate", msg) + } + } + return nil +} + +func updateDownloadDirInPackageRegistryFile(ext *vmextensionhelper.VMExtension) error { + packageRegistry, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, filelockTimeoutDuration) + if err != nil { + return err + } + defer packageRegistry.Close() + existingPackages, err := packageRegistry.GetExistingPackages() if err != nil { return err } - // Creates and writes previous registry content to package registry file for new extension version - err = os.WriteFile(packageRegistryFilePathForCurrentVersion, previousPackageRegistryContent, 0666) + if len(existingPackages) == 0 { + return nil + } + + downloadDirBeforeVersion, _, downloadDirAfterVersion, err := splitPathAroundVersionedDirWindows(ext.HandlerEnv.DataFolder) if err != nil { return err } + // Build a regex that matches: // in DownloadDir paths + // Use forward slashes since DownloadDir is stored with filepath.ToSlash + escapedPrefix := regexp.QuoteMeta(filepath.ToSlash(downloadDirBeforeVersion)) + escapedSuffix := regexp.QuoteMeta(filepath.ToSlash(downloadDirAfterVersion)) + downloadDirVersionRegex := regexp.MustCompile(escapedPrefix + `/[^/]+/` + escapedSuffix) + replacement := filepath.ToSlash(downloadDirBeforeVersion) + "/" + ExtensionVersion + "/" + filepath.ToSlash(downloadDirAfterVersion) + + for packageName, packageInfo := range existingPackages { + normalized := filepath.ToSlash(packageInfo.DownloadDir) + updated := downloadDirVersionRegex.ReplaceAllLiteralString(normalized, replacement) + if updated == normalized { + ext.ExtensionLogger.Warn("Could not update downloadDir for package '%s', no version segment matched", packageName) + ext.ExtensionEvents.LogWarningEvent("vm-application-manager-update", fmt.Sprintf("Could not update downloadDir for package '%s', no version segment matched", packageName)) + } else { + packageInfo.DownloadDir = filepath.FromSlash(updated) + } + } + + return packageRegistry.WriteToDisk(existingPackages) +} + +// move the download directory from old version to new version +func moveDownloadDirToCurrentVersion(ext *vmextensionhelper.VMExtension) error { + packageRegistry, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, filelockTimeoutDuration) + if err != nil { + return err + } + defer packageRegistry.Close() + + rootOfAllVersions, versionedDirName, relativePathAfterVersion, err := splitPathAroundVersionedDirWindows(ext.HandlerEnv.DataFolder) + if err != nil { + return err + } + + if !strings.EqualFold(versionedDirName, ExtensionVersion) { + msg := fmt.Sprintf("ExtensionVersion mismatch: ext.HandlerEnv.DataFolder path '%s' does contain versionedDirName '%s'", ext.HandlerEnv.DataFolder, versionedDirName) + ext.ExtensionLogger.Warn(msg) + ext.ExtensionEvents.LogWarningEvent("ExtensionVersion", msg) + } + + entries, err := os.ReadDir(rootOfAllVersions) + if err != nil { + return err + } + var downloadDirectoryForAllVersions []string + for _, entry := range entries { + if entry.IsDir() && utils.IsValidVersionString(entry.Name()) { + dirName := filepath.Join(rootOfAllVersions, entry.Name(), relativePathAfterVersion) + downloadDir, err := os.Stat(dirName) + if err != nil { + ext.ExtensionLogger.Warn("Skipping directory %s when looking for download directories to move, with error: %v", dirName, err) + ext.ExtensionEvents.LogWarningEvent("vm-application-manager-update", fmt.Sprintf("Skipping directory %s when looking for download directories to move, with error: %v", dirName, err)) + continue + } + if !downloadDir.IsDir() || entry.Name() == ExtensionVersion { + //if the config folder is not a directory, or if the version folder is the same as the current version, then skip it + continue + } + downloadDirectoryForAllVersions = append(downloadDirectoryForAllVersions, dirName) + } + } + + for _, downloadDir := range downloadDirectoryForAllVersions { + directoryContents, err := os.ReadDir(downloadDir) + if err != nil { + ext.ExtensionLogger.Warn("Failed to read directory %s with error: %v", downloadDir, err) + ext.ExtensionEvents.LogWarningEvent("vm-application-manager-update", fmt.Sprintf("Failed to read directory %s with error: %v", downloadDir, err)) + continue + } + for _, entry := range directoryContents { + if entry.IsDir() { + sourceDirFullPath := filepath.Join(downloadDir, entry.Name()) + destDirFullPath := filepath.Join(ext.HandlerEnv.DataFolder, entry.Name()) + err = copySubdirectoryUsingRobocopy(sourceDirFullPath, destDirFullPath) + // copy the directory from current entry to ext.HandlerEnv.DataFolder + if err != nil { + ext.ExtensionLogger.Warn("Failed to copy directory from %s to %s with error: %s", sourceDirFullPath, destDirFullPath, err.Error()) + ext.ExtensionEvents.LogWarningEvent("vm-application-manager-update", fmt.Sprintf("Failed to copy directory from %s to %s with error: %s", sourceDirFullPath, destDirFullPath, err.Error())) + } + } + } + } + return nil } + +func copySubdirectoryUsingRobocopy(src, dst string) error { + cmd := exec.Command("robocopy", src, dst, "/E", "/sl", "/NFL", "/NDL", "/NJH", "/NJS") + err := cmd.Run() + // robocopy exit codes 0-7 are success/informational + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() < 8 { + return nil + } + return err +} diff --git a/main/update_windows_test.go b/main/update_windows_test.go index 9fdb247..425de28 100644 --- a/main/update_windows_test.go +++ b/main/update_windows_test.go @@ -5,7 +5,6 @@ package main import ( "bytes" - "io/ioutil" "os" "path" "path/filepath" @@ -14,7 +13,8 @@ import ( "github.com/Azure/VMApplication-Extension/internal/extdeserialization" "github.com/Azure/VMApplication-Extension/internal/packageregistry" - "github.com/stretchr/testify/assert" + vmextension "github.com/Azure/azure-extension-platform/vmextension" + "github.com/stretchr/testify/require" ) func Test_didFileMove(t *testing.T) { @@ -33,29 +33,29 @@ func Test_didFileMove(t *testing.T) { testFolderPath := ext.HandlerEnv.ConfigFolder //path to create test version folders ext.HandlerEnv.ConfigFolder = filepath.Join(ext.HandlerEnv.ConfigFolder, ExtensionVersion, runtimeFolderName) //overwrite to match path pattern of config folder in VM err := os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) //creates new folders - assert.NoError(t, err) + require.NoError(t, err) fileName := packageregistry.LocalApplicationRegistryFileName //gets name of application registry file err = createTestFiles(testFolderPath, runtimeFolderName, fileName) - assert.NoError(t, err) + require.NoError(t, err) // cleanup defer os.RemoveAll(testFolderPath) oldFileContents, err := os.ReadFile(filepath.Join(testFolderPath, "0.0.1", runtimeFolderName, fileName)) - assert.NoError(t, err) + require.NoError(t, err) //call update err = vmAppUpdateCallback(ext) - assert.NoError(t, err) + require.NoError(t, err) oldFileContentsAfterUpdate, err := os.ReadFile(filepath.Join(testFolderPath, "0.0.1", runtimeFolderName, fileName)) - assert.NoError(t, err) + require.NoError(t, err) newFileContents, err := os.ReadFile(filepath.Join(ext.HandlerEnv.ConfigFolder, fileName)) - assert.NoError(t, err) + require.NoError(t, err) //checks - assert.True(t, bytes.Equal(oldFileContents, newFileContents)) - assert.True(t, bytes.Equal([]byte("[]"), oldFileContentsAfterUpdate)) + require.True(t, bytes.Equal(oldFileContents, newFileContents)) + require.True(t, bytes.Equal([]byte("[]"), oldFileContentsAfterUpdate)) } func Test_noInfiniteLoops(t *testing.T) { @@ -68,13 +68,105 @@ func Test_noInfiniteLoops(t *testing.T) { } ext := createTestVMExtension(t, vmApplications) - //set up test files - runtimeFolderName := "RuntimeSettings" //path to create test version folders - ext.HandlerEnv.ConfigFolder = filepath.Join(ext.HandlerEnv.ConfigFolder, "6.6.6", runtimeFolderName) //overwrite to match path pattern of config folder in VM + // this overwrite creates a path that does not contain a version folder, so the update function should return an error instead of infinitely looping + ext.HandlerEnv.ConfigFolder = filepath.Join(ext.HandlerEnv.ConfigFolder, "someRadomFolder", "random2", "random3", "RuntimeSettings") //call update err := vmAppUpdateCallback(ext) - assert.ErrorIs(t, err, errorExtensionVersionDirNotFound) + require.ErrorIs(t, err, errorExtensionVersionDirNotFound) +} + +func Test_findVersionDir_fallsBackThroughComparisonFunctions(t *testing.T) { + // Create a directory structure: /1.0.10/RuntimeSettings + root := t.TempDir() + extensionVersionOriginalValue := ExtensionVersion + versionDir := filepath.Join(root, extensionVersionOriginalValue, "RuntimeSettings") + err := os.MkdirAll(versionDir, os.ModeDir) + require.NoError(t, err) + + defer func() { + // revert it to what the other tests might expect after this test is run + ExtensionVersion = "1.0.10" + }() + + // Subtest 1: env vars not set — falls back to ExtensionVersion match + t.Run("no_env_vars_uses_pattern_match", func(t *testing.T) { + os.Unsetenv(string(vmextension.GuestAgentEnvVarExtensionVersion)) + os.Unsetenv(string(vmextension.GuestAgentEnvVarUpdateToVersion)) + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirWindows(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "RuntimeSettings", relPath) + require.Equal(t, extensionVersionOriginalValue, dirWithVersion) + }) + + // Subtest 2: AZURE_GUEST_AGENT_EXTENSION_VERSION matches — uses first checker + t.Run("extension_version_env_var_matches", func(t *testing.T) { + t.Setenv(string(vmextension.GuestAgentEnvVarExtensionVersion), extensionVersionOriginalValue) + os.Unsetenv(string(vmextension.GuestAgentEnvVarUpdateToVersion)) + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirWindows(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "RuntimeSettings", relPath) + extensionVersionfromEnv, err := vmextension.GetGuestAgentEnvironmetVariable(vmextension.GuestAgentEnvVarExtensionVersion) + require.NoError(t, err) + require.Equal(t, extensionVersionfromEnv, dirWithVersion) + }) + + // Subtest 3: VERSION env var matches — uses second checker + t.Run("update_to_version_env_var_matches", func(t *testing.T) { + os.Unsetenv(string(vmextension.GuestAgentEnvVarExtensionVersion)) + t.Setenv(string(vmextension.GuestAgentEnvVarUpdateToVersion), extensionVersionOriginalValue) + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirWindows(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "RuntimeSettings", relPath) + updateToVersionFromEnv, err := vmextension.GetGuestAgentEnvironmetVariable(vmextension.GuestAgentEnvVarUpdateToVersion) + require.NoError(t, err) + require.Equal(t, updateToVersionFromEnv, dirWithVersion) + }) + + // Subtest 4: env vars set to wrong values — falls back to ExtensionVersion match + t.Run("env_vars_wrong_falls_back_to_ExtensionVersion", func(t *testing.T) { + t.Setenv(string(vmextension.GuestAgentEnvVarExtensionVersion), "9.9.9") + t.Setenv(string(vmextension.GuestAgentEnvVarUpdateToVersion), "8.8.8") + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirWindows(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "RuntimeSettings", relPath) + require.Equal(t, extensionVersionOriginalValue, dirWithVersion) + }) + + // Subtest 5: env vars set to wrong values, ExtensionVersion doesn't match — falls back to pattern match + t.Run("env_vars_wrong_falls_back_to_pattern", func(t *testing.T) { + t.Setenv(string(vmextension.GuestAgentEnvVarExtensionVersion), "9.9.9") + t.Setenv(string(vmextension.GuestAgentEnvVarUpdateToVersion), "8.8.8") + + ExtensionVersion = "1.0.0" // the directory was created with extension version 1.0.10, this should fail to match + + parent, dirWithVersion, relPath, err := splitPathAroundVersionedDirWindows(versionDir) + require.NoError(t, err) + require.Equal(t, root, parent) + require.Equal(t, "RuntimeSettings", relPath) + require.Equal(t, ExtensionVersion, dirWithVersion) // should still find the version dir based on pattern match even though env vars and ExtensionVersion value don't match + }) + + // Subtest 6: no version dir in path and no env vars — should return error + t.Run("no_version_dir_returns_error", func(t *testing.T) { + os.Unsetenv(string(vmextension.GuestAgentEnvVarExtensionVersion)) + os.Unsetenv(string(vmextension.GuestAgentEnvVarUpdateToVersion)) + + noVersionPath := filepath.Join(t.TempDir(), "noVersion", "data") + err := os.MkdirAll(noVersionPath, os.ModeDir) + require.NoError(t, err) + + _, _, _, err = splitPathAroundVersionedDirWindows(noVersionPath) + require.ErrorIs(t, err, errorExtensionVersionDirNotFound) + }) } func Test_cannotFindPackageConfigFile(t *testing.T) { @@ -93,7 +185,7 @@ func Test_cannotFindPackageConfigFile(t *testing.T) { //call update err := vmAppUpdateCallback(ext) - assert.ErrorIs(t, err, errorNoOlderPakcageRegistryFileFound) + require.ErrorIs(t, err, errorNoOlderPakcageRegistryFileFound) } func Test_existingPackageRegistryFileIsNotOverwritten(t *testing.T) { @@ -103,23 +195,23 @@ func Test_existingPackageRegistryFileIsNotOverwritten(t *testing.T) { testFolderPath := ext.HandlerEnv.ConfigFolder //path to create test version folders ext.HandlerEnv.ConfigFolder = filepath.Join(ext.HandlerEnv.ConfigFolder, ExtensionVersion, runtimeFolderName) //overwrite to match path pattern of config folder in VM err := os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) //creates new folders - assert.NoError(t, err) + require.NoError(t, err) fileName := packageregistry.LocalApplicationRegistryFileName //gets name of application registry file err = createTestFiles(testFolderPath, runtimeFolderName, fileName) - assert.NoError(t, err) + require.NoError(t, err) // cleanup defer os.RemoveAll(testFolderPath) fileBytes := []byte("special message") packageRegistryFilePath := path.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) - err = ioutil.WriteFile(packageRegistryFilePath, fileBytes, 0777) - assert.NoError(t, err) + err = os.WriteFile(packageRegistryFilePath, fileBytes, 0777) + require.NoError(t, err) err = vmAppUpdateCallback(ext) - assert.NoError(t, err) + require.NoError(t, err) // verify file was not overwritten - readBytes, err := ioutil.ReadFile(packageRegistryFilePath) - assert.NoError(t, err) - assert.True(t, bytes.Equal(fileBytes, readBytes)) + readBytes, err := os.ReadFile(packageRegistryFilePath) + require.NoError(t, err) + require.True(t, bytes.Equal(fileBytes, readBytes)) } func createTestFiles(folderPath, runtimeFolderName, fileName string) error { @@ -141,20 +233,340 @@ func createTestFiles(folderPath, runtimeFolderName, fileName string) error { //creating test file testContent := []byte("badcontent") - err = ioutil.WriteFile(filepath.Join(folderPath, "1.0.1", runtimeFolderName, fileName), testContent, 0777) + err = os.WriteFile(filepath.Join(folderPath, "1.0.1", runtimeFolderName, fileName), testContent, 0777) if err != nil { return err } - err = ioutil.WriteFile(filepath.Join(folderPath, "1.0.3", runtimeFolderName, fileName), testContent, 0777) + err = os.WriteFile(filepath.Join(folderPath, "1.0.3", runtimeFolderName, fileName), testContent, 0777) if err != nil { return err } testContent = []byte("Test File Contents") time.Sleep(time.Second) - err = ioutil.WriteFile(filepath.Join(folderPath, "0.0.1", runtimeFolderName, fileName), testContent, 0777) + err = os.WriteFile(filepath.Join(folderPath, "0.0.1", runtimeFolderName, fileName), testContent, 0777) if err != nil { return err } return nil } + +// setupDataFolderForMoveTest creates a directory structure simulating older version data folders: +// +// rootDir//downloads/appA/file.txt +// rootDir//downloads/appB/file.txt +// +// and sets ext.HandlerEnv.DataFolder to rootDir//downloads +func setupDataFolderForMoveTest(t *testing.T, ext *vmextension.VMExtension, oldVersions []string) string { + t.Helper() + rootDir := t.TempDir() + downloadsSubpath := "downloads" + + // Create data folder for current version (empty) + currentDataFolder := filepath.Join(rootDir, ExtensionVersion, downloadsSubpath) + err := os.MkdirAll(currentDataFolder, os.ModeDir) + require.NoError(t, err) + ext.HandlerEnv.DataFolder = currentDataFolder + + // Create old version data folders with sample subdirectories and files + for _, ver := range oldVersions { + for _, app := range []string{"appA", "appB"} { + appDir := filepath.Join(rootDir, ver, downloadsSubpath, app) + err := os.MkdirAll(appDir, os.ModeDir) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(appDir, "file.txt"), []byte("content-"+ver+"-"+app), 0666) + require.NoError(t, err) + } + } + + return rootDir +} + +func Test_moveDownloadDirToCurrentVersion_copiesFromOlderVersions(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + rootDir := setupDataFolderForMoveTest(t, ext, []string{"0.0.1", "1.0.3"}) + defer os.RemoveAll(rootDir) + + // Ensure config folder exists for the package registry lock file + err := os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) + require.NoError(t, err) + + err = moveDownloadDirToCurrentVersion(ext) + require.NoError(t, err) + + // Verify subdirectories were copied into the current DataFolder + for _, app := range []string{"appA", "appB"} { + copiedFile := filepath.Join(ext.HandlerEnv.DataFolder, app, "file.txt") + _, err := os.Stat(copiedFile) + require.NoError(t, err, "expected copied file at %s", copiedFile) + } +} + +func Test_moveDownloadDirToCurrentVersion_skipsCurrentVersion(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + // Only create a data folder for the current version itself (no older versions) + rootDir := t.TempDir() + downloadsSubpath := "downloads" + + currentDataFolder := filepath.Join(rootDir, ExtensionVersion, downloadsSubpath) + err := os.MkdirAll(currentDataFolder, os.ModeDir) + require.NoError(t, err) + ext.HandlerEnv.DataFolder = currentDataFolder + + // Create a subdirectory in the current version's data folder with a marker file + appDir := filepath.Join(rootDir, ExtensionVersion, downloadsSubpath, "appFromCurrent") + err = os.MkdirAll(appDir, os.ModeDir) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(appDir, "marker.txt"), []byte("should-not-be-copied"), 0666) + require.NoError(t, err) + + err = os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) + require.NoError(t, err) + + err = moveDownloadDirToCurrentVersion(ext) + require.NoError(t, err) + + // The current version's own dirs should not be re-copied into DataFolder root + // (DataFolder already IS the current version folder, so we just check no error) +} + +func Test_moveDownloadDirToCurrentVersion_noOlderVersions(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + rootDir := t.TempDir() + downloadsSubpath := "downloads" + + currentDataFolder := filepath.Join(rootDir, ExtensionVersion, downloadsSubpath) + err := os.MkdirAll(currentDataFolder, os.ModeDir) + require.NoError(t, err) + ext.HandlerEnv.DataFolder = currentDataFolder + + err = os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) + require.NoError(t, err) + + err = moveDownloadDirToCurrentVersion(ext) + require.NoError(t, err) +} + +func Test_moveDownloadDirToCurrentVersion_noVersionDirFound(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + // Set DataFolder to a path containing no version-pattern directory + ext.HandlerEnv.DataFolder = filepath.Join(t.TempDir(), "noVersionHere", "data") + err := os.MkdirAll(ext.HandlerEnv.DataFolder, os.ModeDir) + require.NoError(t, err) + + err = moveDownloadDirToCurrentVersion(ext) + require.ErrorIs(t, err, errorExtensionVersionDirNotFound) +} + +func Test_moveDownloadDirToCurrentVersion_nonVersionDirsIgnored(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + rootDir := setupDataFolderForMoveTest(t, ext, []string{"0.0.1"}) + defer os.RemoveAll(rootDir) + + // Create a non-version directory sibling (should be ignored) + nonVersionDir := filepath.Join(rootDir, "notAVersion", "downloads", "appX") + err := os.MkdirAll(nonVersionDir, os.ModeDir) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(nonVersionDir, "file.txt"), []byte("should-not-copy"), 0666) + require.NoError(t, err) + + err = os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) + require.NoError(t, err) + + err = moveDownloadDirToCurrentVersion(ext) + require.NoError(t, err) + + // Verify the non-version directory content was NOT copied + _, err = os.Stat(filepath.Join(ext.HandlerEnv.DataFolder, "appX")) + require.True(t, os.IsNotExist(err), "non-version directory content should not be copied") + + // Verify old version content WAS copied + _, err = os.Stat(filepath.Join(ext.HandlerEnv.DataFolder, "appA", "file.txt")) + require.NoError(t, err, "old version content should be copied") +} + +func Test_moveAndUpdateDownloadDir_updatesRegistryPaths(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + oldVersion := "0.0.1" + rootDir := setupDataFolderForMoveTest(t, ext, []string{oldVersion}) + defer os.RemoveAll(rootDir) + + err := os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) + require.NoError(t, err) + + // Write a package registry file with DownloadDir pointing to the old version + oldDownloadDir := filepath.Join(rootDir, oldVersion, "downloads") + registryContent := `[{"applicationName":"appA","version":"1.0","downloadDir":"` + filepath.ToSlash(oldDownloadDir) + `"},{"applicationName":"appB","version":"2.0","downloadDir":"` + filepath.ToSlash(oldDownloadDir) + `"}]` + registryFilePath := filepath.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) + err = os.WriteFile(registryFilePath, []byte(registryContent), 0666) + require.NoError(t, err) + + // Move download dirs, then update paths in registry + err = moveDownloadDirToCurrentVersion(ext) + require.NoError(t, err) + + err = updateDownloadDirInPackageRegistryFile(ext) + require.NoError(t, err) + + // Read back the registry and verify DownloadDir was updated to the current version + pkr, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, 1*time.Second) + require.NoError(t, err) + defer pkr.Close() + + packages, err := pkr.GetExistingPackages() + require.NoError(t, err) + + expectedDownloadDir := filepath.Join(rootDir, ExtensionVersion, "downloads") + for _, pkg := range packages { + require.Equal(t, expectedDownloadDir, pkg.DownloadDir, + "DownloadDir for %s should point to current version", pkg.ApplicationName) + } +} + +func Test_updateDownloadDirInPackageRegistryFile_noPackages(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + + err := os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) + require.NoError(t, err) + + // Write an empty package registry + registryFilePath := filepath.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) + err = os.WriteFile(registryFilePath, []byte("[]"), 0666) + require.NoError(t, err) + + err = updateDownloadDirInPackageRegistryFile(ext) + require.NoError(t, err) +} + +func Test_updateDownloadDirInPackageRegistryFile_packageWithNoVersionInPath(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + + err := os.MkdirAll(ext.HandlerEnv.ConfigFolder, os.ModeDir) + require.NoError(t, err) + + // Set up DataFolder with a version path so findVersionDir succeeds + dataRoot := t.TempDir() + ext.HandlerEnv.DataFolder = filepath.Join(dataRoot, ExtensionVersion, "downloads") + err = os.MkdirAll(ext.HandlerEnv.DataFolder, os.ModeDir) + require.NoError(t, err) + + // Write a registry where DownloadDir has no version segment — regex won't match, so it warns but doesn't fail + registryContent := `[{"applicationName":"appX","version":"1.0","downloadDir":"C:/noVersionHere/downloads"}]` + registryFilePath := filepath.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) + err = os.WriteFile(registryFilePath, []byte(registryContent), 0666) + require.NoError(t, err) + + err = updateDownloadDirInPackageRegistryFile(ext) + require.NoError(t, err) + + // Verify the DownloadDir is unchanged since no version dir was found + pkr, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, 1*time.Second) + require.NoError(t, err) + defer pkr.Close() + + packages, err := pkr.GetExistingPackages() + require.NoError(t, err) + require.Len(t, packages, 1) + require.Equal(t, "C:/noVersionHere/downloads", packages["appX"].DownloadDir, + "DownloadDir should remain unchanged when no version dir is found") +} + +func Test_vmAppUpdateCallback_endToEnd(t *testing.T) { + ext := createTestVMExtension(t, []extdeserialization.VmAppSetting{}) + oldVersion := "0.0.1" + configSubpath := "RuntimeSettings" + + // --- Set up config folder structure: //RuntimeSettings --- + configRoot := t.TempDir() + oldConfigDir := filepath.Join(configRoot, oldVersion, configSubpath) + err := os.MkdirAll(oldConfigDir, os.ModeDir) + require.NoError(t, err) + + currentConfigDir := filepath.Join(configRoot, ExtensionVersion, configSubpath) + err = os.MkdirAll(currentConfigDir, os.ModeDir) + require.NoError(t, err) + ext.HandlerEnv.ConfigFolder = currentConfigDir + + // --- Set up data folder structure: //downloads/// --- + dataRoot := t.TempDir() + oldDataDir := filepath.Join(dataRoot, oldVersion, "downloads") + appVersions := map[string]string{"appA": "1.0", "appB": "2.0"} + for app, ver := range appVersions { + appDir := filepath.Join(oldDataDir, app, ver) + err = os.MkdirAll(appDir, os.ModeDir) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(appDir, "file.txt"), []byte("content-"+app), 0666) + require.NoError(t, err) + } + currentDataDir := filepath.Join(dataRoot, ExtensionVersion, "downloads") + ext.HandlerEnv.DataFolder = currentDataDir + + // --- Write a package registry file for the OLD version using the package registry --- + // Temporarily point ConfigFolder to old config dir to write the registry there + ext.HandlerEnv.ConfigFolder = oldConfigDir + oldPkr, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, 1*time.Second) + require.NoError(t, err) + + oldPackages := packageregistry.CurrentPackageRegistry{ + "appA": &packageregistry.VMAppPackageCurrent{ + ApplicationName: "appA", + Version: "1.0", + DownloadDir: filepath.ToSlash(filepath.Join(oldDataDir, "appA", "1.0")), + }, + "appB": &packageregistry.VMAppPackageCurrent{ + ApplicationName: "appB", + Version: "2.0", + DownloadDir: filepath.ToSlash(filepath.Join(oldDataDir, "appB", "2.0")), + }, + } + err = oldPkr.WriteToDisk(oldPackages) + require.NoError(t, err) + err = oldPkr.Close() + require.NoError(t, err) + + // Restore ConfigFolder to the current version + ext.HandlerEnv.ConfigFolder = currentConfigDir + + // --- Call vmAppUpdateCallback --- + err = vmAppUpdateCallback(ext) + require.NoError(t, err) + + // --- Validation 1: Package registry file was copied to current version's config folder --- + newRegistryPath := filepath.Join(currentConfigDir, packageregistry.LocalApplicationRegistryFileName) + _, err = os.Stat(newRegistryPath) + require.NoError(t, err, "package registry file should exist in the current version config folder") + + // Old version's registry should be emptied + oldRegistryAfterUpdate, err := os.ReadFile(filepath.Join(oldConfigDir, packageregistry.LocalApplicationRegistryFileName)) + require.NoError(t, err) + require.True(t, bytes.Equal([]byte("[]"), oldRegistryAfterUpdate), + "old version's package registry should be overwritten with empty list") + + // --- Validation 2: Download directories were copied to current version's data folder --- + for app, ver := range appVersions { + copiedFile := filepath.Join(currentDataDir, app, ver, "file.txt") + _, err = os.Stat(copiedFile) + require.NoError(t, err, "download directory for %s should be copied to current version", app) + } + + // --- Validation 3: DownloadDir in the registry was updated to point to current version --- + pkr, err := packageregistry.New(ext.ExtensionLogger, ext.HandlerEnv, 1*time.Second) + require.NoError(t, err) + defer pkr.Close() + + packages, err := pkr.GetExistingPackages() + require.NoError(t, err) + require.Len(t, packages, 2) + + for appName, pkg := range packages { + expectedDownloadDir := filepath.Join(dataRoot, ExtensionVersion, "downloads", appName, appVersions[appName]) + require.Equal(t, expectedDownloadDir, pkg.DownloadDir, + "DownloadDir for %s should point to current version path", appName) + + // Verify the expected files exist at the DownloadDir path + fileContent, err := os.ReadFile(filepath.Join(pkg.DownloadDir, "file.txt")) + require.NoError(t, err, "file.txt should exist in DownloadDir for %s", appName) + require.Equal(t, "content-"+appName, string(fileContent), + "file.txt content for %s should match", appName) + } +} diff --git a/makefile.win b/makefile.win index 4c1aff1..208afa8 100644 --- a/makefile.win +++ b/makefile.win @@ -7,7 +7,6 @@ BUNDLEDIR_PROD=bundle\windows\prod BUNDLEDIR_TEST=bundle\windows\test BINDIR_PROD=$(BUNDLEDIR_PROD)\bin BINDIR_TEST=$(BUNDLEDIR_TEST)\bin -EXTENSIONVERSION=1.0.18 EXTENSIONNAME_PROD=Microsoft.CPlat.Core.VMApplicationManagerWindows EXTENSIONNAME_TEST=Microsoft.CPlat.Core.EDP.VMApplicationManagerWindows @@ -27,19 +26,14 @@ validate-extensionname: !ENDIF @echo Using EXTENSIONNAME: $(EXTENSIONNAME) -validate-extensionname: -!IF "$(EXTENSIONNAME)" == "" - @echo Error: EXTENSIONNAME parameter is required - @echo Usage: nmake -f makefile.win EXTENSIONNAME=^ - @echo Valid values: $(EXTENSIONNAME_PROD) or $(EXTENSIONNAME_TEST) +validate-extensionversion: +!IF "$(EXTENSIONVERSION)" == "" + @echo Error: EXTENSIONVERSION parameter is required + @echo Usage: nmake -f makefile.win EXTENSIONVERSION=^ @exit 1 !ENDIF -!IF "$(EXTENSIONNAME)" != "$(EXTENSIONNAME_PROD)" && "$(EXTENSIONNAME)" != "$(EXTENSIONNAME_TEST)" - @echo Error: Invalid EXTENSIONNAME "$(EXTENSIONNAME)" - @echo Valid values: $(EXTENSIONNAME_PROD) or $(EXTENSIONNAME_TEST) - @exit 1 -!ENDIF - @echo Using EXTENSIONNAME: $(EXTENSIONNAME) + @echo $(EXTENSIONVERSION)| findstr /R "^[0-9][0-9]*\.[0-9][0-9]*\.[0-9][0-9]*$$" >nul 2>&1 || ( echo Error: EXTENSIONVERSION $(EXTENSIONVERSION) does not match required pattern n.n.n & exit /b 1 ) + @echo Using EXTENSIONVERSION: $(EXTENSIONVERSION) clean: -del extension-launcher.exe @@ -50,28 +44,28 @@ clean: -rmdir /S /Q $(BUNDLEDIR_TEST) -rmdir /S /Q licenses -extension-launcher: validate-extensionname +extension-launcher: validate-extensionname validate-extensionversion set GOOS=windows set GOARCH=amd64 set CGO_ENABLED=0 go env GOOS GOARCH go build -o extension-launcher.exe -ldflags="-X 'main.ExtensionName=$(EXTENSIONNAME)' -X 'main.ExtensionVersion=$(EXTENSIONVERSION)' -X 'main.ExecutableName=vm-application-manager.exe'" .\launcher -extension-launcher-arm64: validate-extensionname +extension-launcher-arm64: validate-extensionname validate-extensionversion set GOOS=windows set GOARCH=arm64 set CGO_ENABLED=0 go env GOOS GOARCH go build -o extension-launcher-arm64.exe -ldflags="-X 'main.ExtensionName=$(EXTENSIONNAME)' -X 'main.ExtensionVersion=$(EXTENSIONVERSION)' -X 'main.ExecutableName=vm-application-manager.exe'" .\launcher -vm-application-manager: validate-extensionname +vm-application-manager: validate-extensionname validate-extensionversion set GOOS=windows set GOARCH=amd64 set CGO_ENABLED=0 go env GOOS GOARCH go build -o vm-application-manager.exe -ldflags="-X 'main.ExtensionName=$(EXTENSIONNAME)' -X 'main.ExtensionVersion=$(EXTENSIONVERSION)'" .\main -vm-application-manager-arm64: validate-extensionname +vm-application-manager-arm64: validate-extensionname validate-extensionversion set GOOS=windows set GOARCH=arm64 set CGO_ENABLED=0 @@ -90,7 +84,7 @@ collect-licenses: @echo License collection complete! bundle-prod: - nmake -f makefile.win EXTENSIONNAME="$(EXTENSIONNAME_PROD)" extension-launcher extension-launcher-arm64 vm-application-manager vm-application-manager-arm64 + nmake -f makefile.win EXTENSIONNAME="$(EXTENSIONNAME_PROD)" EXTENSIONVERSION="$(EXTENSIONVERSION)" extension-launcher extension-launcher-arm64 vm-application-manager vm-application-manager-arm64 mkdir $(BINDIR_PROD) move extension-launcher.exe "$(BINDIR_PROD)\" move extension-launcher-arm64.exe "$(BINDIR_PROD)\" @@ -101,7 +95,7 @@ bundle-prod: powershell -c "Compress-Archive -Path $(BUNDLEDIR_PROD)\* -DestinationPath $(BUNDLEDIR_PROD)\vm-application-manager.zip" bundle-test: - nmake -f makefile.win EXTENSIONNAME="$(EXTENSIONNAME_TEST)" extension-launcher extension-launcher-arm64 vm-application-manager vm-application-manager-arm64 + nmake -f makefile.win EXTENSIONNAME="$(EXTENSIONNAME_TEST)" EXTENSIONVERSION="$(EXTENSIONVERSION)" extension-launcher extension-launcher-arm64 vm-application-manager vm-application-manager-arm64 mkdir $(BINDIR_TEST) move extension-launcher.exe "$(BINDIR_TEST)\" move extension-launcher-arm64.exe "$(BINDIR_TEST)\" diff --git a/pkg/utils/versionstringutils.go b/pkg/utils/versionstringutils.go index 14a0025..8ee0003 100644 --- a/pkg/utils/versionstringutils.go +++ b/pkg/utils/versionstringutils.go @@ -4,10 +4,15 @@ package utils import ( + "regexp" "strconv" "strings" ) +const ( + versionStringPattern = `^[0-9]+(\.[0-9]+){2,4}$` +) + func AreVersionsEqual(versionString1 *string, versionString2 *string) bool { cmpResult, err := CompareVersion(versionString1, versionString2) if err == nil { @@ -86,6 +91,14 @@ func CompareVersion(versionString1 *string, versionString2 *string) (int, error) } } +func IsValidVersionString(versionString string) bool { + matched, err := regexp.MatchString(versionStringPattern, versionString) + if err != nil { + return false + } + return matched +} + func findNonZeroVersionNumber(versionStringSlice []string, startIndex int, length int) (int, error) { for i := startIndex; i < length; i++ { num, err := strconv.Atoi(versionStringSlice[i])