diff --git a/internal/hostgacommunicator/hostgacommunicator.go b/internal/hostgacommunicator/hostgacommunicator.go index e9f6887..8940392 100644 --- a/internal/hostgacommunicator/hostgacommunicator.go +++ b/internal/hostgacommunicator/hostgacommunicator.go @@ -18,6 +18,67 @@ import ( const hostGaPluginPort = "32526" const WireProtocolAddress = "AZURE_GUEST_AGENT_WIRE_PROTOCOL_ADDRESS" const wireServerFallbackAddress = "http://168.63.129.16:32526" +const HostGaMetadataErrorPrefix = "HostGaCommunicator GetVMAppInfo error" + +type HostGaCommunicatorError int + +const ( + InitializationError HostGaCommunicatorError = iota + MetadataRequestFailedWithRetries + MetadataRequestFailedInvalidResponseBody + DownloadPackageRequestFactoryError + DownloadPackageFileError + DownloadConfigRequestFactoryError + DownloadConfigFileError +) + +func (hostGaCommunicatorError HostGaCommunicatorError) ToString() string { + switch hostGaCommunicatorError { + case InitializationError: + return "InitializationError" + case MetadataRequestFailedWithRetries: + return "MetadataRequestFailedWithRetries" + case MetadataRequestFailedInvalidResponseBody: + return "MetadataRequestFailedInvalidResponseBody" + case DownloadPackageRequestFactoryError: + return "DownloadPackageRequestFactoryError" + case DownloadPackageFileError: + return "DownloadPackageFileError" + case DownloadConfigRequestFactoryError: + return "DownloadConfigRequestFactoryError" + case DownloadConfigFileError: + return "DownloadConfigFileError" + default: + return "UnknownError" + } +} + +type HostGaCommunicatorGetVMAppInfoError struct { + errorMessage string + errorType HostGaCommunicatorError +} + +func (e *HostGaCommunicatorGetVMAppInfoError) Error() string { + return fmt.Sprintf("%s: %s, error type: %s", HostGaMetadataErrorPrefix, e.errorMessage, e.errorType.ToString()) +} + +type DownloadPackageError struct { + errorMessage string + errorType HostGaCommunicatorError +} + +func (e *DownloadPackageError) Error() string { + return fmt.Sprintf("DownloadPackage error: %s, error type: %s", e.errorMessage, e.errorType.ToString()) +} + +type DownloadConfigError struct { + errorMessage string + errorType HostGaCommunicatorError +} + +func (e *DownloadConfigError) Error() string { + return fmt.Sprintf("DownloadConfig error: %s, error type: %s", e.errorMessage, e.errorType.ToString()) +} type IHostGaCommunicator interface { DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error @@ -33,7 +94,10 @@ type HostGaCommunicator struct{} func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName string) (*VMAppMetadata, error) { requestManager, isArc, err := getMetadataRequestManager(el, appName) if err != nil { - return nil, errors.Wrapf(err, "Could not create the request manager") + return nil, &HostGaCommunicatorGetVMAppInfoError{ + errorMessage: fmt.Sprintf("Could not create the request manager: %v", err), + errorType: InitializationError, + } } var resp *http.Response @@ -47,7 +111,10 @@ func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName str } if err != nil { - return nil, errors.Wrapf(err, "Metadata request failed with retries.") + return nil, &HostGaCommunicatorGetVMAppInfoError{ + errorMessage: fmt.Sprintf("Metadata request failed after retries: %v", err), + errorType: MetadataRequestFailedWithRetries, + } } body := resp.Body @@ -56,7 +123,10 @@ func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName str var target VMAppMetadataReceiver err = json.NewDecoder(body).Decode(&target) if err != nil { - return nil, errors.Wrapf(err, "failed to decode response body") + return nil, &HostGaCommunicatorGetVMAppInfoError{ + errorMessage: fmt.Sprintf("Failed to decode response body: %v", err), + errorType: MetadataRequestFailedInvalidResponseBody, + } } return target.MapToVMAppMetadata(), nil @@ -68,11 +138,20 @@ func (*HostGaCommunicator) GetVMAppInfo(el *logging.ExtensionLogger, appName str func (*HostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName string, dst string) error { requestFactory, err := newPackageDownloadRequestFactory(el, appName) if err != nil { - return errors.Wrapf(err, "Could not create the request factory") + return &DownloadPackageError{ + errorMessage: fmt.Sprintf("Could not create the request factory: %v", err), + errorType: DownloadPackageRequestFactoryError, + } } err = requestFactory.downloadFile(el, dst) - return err + if err != nil { + return &DownloadPackageError{ + errorMessage: fmt.Sprintf("Failed to download file: %v", err), + errorType: DownloadPackageFileError, + } + } + return nil } // DownloadConfig downloads the application config through HostGaPlugin to the specified @@ -81,11 +160,20 @@ func (*HostGaCommunicator) DownloadPackage(el *logging.ExtensionLogger, appName func (*HostGaCommunicator) DownloadConfig(el *logging.ExtensionLogger, appName string, dst string) error { requestFactory, err := newConfigDownloadRequestFactory(el, appName) if err != nil { - return errors.Wrapf(err, "Could not create the request factory") + return &DownloadConfigError{ + errorMessage: fmt.Sprintf("Could not create the request factory: %v", err), + errorType: DownloadConfigRequestFactoryError, + } } err = requestFactory.downloadFile(el, dst) - return err + if err != nil { + return &DownloadConfigError{ + errorMessage: fmt.Sprintf("Failed to download file: %v", err), + errorType: DownloadConfigFileError, + } + } + return nil } func getOperationURI(el *logging.ExtensionLogger, appName string, operation string) (string, error) { diff --git a/internal/hostgacommunicator/hostgacommunicator_test.go b/internal/hostgacommunicator/hostgacommunicator_test.go index bed0ee8..900bce6 100644 --- a/internal/hostgacommunicator/hostgacommunicator_test.go +++ b/internal/hostgacommunicator/hostgacommunicator_test.go @@ -47,6 +47,9 @@ func TestGetVmAppInfo_InvalidUri(t *testing.T) { hgc := &HostGaCommunicator{} _, err := hgc.GetVMAppInfo(nopLog(), myAppName) require.NotNil(t, err, "did not fail") + _, ok := err.(*HostGaCommunicatorGetVMAppInfoError) + require.True(t, ok, "expected error to be of type *HostGaCommunicatorGetVMAppInfoError") + require.Contains(t, err.Error(), InitializationError.ToString(), "Wrong error code") require.Contains(t, err.Error(), "Could not parse the HostGA URI", "Wrong message for invalid uri") } @@ -60,7 +63,10 @@ func TestGetVmAppInfo_RequestFailed(t *testing.T) { hgc := &HostGaCommunicator{} _, err := hgc.GetVMAppInfo(nopLog(), myAppName) require.NotNil(t, err, "did not fail") - require.Contains(t, err.Error(), "Metadata request failed with retries.", "Wrong message for failed request") + _, ok := err.(*HostGaCommunicatorGetVMAppInfoError) + require.True(t, ok, "expected error to be of type *HostGaCommunicatorGetVMAppInfoError") + require.Contains(t, err.Error(), MetadataRequestFailedWithRetries.ToString(), "Wrong error code") + require.Contains(t, err.Error(), "Metadata request failed after retries:", "Wrong message for failed request") } func TestGetVmAppInfo_CouldNotDecodeResponse(t *testing.T) { @@ -75,7 +81,10 @@ func TestGetVmAppInfo_CouldNotDecodeResponse(t *testing.T) { hgc := &HostGaCommunicator{} _, err := hgc.GetVMAppInfo(nopLog(), myAppName) require.NotNil(t, err, "did not fail") - require.Contains(t, err.Error(), "failed to decode response body", "Wrong message for invalid response") + _, ok := err.(*HostGaCommunicatorGetVMAppInfoError) + require.True(t, ok, "expected error to be of type *HostGaCommunicatorGetVMAppInfoError") + require.Contains(t, err.Error(), MetadataRequestFailedInvalidResponseBody.ToString(), "Wrong error code") + require.Contains(t, err.Error(), "Failed to decode response body:", "Wrong message for invalid response") } func TestGetVmAppInfo_MissingProperties(t *testing.T) { @@ -155,9 +164,22 @@ func TestDownloadPackage_CannotRemoveExistingFile(t *testing.T) { hgc := &HostGaCommunicator{} err = hgc.DownloadPackage(nopLog(), myAppName, filePath) require.NotNil(t, err, "did not fail") + _, ok := err.(*DownloadPackageError) + require.True(t, ok, "expected error to be of type *DownloadPackageError") + require.Contains(t, err.Error(), "DownloadPackageFileError", "Wrong error code") require.Contains(t, err.Error(), "Could not remove the existing file", "Wrong message for failing to remove locked file") } +func TestDownloadPackage_InvalidUri(t *testing.T) { + os.Setenv(WireProtocolAddress, "htt!p:notgoingtohappen!") + hgc := &HostGaCommunicator{} + err := hgc.DownloadPackage(nopLog(), myAppName, "somepath") + require.NotNil(t, err, "did not fail") + _, ok := err.(*DownloadPackageError) + require.True(t, ok, "expected error to be of type *DownloadPackageError") + require.Contains(t, err.Error(), DownloadPackageRequestFactoryError.ToString(), "Wrong error type") +} + func TestDownloadPackage_InvalidPath(t *testing.T) { filePath := string(make([]byte, 5)) // null characters in file names are invalid in both windows and linux @@ -170,6 +192,9 @@ func TestDownloadPackage_InvalidPath(t *testing.T) { hgc := &HostGaCommunicator{} err := hgc.DownloadPackage(nopLog(), myAppName, filePath) require.NotNil(t, err, "did not fail") + _, ok := err.(*DownloadPackageError) + require.True(t, ok, "expected error to be of type *DownloadPackageError") + require.Contains(t, err.Error(), "DownloadPackageFileError", "Wrong error code") require.Contains(t, err.Error(), "Cannot retrieve file information", "Wrong message for invalid file path") } @@ -220,6 +245,9 @@ func TestDownloadPackage_TooManyTries(t *testing.T) { hgc := &HostGaCommunicator{} err := hgc.DownloadPackage(nopLog(), myAppName, filePath) require.NotNil(t, err, "did not fail") + _, ok := err.(*DownloadPackageError) + require.True(t, ok, "expected error to be of type *DownloadPackageError") + require.Contains(t, err.Error(), "DownloadPackageFileError", "Wrong error code") require.Contains(t, err.Error(), "Failed to completely download the file", "Wrong message for incomplete file") } @@ -249,6 +277,9 @@ func TestDownloadPackage_IntermediateCallFails(t *testing.T) { hgc := &HostGaCommunicator{} err := hgc.DownloadPackage(nopLog(), myAppName, filePath) require.NotNil(t, err, "did not fail") + _, ok := err.(*DownloadPackageError) + require.True(t, ok, "expected error to be of type *DownloadPackageError") + require.Contains(t, err.Error(), "DownloadPackageFileError", "Wrong error code") require.Contains(t, err.Error(), "Unrecoverable error while downloading the file", "Wrong message for failure mid-retries") } @@ -288,6 +319,34 @@ func TestDownloadPackage_MultipleCallDownload(t *testing.T) { verifyFileContents(t, filePath, expected) } +func TestDownloadConfig_InvalidUri(t *testing.T) { + os.Setenv(WireProtocolAddress, "htt!p:notgoingtohappen!") + hgc := &HostGaCommunicator{} + err := hgc.DownloadConfig(nopLog(), myAppName, "somepath") + require.NotNil(t, err, "did not fail") + _, ok := err.(*DownloadConfigError) + require.True(t, ok, "expected error to be of type *DownloadConfigError") + require.Contains(t, err.Error(), DownloadConfigRequestFactoryError.ToString(), "Wrong error code") +} + +func TestDownloadConfig_InvalidPath(t *testing.T) { + filePath := string(make([]byte, 5)) // null characters in file names are invalid in both windows and linux + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + os.Setenv(WireProtocolAddress, srv.URL) + hgc := &HostGaCommunicator{} + err := hgc.DownloadConfig(nopLog(), myAppName, filePath) + require.NotNil(t, err, "did not fail") + _, ok := err.(*DownloadConfigError) + require.True(t, ok, "expected error to be of type *DownloadConfigError") + require.Contains(t, err.Error(), DownloadConfigFileError.ToString(), "Wrong error code") + require.Contains(t, err.Error(), "Cannot retrieve file information", "Wrong message for invalid file path") +} + func TestDownloadConfig_SingeCallDownload(t *testing.T) { expected := "file contents don't matter" createTestDir(t) diff --git a/launcher/main.go b/launcher/main.go index 8f79159..9c868ae 100644 --- a/launcher/main.go +++ b/launcher/main.go @@ -85,7 +85,7 @@ func main() { if requestedSequenceNumber >= currentSequenceNumber { // attempt to write a transitioning status file if it doesn't exist - _, getStatusError := utils.GetStatusType(handlerEnv, requestedSequenceNumber) + _, getStatusError := utils.GetStatus(handlerEnv, requestedSequenceNumber) if getStatusError != nil { // either no transitioning status file was found, or the status file was malformed // either way create a new transitioning status file diff --git a/main/main.go b/main/main.go index 00d3507..fa37b0a 100644 --- a/main/main.go +++ b/main/main.go @@ -27,7 +27,6 @@ import ( var ( ExtensionName string // assign at compile time ExtensionVersion = "1.0.10" // should be assigned at compile time, do not edit in code - reportStatusFunc = utils.ReportStatus getVMExtensionFunc = getVMExtension customEnableFunc = customEnable setSequenceNumberFunc = seqno.SetSequenceNumber @@ -95,15 +94,24 @@ func getExtensionAndRun(arguments []string) error { ext.ExtensionLogger.Error(errorMessage) ext.ExtensionEvents.LogErrorEvent("Enable Failed", errorMessage) default: + if _, ok := enableError.(*hostgacommunicator.HostGaCommunicatorGetVMAppInfoError); ok { + // Preserve the last good status file if it exists and isn't already a + // HostGA network error + if statusObj, err := utils.GetStatus(ext.HandlerEnv, requestedSequenceNumber); err == nil { + msg := statusObj.FormattedMessage.Message + if !strings.HasPrefix(msg, hostgacommunicator.HostGaMetadataErrorPrefix) { + if err := utils.BackupStatusFile(ext.HandlerEnv.StatusFolder, requestedSequenceNumber); err != nil { + ext.ExtensionLogger.Warn("Failed to back up status file for sequence %d: %v", requestedSequenceNumber, err) + } + } + } + } ext.ExtensionLogger.Error(enableError.Error()) ext.ExtensionEvents.LogErrorEvent("Enable Failed", enableError.Error()) // try to save status file statusMessage := enableError.Error() - err := reportStatusFunc(ext.HandlerEnv, requestedSequenceNumber, status.StatusError, vmextensionhelper.EnableOperation.ToStatusName(), statusMessage) + err := reportStatusWrapper(ext, requestedSequenceNumber, status.StatusError, vmextensionhelper.EnableOperation.ToStatusName(), statusMessage) if err != nil { - errorMessage := fmt.Sprintf("Failed to save status file: %s", err.Error()) - ext.ExtensionLogger.Error(errorMessage) - ext.ExtensionEvents.LogErrorEvent("Save Status", errorMessage) return err } } @@ -214,37 +222,18 @@ func customEnable(ext *vmextensionhelper.VMExtension, hostgaCommunicator hostgac return errors.Wrapf(err, "Could not get package registry") } - // write success status if requested sequence number is newer - shouldReportStatus := false - - if ext.CurrentSequenceNumber == nil || requestedSequenceNumber > *ext.CurrentSequenceNumber { - shouldReportStatus = true - } else if requestedSequenceNumber == *ext.CurrentSequenceNumber { - statusType, err := utils.GetStatusType(ext.HandlerEnv, requestedSequenceNumber) - if err != nil || strings.EqualFold(string(statusType), string(status.StatusTransitioning)) { - // either something is wrong with the status file - // or its a transitioning status file - // overwrite it in either case - shouldReportStatus = true - } - } - if shouldReportStatus { - var statusResult status.StatusType - statusMessage := getStatusMessage(currentPackageRegistry.GetPackageCollection(), executeError, result) - if executeError.GetErrorIfDeploymentFailed() == nil { // treatFailureAsDeploymentFailure - statusResult = status.StatusSuccess - } else { - statusResult = status.StatusError - } - err := utils.ReportStatus(ext.HandlerEnv, requestedSequenceNumber, statusResult, vmextensionhelper.EnableOperation.ToStatusName(), statusMessage) + statusUpdated, statusResult, statusMessage := computeStatus(ext, requestedSequenceNumber, ¤tPackageRegistry, executeError, result, vmAppResults) + + if statusUpdated { + err := reportStatusWrapper(ext, requestedSequenceNumber, statusResult, vmextensionhelper.EnableOperation.ToStatusName(), statusMessage) if err != nil { - errorMessage := fmt.Sprintf("Failed to save status file: %s", err.Error()) - ext.ExtensionLogger.Error(errorMessage) - ext.ExtensionEvents.LogErrorEvent("Save Status", errorMessage) return err } + // update the sequence number that has been executed - if err := setSequenceNumberFunc(ExtensionName, ExtensionVersion, requestedSequenceNumber); err != nil { + err = setSequenceNumberFunc(ExtensionName, ExtensionVersion, requestedSequenceNumber) + if err != nil { + // log but not return the error errorMessage := fmt.Sprintf("Failed to update sequence number to %d: %s", requestedSequenceNumber, err.Error()) ext.ExtensionLogger.Error(errorMessage) ext.ExtensionEvents.LogErrorEvent("Update Sequence Number", errorMessage) @@ -258,7 +247,75 @@ func customEnable(ext *vmextensionhelper.VMExtension, hostgaCommunicator hostgac return nil } -// Callback indicating the extension is being removed +func computeStatus( + ext *vmextensionhelper.VMExtension, + requestedSequenceNumber uint, + currentPackageRegistry *packageregistry.CurrentPackageRegistry, + executeError *actionplan.ExecuteError, + customActionResult actionplan.IResult, + vmAppResults *actionplan.PackageOperationResults, +) (bool, status.StatusType, string) { + statusUpdated := false + var statusResult status.StatusType + var statusMessage string + + if vmAppResults != nil && len(*vmAppResults) > 0 { + // executeError is only meaningful if there are VM App operations, otherwise + // it is the equivalent of no error (i.e success). + statusMessage = getStatusMessage(currentPackageRegistry.GetPackageCollection(), executeError, customActionResult) + if executeError.GetErrorIfDeploymentFailed() == nil { // treatFailureAsDeploymentFailure + statusResult = status.StatusSuccess + } else { + statusResult = status.StatusError + } + statusUpdated = true + } else { + // These next cases are dependent on the existing status + statusObj, err := utils.GetStatus(ext.HandlerEnv, requestedSequenceNumber) + if err != nil { + // Existing status file maybe corrupted or missing. The existing behavior is + // to write a success status. + statusMessage = getStatusMessage(currentPackageRegistry.GetPackageCollection(), executeError, customActionResult) + statusResult = status.StatusSuccess + statusUpdated = true + } else if strings.EqualFold(string(statusObj.Status), string(status.StatusTransitioning)) { + // If status is Transitioning and there's no VM App operations, + // then record a succes status. + statusMessage = getStatusMessage(currentPackageRegistry.GetPackageCollection(), executeError, customActionResult) + statusResult = status.StatusSuccess + statusUpdated = true + } else if strings.Contains(statusObj.FormattedMessage.Message, hostgacommunicator.HostGaMetadataErrorPrefix) { + // If there is no VM App operations, but the requested sequence's status is + // a transient host GA communication error, the status should be the same as + // its last stable status. + prevStatusObj, prevStatusErr := utils.GetLastStableStatus(ext.HandlerEnv, requestedSequenceNumber) + if prevStatusErr != nil { + // No last stable status save, should record as success since the hostGA issue + // is gone + statusResult = status.StatusSuccess + statusMessage = getStatusMessage(currentPackageRegistry.GetPackageCollection(), executeError, customActionResult) + } else { + statusResult = prevStatusObj.Status + statusMessage = prevStatusObj.FormattedMessage.Message + } + statusUpdated = true + } + } + + return statusUpdated, statusResult, statusMessage +} + +// A wrapper for utils.ReportStatus to log any errors occurring in that function +func reportStatusWrapper(ext *vmextensionhelper.VMExtension, requestedSequenceNumber uint, statusType status.StatusType, operationName string, message string) error { + err := utils.ReportStatus(ext.HandlerEnv, requestedSequenceNumber, statusType, operationName, message) + if err != nil { + errorMessage := fmt.Sprintf("Failed to save status file: %s", err.Error()) + ext.ExtensionLogger.Error(errorMessage) + ext.ExtensionEvents.LogErrorEvent("Save Status", errorMessage) + } + return err +} + func vmAppUninstallCallback(ext *vmextensionhelper.VMExtension) error { ext.ExtensionEvents.LogInformationalEvent("Uninstalling", "VmApplications extension - removing all applications for uninstall") hostGaCommunicator := hostgacommunicator.HostGaCommunicator{} diff --git a/main/main_test.go b/main/main_test.go index ddf848f..314e618 100644 --- a/main/main_test.go +++ b/main/main_test.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" "os" "path" "path/filepath" @@ -73,15 +72,15 @@ func nopLog() *logging.ExtensionLogger { var maintestdir string -func TestMain(m *testing.M) { - testdir, err := ioutil.TempDir("", "maintest") +func setupTest(t *testing.T) { + testdir, err := os.MkdirTemp("", "maintest") if err != nil { - return + t.Fatalf("Failed to create temp dir: %v", err) } err = os.MkdirAll(testdir, constants.FilePermissions_UserOnly_ReadWriteExecute) if err != nil { - return + t.Fatalf("Failed to create test dir: %v", err) } setSequenceNumberFunc = func(extName, extVersion string, seqNo uint) error { @@ -90,13 +89,14 @@ func TestMain(m *testing.M) { } maintestdir = testdir - exitVal := m.Run() - os.RemoveAll(maintestdir) - os.Exit(exitVal) + t.Cleanup(func() { + os.RemoveAll(maintestdir) + }) } func Test_settingsFailToInit(t *testing.T) { + setupTest(t) ExtensionVersion = "" defer resetExtensionVersion() err := getExtensionAndRun([]string{"vm-application-manager", "enable"}) @@ -104,18 +104,21 @@ func Test_settingsFailToInit(t *testing.T) { } func Test_failToCreateExtension(t *testing.T) { + setupTest(t) // This will fail automatically because Guest Agent hasn't set the required sequence numbers err := getExtensionAndRun([]string{"vm-application-manager", "enable"}) require.Error(t, err) } func Test_getVMPackageData_noSettings(t *testing.T) { + setupTest(t) ext := createTestVMExtension(t, nil) err := customEnable(ext, noopHostGaCommunicator, 0) require.Error(t, err) } func Test_getVMPackageData_cannotDeserialize(t *testing.T) { + setupTest(t) vmPackages := "yabasnarfle {}" ext := createTestVMExtension(t, vmPackages) @@ -124,6 +127,7 @@ func Test_getVMPackageData_cannotDeserialize(t *testing.T) { } func Test_getVMPackageData_noApplications(t *testing.T) { + setupTest(t) vmApplications := []extdeserialization.VmAppSetting{} ext := createTestVMExtension(t, vmApplications) @@ -132,6 +136,7 @@ func Test_getVMPackageData_noApplications(t *testing.T) { } func Test_getVMPackageData_valid(t *testing.T) { + setupTest(t) order := 1 vmApplications := []extdeserialization.VmAppSetting{ { @@ -148,6 +153,7 @@ func Test_getVMPackageData_valid(t *testing.T) { } func Test_getVMAppProtectedSettings_valid(t *testing.T) { + setupTest(t) order := 1 actions := extdeserialization.ActionSetting{ ActionName: "logging", @@ -181,6 +187,7 @@ func Test_getVMAppProtectedSettings_valid(t *testing.T) { } func Test_getVMAppProtectedSettings_valid_no_custom_actions(t *testing.T) { + setupTest(t) order := 1 appSettings := extdeserialization.VmAppSetting{ @@ -201,6 +208,7 @@ func Test_getVMAppProtectedSettings_valid_no_custom_actions(t *testing.T) { } func Test_getVMPackageData_noVersion(t *testing.T) { + setupTest(t) order := 1 vmApplications := []extdeserialization.VmAppSetting{ { @@ -217,6 +225,7 @@ func Test_getVMPackageData_noVersion(t *testing.T) { } func Test_GetApplicationMetadataWithInvalidRebootBehavior_DefaultsToNone(t *testing.T) { + setupTest(t) order := 1 vmApplications := []extdeserialization.VmAppSetting{ { @@ -247,6 +256,7 @@ func Test_GetApplicationMetadataWithInvalidRebootBehavior_DefaultsToNone(t *test } func Test_getVMPackageDataCustomAction_valid(t *testing.T) { + setupTest(t) order := 1 actions := extdeserialization.ActionSetting{ ActionName: "Action1", @@ -285,7 +295,7 @@ func Test_getVMPackageDataCustomAction_valid(t *testing.T) { require.Contains(t, currentpackages[vmApplications[0].ApplicationName].Result, actionplan.Success) // test contents of the status file statusFilePath := filepath.Join(ext.HandlerEnv.StatusFolder, fmt.Sprintf("%d.status", requestedSequenceNumber)) - fileBytes, err := ioutil.ReadFile(statusFilePath) + fileBytes, err := os.ReadFile(statusFilePath) require.NoError(t, err) statusReport := status.StatusReport{} err = json.Unmarshal(fileBytes, &statusReport) @@ -309,6 +319,7 @@ func Test_getVMPackageDataCustomAction_valid(t *testing.T) { } func Test_getVMPackageDataCustomAction_CriticalError(t *testing.T) { + setupTest(t) order := 1 actions := extdeserialization.ActionSetting{ ActionName: "Action1", @@ -333,6 +344,7 @@ func Test_getVMPackageDataCustomAction_CriticalError(t *testing.T) { } func Test_getVMPackageData_noApplicationName(t *testing.T) { + setupTest(t) order := 1 vmApplications := []extdeserialization.VmAppSetting{ { @@ -349,6 +361,7 @@ func Test_getVMPackageData_noApplicationName(t *testing.T) { } func Test_main_statusIsWrittenForCriticalErrors(t *testing.T) { + setupTest(t) order := 1 vmApplications := []extdeserialization.VmAppSetting{ { @@ -372,7 +385,7 @@ func Test_main_statusIsWrittenForCriticalErrors(t *testing.T) { err := getExtensionAndRun([]string{"vm-application-manager", vmextension.EnableOperation.ToString()}) require.NoError(t, err) statusFilePath := filepath.Join(ext.HandlerEnv.StatusFolder, fmt.Sprintf("%d.status", requestedSequenceNumber)) - fileBytes, err := ioutil.ReadFile(statusFilePath) + fileBytes, err := os.ReadFile(statusFilePath) require.NoError(t, err) fileString := string(fileBytes) require.Contains(t, fileString, vmextension.EnableOperation.ToStatusName()) @@ -384,6 +397,7 @@ func Test_main_statusIsWrittenForCriticalErrors(t *testing.T) { } func Test_main_statusIsNotWrittenForFileLockErrors(t *testing.T) { + setupTest(t) order := 1 vmApplications := []extdeserialization.VmAppSetting{ { @@ -417,6 +431,7 @@ func Test_main_statusIsNotWrittenForFileLockErrors(t *testing.T) { } func Test_main_nothingToProcess_noStatusUpdate(t *testing.T) { + setupTest(t) vmApplications := []extdeserialization.VmAppSetting{} ext := createTestVMExtension(t, vmApplications) @@ -427,13 +442,15 @@ func Test_main_nothingToProcess_noStatusUpdate(t *testing.T) { err = customEnable(ext, &hostGaCommunicator, requestedSequenceNumber) require.NoError(t, err) // ensure stautus file is not overwritten - statusType, err := utils.GetStatusType(ext.HandlerEnv, requestedSequenceNumber) + statusObj, err := utils.GetStatus(ext.HandlerEnv, requestedSequenceNumber) require.NoError(t, err) - require.Equal(t, status.StatusError, statusType) + require.NotNil(t, statusObj) + require.Equal(t, status.StatusError, statusObj.Status) require.Equal(t, requestedSequenceNumber, currentSequenceNumber) } func Test_main_transitioningStatusIsUpdated(t *testing.T) { + setupTest(t) vmApplications := []extdeserialization.VmAppSetting{} ext := createTestVMExtension(t, vmApplications) @@ -444,13 +461,15 @@ func Test_main_transitioningStatusIsUpdated(t *testing.T) { err = customEnable(ext, &hostGaCommunicator, requestedSequenceNumber) require.NoError(t, err) // ensure error stautus file is not overwritten - statusType, err := utils.GetStatusType(ext.HandlerEnv, requestedSequenceNumber) + statusObj, err := utils.GetStatus(ext.HandlerEnv, requestedSequenceNumber) require.NoError(t, err) - require.Equal(t, status.StatusSuccess, statusType) + require.NotNil(t, statusObj) + require.Equal(t, status.StatusSuccess, statusObj.Status) require.Equal(t, requestedSequenceNumber, currentSequenceNumber) } func Test_main_nothingToProcess_withStatus(t *testing.T) { + setupTest(t) vmApplications := []extdeserialization.VmAppSetting{} ext := createTestVMExtension(t, vmApplications) hostGaCommunicator := NoopHostGaCommunicator{} @@ -458,7 +477,7 @@ func Test_main_nothingToProcess_withStatus(t *testing.T) { err := customEnable(ext, &hostGaCommunicator, requestedSequenceNumber) require.NoError(t, err) statusFilePath := filepath.Join(ext.HandlerEnv.StatusFolder, fmt.Sprintf("%d.status", requestedSequenceNumber)) - fileBytes, err := ioutil.ReadFile(statusFilePath) + fileBytes, err := os.ReadFile(statusFilePath) require.NoError(t, err) fileString := string(fileBytes) require.Contains(t, fileString, vmextension.EnableOperation.ToStatusName()) @@ -467,6 +486,7 @@ func Test_main_nothingToProcess_withStatus(t *testing.T) { } func Test_uninstall_cannotCreatePackageRegistry(t *testing.T) { + setupTest(t) vmApplications := []extdeserialization.VmAppSetting{} ext := createTestVMExtension(t, vmApplications) hostGaCommunicator := NoopHostGaCommunicator{} @@ -480,13 +500,14 @@ func Test_uninstall_cannotCreatePackageRegistry(t *testing.T) { } func Test_uninstall_cannotReadPackageRegistry(t *testing.T) { + setupTest(t) vmApplications := []extdeserialization.VmAppSetting{} ext := createTestVMExtension(t, vmApplications) hostGaCommunicator := NoopHostGaCommunicator{} // Write an invalid registry so we can't create a package registry appRegistryFilePath := path.Join(ext.HandlerEnv.ConfigFolder, packageregistry.LocalApplicationRegistryFileName) - ioutil.WriteFile(appRegistryFilePath, []byte("}"), 0644) + os.WriteFile(appRegistryFilePath, []byte("}"), 0644) defer os.Remove(appRegistryFilePath) err := doVmAppUninstallCallback(ext, &hostGaCommunicator) @@ -495,6 +516,7 @@ func Test_uninstall_cannotReadPackageRegistry(t *testing.T) { } func Test_uninstall_noAppsToUninstall(t *testing.T) { + setupTest(t) vmApplications := []extdeserialization.VmAppSetting{} ext := createTestVMExtension(t, vmApplications) hostGaCommunicator := NoopHostGaCommunicator{} @@ -536,6 +558,7 @@ func Test_uninstall_noAppsToUninstall(t *testing.T) { } func Test_uninstall_uninstallApps(t *testing.T) { + setupTest(t) vmApplications := []extdeserialization.VmAppSetting{} ext := createTestVMExtension(t, vmApplications) hostGaCommunicator := NoopHostGaCommunicator{} @@ -607,3 +630,117 @@ func createTestVMExtension(t *testing.T, settings interface{}) *vmextension.VMEx ExtensionEvents: eem, } } + +// Test computeStatus when vmAppResults has items and no deployment failure +func Test_computeStatus_WithVMAppResults_Success(t *testing.T) { + setupTest(t) + ext := createTestVMExtension(t, nil) + + vmAppResults := actionplan.PackageOperationResults{ + actionplan.PackageOperationResult{ + PackageName: "testApp", + AppVersion: "1.0.0", + Operation: "install", + Result: "0", + }, + } + executeError := &actionplan.ExecuteError{} + currentPkgReg := packageregistry.CurrentPackageRegistry{} + + updated, statusType, _ := computeStatus(ext, 1, ¤tPkgReg, executeError, &vmAppResults, &vmAppResults) + + assert.True(t, updated) + assert.Equal(t, status.StatusSuccess, statusType) +} + +// Test computeStatus when no vmAppResults and status file doesn't exist +func Test_computeStatus_NoVMAppResults_NoStatusFile(t *testing.T) { + setupTest(t) + ext := createTestVMExtension(t, nil) + + executeError := &actionplan.ExecuteError{} + currentPkgReg := packageregistry.CurrentPackageRegistry{} + emptyResults := actionplan.PackageOperationResults{} + + updated, statusType, _ := computeStatus(ext, 1, ¤tPkgReg, executeError, &emptyResults, nil) + + assert.True(t, updated) + assert.Equal(t, status.StatusSuccess, statusType) +} + +// Test computeStatus when no vmAppResults and current status has HostGA error prefix but no last stable +func Test_computeStatus_NoVMAppResults_HostGAError_NoLastStable(t *testing.T) { + setupTest(t) + ext := createTestVMExtension(t, nil) + + // Write a status file with HostGA error prefix + // ReportStatus adds "Enable failed: " prefix, but Contains still finds the HostGaMetadataErrorPrefix + err := utils.ReportStatus(ext.HandlerEnv, 1, status.StatusError, "Enable", hostgacommunicator.HostGaMetadataErrorPrefix+" some error") + require.NoError(t, err) + + seqNo := uint(1) + ext.CurrentSequenceNumber = &seqNo + + executeError := &actionplan.ExecuteError{} + currentPkgReg := packageregistry.CurrentPackageRegistry{} + emptyResults := actionplan.PackageOperationResults{} + + updated, statusType, _ := computeStatus(ext, 1, ¤tPkgReg, executeError, &emptyResults, nil) + + // HostGA error detected, no last stable status exists, so update to success + assert.True(t, updated) + assert.Equal(t, status.StatusSuccess, statusType) +} + +// Test computeStatus when no vmAppResults and current status has HostGA error with last stable status +func Test_computeStatus_NoVMAppResults_HostGAError_WithLastStable(t *testing.T) { + setupTest(t) + ext := createTestVMExtension(t, nil) + + // First write a stable (lastgood) status file + err := utils.ReportStatus(ext.HandlerEnv, 1, status.StatusSuccess, "Enable", "last good message") + require.NoError(t, err) + err = utils.BackupStatusFile(ext.HandlerEnv.StatusFolder, 1) + require.NoError(t, err) + + // Now write a status file with HostGA error prefix + // ReportStatus adds "Enable failed: " prefix, but Contains still finds the HostGaMetadataErrorPrefix + err = utils.ReportStatus(ext.HandlerEnv, 1, status.StatusError, "Enable", hostgacommunicator.HostGaMetadataErrorPrefix+" some error") + require.NoError(t, err) + + seqNo := uint(1) + ext.CurrentSequenceNumber = &seqNo + + executeError := &actionplan.ExecuteError{} + currentPkgReg := packageregistry.CurrentPackageRegistry{} + emptyResults := actionplan.PackageOperationResults{} + + updated, statusType, msg := computeStatus(ext, 1, ¤tPkgReg, executeError, &emptyResults, nil) + + // HostGA error detected, last stable status restored + assert.True(t, updated) + assert.Equal(t, status.StatusSuccess, statusType) + assert.Contains(t, msg, "last good message") +} + +// Test computeStatus when status is Transitioning and CurrentSequenceNumber is nil (first Enable) +func Test_computeStatus_StatusTransitioning_FirstEnable(t *testing.T) { + setupTest(t) + ext := createTestVMExtension(t, nil) + + // Write a transitioning status file + err := utils.ReportStatus(ext.HandlerEnv, 1, status.StatusTransitioning, "Enable", "transitioning") + require.NoError(t, err) + + // Set CurrentSequenceNumber to nil (first time Enable) + ext.CurrentSequenceNumber = nil + + executeError := &actionplan.ExecuteError{} + currentPkgReg := packageregistry.CurrentPackageRegistry{} + emptyResults := actionplan.PackageOperationResults{} + + updated, statusType, _ := computeStatus(ext, 1, ¤tPkgReg, executeError, &emptyResults, nil) + + assert.True(t, updated) + assert.Equal(t, status.StatusSuccess, statusType) +} diff --git a/pkg/utils/status.go b/pkg/utils/status.go index b9a3732..b0960da 100644 --- a/pkg/utils/status.go +++ b/pkg/utils/status.go @@ -6,7 +6,8 @@ package utils import ( "encoding/json" "fmt" - "io/ioutil" + "io" + "os" "path/filepath" "github.com/Azure/azure-extension-platform/pkg/handlerenv" @@ -14,6 +15,8 @@ import ( "github.com/pkg/errors" ) +const BackupStatusFileSuffix = ".lastStableStatus" + type StatusSaveError struct { Err error } @@ -22,19 +25,29 @@ func (statusServerError *StatusSaveError) Error() string { return statusServerError.Err.Error() } -func GetStatusType(handlerEnv *handlerenv.HandlerEnvironment, sequenceNumber uint) (status.StatusType, error) { - fn := fmt.Sprintf("%d.status", sequenceNumber) - path := filepath.Join(handlerEnv.StatusFolder, fn) - statusBytes, err := ioutil.ReadFile(path) +func readStatusFileHelper(path string) (*status.Status, error) { + statusBytes, err := os.ReadFile(path) if err != nil { - return "", err + return nil, err } statusReport := make(status.StatusReport, 1) err = json.Unmarshal(statusBytes, &statusReport) if err != nil { - return "", err + return nil, err } - return statusReport[0].Status.Status, nil + return &statusReport[0].Status, nil +} + +func GetStatus(handlerEnv *handlerenv.HandlerEnvironment, sequenceNumber uint) (*status.Status, error) { + fn := fmt.Sprintf("%d.status", sequenceNumber) + path := filepath.Join(handlerEnv.StatusFolder, fn) + return readStatusFileHelper(path) +} + +func GetLastStableStatus(handlerEnv *handlerenv.HandlerEnvironment, sequenceNumber uint) (*status.Status, error) { + fn := fmt.Sprintf("%d%s", sequenceNumber, BackupStatusFileSuffix) + path := filepath.Join(handlerEnv.StatusFolder, fn) + return readStatusFileHelper(path) } func ReportStatus(handlerEnv *handlerenv.HandlerEnvironment, requestedSequenceNumber uint, statusType status.StatusType, operationName string, message string) error { @@ -47,3 +60,59 @@ func ReportStatus(handlerEnv *handlerenv.HandlerEnvironment, requestedSequenceNu } return nil } + +// copyFile copies a file from src to dst. If dst already exists, it will be overwritten. +// The file permissions of the destination file will be the same as the source file. +func copyFile(src, dst string) error { + // Open the source file + sourceFile, err := os.Open(src) + if err != nil { + return fmt.Errorf("failed to open source file: %w", err) + } + defer sourceFile.Close() // Get source file info + + sourceInfo, err := sourceFile.Stat() + if err != nil { + return fmt.Errorf("failed to stat source file: %w", err) + } + + // Create the destination file with same mode + destinationFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, sourceInfo.Mode()) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer destinationFile.Close() + + // Copy the content + _, err = io.Copy(destinationFile, sourceFile) + if err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + // Flush file metadata to disk + err = destinationFile.Sync() + if err != nil { + return fmt.Errorf("failed to sync destination file: %w", err) + } + + return nil +} + +// BackupStatusFile renames the current status file so it can be restored later. +// If there is no existing status file, this function returns without error because +// there's nothing to back up. +func BackupStatusFile(statusFolder string, sequenceNumber uint) error { + current := filepath.Join(statusFolder, fmt.Sprintf("%d.status", sequenceNumber)) + backup := filepath.Join(statusFolder, fmt.Sprintf("%d%s", sequenceNumber, BackupStatusFileSuffix)) + info, err := os.Stat(current) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + if info.IsDir() { + return fmt.Errorf("expected a file but found a directory: %s", current) + } + return copyFile(current, backup) +} diff --git a/pkg/utils/status_test.go b/pkg/utils/status_test.go index b45fb4e..ad1d90f 100644 --- a/pkg/utils/status_test.go +++ b/pkg/utils/status_test.go @@ -4,7 +4,9 @@ package utils import ( + "os" "path" + "path/filepath" "strings" "testing" @@ -15,7 +17,53 @@ import ( func TestStatusParsing(t *testing.T) { handlerEnv := handlerenv.HandlerEnvironment{StatusFolder: path.Join(".", "testFiles")} - statusType, err := GetStatusType(&handlerEnv, 1) + statusObj, err := GetStatus(&handlerEnv, 1) require.NoError(t, err) - require.True(t, strings.EqualFold(string(statusType), string(platformstatus.StatusTransitioning))) + require.NotNil(t, statusObj) + require.True(t, strings.EqualFold(string(statusObj.Status), string(platformstatus.StatusTransitioning))) +} + +func TestBackupStatusFile(t *testing.T) { + t.Run("successful backup when status file exists", func(t *testing.T) { + // Create temp directory + tmpDir := t.TempDir() + statusFile := filepath.Join(tmpDir, "1.status") + backupFile := filepath.Join(tmpDir, "1"+BackupStatusFileSuffix) + + // Create a status file + err := os.WriteFile(statusFile, []byte(`[{"status":{"status":"success"}}]`), 0644) + require.NoError(t, err) + + // Backup the status file + err = BackupStatusFile(tmpDir, 1) + require.NoError(t, err) + + // Verify original file is still there (to be overwritten later) + _, err = os.Stat(statusFile) + require.NoError(t, err) + + // Verify backup file exists + _, err = os.Stat(backupFile) + require.NoError(t, err) + }) + + t.Run("successful backup when status file does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + err := BackupStatusFile(tmpDir, 999) + require.NoError(t, err) + }) + + t.Run("error when the status file path is a directory", func(t *testing.T) { + tmpDir := t.TempDir() + statusFileAsDir := filepath.Join(tmpDir, "1.status") + + // Create a directory with the same name as the status file + err := os.Mkdir(statusFileAsDir, 0755) + require.NoError(t, err) + + // Backup should fail because the path is a directory, not a file + err = BackupStatusFile(tmpDir, 1) + require.Error(t, err) + }) }