diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index 7a2c3c78f41..d24711ee8b8 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -887,9 +887,11 @@ SnapshotsEnabled = true AccountsStatePruningEnabled = false PeerStatePruningEnabled = true - MaxStateTrieLevelInMemory = 5 - MaxPeerTrieLevelInMemory = 5 StateStatisticsEnabled = false + MaxUserTrieSizeInMemory = 524288000 #500MB + MaxPeerTrieSizeInMemory = 104857600 #100MB + DataTriesSizeInMemory = 524288000 #500MB + NumLeavesToCollapseSingleRun = 100 [TrieLeavesRetrieverConfig] Enabled = false diff --git a/common/common.go b/common/common.go index cf6601328d5..2efd873b0c9 100644 --- a/common/common.go +++ b/common/common.go @@ -28,6 +28,14 @@ const ( nonceIndex = 0 ) +const ( + // TenMbSize defines the size of 10 megabytes in bytes, used as a constant for memory limits or buffer sizes + TenMbSize = uint64(10485760) + + // NumLeavesToCollapseSingleRun defines the number of leaves to collapse in a single run, used in trie collapsing operations to manage performance and resource usage + NumLeavesToCollapseSingleRun = 100 +) + type executionResultHandler interface { GetMiniBlockHeadersHandlers() []data.MiniBlockHeaderHandler } diff --git a/common/interface.go b/common/interface.go index 9a7c51536ef..98ffeb9417b 100644 --- a/common/interface.go +++ b/common/interface.go @@ -59,6 +59,7 @@ type Trie interface { VerifyProof(rootHash []byte, key []byte, proof [][]byte) (bool, error) GetStorageManager() StorageManager IsMigratedToLatestVersion() (bool, error) + SizeInMemory() int Close() error IsInterfaceNil() bool } @@ -151,13 +152,20 @@ type SnapshotDbHandler interface { // TriesHolder is used to store multiple tries type TriesHolder interface { Put([]byte, Trie) - Replace(key []byte, tr Trie) Get([]byte) Trie GetAll() []Trie + Remove([]byte) + MarkAsDirty([]byte) Reset() IsInterfaceNil() bool } +// DataTrieCreator is an adapter for the Trie interface used only for recreating tries +type DataTrieCreator interface { + Recreate(options RootHashHolder) (Trie, error) + IsInterfaceNil() bool +} + // Locker defines the operations used to lock different critical areas. Implemented by the RWMutex. type Locker interface { Lock() @@ -527,3 +535,15 @@ type AOTSelectionPreempter interface { CancelOngoingSelection() IsInterfaceNil() bool } + +// TrieCollapseManager defines the behavior of a trie collapse manager +type TrieCollapseManager interface { + MarkKeyAsAccessed(key []byte, sizeLoadedInMemory int) + RemoveKey(key []byte, sizeLoadedInMemory int) + ShouldCollapseTrie() bool + GetCollapsibleLeaves() ([][]byte, error) + AddSizeInMemory(size int) + GetSizeInMemory() int + CloneWithoutState() TrieCollapseManager + IsInterfaceNil() bool +} diff --git a/config/config.go b/config/config.go index 80c13872b34..9d757af1901 100644 --- a/config/config.go +++ b/config/config.go @@ -455,12 +455,14 @@ type FacadeConfig struct { // StateTriesConfig will hold information about state tries type StateTriesConfig struct { - SnapshotsEnabled bool - AccountsStatePruningEnabled bool - PeerStatePruningEnabled bool - MaxStateTrieLevelInMemory uint - MaxPeerTrieLevelInMemory uint - StateStatisticsEnabled bool + SnapshotsEnabled bool + AccountsStatePruningEnabled bool + PeerStatePruningEnabled bool + StateStatisticsEnabled bool + MaxUserTrieSizeInMemory uint64 + MaxPeerTrieSizeInMemory uint64 + DataTriesSizeInMemory uint64 + NumLeavesToCollapseSingleRun uint32 } // StateAccessesCollectorConfig will hold information about state accesses collector diff --git a/config/tomlConfig_test.go b/config/tomlConfig_test.go index 1450e50f5d1..5f8dfe96331 100644 --- a/config/tomlConfig_test.go +++ b/config/tomlConfig_test.go @@ -169,11 +169,10 @@ func TestTomlParser(t *testing.T) { }, }, StateTriesConfig: StateTriesConfig{ - SnapshotsEnabled: true, - AccountsStatePruningEnabled: true, - PeerStatePruningEnabled: true, - MaxStateTrieLevelInMemory: 38, - MaxPeerTrieLevelInMemory: 39, + SnapshotsEnabled: true, + AccountsStatePruningEnabled: true, + PeerStatePruningEnabled: true, + NumLeavesToCollapseSingleRun: 100, }, TxCacheBounds: TxCacheBoundsConfig{ MaxNumBytesPerSenderUpperBound: 33_554_432, @@ -632,8 +631,7 @@ func TestTomlParser(t *testing.T) { SnapshotsEnabled = true AccountsStatePruningEnabled = true PeerStatePruningEnabled = true - MaxStateTrieLevelInMemory = 38 - MaxPeerTrieLevelInMemory = 39 + NumLeavesToCollapseSingleRun = 100 ` cfg := Config{} diff --git a/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go b/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go index 003d20a37b6..ed8037542a2 100644 --- a/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go +++ b/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go @@ -14,7 +14,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process/factory" - "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cache" @@ -86,10 +86,10 @@ func createStoreForMeta() dataRetriever.StorageService { } func createTriesHolderForMeta() common.TriesHolder { - triesHolder := state.NewDataTriesHolder() - triesHolder.Put([]byte(dataRetriever.UserAccountsUnit.String()), &trieMock.TrieStub{}) - triesHolder.Put([]byte(dataRetriever.PeerAccountsUnit.String()), &trieMock.TrieStub{}) - return triesHolder + triesContainer := triesHolder.NewTriesHolder() + triesContainer.Put([]byte(dataRetriever.UserAccountsUnit.String()), &trieMock.TrieStub{}) + triesContainer.Put([]byte(dataRetriever.PeerAccountsUnit.String()), &trieMock.TrieStub{}) + return triesContainer } // ------- NewResolversContainerFactory diff --git a/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go b/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go index e2a26be3733..74f99260ec9 100644 --- a/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go +++ b/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go @@ -14,7 +14,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process/factory" - "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cache" @@ -92,10 +92,10 @@ func createStoreForShard() dataRetriever.StorageService { } func createTriesHolderForShard() common.TriesHolder { - triesHolder := state.NewDataTriesHolder() - triesHolder.Put([]byte(dataRetriever.UserAccountsUnit.String()), &trieMock.TrieStub{}) - triesHolder.Put([]byte(dataRetriever.PeerAccountsUnit.String()), &trieMock.TrieStub{}) - return triesHolder + triesContainer := triesHolder.NewTriesHolder() + triesContainer.Put([]byte(dataRetriever.UserAccountsUnit.String()), &trieMock.TrieStub{}) + triesContainer.Put([]byte(dataRetriever.PeerAccountsUnit.String()), &trieMock.TrieStub{}) + return triesContainer } // ------- NewResolversContainerFactory diff --git a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go index 6a44b37b153..1186184fc84 100644 --- a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go +++ b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go @@ -221,8 +221,6 @@ func getArgumentsMeta() storagerequesterscontainer.FactoryArgs { StateTriesConfig: config.StateTriesConfig{ AccountsStatePruningEnabled: false, PeerStatePruningEnabled: false, - MaxStateTrieLevelInMemory: 5, - MaxPeerTrieLevelInMemory: 5, }, }, ShardIDForTries: 0, diff --git a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go index ecf848b35ac..9e5e9645ac8 100644 --- a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go +++ b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go @@ -206,8 +206,6 @@ func getArgumentsShard() storagerequesterscontainer.FactoryArgs { StateTriesConfig: config.StateTriesConfig{ AccountsStatePruningEnabled: false, PeerStatePruningEnabled: false, - MaxStateTrieLevelInMemory: 5, - MaxPeerTrieLevelInMemory: 5, }, }, ShardIDForTries: 0, diff --git a/epochStart/bootstrap/process.go b/epochStart/bootstrap/process.go index 6fa278b77e7..70aa5ccf873 100644 --- a/epochStart/bootstrap/process.go +++ b/epochStart/bootstrap/process.go @@ -15,10 +15,9 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" + "github.com/multiversx/mx-chain-go/state/triesHolder" logger "github.com/multiversx/mx-chain-logger-go" - "github.com/multiversx/mx-chain-go/process/interceptors/processor" - "github.com/multiversx/mx-chain-go/common" disabledCommon "github.com/multiversx/mx-chain-go/common/disabled" "github.com/multiversx/mx-chain-go/common/ordering" @@ -42,11 +41,11 @@ import ( "github.com/multiversx/mx-chain-go/process/heartbeat/validator" "github.com/multiversx/mx-chain-go/process/interceptors" disabledInterceptors "github.com/multiversx/mx-chain-go/process/interceptors/disabled" + "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/peer" "github.com/multiversx/mx-chain-go/redundancy" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/cache" @@ -283,7 +282,7 @@ func NewEpochStartBootstrap(args ArgsEpochStartBootstrap) (*epochStartBootstrap, return nil, err } - epochStartProvider.trieContainer = state.NewDataTriesHolder() + epochStartProvider.trieContainer = triesHolder.NewTriesHolder() epochStartProvider.trieStorageManagers = make(map[string]common.StorageManager) if epochStartProvider.generalConfig.Hardfork.AfterHardFork { @@ -1573,7 +1572,6 @@ func (e *epochStartBootstrap) syncUserAccountsState(rootHash []byte) error { RequestHandler: e.requestHandler, Timeout: common.TimeoutGettingTrieNodes, Cacher: e.dataPool.TrieNodes(), - MaxTrieLevelInMemory: e.generalConfig.StateTriesConfig.MaxStateTrieLevelInMemory, MaxHardCapForMissingNodes: e.maxHardCapForMissingNodes, TrieSyncerVersion: e.trieSyncerVersion, CheckNodesOnDisk: e.checkNodesOnDisk, @@ -1647,7 +1645,6 @@ func (e *epochStartBootstrap) syncValidatorAccountsState(rootHash []byte) error RequestHandler: e.requestHandler, Timeout: common.TimeoutGettingTrieNodes, Cacher: e.dataPool.TrieNodes(), - MaxTrieLevelInMemory: e.generalConfig.StateTriesConfig.MaxPeerTrieLevelInMemory, MaxHardCapForMissingNodes: e.maxHardCapForMissingNodes, TrieSyncerVersion: e.trieSyncerVersion, CheckNodesOnDisk: e.checkNodesOnDisk, diff --git a/epochStart/bootstrap/process_test.go b/epochStart/bootstrap/process_test.go index 49843c2bc8b..7cabfa4e0bd 100644 --- a/epochStart/bootstrap/process_test.go +++ b/epochStart/bootstrap/process_test.go @@ -182,8 +182,9 @@ func createMockEpochStartBootstrapArgs( AccountsStatePruningEnabled: true, SnapshotsEnabled: true, PeerStatePruningEnabled: true, - MaxStateTrieLevelInMemory: 5, - MaxPeerTrieLevelInMemory: 5, + MaxUserTrieSizeInMemory: generalCfg.StateTriesConfig.MaxUserTrieSizeInMemory, + MaxPeerTrieSizeInMemory: generalCfg.StateTriesConfig.MaxPeerTrieSizeInMemory, + DataTriesSizeInMemory: generalCfg.StateTriesConfig.DataTriesSizeInMemory, }, TrieStorageManagerConfig: config.TrieStorageManagerConfig{ PruningBufferLen: 1000, diff --git a/epochStart/metachain/baseRewards_test.go b/epochStart/metachain/baseRewards_test.go index b0e1118b84e..984ff32c6ca 100644 --- a/epochStart/metachain/baseRewards_test.go +++ b/epochStart/metachain/baseRewards_test.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/rewardTx" "github.com/multiversx/mx-chain-core-go/hashing/sha256" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/state" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -21,7 +22,6 @@ import ( "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/testscommon" txExecOrderStub "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" @@ -29,7 +29,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" - "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" ) @@ -1140,20 +1139,14 @@ func getBaseRewardsArguments() BaseRewardsCreatorArgs { hasher := sha256.NewSha256() marshalizer := &marshal.GogoProtoMarshalizer{} - storageManagerArgs := storage.GetStorageManagerArgs() + storageManagerArgs := txExecOrderStub.GetStorageManagerArgs() storageManagerArgs.Marshalizer = marshalizer storageManagerArgs.Hasher = hasher - trieFactoryManager, _ := trie.CreateTrieStorageManager(storageManagerArgs, storage.GetStorageManagerOptions()) - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: hasher, - Marshaller: marshalizer, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - accCreator, _ := factory.NewAccountCreator(argsAccCreator) + trieFactoryManager, _ := trie.CreateTrieStorageManager(storageManagerArgs, txExecOrderStub.GetStorageManagerOptions()) enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} - userAccountsDB := createAccountsDB(hasher, marshalizer, accCreator, trieFactoryManager, enableEpochsHandler) + userAccountsDBArgs := createAccountsDBArgs(hasher, marshalizer, trieFactoryManager, enableEpochsHandler) + userAccountsDB, _ := state.NewAccountsDB(userAccountsDBArgs) shardCoordinator := mock.NewMultiShardsCoordinatorMock(2) shardCoordinator.CurrentShard = core.MetachainShardId shardCoordinator.ComputeIdCalled = func(address []byte) uint32 { diff --git a/epochStart/metachain/systemSCs_test.go b/epochStart/metachain/systemSCs_test.go index aac10d69d01..7adeb57f170 100644 --- a/epochStart/metachain/systemSCs_test.go +++ b/epochStart/metachain/systemSCs_test.go @@ -41,10 +41,12 @@ import ( "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/storage" storageFactory "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + testCommon "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -57,6 +59,7 @@ import ( statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageMock "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" @@ -741,14 +744,23 @@ func prepareStakingContractWithData( log.LogIfError(err) } -func createAccountsDB( +func createAccountsDBArgs( hasher hashing.Hasher, marshaller marshal.Marshalizer, - accountFactory state.AccountFactory, trieStorageManager common.StorageManager, enableEpochsHandler common.EnableEpochsHandler, -) *state.AccountsDB { - tr, _ := trie.NewTrie(trieStorageManager, marshaller, hasher, enableEpochsHandler, 5) +) state.ArgsAccountsDB { + dth, _ := triesHolder.NewDataTriesHolder(common.TenMbSize) + tr, _ := trie.NewTrie(trieStorageManager, marshaller, hasher, enableEpochsHandler, collapseManager.NewDisabledCollapseManager()) + argsAccCreator := factory.ArgsAccountCreator{ + Hasher: hasher, + Marshaller: marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, + DataTrieCreator: tr, + } + accCreator, _ := factory.NewAccountCreator(argsAccCreator) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, @@ -756,37 +768,28 @@ func createAccountsDB( ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) - args := state.ArgsAccountsDB{ + return state.ArgsAccountsDB{ Trie: tr, Hasher: hasher, Marshaller: marshaller, - AccountFactory: accountFactory, + AccountFactory: accCreator, StoragePruningManager: spm, AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: disabledState.NewDisabledSnapshotsManager(), StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, } - adb, _ := state.NewAccountsDB(args) - return adb } func createFullArgumentsForSystemSCProcessing(enableEpochsConfig config.EnableEpochs, trieStorer storage.Storer) (ArgsNewEpochStartSystemSCProcessing, vm.SystemSCContainer) { hasher := sha256.NewSha256() marshalizer := &marshal.GogoProtoMarshalizer{} - storageManagerArgs := storageMock.GetStorageManagerArgs() + storageManagerArgs := testCommon.GetStorageManagerArgs() storageManagerArgs.Marshalizer = marshalizer storageManagerArgs.Hasher = hasher storageManagerArgs.MainStorer = trieStorer - trieFactoryManager, _ := trie.CreateTrieStorageManager(storageManagerArgs, storageMock.GetStorageManagerOptions()) - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: hasher, - Marshaller: marshalizer, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), - } - accCreator, _ := factory.NewAccountCreator(argsAccCreator) - peerAccCreator := factory.NewPeerAccountCreator() + trieFactoryManager, _ := trie.CreateTrieStorageManager(storageManagerArgs, testCommon.GetStorageManagerOptions()) en := forking.NewGenericEpochNotifier() enableEpochsConfig.StakeLimitsEnableEpoch = 10 enableEpochsConfig.StakingV4Step1EnableEpoch = 444 @@ -795,8 +798,12 @@ func createFullArgumentsForSystemSCProcessing(enableEpochsConfig config.EnableEp EnableEpochs: enableEpochsConfig, } enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(epochsConfig.EnableEpochs, en) - userAccountsDB := createAccountsDB(hasher, marshalizer, accCreator, trieFactoryManager, enableEpochsHandler) - peerAccountsDB := createAccountsDB(hasher, marshalizer, peerAccCreator, trieFactoryManager, enableEpochsHandler) + userAccountsDBArgs := createAccountsDBArgs(hasher, marshalizer, trieFactoryManager, enableEpochsHandler) + userAccountsDB, _ := state.NewAccountsDB(userAccountsDBArgs) + + peerAccountsDBArgs := createAccountsDBArgs(hasher, marshalizer, trieFactoryManager, enableEpochsHandler) + peerAccountsDBArgs.AccountFactory = factory.NewPeerAccountCreator() + peerAccountsDB, _ := state.NewAccountsDB(peerAccountsDBArgs) argsValidatorsProcessor := peer.ArgValidatorStatisticsProcessor{ Marshalizer: marshalizer, diff --git a/errors/errors.go b/errors/errors.go index 7f9b5e85438..15ba8d1488c 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -616,3 +616,9 @@ var ErrNilTrieLeavesRetriever = errors.New("nil trie leaves retriever") // ErrNilAOTSelector signals that a nil AOT selector has been provided var ErrNilAOTSelector = errors.New("nil AOT selector") + +// ErrNilDataTriesHolder signals that a nil data tries holder has been provided +var ErrNilDataTriesHolder = errors.New("nil data tries holder") + +// ErrNilDataTrieCreator signals that a nil data trie creator has been provided +var ErrNilDataTrieCreator = errors.New("nil data trie creator") diff --git a/factory/api/apiResolverFactory.go b/factory/api/apiResolverFactory.go index 170d05ac839..1a2efb8fd95 100644 --- a/factory/api/apiResolverFactory.go +++ b/factory/api/apiResolverFactory.go @@ -9,6 +9,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/marshal" + factoryState "github.com/multiversx/mx-chain-go/state/factory" + "github.com/multiversx/mx-chain-go/state/triesHolder" logger "github.com/multiversx/mx-chain-logger-go" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/multiversx/mx-chain-vm-common-go/parsers" @@ -43,7 +45,6 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/blockInfoProviders" disabledState "github.com/multiversx/mx-chain-go/state/disabled" - factoryState "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/state/syncer" @@ -554,17 +555,6 @@ func createShardVmContainerFactory(args scQueryElementArgs, argsHook hooks.ArgBl } func createNewAccountsAdapterApi(args scQueryElementArgs, chainHandler data.ChainHandler) (state.AccountsAdapterAPI, common.StorageManager, error) { - argsAccCreator := factoryState.ArgsAccountCreator{ - Hasher: args.coreComponents.Hasher(), - Marshaller: args.coreComponents.InternalMarshalizer(), - EnableEpochsHandler: args.coreComponents.EnableEpochsHandler(), - StateAccessesCollector: args.stateComponents.StateAccessesCollector(), - } - accountFactory, err := factoryState.NewAccountCreator(argsAccCreator) - if err != nil { - return nil, nil, err - } - storagePruning, err := newStoragePruningManager(args) if err != nil { return nil, nil, err @@ -587,20 +577,40 @@ func createNewAccountsAdapterApi(args scQueryElementArgs, chainHandler data.Chai } trieCreatorArgs := trieFactory.TrieCreateArgs{ - MainStorer: trieStorer, - PruningEnabled: args.generalConfig.StateTriesConfig.AccountsStatePruningEnabled, - MaxTrieLevelInMem: args.generalConfig.StateTriesConfig.MaxStateTrieLevelInMemory, - SnapshotsEnabled: args.generalConfig.StateTriesConfig.SnapshotsEnabled, - IdleProvider: args.coreComponents.ProcessStatusHandler(), - Identifier: dataRetriever.UserAccountsUnit.String(), - EnableEpochsHandler: args.coreComponents.EnableEpochsHandler(), - StatsCollector: args.statusCoreComponents.StateStatsHandler(), + MainStorer: trieStorer, + PruningEnabled: args.generalConfig.StateTriesConfig.AccountsStatePruningEnabled, + SnapshotsEnabled: args.generalConfig.StateTriesConfig.SnapshotsEnabled, + IdleProvider: args.coreComponents.ProcessStatusHandler(), + Identifier: dataRetriever.UserAccountsUnit.String(), + EnableEpochsHandler: args.coreComponents.EnableEpochsHandler(), + StatsCollector: args.statusCoreComponents.StateStatsHandler(), + MaxSizeInMemory: args.generalConfig.StateTriesConfig.MaxUserTrieSizeInMemory, + NumLeavesToCollapseSingleRun: args.generalConfig.StateTriesConfig.NumLeavesToCollapseSingleRun, } trieStorageManager, merkleTrie, err := trFactory.Create(trieCreatorArgs) if err != nil { return nil, nil, err } + // TODO use different dataTriesSizeInMem for accountsDbApi + dth, err := triesHolder.NewDataTriesHolder(args.generalConfig.StateTriesConfig.DataTriesSizeInMemory) + if err != nil { + return nil, nil, err + } + + argsAccCreator := factoryState.ArgsAccountCreator{ + Hasher: args.coreComponents.Hasher(), + Marshaller: args.coreComponents.InternalMarshalizer(), + EnableEpochsHandler: args.coreComponents.EnableEpochsHandler(), + StateAccessesCollector: args.stateComponents.StateAccessesCollector(), + DataTriesHolder: dth, + DataTrieCreator: merkleTrie, + } + accountFactory, err := factoryState.NewAccountCreator(argsAccCreator) + if err != nil { + return nil, nil, err + } + argsAPIAccountsDB := state.ArgsAccountsDB{ Trie: merkleTrie, Hasher: args.coreComponents.Hasher(), @@ -610,6 +620,7 @@ func createNewAccountsAdapterApi(args scQueryElementArgs, chainHandler data.Chai AddressConverter: args.coreComponents.AddressPubKeyConverter(), SnapshotsManager: disabledState.NewDisabledSnapshotsManager(), StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, } provider, err := blockInfoProviders.NewCurrentBlockInfo(chainHandler) diff --git a/factory/api/apiResolverFactory_test.go b/factory/api/apiResolverFactory_test.go index 79329267cf5..b783f98726a 100644 --- a/factory/api/apiResolverFactory_test.go +++ b/factory/api/apiResolverFactory_test.go @@ -37,7 +37,9 @@ import ( "github.com/stretchr/testify/require" ) -const unreachableStep = 10000 +const ( + unreachableStep = 10000 +) type failingSteps struct { marshallerStepCounter int @@ -306,7 +308,9 @@ func createMockSCQueryElementArgs() api.SCQueryElementArgs { SnapshotsGoroutineNum: 1, }, StateTriesConfig: config.StateTriesConfig{ - MaxStateTrieLevelInMemory: 5, + MaxUserTrieSizeInMemory: common.TenMbSize, + MaxPeerTrieSizeInMemory: common.TenMbSize, + DataTriesSizeInMemory: common.TenMbSize, }, VirtualMachine: config.VirtualMachineServicesConfig{ Querying: config.QueryVirtualMachineConfig{ diff --git a/factory/consensus/consensusComponents.go b/factory/consensus/consensusComponents.go index 8862fe21de6..29875f8b118 100644 --- a/factory/consensus/consensusComponents.go +++ b/factory/consensus/consensusComponents.go @@ -547,7 +547,6 @@ func (ccf *consensusComponentsFactory) createArgsBaseAccountsSyncer(trieStorageM RequestHandler: ccf.processComponents.RequestHandler(), Timeout: common.TimeoutGettingTrieNodes, Cacher: ccf.dataComponents.Datapool().TrieNodes(), - MaxTrieLevelInMemory: ccf.config.StateTriesConfig.MaxStateTrieLevelInMemory, MaxHardCapForMissingNodes: ccf.config.TrieSync.MaxHardCapForMissingNodes, TrieSyncerVersion: ccf.config.TrieSync.TrieSyncerVersion, CheckNodesOnDisk: ccf.config.TrieSync.CheckNodesOnDisk, diff --git a/factory/processing/blockProcessorCreator_test.go b/factory/processing/blockProcessorCreator_test.go index 008f5e5a4e7..28dce04a36f 100644 --- a/factory/processing/blockProcessorCreator_test.go +++ b/factory/processing/blockProcessorCreator_test.go @@ -6,8 +6,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-core-go/hashing" - "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/factory" + factoryState "github.com/multiversx/mx-chain-go/state/factory" + "github.com/multiversx/mx-chain-go/state/triesHolder" + common2 "github.com/multiversx/mx-chain-go/testscommon/common" + "github.com/multiversx/mx-chain-go/trie/collapseManager" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/require" @@ -19,14 +22,11 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/accounts" disabledState "github.com/multiversx/mx-chain-go/state/disabled" - factoryState "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager/disabled" "github.com/multiversx/mx-chain-go/testscommon" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" - "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/processMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" - storageManager "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" ) @@ -91,30 +91,20 @@ func Test_newBlockProcessorCreatorForMeta(t *testing.T) { cryptoComponents := componentsMock.GetCryptoComponents(coreComponents) networkComponents := componentsMock.GetNetworkComponents(cryptoComponents) - storageManagerArgs := storageManager.GetStorageManagerArgs() + storageManagerArgs := common2.GetStorageManagerArgs() storageManagerArgs.Marshalizer = coreComponents.InternalMarshalizer() storageManagerArgs.Hasher = coreComponents.Hasher() - storageManagerUser, _ := trie.CreateTrieStorageManager(storageManagerArgs, storageManager.GetStorageManagerOptions()) + storageManagerUser, _ := trie.CreateTrieStorageManager(storageManagerArgs, common2.GetStorageManagerOptions()) storageManagerArgs.MainStorer = mock.NewMemDbMock() - storageManagerPeer, _ := trie.CreateTrieStorageManager(storageManagerArgs, storageManager.GetStorageManagerOptions()) + storageManagerPeer, _ := trie.CreateTrieStorageManager(storageManagerArgs, common2.GetStorageManagerOptions()) trieStorageManagers := make(map[string]common.StorageManager) trieStorageManagers[dataRetriever.UserAccountsUnit.String()] = storageManagerUser trieStorageManagers[dataRetriever.PeerAccountsUnit.String()] = storageManagerPeer - argsAccCreator := factoryState.ArgsAccountCreator{ - Hasher: coreComponents.Hasher(), - Marshaller: coreComponents.InternalMarshalizer(), - EnableEpochsHandler: coreComponents.EnableEpochsHandler(), - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - accCreator, _ := factoryState.NewAccountCreator(argsAccCreator) - adb, err := createAccountAdapter( - &mock.MarshalizerMock{}, - &hashingMocks.HasherMock{}, - accCreator, + coreComponents, trieStorageManagers[dataRetriever.UserAccountsUnit.String()], coreComponents.EnableEpochsHandler(), ) @@ -203,26 +193,37 @@ func Test_newBlockProcessorCreatorForMeta(t *testing.T) { } func createAccountAdapter( - marshaller marshal.Marshalizer, - hasher hashing.Hasher, - accountFactory state.AccountFactory, + coreComponents factory.CoreComponentsHolder, trieStorage common.StorageManager, handler common.EnableEpochsHandler, ) (state.AccountsAdapter, error) { - tr, err := trie.NewTrie(trieStorage, marshaller, hasher, handler, 5) + tr, err := trie.NewTrie(trieStorage, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), handler, collapseManager.NewDisabledCollapseManager()) if err != nil { return nil, err } + dth, _ := triesHolder.NewDataTriesHolder(common.TenMbSize) + + argsAccCreator := factoryState.ArgsAccountCreator{ + Hasher: coreComponents.Hasher(), + Marshaller: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), + StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, + DataTrieCreator: tr, + } + accCreator, _ := factoryState.NewAccountCreator(argsAccCreator) + args := state.ArgsAccountsDB{ Trie: tr, - Hasher: hasher, - Marshaller: marshaller, - AccountFactory: accountFactory, + Hasher: coreComponents.Hasher(), + Marshaller: coreComponents.InternalMarshalizer(), + AccountFactory: accCreator, StoragePruningManager: disabled.NewDisabledStoragePruningManager(), AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: disabledState.NewDisabledSnapshotsManager(), StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, } adb, err := state.NewAccountsDB(args) if err != nil { diff --git a/factory/processing/processComponents.go b/factory/processing/processComponents.go index f5050230e1c..d4e30c29e0b 100644 --- a/factory/processing/processComponents.go +++ b/factory/processing/processComponents.go @@ -1952,7 +1952,6 @@ func (pcf *processComponentsFactory) createExportFactoryHandler( ExportTriesStorageConfig: hardforkConfig.ExportTriesStorageConfig, ExportStateStorageConfig: hardforkConfig.ExportStateStorageConfig, ExportStateKeysConfig: hardforkConfig.ExportKeysStorageConfig, - MaxTrieLevelInMemory: pcf.config.StateTriesConfig.MaxStateTrieLevelInMemory, WhiteListHandler: pcf.whiteListHandler, WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, MainInterceptorsContainer: mainInterceptorsContainer, diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index 0a2e4f278df..e16e7e4c68b 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" chainData "github.com/multiversx/mx-chain-core-go/data" data "github.com/multiversx/mx-chain-core-go/data/stateChange" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" @@ -220,23 +221,30 @@ func (scf *stateComponentsFactory) createSnapshotManager( } func (scf *stateComponentsFactory) createAccountsAdapters(triesContainer common.TriesHolder, StateAccessesCollector state.StateAccessesCollector) (*accountsAdapterCreationResult, error) { + merkleTrie := triesContainer.Get([]byte(dataRetriever.UserAccountsUnit.String())) + storagePruning, err := scf.newStoragePruningManager() + if err != nil { + return nil, err + } + + dth, err := triesHolder.NewDataTriesHolder(scf.config.StateTriesConfig.DataTriesSizeInMemory) + if err != nil { + return nil, fmt.Errorf("failed to create data tries holder: %w", err) + } + argsAccCreator := factoryState.ArgsAccountCreator{ Hasher: scf.core.Hasher(), Marshaller: scf.core.InternalMarshalizer(), EnableEpochsHandler: scf.core.EnableEpochsHandler(), StateAccessesCollector: StateAccessesCollector, + DataTriesHolder: dth, + DataTrieCreator: merkleTrie, } accountFactory, err := factoryState.NewAccountCreator(argsAccCreator) if err != nil { return nil, err } - merkleTrie := triesContainer.Get([]byte(dataRetriever.UserAccountsUnit.String())) - storagePruning, err := scf.newStoragePruningManager() - if err != nil { - return nil, err - } - argStateMetrics := stateMetrics.ArgsStateMetrics{ SnapshotInProgressKey: common.MetricAccountsSnapshotInProgress, LastSnapshotDurationKey: common.MetricLastAccountsSnapshotDurationSec, @@ -262,17 +270,26 @@ func (scf *stateComponentsFactory) createAccountsAdapters(triesContainer common. SnapshotsManager: snapshotsManager, StateAccessesCollector: StateAccessesCollector, PruningEnabled: scf.config.StateTriesConfig.AccountsStatePruningEnabled, + DataTriesHolder: dth, } accountsAdapter, err := state.NewAccountsDB(argsProcessingAccountsDB) if err != nil { return nil, fmt.Errorf("%w: %s", errors.ErrAccountsAdapterCreation, err.Error()) } + // TODO use different size in apiAccountsDB for DataTriesSizeInMemory + apiDataTriesHolder, err := triesHolder.NewDataTriesHolder(scf.config.StateTriesConfig.DataTriesSizeInMemory) + if err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrAccountsAdapterCreation, err.Error()) + } + argsAPIAccCreator := factoryState.ArgsAccountCreator{ Hasher: scf.core.Hasher(), Marshaller: scf.core.InternalMarshalizer(), EnableEpochsHandler: scf.core.EnableEpochsHandler(), StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: apiDataTriesHolder, + DataTrieCreator: merkleTrie, } accountFactoryAPI, err := factoryState.NewAccountCreator(argsAPIAccCreator) if err != nil { @@ -288,6 +305,7 @@ func (scf *stateComponentsFactory) createAccountsAdapters(triesContainer common. AddressConverter: scf.core.AddressPubKeyConverter(), SnapshotsManager: disabled.NewDisabledSnapshotsManager(), StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: apiDataTriesHolder, } accountsAdapterApiOnFinal, err := factoryState.CreateAccountsAdapterAPIOnFinal(argsAPIAccountsDB, scf.chainHandler) @@ -340,6 +358,12 @@ func (scf *stateComponentsFactory) createPeerAdapter(triesContainer common.Tries return nil, err } + // TODO check if this is needed for peer accounts holder + dataTriesHolder, err := triesHolder.NewDataTriesHolder(scf.config.StateTriesConfig.DataTriesSizeInMemory) + if err != nil { + return nil, fmt.Errorf("failed to create data tries holder: %w", err) + } + argStateMetrics := stateMetrics.ArgsStateMetrics{ SnapshotInProgressKey: common.MetricPeersSnapshotInProgress, LastSnapshotDurationKey: common.MetricLastPeersSnapshotDurationSec, @@ -366,6 +390,7 @@ func (scf *stateComponentsFactory) createPeerAdapter(triesContainer common.Tries SnapshotsManager: snapshotManager, StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), PruningEnabled: scf.config.StateTriesConfig.PeerStatePruningEnabled, + DataTriesHolder: dataTriesHolder, } peerAdapter, err := state.NewPeerAccountsDB(argsProcessingPeerAccountsDB) if err != nil { diff --git a/genesis/mock/userAccountMock.go b/genesis/mock/userAccountMock.go index 28ef9e7c966..027423e527e 100644 --- a/genesis/mock/userAccountMock.go +++ b/genesis/mock/userAccountMock.go @@ -64,15 +64,15 @@ func (uam *UserAccountMock) SetRootHash(bytes []byte) { uam.rootHash = bytes } +// SetDataTrieRootHash - +func (uam *UserAccountMock) SetDataTrieRootHash() { +} + // GetRootHash - func (uam *UserAccountMock) GetRootHash() []byte { return uam.rootHash } -// SetDataTrie - -func (uam *UserAccountMock) SetDataTrie(_ common.Trie) { -} - // DataTrie - func (uam *UserAccountMock) DataTrie() common.DataTrieHandler { return nil diff --git a/genesis/process/genesisBlockCreator.go b/genesis/process/genesisBlockCreator.go index ac1ddd405c5..b116b0ffb14 100644 --- a/genesis/process/genesisBlockCreator.go +++ b/genesis/process/genesisBlockCreator.go @@ -27,8 +27,6 @@ import ( "github.com/multiversx/mx-chain-go/process/smartContract/hooks" "github.com/multiversx/mx-chain-go/process/smartContract/hooks/counters" "github.com/multiversx/mx-chain-go/sharding" - disabledState "github.com/multiversx/mx-chain-go/state/disabled" - factoryState "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/statusHandler" "github.com/multiversx/mx-chain-go/storage" @@ -502,24 +500,10 @@ func (gbc *genesisBlockCreator) getNewArgForShard(shardID uint32) (ArgsGenesisBl return newArgument, nil } - argsAccCreator := factoryState.ArgsAccountCreator{ - Hasher: newArgument.Core.Hasher(), - Marshaller: newArgument.Core.InternalMarshalizer(), - EnableEpochsHandler: newArgument.Core.EnableEpochsHandler(), - StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), - } - accCreator, err := factoryState.NewAccountCreator(argsAccCreator) - if err != nil { - return ArgsGenesisBlockCreator{}, err - } - newArgument.Accounts, err = createAccountAdapter( - newArgument.Core.InternalMarshalizer(), - newArgument.Core.Hasher(), - accCreator, + newArgument.Core, gbc.arg.TrieStorageManagers[dataRetriever.UserAccountsUnit.String()], gbc.arg.Core.AddressPubKeyConverter(), - newArgument.Core.EnableEpochsHandler(), ) if err != nil { return ArgsGenesisBlockCreator{}, fmt.Errorf("'%w' while generating an in-memory accounts adapter for shard %d", diff --git a/genesis/process/genesisBlockCreator_test.go b/genesis/process/genesisBlockCreator_test.go index d5dada9caec..a725f6d67bf 100644 --- a/genesis/process/genesisBlockCreator_test.go +++ b/genesis/process/genesisBlockCreator_test.go @@ -29,7 +29,6 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/accounts" - factoryState "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" @@ -57,8 +56,8 @@ func createMockArgument( entireSupply *big.Int, ) ArgsGenesisBlockCreator { - storageManagerArgs := storageCommon.GetStorageManagerArgs() - storageManager, _ := trie.CreateTrieStorageManager(storageManagerArgs, storageCommon.GetStorageManagerOptions()) + storageManagerArgs := commonMocks.GetStorageManagerArgs() + storageManager, _ := trie.CreateTrieStorageManager(storageManagerArgs, commonMocks.GetStorageManagerOptions()) trieStorageManagers := make(map[string]common.StorageManager) trieStorageManagers[dataRetriever.UserAccountsUnit.String()] = storageManager @@ -213,25 +212,14 @@ func createMockArgument( SelfShardId: 0, } - argsAccCreator := factoryState.ArgsAccountCreator{ - Hasher: &hashingMocks.HasherMock{}, - Marshaller: &mock.MarshalizerMock{}, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - accCreator, err := factoryState.NewAccountCreator(argsAccCreator) - require.Nil(t, err) - - arg.Accounts, err = createAccountAdapter( - &mock.MarshalizerMock{}, - &hashingMocks.HasherMock{}, - accCreator, + acc, err := createAccountAdapter( + arg.Core, trieStorageManagers[dataRetriever.UserAccountsUnit.String()], &testscommon.PubkeyConverterMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) require.Nil(t, err) - arg.AccountsProposal = arg.Accounts + arg.Accounts = acc + arg.AccountsProposal = acc arg.ValidatorAccounts = &stateMock.AccountsStub{ RootHashCalled: func() ([]byte, error) { diff --git a/genesis/process/memoryComponents.go b/genesis/process/memoryComponents.go index 91c32211740..91f2f641f55 100644 --- a/genesis/process/memoryComponents.go +++ b/genesis/process/memoryComponents.go @@ -2,39 +2,54 @@ package process import ( "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-core-go/hashing" - "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" disabledState "github.com/multiversx/mx-chain-go/state/disabled" + factoryState "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager/disabled" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" ) -const maxTrieLevelInMemory = uint(5) - func createAccountAdapter( - marshaller marshal.Marshalizer, - hasher hashing.Hasher, - accountFactory state.AccountFactory, + coreComp coreComponentsHandler, trieStorage common.StorageManager, addressConverter core.PubkeyConverter, - enableEpochsHandler common.EnableEpochsHandler, ) (state.AccountsAdapter, error) { - tr, err := trie.NewTrie(trieStorage, marshaller, hasher, enableEpochsHandler, maxTrieLevelInMemory) + tr, err := trie.NewTrie(trieStorage, coreComp.InternalMarshalizer(), coreComp.Hasher(), coreComp.EnableEpochsHandler(), collapseManager.NewDisabledCollapseManager()) + if err != nil { + return nil, err + } + + dth, err := triesHolder.NewDataTriesHolder(common.TenMbSize) + if err != nil { + return nil, err + } + + argsAccCreator := factoryState.ArgsAccountCreator{ + Hasher: coreComp.Hasher(), + Marshaller: coreComp.InternalMarshalizer(), + EnableEpochsHandler: coreComp.EnableEpochsHandler(), + StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, + DataTrieCreator: tr, + } + accCreator, err := factoryState.NewAccountCreator(argsAccCreator) if err != nil { return nil, err } args := state.ArgsAccountsDB{ Trie: tr, - Hasher: hasher, - Marshaller: marshaller, - AccountFactory: accountFactory, + Hasher: coreComp.Hasher(), + Marshaller: coreComp.InternalMarshalizer(), + AccountFactory: accCreator, StoragePruningManager: disabled.NewDisabledStoragePruningManager(), AddressConverter: addressConverter, SnapshotsManager: disabledState.NewDisabledSnapshotsManager(), StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, } adb, err := state.NewAccountsDB(args) diff --git a/integrationTests/benchmarks/loadFromTrie_test.go b/integrationTests/benchmarks/loadFromTrie_test.go index 8b2d2736b1a..2adbdb85dda 100644 --- a/integrationTests/benchmarks/loadFromTrie_test.go +++ b/integrationTests/benchmarks/loadFromTrie_test.go @@ -15,9 +15,10 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" + testStorage "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" - testStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/stretchr/testify/require" ) @@ -101,7 +102,7 @@ func generateTriesWithMaxDepth( ) []*keyForTrie { tries := make([]*keyForTrie, numTries) for i := 0; i < numTries; i++ { - tr, _ := trie.NewTrie(storage, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 2) + tr, _ := trie.NewTrie(storage, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) key := insertKeysIntoTrie(t, tr, numTrieLevels, numChildrenPerBranch) rootHash, _ := tr.RootHash() diff --git a/integrationTests/longTests/storage/storage_test.go b/integrationTests/longTests/storage/storage_test.go index bea274856d8..3da4b1060c9 100644 --- a/integrationTests/longTests/storage/storage_test.go +++ b/integrationTests/longTests/storage/storage_test.go @@ -9,9 +9,10 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/blake2b" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/integrationTests" + testCommon "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" - "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/stretchr/testify/assert" ) @@ -106,18 +107,17 @@ func TestWriteContinuouslyInTree(t *testing.T) { nbTxsWrite := 1000000 testStorage := integrationTests.NewTestStorage() store := testStorage.CreateStorageLevelDB() - storageManagerArgs := storage.GetStorageManagerArgs() + storageManagerArgs := testCommon.GetStorageManagerArgs() storageManagerArgs.MainStorer = store storageManagerArgs.Marshalizer = &marshal.JsonMarshalizer{} storageManagerArgs.Hasher = blake2b.NewBlake2b() - options := storage.GetStorageManagerOptions() + options := testCommon.GetStorageManagerOptions() options.PruningEnabled = false trieStorage, _ := trie.CreateTrieStorageManager(storageManagerArgs, options) - maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(trieStorage, &marshal.JsonMarshalizer{}, blake2b.NewBlake2b(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, &marshal.JsonMarshalizer{}, blake2b.NewBlake2b(), &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) defer func() { _ = store.DestroyUnit() diff --git a/integrationTests/multiShard/hardFork/hardFork_test.go b/integrationTests/multiShard/hardFork/hardFork_test.go index 234dd81f063..06477abd3a2 100644 --- a/integrationTests/multiShard/hardFork/hardFork_test.go +++ b/integrationTests/multiShard/hardFork/hardFork_test.go @@ -656,7 +656,6 @@ func createHardForkExporter( }, ExportStateStorageConfig: exportConfig, ExportStateKeysConfig: keysConfig, - MaxTrieLevelInMemory: uint(5), WhiteListHandler: node.WhiteListHandler, WhiteListerVerifiedTxs: node.WhiteListerVerifiedTxs, MainInterceptorsContainer: node.MainInterceptorsContainer, diff --git a/integrationTests/state/stateTrie/stateTrie_test.go b/integrationTests/state/stateTrie/stateTrie_test.go index befb38a5d57..b3108e8b405 100644 --- a/integrationTests/state/stateTrie/stateTrie_test.go +++ b/integrationTests/state/stateTrie/stateTrie_test.go @@ -25,7 +25,10 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/sha256" crypto "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-go/epochStart/notifier" + "github.com/multiversx/mx-chain-go/state/triesHolder" + testCommon "github.com/multiversx/mx-chain-go/testscommon/common" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/multiversx/mx-chain-storage-go/types" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" @@ -54,7 +57,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" - testStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" ) @@ -271,13 +273,12 @@ func TestAccountsDB_CommitTwoOkAccountsShouldWork(t *testing.T) { func TestTrieDB_RecreateFromStorageShouldWork(t *testing.T) { hasher := integrationTests.TestHasher store := integrationTests.CreateMemUnit() - args := testStorage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() args.MainStorer = store args.Hasher = hasher trieStorage, _ := trie.NewTrieStorageManager(args) - maxTrieLevelInMemory := uint(5) - tr1, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr1, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) key := hasher.Compute("key") value := hasher.Compute("value") @@ -1056,17 +1057,19 @@ func createAccounts( HashesSize: evictionWaitListSize * 100, } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) - args := testStorage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() args.MainStorer = store trieStorage, _ := trie.NewTrieStorageManager(args) - maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) + dth, _ := triesHolder.NewDataTriesHolder(integrationTests.TenMbSize) argsAccCreator := factory.ArgsAccountCreator{ Hasher: integrationTests.TestHasher, Marshaller: integrationTests.TestMarshalizer, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, + DataTrieCreator: tr, } accCreator, _ := factory.NewAccountCreator(argsAccCreator) snapshotsManager, _ := state.NewSnapshotsManager(state.ArgsNewSnapshotsManager{ @@ -1089,6 +1092,7 @@ func createAccounts( AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: snapshotsManager, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, } adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -1276,8 +1280,11 @@ func TestTrieDbPruning_GetDataTrieTrackerAfterPruning(t *testing.T) { require.NotNil(t, acc) require.Nil(t, err) - collapseTrie(state1, t) - collapseTrie(userAccount, t) + acc1, _ = adb.LoadAccount(address1) + state1 = acc1.(state.UserAccountHandler) + + acc2, _ = adb.LoadAccount(address2) + userAccount = acc2.(state.UserAccountHandler) val, _, err := state1.RetrieveValue(key1) require.Nil(t, err) @@ -1288,15 +1295,6 @@ func TestTrieDbPruning_GetDataTrieTrackerAfterPruning(t *testing.T) { require.Equal(t, value1, val) } -func collapseTrie(state state.UserAccountHandler, t *testing.T) { - stateRootHash := state.GetRootHash() - stateTrie := state.DataTrie().(common.Trie) - stateNewTrie, _ := stateTrie.Recreate(holders.NewDefaultRootHashesHolder(stateRootHash)) - require.NotNil(t, stateNewTrie) - - state.SetDataTrie(stateNewTrie) -} - func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") @@ -2733,17 +2731,19 @@ func createAccountsDBTestSetup() *state.AccountsDB { HashesSize: evictionWaitListSize * 100, } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) - args := testStorage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() args.GeneralConfig = generalCfg trieStorage, _ := trie.NewTrieStorageManager(args) - maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) + dth, _ := triesHolder.NewDataTriesHolder(integrationTests.TenMbSize) argsAccCreator := factory.ArgsAccountCreator{ Hasher: integrationTests.TestHasher, Marshaller: integrationTests.TestMarshalizer, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, + DataTrieCreator: tr, } accCreator, _ := factory.NewAccountCreator(argsAccCreator) @@ -2768,6 +2768,7 @@ func createAccountsDBTestSetup() *state.AccountsDB { AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: snapshotsManager, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, } adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2789,11 +2790,10 @@ func TestStateSnapshot_MultipleEpochsWithoutCompleteSnapshot(t *testing.T) { numEpochs := 5 getCounters := make([]int, numEpochs) - args := testStorage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() args.MainStorer = mainStorer trieStorage, _ := trie.NewTrieStorageManager(args) - maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(trieStorage, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) defer func() { _ = trieStorage.Close() }() diff --git a/integrationTests/state/stateTrieClose/stateTrieClose_test.go b/integrationTests/state/stateTrieClose/stateTrieClose_test.go index 9d99a178484..2b062184d51 100644 --- a/integrationTests/state/stateTrieClose/stateTrieClose_test.go +++ b/integrationTests/state/stateTrieClose/stateTrieClose_test.go @@ -12,10 +12,11 @@ import ( "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/state/parsers" + testCommon "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/stretchr/testify/assert" ) @@ -23,7 +24,7 @@ import ( func TestPatriciaMerkleTrie_Close(t *testing.T) { numLeavesToAdd := 200 trieStorage, _ := integrationTests.CreateTrieStorageManager(integrationTests.CreateMemUnit()) - tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr, _ := trie.NewTrie(trieStorage, integrationTests.TestMarshalizer, integrationTests.TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) for i := 0; i < numLeavesToAdd; i++ { _ = tr.Update([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) @@ -139,7 +140,7 @@ func TestPatriciaMerkleTrie_Close(t *testing.T) { } func TestTrieStorageManager_Close(t *testing.T) { - args := storage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() gc := goroutines.NewGoCounter(goroutines.TestsRelevantGoRoutines) idxInitial, _ := gc.Snapshot() diff --git a/integrationTests/state/stateTrieSync/stateTrieSync_test.go b/integrationTests/state/stateTrieSync/stateTrieSync_test.go index cd416f663c5..2c7729f9058 100644 --- a/integrationTests/state/stateTrieSync/stateTrieSync_test.go +++ b/integrationTests/state/stateTrieSync/stateTrieSync_test.go @@ -609,7 +609,6 @@ func getUserAccountSyncerArgs(node *integrationTests.TestProcessorNode, version RequestHandler: node.RequestHandler, Timeout: common.TimeoutGettingTrieNodes, Cacher: node.DataPool.TrieNodes(), - MaxTrieLevelInMemory: 200, MaxHardCapForMissingNodes: 5000, TrieSyncerVersion: version, UserAccountsSyncStatisticsHandler: statistics.NewTrieSyncStatistics(), diff --git a/integrationTests/testInitializer.go b/integrationTests/testInitializer.go index 8e1d8969632..67c8b0f013d 100644 --- a/integrationTests/testInitializer.go +++ b/integrationTests/testInitializer.go @@ -29,6 +29,9 @@ import ( "github.com/multiversx/mx-chain-crypto-go/signing/ed25519" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" "github.com/multiversx/mx-chain-crypto-go/signing/secp256k1" + "github.com/multiversx/mx-chain-go/state/triesHolder" + trieTestComponents "github.com/multiversx/mx-chain-go/testscommon/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" logger "github.com/multiversx/mx-chain-logger-go" wasmConfig "github.com/multiversx/mx-chain-vm-go/config" "github.com/pkg/errors" @@ -81,7 +84,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/stakingcommon" testStorage "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - testcommonStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/vm" @@ -104,6 +106,9 @@ var InitialRating = uint32(50) // AdditionalGasLimit is the value that can be added on a transaction in the GasLimit var AdditionalGasLimit = uint64(999000) +// TenMbSize represents 10 MB in bytes +const TenMbSize = uint64(10485760) + // GasSchedulePath -- const GasSchedulePath = "../../../../cmd/node/config/gasSchedules/gasScheduleV4.toml" @@ -114,7 +119,6 @@ const ( shuffleBetweenShards = false adaptivity = false hysteresis = float32(0.2) - maxTrieLevelInMemory = uint(5) delegationContractsList = "delegationContracts" ) @@ -431,7 +435,7 @@ func CreateTrieStorageManagerWithPruningStorer(coordinator sharding.Coordinator, fmt.Println("err creating main storer" + err.Error()) } - args := testcommonStorage.GetStorageManagerArgs() + args := commonMocks.GetStorageManagerArgs() args.MainStorer = mainStorer args.Marshalizer = TestMarshalizer args.Hasher = TestHasher @@ -443,7 +447,7 @@ func CreateTrieStorageManagerWithPruningStorer(coordinator sharding.Coordinator, // CreateTrieStorageManager creates the trie storage manager for the tests func CreateTrieStorageManager(store storage.Storer) (common.StorageManager, storage.Storer) { - args := testcommonStorage.GetStorageManagerArgs() + args := commonMocks.GetStorageManagerArgs() args.MainStorer = store args.Marshalizer = TestMarshalizer args.Hasher = TestHasher @@ -467,14 +471,14 @@ func CreateAccountsDBWithEnableEpochsHandler( trieStorageManager common.StorageManager, enableEpochsHandler common.EnableEpochsHandler, ) (*state.AccountsDB, common.Trie) { - tr, _ := trie.NewTrie(trieStorageManager, TestMarshalizer, TestHasher, enableEpochsHandler, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorageManager, TestMarshalizer, TestHasher, enableEpochsHandler, collapseManager.NewDisabledCollapseManager()) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) - accountFactory, _ := getAccountFactory(accountType, enableEpochsHandler) + accountFactory, dth, _ := getAccountFactory(accountType, enableEpochsHandler, tr) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) snapshotsManager, _ := state.NewSnapshotsManager(state.ArgsNewSnapshotsManager{ @@ -500,26 +504,43 @@ func CreateAccountsDBWithEnableEpochsHandler( SnapshotsManager: snapshotsManager, StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), PruningEnabled: trieStorageManager.IsPruningEnabled(), + DataTriesHolder: dth, } adb, _ := state.NewAccountsDB(args) return adb, tr } -func getAccountFactory(accountType Type, enableEpochsHandler common.EnableEpochsHandler) (state.AccountFactory, error) { +func getAccountFactory( + accountType Type, + enableEpochsHandler common.EnableEpochsHandler, + tr common.Trie, +) (state.AccountFactory, common.TriesHolder, error) { switch accountType { case UserAccount: + dth, err := triesHolder.NewDataTriesHolder(TenMbSize) + if err != nil { + return nil, nil, err + } + argsAccCreator := factory.ArgsAccountCreator{ Hasher: TestHasher, Marshaller: TestMarshalizer, EnableEpochsHandler: enableEpochsHandler, StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, + DataTrieCreator: tr, } - return factory.NewAccountCreator(argsAccCreator) + accCreator, err := factory.NewAccountCreator(argsAccCreator) + if err != nil { + return nil, nil, err + } + + return accCreator, dth, nil case ValidatorAccount: - return factory.NewPeerAccountCreator(), nil + return factory.NewPeerAccountCreator(), &trieTestComponents.TriesHolderStub{}, nil default: - return nil, fmt.Errorf("invalid account type provided") + return nil, nil, fmt.Errorf("invalid account type provided") } } @@ -1056,7 +1077,16 @@ func GenerateAddressJournalAccountAccountsDB() ([]byte, state.UserAccountHandler adb, _ := CreateAccountsDB(UserAccount, trieStorage) dtlp, _ := parsers.NewDataTrieLeafParser(adr, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) - dtt, _ := trackableDataTrie.NewTrackableDataTrie(adr, &testscommon.HasherStub{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, disabled.NewDisabledStateAccessesCollector()) + args := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: adr, + Hasher: &testscommon.HasherStub{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: &trieTestComponents.TriesHolderStub{}, + DataTrieCreator: &trieTestComponents.TrieStub{}, + } + dtt, _ := trackableDataTrie.NewTrackableDataTrie(args) account, _ := accounts.NewUserAccount(adr, dtt, dtlp) @@ -1158,13 +1188,13 @@ func CreateSimpleTxProcessor(accnts state.AccountsAdapter) process.TransactionPr // CreateNewDefaultTrie returns a new trie with test hasher and marsahalizer func CreateNewDefaultTrie() common.Trie { - args := testcommonStorage.GetStorageManagerArgs() + args := commonMocks.GetStorageManagerArgs() args.Marshalizer = TestMarshalizer args.Hasher = TestHasher trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, TestMarshalizer, TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, TestMarshalizer, TestHasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) return tr } diff --git a/integrationTests/testProcessorNode.go b/integrationTests/testProcessorNode.go index b63908e027a..852192f185d 100644 --- a/integrationTests/testProcessorNode.go +++ b/integrationTests/testProcessorNode.go @@ -31,6 +31,7 @@ import ( ed25519SingleSig "github.com/multiversx/mx-chain-crypto-go/signing/ed25519/singlesig" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" mclsig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/singlesig" + "github.com/multiversx/mx-chain-go/state/triesHolder" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/multiversx/mx-chain-vm-common-go/parsers" wasmConfig "github.com/multiversx/mx-chain-vm-go/config" @@ -705,7 +706,7 @@ func (tpn *TestProcessorNode) initAccountDBsWithPruningStorer() { tpn.EpochStartNotifier = notifier.NewEpochStartSubscriptionHandler() } trieStorageManager := CreateTrieStorageManagerWithPruningStorer(tpn.ShardCoordinator, tpn.EpochStartNotifier) - tpn.TrieContainer = state.NewDataTriesHolder() + tpn.TrieContainer = triesHolder.NewTriesHolder() var stateTrie common.Trie tpn.AccntState, stateTrie = CreateAccountsDBWithEnableEpochsHandler(UserAccount, trieStorageManager, tpn.EnableEpochsHandler) tpn.AccntStateProposal, _ = CreateAccountsDBWithEnableEpochsHandler(UserAccount, trieStorageManager, tpn.EnableEpochsHandler) @@ -722,7 +723,7 @@ func (tpn *TestProcessorNode) initAccountDBsWithPruningStorer() { func (tpn *TestProcessorNode) initAccountDBs(store storage.Storer) { trieStorageManager, _ := CreateTrieStorageManager(store) - tpn.TrieContainer = state.NewDataTriesHolder() + tpn.TrieContainer = triesHolder.NewTriesHolder() var stateTrie common.Trie tpn.AccntState, stateTrie = CreateAccountsDBWithEnableEpochsHandler(UserAccount, trieStorageManager, tpn.EnableEpochsHandler) tpn.AccntStateProposal, _ = CreateAccountsDBWithEnableEpochsHandler(UserAccount, trieStorageManager, tpn.EnableEpochsHandler) diff --git a/integrationTests/vm/staking/componentsHolderCreator.go b/integrationTests/vm/staking/componentsHolderCreator.go index d46283b8d8a..3897e27bbe6 100644 --- a/integrationTests/vm/staking/componentsHolderCreator.go +++ b/integrationTests/vm/staking/componentsHolderCreator.go @@ -11,6 +11,9 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" "github.com/multiversx/mx-chain-core-go/hashing/sha256" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/state/triesHolder" + trieTestComponents "github.com/multiversx/mx-chain-go/testscommon/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/enablers" @@ -165,17 +168,8 @@ func createStateComponents(coreComponents factory.CoreComponentsHolder) factory. tsm, _ := trie.CreateTrieStorageManager(tsmArgs, trie.StorageManagerOptions{}) trieFactoryManager, _ := trie.NewTrieStorageManagerWithoutPruning(tsm) - argsAccCreator := stateFactory.ArgsAccountCreator{ - Hasher: coreComponents.Hasher(), - Marshaller: coreComponents.InternalMarshalizer(), - EnableEpochsHandler: coreComponents.EnableEpochsHandler(), - StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), - } - - accCreator, _ := stateFactory.NewAccountCreator(argsAccCreator) - - userAccountsDB := createAccountsDB(coreComponents, accCreator, trieFactoryManager) - peerAccountsDB := createAccountsDB(coreComponents, stateFactory.NewPeerAccountCreator(), trieFactoryManager) + userAccountsDB := createAccountsDB(coreComponents, userAdb, trieFactoryManager) + peerAccountsDB := createAccountsDB(coreComponents, peerAdb, trieFactoryManager) _ = userAccountsDB.SetSyncer(&mock.AccountsDBSyncerStub{}) _ = peerAccountsDB.SetSyncer(&mock.AccountsDBSyncerStub{}) @@ -199,9 +193,35 @@ func getNewTrieStorageManagerArgs(coreComponents factory.CoreComponentsHolder) t } } +type adbType string + +const ( + userAdb = "userAdb" + peerAdb = "peerAdb" +) + +func getAccountsCreator(t adbType, tr common.Trie, coreComponents factory.CoreComponentsHolder) (state.AccountFactory, common.TriesHolder) { + if t == peerAdb { + return stateFactory.NewPeerAccountCreator(), &trieTestComponents.TriesHolderStub{} + } + + dth, _ := triesHolder.NewDataTriesHolder(integrationTests.TenMbSize) + + argsAccCreator := stateFactory.ArgsAccountCreator{ + Hasher: coreComponents.Hasher(), + Marshaller: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), + StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, + DataTrieCreator: tr, + } + accCreator, _ := stateFactory.NewAccountCreator(argsAccCreator) + return accCreator, dth +} + func createAccountsDB( coreComponents factory.CoreComponentsHolder, - accountFactory state.AccountFactory, + adbType adbType, trieStorageManager common.StorageManager, ) *state.AccountsDB { tr, _ := trie.NewTrie( @@ -209,9 +229,11 @@ func createAccountsDB( coreComponents.InternalMarshalizer(), coreComponents.Hasher(), coreComponents.EnableEpochsHandler(), - 5, + collapseManager.NewDisabledCollapseManager(), ) + accCreator, dth := getAccountsCreator(adbType, tr, coreComponents) + argsEvictionWaitingList := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 10, HashesSize: hashSize, @@ -222,11 +244,12 @@ func createAccountsDB( Trie: tr, Hasher: coreComponents.Hasher(), Marshaller: coreComponents.InternalMarshalizer(), - AccountFactory: accountFactory, + AccountFactory: accCreator, StoragePruningManager: spm, AddressConverter: coreComponents.AddressPubKeyConverter(), SnapshotsManager: &stateTests.SnapshotsManagerStub{}, StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, } adb, _ := state.NewAccountsDB(argsAccountsDb) return adb diff --git a/integrationTests/vm/wasm/wasmvm/mockContracts.go b/integrationTests/vm/wasm/wasmvm/mockContracts.go index fba56804747..94da6cb71b2 100644 --- a/integrationTests/vm/wasm/wasmvm/mockContracts.go +++ b/integrationTests/vm/wasm/wasmvm/mockContracts.go @@ -19,11 +19,6 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" - stateFactory "github.com/multiversx/mx-chain-go/state/factory" - "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" - "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" - stateMock "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" ) @@ -39,6 +34,12 @@ var InitialEsdt = uint64(100) // EsdtTokenIdentifier is the token identifier in tests var EsdtTokenIdentifier = []byte("TTT-010101") +// AdbWithAccountsFactory - +type AdbWithAccountsFactory interface { + state.AccountsAdapter + GetAccountsFactory() state.AccountFactory +} + // InitializeMockContracts - func InitializeMockContracts( t *testing.T, @@ -107,15 +108,11 @@ func GetAddressForNewAccountOnWalletAndNodeWithVM( require.Nil(t, err) address := net.NewAddressWithVM(wallet, vmType) - argsAccCreation := stateFactory.ArgsAccountCreator{ - Hasher: &hashingMocks.HasherMock{}, - Marshaller: &marshallerMock.MarshalizerMock{}, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - accountFactory, _ := stateFactory.NewAccountCreator(argsAccCreation) - account, _ := accountFactory.CreateAccount(address) + adb, ok := node.AccntState.(AdbWithAccountsFactory) + require.True(t, ok) + + account, _ := adb.GetAccountsFactory().CreateAccount(address) userAccount := account.(state.UserAccountHandler) _ = userAccount.AddToBalance(MockInitialBalance) userAccount.SetCode(address) diff --git a/node/chainSimulator/components/testOnlyProcessingNode.go b/node/chainSimulator/components/testOnlyProcessingNode.go index 3360e56a6bc..b0bd2b87887 100644 --- a/node/chainSimulator/components/testOnlyProcessingNode.go +++ b/node/chainSimulator/components/testOnlyProcessingNode.go @@ -1,6 +1,7 @@ package components import ( + "bytes" "encoding/base64" "encoding/hex" "errors" @@ -528,19 +529,19 @@ func (node *testOnlyProcessingNode) SetStateForAddress(address []byte, addressSt return err } - rootHash, err := base64.StdEncoding.DecodeString(addressState.RootHash) + accountsAdapter := node.StateComponentsHolder.AccountsAdapter() + err = accountsAdapter.SaveAccount(userAccount) if err != nil { return err } - if len(rootHash) != 0 { - userAccount.SetRootHash(rootHash) - } - accountsAdapter := node.StateComponentsHolder.AccountsAdapter() - err = accountsAdapter.SaveAccount(userAccount) + rootHash, err := base64.StdEncoding.DecodeString(addressState.RootHash) if err != nil { return err } + if !bytes.Equal(userAccount.GetRootHash(), rootHash) && len(rootHash) != 0 { + return fmt.Errorf("account root hash does not match the provided one, expected: %s, got: %s", rootHash, userAccount.GetRootHash()) + } newRootHash, err := accountsAdapter.Commit() node.setBlockchainRootHashIfSupernovaIsActive(newRootHash) diff --git a/node/chainSimulator/components/testOnlyProcessingNode_test.go b/node/chainSimulator/components/testOnlyProcessingNode_test.go index 65f93c45c18..9c63b210613 100644 --- a/node/chainSimulator/components/testOnlyProcessingNode_test.go +++ b/node/chainSimulator/components/testOnlyProcessingNode_test.go @@ -150,7 +150,7 @@ func TestNewTestOnlyProcessingNode(t *testing.T) { t.Run("CreateStateComponents failure should error", func(t *testing.T) { args := createMockArgsTestOnlyProcessingNode(t) args.ShardIDStr = common.MetachainShardName // coverage only - args.Configs.GeneralConfig.StateTriesConfig.MaxStateTrieLevelInMemory = 0 + args.Configs.GeneralConfig.TrieStorageManagerConfig = config.TrieStorageManagerConfig{} node, err := NewTestOnlyProcessingNode(args) require.Error(t, err) require.Nil(t, node) diff --git a/node/nodeRunner.go b/node/nodeRunner.go index dabbe529bac..b4793cd632b 100644 --- a/node/nodeRunner.go +++ b/node/nodeRunner.go @@ -653,7 +653,6 @@ func getUserAccountSyncer( bootstrapComponents mainFactory.BootstrapComponentsHolder, processComponents mainFactory.ProcessComponentsHolder, ) (process.AccountsDBSyncer, error) { - maxTrieLevelInMemory := config.StateTriesConfig.MaxStateTrieLevelInMemory userTrie := stateComponents.TriesContainer().Get([]byte(dataRetriever.UserAccountsUnit.String())) storageManager := userTrie.GetStorageManager() @@ -669,7 +668,6 @@ func getUserAccountSyncer( dataComponents, processComponents, storageManager, - maxTrieLevelInMemory, ), ShardId: bootstrapComponents.ShardCoordinator().SelfId(), Throttler: thr, @@ -686,7 +684,6 @@ func getValidatorAccountSyncer( stateComponents mainFactory.StateComponentsHolder, processComponents mainFactory.ProcessComponentsHolder, ) (process.AccountsDBSyncer, error) { - maxTrieLevelInMemory := config.StateTriesConfig.MaxPeerTrieLevelInMemory peerTrie := stateComponents.TriesContainer().Get([]byte(dataRetriever.PeerAccountsUnit.String())) storageManager := peerTrie.GetStorageManager() @@ -697,7 +694,6 @@ func getValidatorAccountSyncer( dataComponents, processComponents, storageManager, - maxTrieLevelInMemory, ), } @@ -710,7 +706,6 @@ func getBaseAccountSyncerArgs( dataComponents mainFactory.DataComponentsHolder, processComponents mainFactory.ProcessComponentsHolder, storageManager common.StorageManager, - maxTrieLevelInMemory uint, ) syncer.ArgsNewBaseAccountsSyncer { return syncer.ArgsNewBaseAccountsSyncer{ Hasher: coreComponents.Hasher(), @@ -719,7 +714,6 @@ func getBaseAccountSyncerArgs( RequestHandler: processComponents.RequestHandler(), Timeout: common.TimeoutGettingTrieNodes, Cacher: dataComponents.Datapool().TrieNodes(), - MaxTrieLevelInMemory: maxTrieLevelInMemory, MaxHardCapForMissingNodes: config.TrieSync.MaxHardCapForMissingNodes, TrieSyncerVersion: config.TrieSync.TrieSyncerVersion, CheckNodesOnDisk: true, diff --git a/node/node_test.go b/node/node_test.go index 762dbe08c70..38d9b0e8227 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -102,7 +102,33 @@ func createMockPubkeyConverter() *testscommon.PubkeyConverterMock { func createAcc(address []byte) state.UserAccountHandler { dtlp, _ := parsers.NewDataTrieLeafParser(address, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) - dtt, _ := trackableDataTrie.NewTrackableDataTrie(address, &testscommon.HasherStub{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, disabled.NewDisabledStateAccessesCollector()) + args := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: address, + Hasher: &testscommon.HasherStub{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: &trieMock.TriesHolderStub{}, + DataTrieCreator: &trieMock.TrieStub{}, + } + dtt, _ := trackableDataTrie.NewTrackableDataTrie(args) + acc, _ := accounts.NewUserAccount(address, dtt, dtlp) + + return acc +} + +func createAccWithDth(address []byte, dth common.TriesHolder) state.UserAccountHandler { + dtlp, _ := parsers.NewDataTrieLeafParser(address, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + args := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: address, + Hasher: &testscommon.HasherStub{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, + DataTrieCreator: &trieMock.TrieStub{}, + } + dtt, _ := trackableDataTrie.NewTrackableDataTrie(args) acc, _ := accounts.NewUserAccount(address, dtt, dtlp) return acc @@ -448,33 +474,37 @@ func TestNode_GetKeyValuePairsAccNotFoundShouldReturnEmpty(t *testing.T) { func TestNode_GetKeyValuePairs(t *testing.T) { t.Parallel() - acc := createAcc([]byte("newaddress")) - + address := []byte("newaddress") k1, v1 := []byte("key1"), []byte("value1") k2, v2 := []byte("key2"), []byte("value2") - accDB := &stateMock.AccountsStub{} - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { - go func() { - suffix := append(k1, acc.AddressBytes()...) - trieLeaf, _ := tlp.ParseLeaf(k1, append(v1, suffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - - suffix = append(k2, acc.AddressBytes()...) - trieLeaf2, _ := tlp.ParseLeaf(k2, append(v2, suffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf2 - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { + go func() { + suffix := append(k1, address...) + trieLeaf, _ := tlp.ParseLeaf(k1, append(v1, suffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + + suffix = append(k2, address...) + trieLeaf2, _ := tlp.ParseLeaf(k2, append(v2, suffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf2 + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(address, dth) - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + accDB := &stateMock.AccountsStub{} accDB.GetAccountWithBlockInfoCalled = func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) { return acc, nil, nil @@ -516,25 +546,27 @@ func TestNode_GetKeyValuePairs(t *testing.T) { func TestNode_GetKeyValuePairs_GetAllLeavesShouldFail(t *testing.T) { t.Parallel() - acc := createAcc([]byte("newaddress")) - - accDB := &stateMock.AccountsStub{} - + address := []byte("newaddress") expectedErr := errors.New("expected err") - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { - go func() { - leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) - close(leavesChannels.LeavesChan) - }() + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { + go func() { + leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) + close(leavesChannels.LeavesChan) + }() - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(address, dth) + accDB := &stateMock.AccountsStub{} accDB.GetAccountWithBlockInfoCalled = func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) { return acc, nil, nil @@ -571,24 +603,27 @@ func TestNode_GetKeyValuePairs_GetAllLeavesShouldFail(t *testing.T) { func TestNode_GetKeyValuePairsContextShouldTimeout(t *testing.T) { t.Parallel() - acc := createAcc([]byte("newaddress")) - + address := []byte("newaddress") + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { + go func() { + time.Sleep(time.Second) + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(address, dth) accDB := &stateMock.AccountsStub{} - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { - go func() { - time.Sleep(time.Second) - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() - - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) accDB.GetAccountWithBlockInfoCalled = func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) { return acc, nil, nil @@ -842,9 +877,28 @@ func TestNode_GetESDTDataForNFT(t *testing.T) { func TestNode_GetAllESDTTokens(t *testing.T) { t.Parallel() - acc := createAcc(testscommon.TestPubKeyAlice) esdtToken := "newToken" esdtKey := []byte(core.ProtectedKeyPrefix + core.ESDTKeyIdentifier + esdtToken) + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { + go func() { + trieLeaf := keyValStorage.NewKeyValStorage(esdtKey, nil) + leavesChannels.LeavesChan <- trieLeaf + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(testscommon.TestPubKeyAlice, dth) esdtData := &esdt.ESDigitalToken{Value: big.NewInt(10)} @@ -854,23 +908,6 @@ func TestNode_GetAllESDTTokens(t *testing.T) { }, } - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { - go func() { - trieLeaf := keyValStorage.NewKeyValStorage(esdtKey, nil) - leavesChannels.LeavesChan <- trieLeaf - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() - - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) - accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { return nil @@ -910,23 +947,25 @@ func TestNode_GetAllESDTTokens(t *testing.T) { func TestNode_GetAllESDTTokens_GetAllLeavesShouldFail(t *testing.T) { t.Parallel() - acc := createAcc(testscommon.TestPubKeyAlice) - expectedErr := errors.New("expected error") - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { - go func() { - leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) - close(leavesChannels.LeavesChan) - }() + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { + go func() { + leavesChannels.ErrChan.WriteInChanNonBlocking(expectedErr) + close(leavesChannels.LeavesChan) + }() - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(testscommon.TestPubKeyAlice, dth) accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { @@ -966,23 +1005,25 @@ func TestNode_GetAllESDTTokens_GetAllLeavesShouldFail(t *testing.T) { func TestNode_GetAllESDTTokensContextShouldTimeout(t *testing.T) { t.Parallel() - acc := createAcc(testscommon.TestPubKeyAlice) - - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { - go func() { - time.Sleep(time.Second) - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() - - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { + go func() { + time.Sleep(time.Second) + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(testscommon.TestPubKeyAlice, dth) accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { @@ -1062,20 +1103,19 @@ func TestNode_GetAllESDTsAccNotFoundShouldReturnEmpty(t *testing.T) { func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { t.Parallel() - acc := createAcc(testscommon.TestPubKeyAlice) - + address := testscommon.TestPubKeyAlice esdtToken := "TKKR-7q8w9e" esdtKey := []byte(core.ProtectedKeyPrefix + core.ESDTKeyIdentifier + esdtToken) esdtData := &esdt.ESDigitalToken{Value: big.NewInt(10)} marshalledData, _ := getMarshalizer().Marshal(esdtData) - suffix := append(esdtKey, acc.AddressBytes()...) + suffix := append(esdtKey, address...) nftToken := "TCKR-67tgv3" nftNonce := big.NewInt(1) nftKey := []byte(core.ProtectedKeyPrefix + core.ESDTKeyIdentifier + nftToken) nftKeyWithBytes := append(nftKey, nftNonce.Bytes()...) - nftSuffix := append(nftKeyWithBytes, acc.AddressBytes()...) + nftSuffix := append(nftKeyWithBytes, address...) nftMetaData := &esdt.MetaData{Nonce: nftNonce.Uint64(), Creator: []byte("12345678901234567890123456789012")} nftData := &esdt.ESDigitalToken{Type: uint32(core.NonFungible), Value: big.NewInt(10), TokenMetaData: nftMetaData} @@ -1085,7 +1125,7 @@ func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { dynamicNftNonce := big.NewInt(100) dynamicNftKey := []byte(core.ProtectedKeyPrefix + core.ESDTKeyIdentifier + dynamicNft) dynamicNftKeyWithBytes := append(dynamicNftKey, dynamicNftNonce.Bytes()...) - dynamicNftSuffix := append(dynamicNftKeyWithBytes, acc.AddressBytes()...) + dynamicNftSuffix := append(dynamicNftKeyWithBytes, address...) dynamicNftData := &esdt.ESDigitalToken{ Type: uint32(core.DynamicNFT), Value: big.NewInt(0), @@ -1110,34 +1150,39 @@ func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { } }, } - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { - wg := &sync.WaitGroup{} - wg.Add(1) - go func() { - trieLeaf := keyValStorage.NewKeyValStorage(esdtKey, append(marshalledData, suffix...)) - leavesChannels.LeavesChan <- trieLeaf - trieLeaf = keyValStorage.NewKeyValStorage(nftKey, append(marshalledNftData, nftSuffix...)) - leavesChannels.LeavesChan <- trieLeaf + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + trieLeaf := keyValStorage.NewKeyValStorage(esdtKey, append(marshalledData, suffix...)) + leavesChannels.LeavesChan <- trieLeaf - trieLeaf = keyValStorage.NewKeyValStorage(dynamicNftKey, append(marshalledDynamicNftData, dynamicNftSuffix...)) - leavesChannels.LeavesChan <- trieLeaf + trieLeaf = keyValStorage.NewKeyValStorage(nftKey, append(marshalledNftData, nftSuffix...)) + leavesChannels.LeavesChan <- trieLeaf - wg.Done() - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() + trieLeaf = keyValStorage.NewKeyValStorage(dynamicNftKey, append(marshalledDynamicNftData, dynamicNftSuffix...)) + leavesChannels.LeavesChan <- trieLeaf - wg.Wait() + wg.Done() + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + wg.Wait() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(address, dth) accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { @@ -1179,7 +1224,7 @@ func TestNode_GetAllESDTTokensShouldReturnEsdtAndFormattedNft(t *testing.T) { func TestNode_GetAllIssuedESDTs(t *testing.T) { t.Parallel() - acc := createAcc([]byte("newaddress")) + address := []byte("newaddress") esdtToken := []byte("TCK-RANDOM") sftToken := []byte("SFT-RANDOM") sftTokenDynamic := []byte("SFT-Dynamic") @@ -1187,66 +1232,71 @@ func TestNode_GetAllIssuedESDTs(t *testing.T) { nftTokenV2 := []byte("NFT-RANDOM-V2") nftTokenDynamic := []byte("NFT-Dynamic") + esdtSuffix := append(esdtToken, address...) + nftSuffix := append(nftToken, address...) + nftSuffix2 := append(nftTokenV2, address...) + nftDynamicSuffix := append(nftTokenDynamic, address...) + sftSuffix := append(sftToken, address...) + sftDynamicSuffix := append(sftTokenDynamic, address...) + esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.FungibleESDT)} marshalledData, _ := getMarshalizer().Marshal(esdtData) - _ = acc.SaveKeyValue(esdtToken, marshalledData) sftData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("semi fungible"), TokenType: []byte(core.SemiFungibleESDT)} sftMarshalledData, _ := getMarshalizer().Marshal(sftData) - _ = acc.SaveKeyValue(sftToken, sftMarshalledData) sftDataDynamic := &systemSmartContracts.ESDTDataV2{TokenName: []byte("semi fungible dynamic"), TokenType: []byte(core.DynamicSFTESDT)} sftMarshalledDataDynamic, _ := getMarshalizer().Marshal(sftDataDynamic) - _ = acc.SaveKeyValue(sftTokenDynamic, sftMarshalledDataDynamic) nftData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("non fungible"), TokenType: []byte(core.NonFungibleESDT)} nftMarshalledData, _ := getMarshalizer().Marshal(nftData) - _ = acc.SaveKeyValue(nftToken, nftMarshalledData) nftData2 := &systemSmartContracts.ESDTDataV2{TokenName: []byte("non fungible v2"), TokenType: []byte(core.NonFungibleESDT)} nftMarshalledData2, _ := getMarshalizer().Marshal(nftData2) - _ = acc.SaveKeyValue(nftTokenV2, nftMarshalledData2) nftDataDynamic := &systemSmartContracts.ESDTDataV2{TokenName: []byte("non fungible dynamic"), TokenType: []byte(core.DynamicNFTESDT)} nftMarshalledDataDyamic, _ := getMarshalizer().Marshal(nftDataDynamic) - _ = acc.SaveKeyValue(nftTokenDynamic, nftMarshalledDataDyamic) - esdtSuffix := append(esdtToken, acc.AddressBytes()...) - nftSuffix := append(nftToken, acc.AddressBytes()...) - nftSuffix2 := append(nftTokenV2, acc.AddressBytes()...) - nftDynamicSuffix := append(nftTokenDynamic, acc.AddressBytes()...) - sftSuffix := append(sftToken, acc.AddressBytes()...) - sftDynamicSuffix := append(sftTokenDynamic, acc.AddressBytes()...) - - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { - go func() { - trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - - trieLeaf, _ = tlp.ParseLeaf(sftToken, append(sftMarshalledData, sftSuffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - trieLeaf, _ = tlp.ParseLeaf(sftTokenDynamic, append(sftMarshalledDataDynamic, sftDynamicSuffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - - trieLeaf, _ = tlp.ParseLeaf(nftToken, append(nftMarshalledData, nftSuffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - trieLeaf, _ = tlp.ParseLeaf(nftTokenV2, append(nftMarshalledData2, nftSuffix2...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - trieLeaf, _ = tlp.ParseLeaf(nftTokenDynamic, append(nftMarshalledDataDyamic, nftDynamicSuffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { + go func() { + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + + trieLeaf, _ = tlp.ParseLeaf(sftToken, append(sftMarshalledData, sftSuffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + trieLeaf, _ = tlp.ParseLeaf(sftTokenDynamic, append(sftMarshalledDataDynamic, sftDynamicSuffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + + trieLeaf, _ = tlp.ParseLeaf(nftToken, append(nftMarshalledData, nftSuffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + trieLeaf, _ = tlp.ParseLeaf(nftTokenV2, append(nftMarshalledData2, nftSuffix2...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + trieLeaf, _ = tlp.ParseLeaf(nftTokenDynamic, append(nftMarshalledDataDyamic, nftDynamicSuffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(address, dth) - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + _ = acc.SaveKeyValue(esdtToken, marshalledData) + _ = acc.SaveKeyValue(sftToken, sftMarshalledData) + _ = acc.SaveKeyValue(sftTokenDynamic, sftMarshalledDataDynamic) + _ = acc.SaveKeyValue(nftToken, nftMarshalledData) + _ = acc.SaveKeyValue(nftTokenV2, nftMarshalledData2) + _ = acc.SaveKeyValue(nftTokenDynamic, nftMarshalledDataDyamic) accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { @@ -1311,38 +1361,39 @@ func TestNode_GetESDTsWithRole(t *testing.T) { t.Parallel() addrBytes := testscommon.TestPubKeyAlice - acc := createAcc(addrBytes) - esdtToken := []byte("TCK-RANDOM") - specialRoles := []*systemSmartContracts.ESDTRoles{ { Address: addrBytes, Roles: [][]byte{[]byte(core.ESDTRoleNFTAddQuantity), []byte(core.ESDTRoleLocalMint)}, }, } - + esdtToken := []byte("TCK-RANDOM") esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.FungibleESDT), SpecialRoles: specialRoles} marshalledData, _ := getMarshalizer().Marshal(esdtData) - _ = acc.SaveKeyValue(esdtToken, marshalledData) + esdtSuffix := append(esdtToken, addrBytes...) - esdtSuffix := append(esdtToken, acc.AddressBytes()...) - - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { - go func() { - trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { + go func() { + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(addrBytes, dth) - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + _ = acc.SaveKeyValue(esdtToken, marshalledData) accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { @@ -1391,37 +1442,36 @@ func TestNode_GetESDTsRoles(t *testing.T) { t.Parallel() addrBytes := testscommon.TestPubKeyAlice - acc := createAcc(addrBytes) - esdtToken := []byte("TCK-RANDOM") - specialRoles := []*systemSmartContracts.ESDTRoles{ { Address: addrBytes, Roles: [][]byte{[]byte(core.ESDTRoleNFTAddQuantity), []byte(core.ESDTRoleLocalMint)}, }, } - + esdtToken := []byte("TCK-RANDOM") esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.FungibleESDT), SpecialRoles: specialRoles} marshalledData, _ := getMarshalizer().Marshal(esdtData) - - esdtSuffix := append(esdtToken, acc.AddressBytes()...) - - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { - go func() { - trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() - - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + esdtSuffix := append(esdtToken, addrBytes...) + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { + go func() { + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + acc := createAccWithDth(addrBytes, dth) accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { @@ -1462,32 +1512,32 @@ func TestNode_GetNFTTokenIDsRegisteredByAddress(t *testing.T) { t.Parallel() addrBytes := testscommon.TestPubKeyAlice - acc := createAcc(addrBytes) esdtToken := []byte("TCK-RANDOM") - esdtData := &systemSmartContracts.ESDTDataV2{TokenName: []byte("fungible"), TokenType: []byte(core.SemiFungibleESDT), OwnerAddress: addrBytes} marshalledData, _ := getMarshalizer().Marshal(esdtData) - _ = acc.SaveKeyValue(esdtToken, marshalledData) - - esdtSuffix := append(esdtToken, acc.AddressBytes()...) - - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { - go func() { - trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() - - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, + esdtSuffix := append(esdtToken, addrBytes...) + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { + go func() { + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } }, - ) + } + acc := createAccWithDth(addrBytes, dth) + + _ = acc.SaveKeyValue(esdtToken, marshalledData) accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { @@ -1527,24 +1577,25 @@ func TestNode_GetNFTTokenIDsRegisteredByAddressContextShouldTimeout(t *testing.T t.Parallel() addrBytes := testscommon.TestPubKeyAlice - acc := createAcc(addrBytes) - - acc.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { - go func() { - time.Sleep(time.Second) - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() - - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, _ common.TrieLeafParser) error { + go func() { + time.Sleep(time.Second) + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() + + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } }, - ) + } + acc := createAccWithDth(addrBytes, dth) accDB := &stateMock.AccountsStub{ RecreateTrieCalled: func(rootHash common.RootHashHolder) error { @@ -3554,18 +3605,24 @@ func TestNode_GetAccountAccountExistsShouldReturn(t *testing.T) { } func TestNode_GetAccountAccountWithKeysErrorShouldFail(t *testing.T) { - accnt := createAcc(testscommon.TestPubKeyBob) - _ = accnt.AddToBalance(big.NewInt(1)) + t.Parallel() + expectedErr := errors.New("expected error") - accnt.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { - return expectedErr - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { + return expectedErr + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + accnt := createAccWithDth(testscommon.TestPubKeyBob, dth) + _ = accnt.AddToBalance(big.NewInt(1)) accDB := &stateMock.AccountsStub{ GetAccountWithBlockInfoCalled: func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) { @@ -3588,34 +3645,37 @@ func TestNode_GetAccountAccountWithKeysErrorShouldFail(t *testing.T) { func TestNode_GetAccountAccountWithKeysShouldWork(t *testing.T) { t.Parallel() - accnt := createAcc(testscommon.TestPubKeyBob) - _ = accnt.AddToBalance(big.NewInt(1)) - + address := testscommon.TestPubKeyBob k1, v1 := []byte("key1"), []byte("value1") k2, v2 := []byte("key2"), []byte("value2") - accnt.SetDataTrie( - &trieMock.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { - go func() { - suffix := append(k1, accnt.AddressBytes()...) - trieLeaf, _ := tlp.ParseLeaf(k1, append(v1, suffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { + go func() { + suffix := append(k1, address...) + trieLeaf, _ := tlp.ParseLeaf(k1, append(v1, suffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf - suffix = append(k2, accnt.AddressBytes()...) - trieLeaf2, _ := tlp.ParseLeaf(k2, append(v2, suffix...), core.NotSpecified) - leavesChannels.LeavesChan <- trieLeaf2 + suffix = append(k2, address...) + trieLeaf2, _ := tlp.ParseLeaf(k2, append(v2, suffix...), core.NotSpecified) + leavesChannels.LeavesChan <- trieLeaf2 - close(leavesChannels.LeavesChan) - leavesChannels.ErrChan.Close() - }() + close(leavesChannels.LeavesChan) + leavesChannels.ErrChan.Close() + }() - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }) + return nil + }, + RootCalled: func() ([]byte, error) { + return nil, nil + }, + } + }, + } + accnt := createAccWithDth(address, dth) + _ = accnt.AddToBalance(big.NewInt(1)) accDB := &stateMock.AccountsStub{ GetAccountWithBlockInfoCalled: func(address []byte, options common.RootHashHolder) (vmcommon.AccountHandler, common.BlockInfo, error) { @@ -3668,7 +3728,6 @@ func TestNode_GetAccountAccountWithKeysNilDataTrieShouldWork(t *testing.T) { t.Parallel() accnt := createAcc(testscommon.TestPubKeyBob) - accnt.SetDataTrie(nil) _ = accnt.AddToBalance(big.NewInt(1)) accDB := &stateMock.AccountsStub{ @@ -4542,12 +4601,17 @@ func TestNode_IsDataTrieMigrated(t *testing.T) { t.Run("should work and return false", func(t *testing.T) { t.Parallel() - acc := createAcc([]byte("000000000000000000010000000000000000000000000000000000000001ffff")) - acc.SetDataTrie(&trieMock.TrieStub{ - IsMigratedToLatestVersionCalled: func() (bool, error) { - return false, nil + address := []byte("000000000000000000010000000000000000000000000000000000000001ffff") + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + IsMigratedToLatestVersionCalled: func() (bool, error) { + return false, nil + }, + } }, - }) + } + acc := createAccWithDth(address, dth) stateComponents := getDefaultStateComponents() stateComponents.AccountsRepo = &stateMock.AccountsRepositoryStub{ @@ -4569,12 +4633,17 @@ func TestNode_IsDataTrieMigrated(t *testing.T) { t.Run("should work and return true", func(t *testing.T) { t.Parallel() - acc := createAcc([]byte("000000000000000000010000000000000000000000000000000000000001ffff")) - acc.SetDataTrie(&trieMock.TrieStub{ - IsMigratedToLatestVersionCalled: func() (bool, error) { - return true, nil + address := []byte("000000000000000000010000000000000000000000000000000000000001ffff") + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{ + IsMigratedToLatestVersionCalled: func() (bool, error) { + return true, nil + }, + } }, - }) + } + acc := createAccWithDth(address, dth) stateComponents := getDefaultStateComponents() stateComponents.AccountsRepo = &stateMock.AccountsRepositoryStub{ diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index 506e1e2a0cc..5aac9420ede 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -43,7 +43,6 @@ import ( "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/storageunit" @@ -2507,16 +2506,11 @@ func (bp *baseProcessor) commitTrieEpochRootHashIfNeeded(metaBlock data.MetaHead totalSizeAccountsDataTries := 0 totalSizeCodeLeaves := 0 - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: bp.hasher, - Marshaller: bp.marshalizer, - EnableEpochsHandler: bp.enableEpochsHandler, - StateAccessesCollector: bp.stateAccessesCollector, - } - accountCreator, err := factory.NewAccountCreator(argsAccCreator) - if err != nil { - return err + adb, ok := userAccountsDb.(adbWithAccountsFactory) + if !ok { + return fmt.Errorf("%w: cannot assert accountsDB to have account factory: %T", process.ErrWrongTypeAssertion, userAccountsDb) } + accountCreator := adb.GetAccountsFactory() for leaf := range iteratorChannels.LeavesChan { userAccount, errUnmarshal := bp.unmarshalUserAccount(accountCreator, leaf.Key(), leaf.Value()) diff --git a/process/block/baseProcess_test.go b/process/block/baseProcess_test.go index b6a1ea05464..adee5578f92 100644 --- a/process/block/baseProcess_test.go +++ b/process/block/baseProcess_test.go @@ -27,6 +27,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -2561,6 +2562,13 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeededShouldUseDataTrieIfNeededW return nil }, + GetAccountsFactoryCalled: func() state.AccountFactory { + return &stateMock.AccountsFactoryStub{ + CreateAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { + return stateMock.NewAccountWrapMock([]byte("address")), nil + }, + } + }, }, } diff --git a/process/block/interface.go b/process/block/interface.go index 53ed9422f9a..8560f198035 100644 --- a/process/block/interface.go +++ b/process/block/interface.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process/asyncExecution/executionTrack" @@ -87,3 +88,8 @@ type MissingDataResolver interface { Reset() IsInterfaceNil() bool } + +type adbWithAccountsFactory interface { + state.AccountsAdapter + GetAccountsFactory() state.AccountFactory +} diff --git a/process/rewardTransaction/process_test.go b/process/rewardTransaction/process_test.go index a26c0cf9441..872831d2546 100644 --- a/process/rewardTransaction/process_test.go +++ b/process/rewardTransaction/process_test.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/rewardTx" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/rewardTransaction" @@ -273,13 +274,16 @@ func TestRewardTxProcessor_ProcessRewardTransactionMissingTrieNode(t *testing.T) missingNodeErr := fmt.Errorf(core.GetNodeFromDBErrorString) accountsDb := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { - acc := stateMock.NewAccountWrapMock(address) - acc.SetDataTrie(&trie.TrieStub{ - GetCalled: func(key []byte) ([]byte, uint32, error) { - return nil, 0, missingNodeErr + dth := &trie.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trie.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, missingNodeErr + }, + } }, - }, - ) + } + acc := stateMock.NewAccountWrapMockWithDataTrieHolder(dth) return acc, nil }, @@ -311,13 +315,16 @@ func TestRewardTxProcessor_ProcessRewardTransactionToASmartContractShouldWork(t address := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6} - dtt, _ := trackableDataTrie.NewTrackableDataTrie( - address, - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), - &stateMock.StateAccessesCollectorStub{}, - ) + args := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: address, + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: &trie.TriesHolderStub{}, + DataTrieCreator: &trie.TrieStub{}, + } + dtt, _ := trackableDataTrie.NewTrackableDataTrie(args) userAccount, _ := accounts.NewUserAccount(address, dtt, &trie.TrieLeafParserStub{}) accountsDb := &stateMock.AccountsStub{ LoadAccountCalled: func(address []byte) (vmcommon.AccountHandler, error) { diff --git a/process/scToProtocol/stakingToPeer_test.go b/process/scToProtocol/stakingToPeer_test.go index db8efdac667..97890d10be2 100644 --- a/process/scToProtocol/stakingToPeer_test.go +++ b/process/scToProtocol/stakingToPeer_test.go @@ -61,13 +61,16 @@ func createBlockBody() *block.Body { } func createStakingScAccount() state.UserAccountHandler { - dtt, _ := trackableDataTrie.NewTrackableDataTrie( - vm.StakingSCAddress, - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), - &stateMock.StateAccessesCollectorStub{}, - ) + args := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: vm.StakingSCAddress, + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: &trie.TriesHolderStub{}, + DataTrieCreator: &trie.TrieStub{}, + } + dtt, _ := trackableDataTrie.NewTrackableDataTrie(args) userAcc, _ := accounts.NewUserAccount(vm.StakingSCAddress, dtt, &trie.TrieLeafParserStub{}) return userAcc diff --git a/process/smartContract/hooks/blockChainHook_test.go b/process/smartContract/hooks/blockChainHook_test.go index 803b816182d..9a0890375fb 100644 --- a/process/smartContract/hooks/blockChainHook_test.go +++ b/process/smartContract/hooks/blockChainHook_test.go @@ -2333,7 +2333,6 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { addressHandler := stateMock.NewAccountWrapMock(address) - addressHandler.SetDataTrie(nil) return addressHandler, nil }, @@ -2354,12 +2353,16 @@ func TestBlockChainHookImpl_GetESDTToken(t *testing.T) { args := createMockBlockChainHookArgs() args.Accounts = &stateMock.AccountsStub{ GetExistingAccountCalled: func(addressContainer []byte) (vmcommon.AccountHandler, error) { - addressHandler := stateMock.NewAccountWrapMock(address) - addressHandler.SetDataTrie(&trie.TrieStub{ - GetCalled: func(_ []byte) ([]byte, uint32, error) { - return make([]byte, 0), 0, nil + dth := &trie.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trie.TrieStub{ + GetCalled: func(_ []byte) ([]byte, uint32, error) { + return make([]byte, 0), 0, nil + }, + } }, - }) + } + addressHandler := stateMock.NewAccountWrapMockWithDataTrieHolder(dth) return addressHandler, nil }, diff --git a/process/smartContract/processorV2/process_test.go b/process/smartContract/processorV2/process_test.go index 28e644e86a8..5bab320366a 100644 --- a/process/smartContract/processorV2/process_test.go +++ b/process/smartContract/processorV2/process_test.go @@ -37,6 +37,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" testsCommonStorage "github.com/multiversx/mx-chain-go/testscommon/storage" + "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/testscommon/vmcommonMocks" "github.com/multiversx/mx-chain-go/txcache" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -66,6 +67,8 @@ func createAccount(address []byte) state.UserAccountHandler { Marshaller: &marshallerMock.MarshalizerMock{}, EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: &trie.TriesHolderStub{}, + DataTrieCreator: &trie.TrieStub{}, } accountFactory, _ := stateFactory.NewAccountCreator(argsAccCreation) account, _ := accountFactory.CreateAccount(address) diff --git a/process/sync/baseForkDetector_test.go b/process/sync/baseForkDetector_test.go index 2fa476de1df..c140471b48f 100644 --- a/process/sync/baseForkDetector_test.go +++ b/process/sync/baseForkDetector_test.go @@ -5,9 +5,6 @@ import ( "testing" "time" - "github.com/multiversx/mx-chain-go/testscommon/chainParameters" - "github.com/multiversx/mx-chain-go/testscommon/processMocks" - "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" @@ -17,8 +14,10 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/processMocks" "github.com/stretchr/testify/assert" ) diff --git a/process/sync/trieIterators/tokensSuppliesProcessor_test.go b/process/sync/trieIterators/tokensSuppliesProcessor_test.go index 8bd3949d4bc..a2e2217d9e5 100644 --- a/process/sync/trieIterators/tokensSuppliesProcessor_test.go +++ b/process/sync/trieIterators/tokensSuppliesProcessor_test.go @@ -121,16 +121,19 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { expectedErr := errors.New("error") - userAcc, _ := accounts.NewUserAccount([]byte("addr"), &trie.DataTrieTrackerStub{}, &trie.TrieLeafParserStub{}) - userAcc.SetRootHash([]byte("rootHash")) - userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { - return expectedErr - }, - RootCalled: func() ([]byte, error) { - return []byte("rootHash"), nil + userAcc, _ := accounts.NewUserAccount([]byte("addr"), &trie.DataTrieTrackerStub{ + DataTrieCalled: func() common.Trie { + return &trie.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { + return expectedErr + }, + RootCalled: func() ([]byte, error) { + return []byte("rootHash"), nil + }, + } }, - }) + }, &trie.TrieLeafParserStub{}) + userAcc.SetRootHash([]byte("rootHash")) err := tsp.HandleTrieAccountIteration(userAcc) require.ErrorIs(t, err, expectedErr) @@ -143,19 +146,22 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { args := getTokensSuppliesProcessorArgs() tsp, _ := NewTokensSuppliesProcessor(args) - userAcc, _ := accounts.NewUserAccount([]byte("addr"), &trie.DataTrieTrackerStub{}, &trie.TrieLeafParserStub{}) - userAcc.SetRootHash([]byte("rootHash")) - userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { - leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("not a token key"), []byte("not a token value")) + userAcc, _ := accounts.NewUserAccount([]byte("addr"), &trie.DataTrieTrackerStub{ + DataTrieCalled: func() common.Trie { + return &trie.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { + leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage([]byte("not a token key"), []byte("not a token value")) - close(leavesChannels.LeavesChan) - return nil - }, - RootCalled: func() ([]byte, error) { - return []byte("rootHash"), nil + close(leavesChannels.LeavesChan) + return nil + }, + RootCalled: func() ([]byte, error) { + return []byte("rootHash"), nil + }, + } }, - }) + }, &trie.TrieLeafParserStub{}) + userAcc.SetRootHash([]byte("rootHash")) err := tsp.HandleTrieAccountIteration(userAcc) require.NoError(t, err) @@ -168,23 +174,26 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { args := getTokensSuppliesProcessorArgs() tsp, _ := NewTokensSuppliesProcessor(args) - userAcc, _ := accounts.NewUserAccount(vmcommon.SystemAccountAddress, &trie.DataTrieTrackerStub{}, &trie.TrieLeafParserStub{}) - userAcc.SetRootHash([]byte("rootHash")) - userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { - esToken := &esdt.ESDigitalToken{ - Value: big.NewInt(37), + userAcc, _ := accounts.NewUserAccount(vmcommon.SystemAccountAddress, &trie.DataTrieTrackerStub{ + DataTrieCalled: func() common.Trie { + return &trie.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { + esToken := &esdt.ESDigitalToken{ + Value: big.NewInt(37), + } + esBytes, _ := args.Marshaller.Marshal(esToken) + tknKey := []byte("ELRONDesdtTKN-00aacc") + value := append(esBytes, tknKey...) + value = append(value, []byte("addr")...) + leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage(tknKey, value) + + close(leavesChannels.LeavesChan) + return nil + }, } - esBytes, _ := args.Marshaller.Marshal(esToken) - tknKey := []byte("ELRONDesdtTKN-00aacc") - value := append(esBytes, tknKey...) - value = append(value, []byte("addr")...) - leavesChannels.LeavesChan <- keyValStorage.NewKeyValStorage(tknKey, value) - - close(leavesChannels.LeavesChan) - return nil }, - }) + }, &trie.TrieLeafParserStub{}) + userAcc.SetRootHash([]byte("rootHash")) err := tsp.HandleTrieAccountIteration(userAcc) require.NoError(t, err) @@ -196,49 +205,54 @@ func TestTokensSuppliesProcessor_HandleTrieAccountIteration(t *testing.T) { args := getTokensSuppliesProcessorArgs() tsp, _ := NewTokensSuppliesProcessor(args) - - dtt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("addr"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + argsTdt := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: []byte("addr"), + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: &trie.TriesHolderStub{ + GetCalled: func(i []byte) common.Trie { + return &trie.TrieStub{ + GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { + esToken := &esdt.ESDigitalToken{ + Value: big.NewInt(37), + } + esBytes, _ := args.Marshaller.Marshal(esToken) + tknKey := []byte("ELRONDesdtTKN-00aacc") + value := append(esBytes, tknKey...) + value = append(value, []byte("addr")...) + leaf, err := leafParser.ParseLeaf(tknKey, value, 0) + require.Nil(t, err) + leavesChannels.LeavesChan <- leaf + + sft := &esdt.ESDigitalToken{ + Value: big.NewInt(1), + } + sftBytes, _ := args.Marshaller.Marshal(sft) + sftKey := []byte("ELRONDesdtSFT-00aabb") + sftKey = append(sftKey, big.NewInt(37).Bytes()...) + value = append(sftBytes, sftKey...) + value = append(value, []byte("addr")...) + leaf, err = leafParser.ParseLeaf(sftKey, value, 0) + require.Nil(t, err) + leavesChannels.LeavesChan <- leaf + + close(leavesChannels.LeavesChan) + return nil + }, + RootCalled: func() ([]byte, error) { + return []byte("rootHash"), nil + }, + } + }, + }, + DataTrieCreator: &trie.TrieStub{}, + } + dtt, _ := trackableDataTrie.NewTrackableDataTrie(argsTdt) dtlp, _ := parsers.NewDataTrieLeafParser([]byte("addr"), &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) userAcc, _ := accounts.NewUserAccount([]byte("addr"), dtt, dtlp) userAcc.SetRootHash([]byte("rootHash")) - userAcc.SetDataTrie(&trie.TrieStub{ - GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, keyBuilder common.KeyBuilder, leafParser common.TrieLeafParser) error { - esToken := &esdt.ESDigitalToken{ - Value: big.NewInt(37), - } - esBytes, _ := args.Marshaller.Marshal(esToken) - tknKey := []byte("ELRONDesdtTKN-00aacc") - value := append(esBytes, tknKey...) - value = append(value, []byte("addr")...) - leaf, err := leafParser.ParseLeaf(tknKey, value, 0) - require.Nil(t, err) - leavesChannels.LeavesChan <- leaf - - sft := &esdt.ESDigitalToken{ - Value: big.NewInt(1), - } - sftBytes, _ := args.Marshaller.Marshal(sft) - sftKey := []byte("ELRONDesdtSFT-00aabb") - sftKey = append(sftKey, big.NewInt(37).Bytes()...) - value = append(sftBytes, sftKey...) - value = append(value, []byte("addr")...) - leaf, err = leafParser.ParseLeaf(sftKey, value, 0) - require.Nil(t, err) - leavesChannels.LeavesChan <- leaf - - close(leavesChannels.LeavesChan) - return nil - }, - RootCalled: func() ([]byte, error) { - return []byte("rootHash"), nil - }, - }) err := tsp.HandleTrieAccountIteration(userAcc) require.NoError(t, err) diff --git a/state/accounts/userAccount.go b/state/accounts/userAccount.go index d7bb1055e08..0af05b86e6f 100644 --- a/state/accounts/userAccount.go +++ b/state/accounts/userAccount.go @@ -147,6 +147,12 @@ func (a *userAccount) SetCodeHash(codeHash []byte) { // SetRootHash sets the root hash associated with the account func (a *userAccount) SetRootHash(roothash []byte) { a.RootHash = roothash + a.dataTrieInteractor.SetRootHash(roothash) +} + +// SetDataTrieRootHash sets the root hash of the data trie to the one stored in the account +func (a *userAccount) SetDataTrieRootHash() { + a.dataTrieInteractor.SetRootHash(a.RootHash) } // SetCodeMetadata sets the code metadata diff --git a/state/accountsDB.go b/state/accountsDB.go index 9d57fcc06a5..a450fe057ae 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/stateChange" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/errors" logger "github.com/multiversx/mx-chain-logger-go" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -107,6 +108,7 @@ type ArgsAccountsDB struct { SnapshotsManager SnapshotsManager StateAccessesCollector StateAccessesCollector PruningEnabled bool + DataTriesHolder common.TriesHolder } // NewAccountsDB creates a new account manager @@ -128,7 +130,7 @@ func createAccountsDb(args ArgsAccountsDB) *AccountsDB { storagePruningManager: args.StoragePruningManager, entries: make([]JournalEntry, 0), mutOp: sync.RWMutex{}, - dataTries: NewDataTriesHolder(), + dataTries: args.DataTriesHolder, obsoleteDataTrieHashes: make(map[string][][]byte), loadCodeMeasurements: &loadingMeasurements{ identifier: "load code", @@ -165,6 +167,9 @@ func checkArgsAccountsDB(args ArgsAccountsDB) error { if check.IfNil(args.StateAccessesCollector) { return ErrNilStateAccessesCollector } + if check.IfNil(args.DataTriesHolder) { + return errors.ErrNilDataTriesHolder + } return nil } @@ -502,33 +507,6 @@ func saveCodeEntry(codeHash []byte, entry *CodeEntry, trie Updater, marshalizer return codeEntry, nil } -// loadDataTrieConcurrentSafe retrieves and saves the SC data inside accountHandler object. -// Errors if something went wrong -func (adb *AccountsDB) loadDataTrieConcurrentSafe(accountHandler baseAccountHandler, mainTrie common.Trie) error { - adb.mutOp.Lock() - defer adb.mutOp.Unlock() - - dataTrie := adb.dataTries.Get(accountHandler.AddressBytes()) - if dataTrie != nil { - accountHandler.SetDataTrie(dataTrie) - return nil - } - - if len(accountHandler.GetRootHash()) == 0 { - return nil - } - - rootHashHolder := holders.NewDefaultRootHashesHolder(accountHandler.GetRootHash()) - dataTrie, err := mainTrie.Recreate(rootHashHolder) - if err != nil { - return fmt.Errorf("trie was not found for hash, rootHash = %s, err = %w", hex.EncodeToString(accountHandler.GetRootHash()), err) - } - - accountHandler.SetDataTrie(dataTrie) - adb.dataTries.Put(accountHandler.AddressBytes(), dataTrie) - return nil -} - // saveDataTrie is used to save the data trie (not committing it) and to recompute the new Root value // If data is not dirtied, method will not create its JournalEntries to keep track of data modification func (adb *AccountsDB) saveDataTrie(accountHandler baseAccountHandler) ([]*stateChange.DataTrieChange, error) { @@ -553,16 +531,17 @@ func (adb *AccountsDB) saveDataTrie(accountHandler baseAccountHandler) ([]*state accountHandler.SetRootHash(rootHash) log.Trace("saveDataTrie: rootHash changed", "address", accountHandler.AddressBytes(), "rootHash", rootHash) - if check.IfNil(adb.dataTries.Get(accountHandler.AddressBytes())) { - trie, ok := accountHandler.DataTrie().(common.Trie) - if !ok { - log.Warn("wrong type conversion", "trie type", fmt.Sprintf("%T", accountHandler.DataTrie())) - return nil, nil - } + if !check.IfNil(adb.dataTries.Get(accountHandler.AddressBytes())) { + adb.dataTries.MarkAsDirty(accountHandler.AddressBytes()) + return newValues, nil + } - adb.dataTries.Put(accountHandler.AddressBytes(), trie) + trie, ok := accountHandler.DataTrie().(common.Trie) + if !ok { + return nil, fmt.Errorf("wrong type conversion, trie type %T", accountHandler.DataTrie()) } + adb.dataTries.Put(accountHandler.AddressBytes(), trie) return newValues, nil } @@ -671,6 +650,11 @@ func (adb *AccountsDB) removeDataTrie(baseAcc baseAccountHandler) error { } adb.journalize(entry) + // Evict the cached trie for this address so that a subsequent recreation of + // the account at the same address cannot inherit the stale data trie from + // this deleted incarnation (see loadDataTrieConcurrentSafe / saveDataTrie). + adb.dataTries.Remove(baseAcc.AddressBytes()) + return nil } @@ -707,14 +691,6 @@ func (adb *AccountsDB) LoadAccount(address []byte) (vmcommon.AccountHandler, err return adb.accountFactory.CreateAccount(address) } - baseAcc, ok := acnt.(baseAccountHandler) - if ok { - err = adb.loadDataTrieConcurrentSafe(baseAcc, mainTrie) - if err != nil { - return nil, err - } - } - return acnt, nil } @@ -746,7 +722,13 @@ func (adb *AccountsDB) getAccount(address []byte, mainTrie common.Trie) (vmcommo return nil, err } - return acnt, nil + baseAcc, ok := acnt.(baseAccountHandler) + if !ok { + return acnt, nil + } + baseAcc.SetDataTrieRootHash() + + return baseAcc, nil } // GetExistingAccount returns an existing account if exists or nil if missing @@ -766,14 +748,6 @@ func (adb *AccountsDB) GetExistingAccount(address []byte) (vmcommon.AccountHandl return nil, ErrAccNotFound } - baseAcc, ok := acnt.(baseAccountHandler) - if ok { - err = adb.loadDataTrieConcurrentSafe(baseAcc, mainTrie) - if err != nil { - return nil, err - } - } - return acnt, nil } @@ -798,12 +772,8 @@ func (adb *AccountsDB) GetAccountFromBytes(address []byte, accountBytes []byte) return acnt, nil } - err = adb.loadDataTrieConcurrentSafe(baseAcc, adb.getMainTrie()) - if err != nil { - return nil, err - } - - return acnt, nil + baseAcc.SetDataTrieRootHash() + return baseAcc, nil } // loadCode retrieves and saves the SC code inside AccountState object. Errors if something went wrong @@ -924,7 +894,7 @@ func (adb *AccountsDB) commit() ([]byte, error) { oldHashes := make(common.ModifiedHashes) newHashes := make(common.ModifiedHashes) - // Step 1. commit all data tries + // Step 1. commit all data tries. GetAll returns only the dirty tries for the dataTriesHolder implementation dataTries := adb.dataTries.GetAll() for i := 0; i < len(dataTries); i++ { err := adb.commitTrie(dataTries[i], oldHashes, newHashes) @@ -932,7 +902,6 @@ func (adb *AccountsDB) commit() ([]byte, error) { return nil, err } } - adb.dataTries.Reset() oldRoot := adb.mainTrie.GetOldRoot() @@ -1374,6 +1343,11 @@ func (adb *AccountsDB) IsSnapshotInProgress() bool { return adb.snapshotsManger.IsSnapshotInProgress() } +// GetAccountsFactory returns the accounts factory used by the accountsDB +func (adb *AccountsDB) GetAccountsFactory() AccountFactory { + return adb.accountFactory +} + // IsInterfaceNil returns true if there is no value under the interface func (adb *AccountsDB) IsInterfaceNil() bool { return adb == nil diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index 2f2125c7737..31cf47b1d97 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/keyValStorage" data "github.com/multiversx/mx-chain-core-go/data/stateChange" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/state/triesHolder" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/multiversx/mx-chain-vm-common-go/dataTrieMigrator" "github.com/stretchr/testify/assert" @@ -41,17 +42,20 @@ import ( "github.com/multiversx/mx-chain-go/state/storagePruningManager/disabled" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/testscommon" + testCommon "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" - "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/storageManager" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" ) -const trieDbOperationDelay = time.Second +const ( + trieDbOperationDelay = time.Second +) func createMockAccountsDBArgs() state.ArgsAccountsDB { accCreator := &stateMock.AccountsFactoryStub{ @@ -85,6 +89,7 @@ func createMockAccountsDBArgs() state.ArgsAccountsDB { AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: snapshotsManager, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: &trieMock.TriesHolderStub{}, } } @@ -141,21 +146,24 @@ func getDefaultStateComponents( marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} - args := storage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() args.MainStorer = db trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, enableEpochsHandler, 5) + tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, enableEpochsHandler, collapseManager.NewDisabledCollapseManager()) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, generalCfg.PruningBufferLen) + dth, _ := triesHolder.NewDataTriesHolder(common.TenMbSize) argsAccCreator := factory.ArgsAccountCreator{ Hasher: hasher, Marshaller: marshaller, EnableEpochsHandler: enableEpochsHandler, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, + DataTrieCreator: tr, } accCreator, _ := factory.NewAccountCreator(argsAccCreator) @@ -182,6 +190,7 @@ func getDefaultStateComponents( SnapshotsManager: snapshotsManager, StateAccessesCollector: collector, PruningEnabled: true, + DataTriesHolder: dth, } adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -741,6 +750,55 @@ func TestAccountsDB_RemoveAccountShouldWork(t *testing.T) { assert.Equal(t, 2, adb.JournalLen()) } +func TestAccountsDB_RemoveAccountCommitAndRecreateSameAddressShouldNotReuseOldDataTrie(t *testing.T) { + t.Parallel() + + _, adb := getDefaultTrieAndAccountsDb() + address := generateRandomByteArray(32) + oldKey := []byte("old-key") + oldValue := []byte("old-value") + newKey := []byte("new-key") + newValue := []byte("new-value") + + acc, err := adb.LoadAccount(address) + require.NoError(t, err) + userAcc := acc.(state.UserAccountHandler) + + err = userAcc.SaveKeyValue(oldKey, oldValue) + require.NoError(t, err) + err = adb.SaveAccount(userAcc) + require.NoError(t, err) + _, err = adb.Commit() + require.NoError(t, err) + + err = adb.RemoveAccount(address) + require.NoError(t, err) + _, err = adb.Commit() + require.NoError(t, err) + + acc, err = adb.LoadAccount(address) + require.NoError(t, err) + userAcc = acc.(state.UserAccountHandler) + err = userAcc.SaveKeyValue(newKey, newValue) + require.NoError(t, err) + err = adb.SaveAccount(userAcc) + require.NoError(t, err) + _, err = adb.Commit() + require.NoError(t, err) + + acc, err = adb.LoadAccount(address) + require.NoError(t, err) + userAcc = acc.(state.UserAccountHandler) + + val, _, err := userAcc.RetrieveValue(oldKey) + require.NoError(t, err) + assert.Len(t, val, 0) + + val, _, err = userAcc.RetrieveValue(newKey) + require.NoError(t, err) + assert.Equal(t, newValue, val) +} + // ------- LoadAccount func TestAccountsDB_LoadAccountMalfunctionTrieShouldErr(t *testing.T) { @@ -814,7 +872,19 @@ func TestAccountsDB_LoadAccountExistingShouldLoadDataTrie(t *testing.T) { }, } - adb := generateAccountDBFromTrie(trieStub) + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return dataTrie + }, + } + args := createMockAccountsDBArgs() + args.Trie = trieStub + args.AccountFactory = &stateMock.AccountsFactoryStub{ + CreateAccountCalled: func(_ []byte) (vmcommon.AccountHandler, error) { + return stateMock.NewAccountWrapMockWithDataTrieHolder(dth), nil + }, + } + adb, _ := state.NewAccountsDB(args) retrievedAccount, err := adb.LoadAccount(acc.AddressBytes()) assert.Nil(t, err) @@ -892,7 +962,19 @@ func TestAccountsDB_GetExistingAccountFoundShouldRetAccount(t *testing.T) { }, } - adb := generateAccountDBFromTrie(trieStub) + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return dataTrie + }, + } + args := createMockAccountsDBArgs() + args.Trie = trieStub + args.AccountFactory = &stateMock.AccountsFactoryStub{ + CreateAccountCalled: func(_ []byte) (vmcommon.AccountHandler, error) { + return stateMock.NewAccountWrapMockWithDataTrieHolder(dth), nil + }, + } + adb, _ := state.NewAccountsDB(args) retrievedAccount, err := adb.GetExistingAccount(acc.AddressBytes()) assert.Nil(t, err) @@ -1013,134 +1095,31 @@ func TestAccountsDB_LoadCodeOkValsShouldWork(t *testing.T) { assert.Equal(t, adr, state.GetCode(account)) } -// ------- RetrieveData - -func TestAccountsDB_LoadDataNilRootShouldRetNil(t *testing.T) { - t.Parallel() - - tr := &trieMock.TrieStub{ - GetStorageManagerCalled: func() common.StorageManager { - return &storageManager.StorageManagerStub{} - }, - } - _, account, adb := generateAddressAccountAccountsDB(tr) - - // since root is nil, result should be nil and data trie should be nil - err := adb.LoadDataTrieConcurrentSafe(account) - assert.Nil(t, err) - assert.Nil(t, account.DataTrie()) -} - -func TestAccountsDB_LoadDataBadLengthShouldErr(t *testing.T) { - t.Parallel() - - _, account, adb := generateAddressAccountAccountsDB(&trieMock.TrieStub{ - GetStorageManagerCalled: func() common.StorageManager { - return &storageManager.StorageManagerStub{} - }, - }) - - account.SetRootHash([]byte("12345")) - - // should return error - err := adb.LoadDataTrieConcurrentSafe(account) - assert.NotNil(t, err) -} - -func TestAccountsDB_LoadDataMalfunctionTrieShouldErr(t *testing.T) { - t.Parallel() - - account := generateAccount() - account.SetRootHash([]byte("12345")) - - mockTrie := &trieMock.TrieStub{ - GetStorageManagerCalled: func() common.StorageManager { - return &storageManager.StorageManagerStub{} - }, - } - adb := generateAccountDBFromTrie(mockTrie) - - // should return error - err := adb.LoadDataTrieConcurrentSafe(account) - assert.NotNil(t, err) -} - -func TestAccountsDB_LoadDataNotFoundRootShouldReturnErr(t *testing.T) { - t.Parallel() - - _, account, adb := generateAddressAccountAccountsDB(&trieMock.TrieStub{ - GetStorageManagerCalled: func() common.StorageManager { - return &storageManager.StorageManagerStub{} - }, - }) - - rootHash := make([]byte, (&hashingMocks.HasherMock{}).Size()) - rootHash[0] = 1 - account.SetRootHash(rootHash) - - // should return error - err := adb.LoadDataTrieConcurrentSafe(account) - assert.NotNil(t, err) - fmt.Println(err.Error()) -} +// ------- Commit -func TestAccountsDB_LoadDataWithSomeValuesShouldWork(t *testing.T) { +func TestAccountsDB_CommitShouldCallCommitFromTrie(t *testing.T) { t.Parallel() - rootHash := make([]byte, (&hashingMocks.HasherMock{}).Size()) - rootHash[0] = 1 - keyRequired := []byte{65, 66, 67} - val := []byte{32, 33, 34} - - trieVal := append(val, keyRequired...) - trieVal = append(trieVal, []byte("identifier")...) - + commitCalled := 0 + marshaller := &marshallerMock.MarshalizerMock{} + serializedAccount, _ := marshaller.Marshal(stateMock.AccountWrapMock{}) dataTrie := &trieMock.TrieStub{ - GetCalled: func(key []byte) ([]byte, uint32, error) { - if bytes.Equal(key, keyRequired) { - return trieVal, 0, nil - } - - return nil, 0, nil + GetCalled: func(_ []byte) ([]byte, uint32, error) { + return []byte("doge"), 0, nil }, - } - - account := generateAccount() - mockTrie := &trieMock.TrieStub{ - RecreateCalled: func(root common.RootHashHolder) (trie common.Trie, e error) { - if !bytes.Equal(root.GetRootHash(), rootHash) { - return nil, errors.New("bad root hash") - } + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { + return nil + }, + CommitCalled: func() error { + commitCalled++ - return dataTrie, nil + return nil }, - GetStorageManagerCalled: func() common.StorageManager { - return &storageManager.StorageManagerStub{} + RootCalled: func() ([]byte, error) { + return nil, nil }, } - adb := generateAccountDBFromTrie(mockTrie) - - account.SetRootHash(rootHash) - - // should not return error - err := adb.LoadDataTrieConcurrentSafe(account) - assert.Nil(t, err) - - // verify data - dataRecov, _, err := account.RetrieveValue(keyRequired) - assert.Nil(t, err) - assert.Equal(t, val, dataRecov) -} - -// ------- Commit - -func TestAccountsDB_CommitShouldCallCommitFromTrie(t *testing.T) { - t.Parallel() - - commitCalled := 0 - marshaller := &marshallerMock.MarshalizerMock{} - serializedAccount, _ := marshaller.Marshal(stateMock.AccountWrapMock{}) - trieStub := trieMock.TrieStub{ + trieStub := &trieMock.TrieStub{ CommitCalled: func() error { commitCalled++ @@ -1152,30 +1131,20 @@ func TestAccountsDB_CommitShouldCallCommitFromTrie(t *testing.T) { GetCalled: func(_ []byte) ([]byte, uint32, error) { return serializedAccount, 0, nil }, - RecreateCalled: func(root common.RootHashHolder) (trie common.Trie, err error) { - return &trieMock.TrieStub{ - GetCalled: func(_ []byte) ([]byte, uint32, error) { - return []byte("doge"), 0, nil - }, - UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { - return nil - }, - CommitCalled: func() error { - commitCalled++ - - return nil - }, - RootCalled: func() ([]byte, error) { - return nil, nil - }, - }, nil - }, GetStorageManagerCalled: func() common.StorageManager { return &storageManager.StorageManagerStub{} }, } - adb := generateAccountDBFromTrie(&trieStub) + args := createMockAccountsDBArgs() + args.Trie = trieStub + dth := &trieMock.TriesHolderStub{ + GetAllCalled: func() []common.Trie { + return []common.Trie{dataTrie} + }, + } + args.DataTriesHolder = dth + adb, _ := state.NewAccountsDB(args) accnt, _ := adb.LoadAccount(make([]byte, 32)) _ = accnt.(state.UserAccountHandler).SaveKeyValue([]byte("dog"), []byte("puppy")) @@ -2126,10 +2095,9 @@ func TestAccountsDB_MainTrieAutomaticallyMarksCodeUpdatesForEviction(t *testing. marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} ewl := stateMock.NewEvictionWaitingListMock(100) - args := storage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() tsm, _ := trie.NewTrieStorageManager(args) - maxTrieLevelInMemory := uint(5) - tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 5) argsAccountsDB := createMockAccountsDBArgs() @@ -2137,13 +2105,9 @@ func TestAccountsDB_MainTrieAutomaticallyMarksCodeUpdatesForEviction(t *testing. argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: hasher, - Marshaller: marshaller, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) + accFactory, dth := getAccountsCreator(argsAccountsDB) + argsAccountsDB.AccountFactory = accFactory + argsAccountsDB.DataTriesHolder = dth argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2205,14 +2169,13 @@ func TestAccountsDB_RemoveAccountSetsObsoleteHashes(t *testing.T) { func TestAccountsDB_RemoveAccountMarksObsoleteHashesForEviction(t *testing.T) { t.Parallel() - maxTrieLevelInMemory := uint(5) marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} ewl := stateMock.NewEvictionWaitingListMock(100) - args := storage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() tsm, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 5) argsAccountsDB := createMockAccountsDBArgs() @@ -2220,13 +2183,9 @@ func TestAccountsDB_RemoveAccountMarksObsoleteHashesForEviction(t *testing.T) { argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: hasher, - Marshaller: marshaller, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) + accFactory, dth := getAccountsCreator(argsAccountsDB) + argsAccountsDB.AccountFactory = accFactory + argsAccountsDB.DataTriesHolder = dth argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2411,25 +2370,20 @@ func modifyDataTries(t *testing.T, accountsAddresses [][]byte, adb *state.Accoun func TestAccountsDB_GetCode(t *testing.T) { t.Parallel() - maxTrieLevelInMemory := uint(5) marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} - args := storage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() tsm, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) spm := disabled.NewDisabledStoragePruningManager() argsAccountsDB := createMockAccountsDBArgs() argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: hasher, - Marshaller: marshaller, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) + accFactory, dth := getAccountsCreator(argsAccountsDB) + argsAccountsDB.AccountFactory = accFactory + argsAccountsDB.DataTriesHolder = dth argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2591,13 +2545,9 @@ func TestAccountsDB_Close(t *testing.T) { argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: hasher, - Marshaller: marshaller, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) + accFactory, dth := getAccountsCreator(argsAccountsDB) + argsAccountsDB.AccountFactory = accFactory + argsAccountsDB.DataTriesHolder = dth argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -2676,47 +2626,6 @@ func TestAccountsDB_GetAccountFromBytes(t *testing.T) { assert.Equal(t, expectedAccount, acc) }) - t.Run("loads data trie for user account", func(t *testing.T) { - t.Parallel() - - rootHash := []byte("root hash") - setDataTrieCalled := false - expectedAccount := &stateMock.UserAccountStub{ - SetDataTrieCalled: func(_ common.Trie) { - setDataTrieCalled = true - }, - GetRootHashCalled: func() []byte { - return rootHash - }, - } - - args := createMockAccountsDBArgs() - args.AccountFactory = &stateMock.AccountsFactoryStub{ - CreateAccountCalled: func(_ []byte) (vmcommon.AccountHandler, error) { - return expectedAccount, nil - }, - } - args.Marshaller = &marshallerMock.MarshalizerStub{ - UnmarshalCalled: func(_ interface{}, _ []byte) error { - return nil - }, - } - args.Trie = &trieMock.TrieStub{ - RecreateCalled: func(root common.RootHashHolder) (common.Trie, error) { - assert.Equal(t, rootHash, root.GetRootHash()) - return &trieMock.TrieStub{}, nil - }, - GetStorageManagerCalled: func() common.StorageManager { - return &storageManager.StorageManagerStub{} - }, - } - adb, _ := state.NewAccountsDB(args) - - acc, err := adb.GetAccountFromBytes([]byte{1}, []byte{}) - assert.Nil(t, err) - assert.Equal(t, expectedAccount, acc) - assert.True(t, setDataTrieCalled) - }) } func TestAccountsDB_GetAccountFromBytesShouldLoadDataTrie(t *testing.T) { @@ -2743,7 +2652,19 @@ func TestAccountsDB_GetAccountFromBytesShouldLoadDataTrie(t *testing.T) { }, } - adb := generateAccountDBFromTrie(trieStub) + dth := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return dataTrie + }, + } + args := createMockAccountsDBArgs() + args.Trie = trieStub + args.AccountFactory = &stateMock.AccountsFactoryStub{ + CreateAccountCalled: func(_ []byte) (vmcommon.AccountHandler, error) { + return stateMock.NewAccountWrapMockWithDataTrieHolder(dth), nil + }, + } + adb, _ := state.NewAccountsDB(args) retrievedAccount, err := adb.GetAccountFromBytes(acc.AddressBytes(), serializerAcc) assert.Nil(t, err) @@ -2866,27 +2787,37 @@ func TestAccountsDB_NewAccountsDbStartsSnapshotAfterRestart(t *testing.T) { assert.True(t, takeSnapshotCalled.IsSet()) } +func getAccountsCreator(args state.ArgsAccountsDB) (state.AccountFactory, common.TriesHolder) { + dth, _ := triesHolder.NewDataTriesHolder(common.TenMbSize) + argsAccCreator := factory.ArgsAccountCreator{ + Hasher: args.Hasher, + Marshaller: args.Marshaller, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, + DataTrieCreator: args.Trie, + } + + accCreator, _ := factory.NewAccountCreator(argsAccCreator) + return accCreator, dth +} + func BenchmarkAccountsDb_GetCodeEntry(b *testing.B) { - maxTrieLevelInMemory := uint(5) marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} - args := storage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() tsm, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) spm := disabled.NewDisabledStoragePruningManager() argsAccountsDB := createMockAccountsDBArgs() argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: hasher, - Marshaller: marshaller, - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) + accFactory, dth := getAccountsCreator(argsAccountsDB) + argsAccountsDB.AccountFactory = accFactory + argsAccountsDB.DataTriesHolder = dth argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -3196,21 +3127,17 @@ func TestAccountsDB_RevertTxWhichMigratesDataRemovesMigratedData(t *testing.T) { marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} enableEpochsHandler := enableEpochsHandlerMock.NewEnableEpochsHandlerStub() - tsm, _ := trie.NewTrieStorageManager(storage.GetStorageManagerArgs()) - tr, _ := trie.NewTrie(tsm, marshaller, hasher, enableEpochsHandler, uint(5)) + tsm, _ := trie.NewTrieStorageManager(testCommon.GetStorageManagerArgs()) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, enableEpochsHandler, collapseManager.NewDisabledCollapseManager()) spm := &stateMock.StoragePruningManagerStub{} argsAccountsDB := createMockAccountsDBArgs() argsAccountsDB.PruningEnabled = true argsAccountsDB.Trie = tr argsAccountsDB.Hasher = hasher argsAccountsDB.Marshaller = marshaller - argsAccCreator := factory.ArgsAccountCreator{ - Hasher: hasher, - Marshaller: marshaller, - EnableEpochsHandler: enableEpochsHandler, - StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, - } - argsAccountsDB.AccountFactory, _ = factory.NewAccountCreator(argsAccCreator) + accFactory, dth := getAccountsCreator(argsAccountsDB) + argsAccountsDB.AccountFactory = accFactory + argsAccountsDB.DataTriesHolder = dth argsAccountsDB.StoragePruningManager = spm adb, _ := state.NewAccountsDB(argsAccountsDB) diff --git a/state/dataTriesHolder_test.go b/state/dataTriesHolder_test.go deleted file mode 100644 index 8e65f1bb3b2..00000000000 --- a/state/dataTriesHolder_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package state_test - -import ( - "strconv" - "sync" - "testing" - - "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/state" - trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" - "github.com/stretchr/testify/assert" -) - -func TestNewDataTriesHolder(t *testing.T) { - t.Parallel() - - dth := state.NewDataTriesHolder() - assert.False(t, check.IfNil(dth)) -} - -func TestDataTriesHolder_PutAndGet(t *testing.T) { - t.Parallel() - - tr1 := &trieMock.TrieStub{} - - dth := state.NewDataTriesHolder() - dth.Put([]byte("trie1"), tr1) - tr := dth.Get([]byte("trie1")) - - assert.True(t, tr == tr1) -} - -func TestDataTriesHolder_Replace(t *testing.T) { - t.Parallel() - - tr1 := &trieMock.TrieStub{} - tr2 := &trieMock.TrieStub{} - - dth := state.NewDataTriesHolder() - dth.Put([]byte("trie1"), tr1) - dth.Replace([]byte("trie1"), tr2) - retrievedTrie := dth.Get([]byte("trie1")) - - assert.True(t, retrievedTrie == tr2) - assert.True(t, retrievedTrie != tr1) -} - -func TestDataTriesHolder_GetAll(t *testing.T) { - t.Parallel() - - tr1 := &trieMock.TrieStub{} - tr2 := &trieMock.TrieStub{} - tr3 := &trieMock.TrieStub{} - - dth := state.NewDataTriesHolder() - dth.Put([]byte("trie1"), tr1) - dth.Put([]byte("trie2"), tr2) - dth.Put([]byte("trie3"), tr3) - tries := dth.GetAll() - - assert.Equal(t, 3, len(tries)) -} - -func TestDataTriesHolder_Reset(t *testing.T) { - t.Parallel() - - tr1 := &trieMock.TrieStub{} - - dth := state.NewDataTriesHolder() - dth.Put([]byte("trie1"), tr1) - dth.Reset() - - tr := dth.Get([]byte("trie1")) - assert.Nil(t, tr) -} - -func TestDataTriesHolder_Concurrency(t *testing.T) { - t.Parallel() - - dth := state.NewDataTriesHolder() - numTries := 50 - - wg := sync.WaitGroup{} - wg.Add(numTries) - - for i := 0; i < numTries; i++ { - go func(key int) { - dth.Put([]byte(strconv.Itoa(key)), &trieMock.TrieStub{}) - wg.Done() - }(i) - } - - wg.Wait() - - tries := dth.GetAll() - assert.Equal(t, numTries, len(tries)) -} - -func TestDataTriesHolder_GetAllTries(t *testing.T) { - t.Parallel() - - dth := state.NewDataTriesHolder() - numTries := 50 - - wg := sync.WaitGroup{} - wg.Add(numTries) - - for i := 0; i < numTries; i++ { - go func(key int) { - dth.Put([]byte(strconv.Itoa(key)), &trieMock.TrieStub{}) - wg.Done() - }(i) - } - - wg.Wait() - - tries := dth.GetAllTries() - assert.Equal(t, numTries, len(tries)) -} diff --git a/state/export_test.go b/state/export_test.go index bbc209312e4..f2bd19b7ad4 100644 --- a/state/export_test.go +++ b/state/export_test.go @@ -13,11 +13,6 @@ func (adb *AccountsDB) LoadCode(accountHandler baseAccountHandler) error { return adb.loadCode(accountHandler) } -// LoadDataTrieConcurrentSafe - -func (adb *AccountsDB) LoadDataTrieConcurrentSafe(accountHandler baseAccountHandler) error { - return adb.loadDataTrieConcurrentSafe(accountHandler, adb.getMainTrie()) -} - // GetAccount - func (adb *AccountsDB) GetAccount(address []byte) (vmcommon.AccountHandler, error) { return adb.getAccount(address, adb.getMainTrie()) diff --git a/state/factory/accountCreator.go b/state/factory/accountCreator.go index fda887fa1e0..995d896c345 100644 --- a/state/factory/accountCreator.go +++ b/state/factory/accountCreator.go @@ -19,6 +19,8 @@ type ArgsAccountCreator struct { Marshaller marshal.Marshalizer EnableEpochsHandler common.EnableEpochsHandler StateAccessesCollector state.StateAccessesCollector + DataTriesHolder common.TriesHolder + DataTrieCreator common.DataTrieCreator } // AccountCreator has method to create a new account @@ -26,7 +28,9 @@ type accountCreator struct { hasher hashing.Hasher marshaller marshal.Marshalizer enableEpochsHandler common.EnableEpochsHandler - StateAccessesCollector state.StateAccessesCollector + stateAccessesCollector state.StateAccessesCollector + dataTriesHolder common.TriesHolder + dataTrieCreator common.DataTrieCreator } // NewAccountCreator creates a new instance of AccountCreator @@ -43,18 +47,35 @@ func NewAccountCreator(args ArgsAccountCreator) (state.AccountFactory, error) { if check.IfNil(args.StateAccessesCollector) { return nil, state.ErrNilStateAccessesCollector } + if check.IfNil(args.DataTriesHolder) { + return nil, errors.ErrNilDataTriesHolder + } + if check.IfNil(args.DataTrieCreator) { + return nil, errors.ErrNilDataTrieCreator + } return &accountCreator{ hasher: args.Hasher, marshaller: args.Marshaller, enableEpochsHandler: args.EnableEpochsHandler, - StateAccessesCollector: args.StateAccessesCollector, + stateAccessesCollector: args.StateAccessesCollector, + dataTriesHolder: args.DataTriesHolder, + dataTrieCreator: args.DataTrieCreator, }, nil } // CreateAccount calls the new Account creator and returns the result func (ac *accountCreator) CreateAccount(address []byte) (vmcommon.AccountHandler, error) { - tdt, err := trackableDataTrie.NewTrackableDataTrie(address, ac.hasher, ac.marshaller, ac.enableEpochsHandler, ac.StateAccessesCollector) + args := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: address, + Hasher: ac.hasher, + Marshaller: ac.marshaller, + EnableEpochsHandler: ac.enableEpochsHandler, + StateAccessesCollector: ac.stateAccessesCollector, + DataTriesHolder: ac.dataTriesHolder, + DataTrieCreator: ac.dataTrieCreator, + } + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) if err != nil { return nil, err } diff --git a/state/factory/accountCreator_test.go b/state/factory/accountCreator_test.go index 1082eadc01f..27528b9a20c 100644 --- a/state/factory/accountCreator_test.go +++ b/state/factory/accountCreator_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" + "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/stretchr/testify/assert" ) @@ -22,6 +23,8 @@ func getDefaultArgs() factory.ArgsAccountCreator { Marshaller: &marshallerMock.MarshalizerMock{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: &trie.TriesHolderStub{}, + DataTrieCreator: &trie.TrieStub{}, } } @@ -55,6 +58,22 @@ func TestNewAccountCreator(t *testing.T) { assert.True(t, check.IfNil(accF)) assert.Equal(t, errors.ErrNilEnableEpochsHandler, err) }) + t.Run("nil stateAccessesCollector", func(t *testing.T) { + t.Parallel() + args := getDefaultArgs() + args.StateAccessesCollector = nil + accF, err := factory.NewAccountCreator(args) + assert.True(t, check.IfNil(accF)) + assert.Equal(t, state.ErrNilStateAccessesCollector, err) + }) + t.Run("nil dataTriesHolder", func(t *testing.T) { + t.Parallel() + args := getDefaultArgs() + args.DataTriesHolder = nil + accF, err := factory.NewAccountCreator(args) + assert.True(t, check.IfNil(accF)) + assert.Equal(t, errors.ErrNilDataTriesHolder, err) + }) t.Run("should work", func(t *testing.T) { t.Parallel() diff --git a/state/factory/accountsAdapterAPICreator_test.go b/state/factory/accountsAdapterAPICreator_test.go index a735fc7f9ce..9d4863df0d9 100644 --- a/state/factory/accountsAdapterAPICreator_test.go +++ b/state/factory/accountsAdapterAPICreator_test.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" mockState "github.com/multiversx/mx-chain-go/testscommon/state" @@ -17,6 +18,7 @@ import ( ) func createMockAccountsArgs() state.ArgsAccountsDB { + dth, _ := triesHolder.NewDataTriesHolder(common.TenMbSize) return state.ArgsAccountsDB{ Trie: &mockTrie.TrieStub{ GetStorageManagerCalled: func() common.StorageManager { @@ -30,6 +32,7 @@ func createMockAccountsArgs() state.ArgsAccountsDB { AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: &mockState.SnapshotsManagerStub{}, StateAccessesCollector: &mockState.StateAccessesCollectorStub{}, + DataTriesHolder: dth, } } diff --git a/state/interface.go b/state/interface.go index b9c8b988f5f..9e21fdd382c 100644 --- a/state/interface.go +++ b/state/interface.go @@ -161,8 +161,8 @@ type baseAccountHandler interface { SetCodeHash([]byte) GetCodeHash() []byte SetRootHash([]byte) + SetDataTrieRootHash() GetRootHash() []byte - SetDataTrie(trie common.Trie) DataTrie() common.DataTrieHandler SaveDirtyData(trie common.Trie) ([]*data.DataTrieChange, []core.TrieData, error) IsInterfaceNil() bool @@ -240,7 +240,6 @@ type UserAccountHandler interface { GetCodeHash() []byte SetRootHash([]byte) GetRootHash() []byte - SetDataTrie(trie common.Trie) DataTrie() common.DataTrieHandler RetrieveValue(key []byte) ([]byte, uint32, error) SaveKeyValue(key []byte, value []byte) error @@ -264,7 +263,7 @@ type UserAccountHandler interface { type DataTrieTracker interface { RetrieveValue(key []byte) ([]byte, uint32, error) SaveKeyValue(key []byte, value []byte) error - SetDataTrie(tr common.Trie) + SetRootHash(rootHash []byte) DataTrie() common.DataTrieHandler SaveDirtyData(common.Trie) ([]*data.DataTrieChange, []core.TrieData, error) MigrateDataTrieLeaves(args vmcommon.ArgsMigrateDataTrieLeaves) error diff --git a/state/journalEntries_test.go b/state/journalEntries_test.go index 7530f6cbfae..1bda66c9b5d 100644 --- a/state/journalEntries_test.go +++ b/state/journalEntries_test.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/accounts" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" @@ -226,15 +227,18 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenUpdateFails(t *testing.T) { Value: []byte("b"), Version: 0, }) - accnt := stateMock.NewAccountWrapMock(nil) - tr := &trieMock.TrieStub{ UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return expectedErr }, } + dataTriesHolder := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return tr + }, + } + accnt := stateMock.NewAccountWrapMockWithDataTrieHolder(dataTriesHolder) - accnt.SetDataTrie(tr) entry, _ := state.NewJournalEntryDataTrieUpdates(trieUpdates, accnt) acc, err := entry.Revert() @@ -253,7 +257,6 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenAccountRootFails(t *testing. Value: []byte("b"), Version: 0, }) - accnt := stateMock.NewAccountWrapMock(nil) tr := &trieMock.TrieStub{ UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { @@ -264,7 +267,12 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenAccountRootFails(t *testing. }, } - accnt.SetDataTrie(tr) + dataTriesHolder := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return tr + }, + } + accnt := stateMock.NewAccountWrapMockWithDataTrieHolder(dataTriesHolder) entry, _ := state.NewJournalEntryDataTrieUpdates(trieUpdates, accnt) acc, err := entry.Revert() @@ -284,7 +292,6 @@ func TestJournalEntryDataTrieUpdates_RevertShouldWork(t *testing.T) { Value: []byte("b"), Version: 0, }) - accnt := stateMock.NewAccountWrapMock(nil) tr := &trieMock.TrieStub{ UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { @@ -297,7 +304,12 @@ func TestJournalEntryDataTrieUpdates_RevertShouldWork(t *testing.T) { }, } - accnt.SetDataTrie(tr) + dataTriesHolder := &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return tr + }, + } + accnt := stateMock.NewAccountWrapMockWithDataTrieHolder(dataTriesHolder) entry, _ := state.NewJournalEntryDataTrieUpdates(trieUpdates, accnt) acc, err := entry.Revert() diff --git a/state/storagePruningManager/storagePruningManager_test.go b/state/storagePruningManager/storagePruningManager_test.go index 0ca2df57801..c9681f8a58a 100644 --- a/state/storagePruningManager/storagePruningManager_test.go +++ b/state/storagePruningManager/storagePruningManager_test.go @@ -3,6 +3,7 @@ package storagePruningManager import ( "testing" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/stretchr/testify/assert" "github.com/multiversx/mx-chain-go/common" @@ -15,12 +16,13 @@ import ( "github.com/multiversx/mx-chain-go/state/lastSnapshotMarker" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/testscommon" + testCommon "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" - "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" ) func getDefaultTrieAndAccountsDbAndStoragePruningManager() (common.Trie, *state.AccountsDB, *storagePruningManager) { @@ -31,21 +33,23 @@ func getDefaultTrieAndAccountsDbAndStoragePruningManager() (common.Trie, *state. } marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} - args := storage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr, _ := trie.NewTrie(trieStorage, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, HashesSize: 10000, } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) spm, _ := NewStoragePruningManager(ewl, generalCfg.PruningBufferLen) - + dth, _ := triesHolder.NewDataTriesHolder(common.TenMbSize) argsAccCreator := factory.ArgsAccountCreator{ Hasher: hasher, Marshaller: marshaller, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, + DataTrieCreator: tr, } accCreator, _ := factory.NewAccountCreator(argsAccCreator) @@ -70,6 +74,7 @@ func getDefaultTrieAndAccountsDbAndStoragePruningManager() (common.Trie, *state. AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: snapshotsManager, StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: dth, } adb, _ := state.NewAccountsDB(argsAccountsDB) @@ -264,7 +269,7 @@ func TestStoragePruningManager_MarkForEviction_removeDuplicatedKeys(t *testing.T func TestStoragePruningManager_Reset(t *testing.T) { t.Parallel() - args := storage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() trieStorage, _ := trie.NewTrieStorageManager(args) ewlArgs := evictionWaitingList.MemoryEvictionWaitingListArgs{ RootHashesSize: 100, diff --git a/state/syncer/baseAccountsSyncer.go b/state/syncer/baseAccountsSyncer.go index 3cee93d7325..6e04b170854 100644 --- a/state/syncer/baseAccountsSyncer.go +++ b/state/syncer/baseAccountsSyncer.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" ) type baseAccountsSyncer struct { @@ -28,7 +29,6 @@ type baseAccountsSyncer struct { timeoutHandler trie.TimeoutHandler shardId uint32 cacher storage.Cacher - maxTrieLevelInMemory uint name string maxHardCapForMissingNodes int checkNodesOnDisk bool @@ -54,7 +54,6 @@ type ArgsNewBaseAccountsSyncer struct { UserAccountsSyncStatisticsHandler common.SizeSyncStatisticsHandler AppStatusHandler core.AppStatusHandler EnableEpochsHandler common.EnableEpochsHandler - MaxTrieLevelInMemory uint MaxHardCapForMissingNodes int TrieSyncerVersion int CheckNodesOnDisk bool @@ -217,7 +216,7 @@ func (b *baseAccountsSyncer) GetSyncedTries() map[string]common.Trie { b.mutex.Lock() defer b.mutex.Unlock() - dataTrie, err := trie.NewTrie(b.trieStorageManager, b.marshalizer, b.hasher, b.enableEpochsHandler, b.maxTrieLevelInMemory) + dataTrie, err := trie.NewTrie(b.trieStorageManager, b.marshalizer, b.hasher, b.enableEpochsHandler, collapseManager.NewDisabledCollapseManager()) if err != nil { log.Warn("error creating a new trie in baseAccountsSyncer.GetSyncedTries", "error", err) return make(map[string]common.Trie) diff --git a/state/syncer/baseAccoutnsSyncer_test.go b/state/syncer/baseAccoutnsSyncer_test.go index e2fcf5336f0..4a7ec3cb1ab 100644 --- a/state/syncer/baseAccoutnsSyncer_test.go +++ b/state/syncer/baseAccoutnsSyncer_test.go @@ -28,7 +28,6 @@ func getDefaultBaseAccSyncerArgs() syncer.ArgsNewBaseAccountsSyncer { UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - MaxTrieLevelInMemory: 5, MaxHardCapForMissingNodes: 100, TrieSyncerVersion: 3, CheckNodesOnDisk: false, diff --git a/state/syncer/userAccountSyncer_test.go b/state/syncer/userAccountSyncer_test.go index 3ecdf5cd178..f532f407782 100644 --- a/state/syncer/userAccountSyncer_test.go +++ b/state/syncer/userAccountSyncer_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/stretchr/testify/assert" "github.com/multiversx/mx-chain-go/dataRetriever/mock" @@ -29,7 +30,6 @@ func getDefaultBaseAccSyncerArgs() ArgsNewBaseAccountsSyncer { Cacher: cache.NewCacherMock(), UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, - MaxTrieLevelInMemory: 0, MaxHardCapForMissingNodes: 100, TrieSyncerVersion: 2, CheckNodesOnDisk: false, @@ -90,7 +90,7 @@ func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { }, } - tr, _ := trie.NewTrie(tsm, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr, _ := trie.NewTrie(tsm, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) key := []byte("key") value := []byte("value") _ = tr.Update(key, value) diff --git a/state/syncer/userAccountsSyncer.go b/state/syncer/userAccountsSyncer.go index a63745aa387..626b8457b7c 100644 --- a/state/syncer/userAccountsSyncer.go +++ b/state/syncer/userAccountsSyncer.go @@ -85,7 +85,6 @@ func NewUserAccountsSyncer(args ArgsNewUserAccountsSyncer) (*userAccountsSyncer, timeoutHandler: timeoutHandler, shardId: args.ShardId, cacher: args.Cacher, - maxTrieLevelInMemory: args.MaxTrieLevelInMemory, name: fmt.Sprintf("user accounts for shard %s", core.GetShardIDString(args.ShardId)), maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes, trieSyncerVersion: args.TrieSyncerVersion, diff --git a/state/syncer/userAccountsSyncer_test.go b/state/syncer/userAccountsSyncer_test.go index 5d7252d3b2e..cd4e0c7eda6 100644 --- a/state/syncer/userAccountsSyncer_test.go +++ b/state/syncer/userAccountsSyncer_test.go @@ -10,6 +10,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -111,7 +112,7 @@ func getSerializedTrieNode( }, } - tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) _ = tr.Update(key, []byte("value")) _ = tr.Commit() @@ -162,7 +163,7 @@ func TestUserAccountsSyncer_SyncAccounts(t *testing.T) { }) } -func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, common.EnableEpochsHandler, uint) { +func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, common.EnableEpochsHandler, common.TrieCollapseManager) { marshalizer := &testscommon.ProtobufMarshalizerMock{} hasher := &testscommon.KeccakMock{} @@ -183,9 +184,8 @@ func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, has } trieStorageManager, _ := trie.NewTrieStorageManager(args) - maxTrieLevelInMemory := uint(1) - return trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory + return trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager() } func emptyTrie() common.Trie { @@ -237,7 +237,7 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { s, err := syncer.NewUserAccountsSyncer(args) require.Nil(t, err) - _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) tr := emptyTrie() account, err := accounts.NewUserAccount(testscommon.TestPubKeyAlice, &trieMock.DataTrieTrackerStub{}, &trieMock.TrieLeafParserStub{}) @@ -294,7 +294,7 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { s, err := syncer.NewUserAccountsSyncer(args) require.Nil(t, err) - _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + _, _ = trie.NewTrie(args.TrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) tr := emptyTrie() account, err := accounts.NewUserAccount(testscommon.TestPubKeyAlice, &trieMock.DataTrieTrackerStub{}, &trieMock.TrieLeafParserStub{}) @@ -361,7 +361,7 @@ func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { }, } - tr, _ := trie.NewTrie(tsm, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr, _ := trie.NewTrie(tsm, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) key := []byte("key") value := []byte("value") _ = tr.Update(key, value) diff --git a/state/syncer/validatorAccountsSyncer.go b/state/syncer/validatorAccountsSyncer.go index e436bde8e8c..8463f65dfca 100644 --- a/state/syncer/validatorAccountsSyncer.go +++ b/state/syncer/validatorAccountsSyncer.go @@ -43,7 +43,6 @@ func NewValidatorAccountsSyncer(args ArgsNewValidatorAccountsSyncer) (*validator timeoutHandler: timeoutHandler, shardId: core.MetachainShardId, cacher: args.Cacher, - maxTrieLevelInMemory: args.MaxTrieLevelInMemory, name: "peer accounts", maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes, trieSyncerVersion: args.TrieSyncerVersion, diff --git a/state/trackableDataTrie/export_test.go b/state/trackableDataTrie/export_test.go index ae44bbbdf7e..db3927ce821 100644 --- a/state/trackableDataTrie/export_test.go +++ b/state/trackableDataTrie/export_test.go @@ -1,6 +1,9 @@ package trackableDataTrie -import "github.com/multiversx/mx-chain-core-go/core" +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/common" +) // DirtyData - type DirtyData struct { @@ -27,3 +30,13 @@ func (tdt *trackableDataTrie) GetValueForVersion(key []byte, val []byte, version valWithMetadata, _ := tdt.getValueForVersion(key, val, version) return valWithMetadata } + +// LoadTrie - +func (tdt *trackableDataTrie) LoadTrie() error { + return tdt.loadTrie() +} + +// GetDataTrie - +func (tdt *trackableDataTrie) GetDataTrie() common.Trie { + return tdt.tr +} diff --git a/state/trackableDataTrie/trackableDataTrie.go b/state/trackableDataTrie/trackableDataTrie.go index 8b6a7333496..5c5ff67cae1 100644 --- a/state/trackableDataTrie/trackableDataTrie.go +++ b/state/trackableDataTrie/trackableDataTrie.go @@ -2,6 +2,7 @@ package trackableDataTrie import ( "bytes" + "errors" "fmt" "sort" @@ -34,35 +35,60 @@ type dirtyData struct { type trackableDataTrie struct { dirtyData map[string]dirtyData tr common.Trie + rootHash []byte hasher hashing.Hasher marshaller marshal.Marshalizer enableEpochsHandler common.EnableEpochsHandler identifier []byte stateAccessesCollector state.StateAccessesCollector + dataTriesHolder common.TriesHolder + dataTrieCreator common.DataTrieCreator } -// NewTrackableDataTrie returns an instance of trackableDataTrie -func NewTrackableDataTrie( - identifier []byte, - hasher hashing.Hasher, - marshaller marshal.Marshalizer, - enableEpochsHandler common.EnableEpochsHandler, - stateAccessesCollector state.StateAccessesCollector, -) (*trackableDataTrie, error) { - if check.IfNil(hasher) { - return nil, state.ErrNilHasher +// TrackableDataTrieArgs represent the args needed to create a new trackableDataTrie +type TrackableDataTrieArgs struct { + Identifier []byte + Hasher hashing.Hasher + Marshaller marshal.Marshalizer + EnableEpochsHandler common.EnableEpochsHandler + StateAccessesCollector state.StateAccessesCollector + DataTriesHolder common.TriesHolder + DataTrieCreator common.DataTrieCreator +} + +func checkTrackableDataTrieArgs(args TrackableDataTrieArgs) error { + if check.IfNil(args.Hasher) { + return state.ErrNilHasher + } + if check.IfNil(args.Marshaller) { + return state.ErrNilMarshalizer } - if check.IfNil(marshaller) { - return nil, state.ErrNilMarshalizer + if check.IfNil(args.EnableEpochsHandler) { + return state.ErrNilEnableEpochsHandler } - if check.IfNil(enableEpochsHandler) { - return nil, state.ErrNilEnableEpochsHandler + if check.IfNil(args.StateAccessesCollector) { + return state.ErrNilStateAccessesCollector } - if check.IfNil(stateAccessesCollector) { - return nil, state.ErrNilStateAccessesCollector + if check.IfNil(args.DataTriesHolder) { + return errorsCommon.ErrNilDataTriesHolder } + if check.IfNil(args.DataTrieCreator) { + return errorsCommon.ErrNilDataTrieCreator + } + + return nil +} - err := core.CheckHandlerCompatibility(enableEpochsHandler, []core.EnableEpochFlag{ +// NewTrackableDataTrie returns an instance of trackableDataTrie +func NewTrackableDataTrie( + args TrackableDataTrieArgs, +) (*trackableDataTrie, error) { + err := checkTrackableDataTrieArgs(args) + if err != nil { + return nil, err + } + + err = core.CheckHandlerCompatibility(args.EnableEpochsHandler, []core.EnableEpochFlag{ common.AutoBalanceDataTriesFlag, }) if err != nil { @@ -70,16 +96,47 @@ func NewTrackableDataTrie( } return &trackableDataTrie{ - tr: nil, - hasher: hasher, - marshaller: marshaller, dirtyData: make(map[string]dirtyData), - identifier: identifier, - enableEpochsHandler: enableEpochsHandler, - stateAccessesCollector: stateAccessesCollector, + tr: nil, + hasher: args.Hasher, + marshaller: args.Marshaller, + enableEpochsHandler: args.EnableEpochsHandler, + identifier: args.Identifier, + stateAccessesCollector: args.StateAccessesCollector, + dataTriesHolder: args.DataTriesHolder, + dataTrieCreator: args.DataTrieCreator, }, nil } +func (tdt *trackableDataTrie) loadTrie() error { + if !check.IfNil(tdt.tr) { + // the trie is already loaded + return nil + } + + // check the cache for the trie, and load it from there if found + tr := tdt.dataTriesHolder.Get(tdt.identifier) + if !check.IfNil(tr) { + tdt.tr = tr + return nil + } + + // try to recreate the trie from db + if common.IsEmptyTrie(tdt.rootHash) { + return state.ErrNilTrie + } + + tr, err := tdt.dataTrieCreator.Recreate(holders.NewDefaultRootHashesHolder(tdt.rootHash)) + if err != nil { + return err + } + + tdt.tr = tr + tdt.dataTriesHolder.Put(tdt.identifier, tr) + + return nil +} + // RetrieveValue fetches the value from a particular key searching the account data store // The search starts with dirty map, continues with original map and ends with the trie // Data must have been retrieved from its trie @@ -92,8 +149,9 @@ func (tdt *trackableDataTrie) RetrieveValue(key []byte) ([]byte, uint32, error) } // ok, not in cache, retrieve from trie - if check.IfNil(tdt.tr) { - return nil, 0, state.ErrNilTrie + err := tdt.loadTrie() + if err != nil { + return nil, 0, err } trieValue, depth, err := tdt.retrieveValueFromTrie(key) if err != nil { @@ -129,8 +187,9 @@ func (tdt *trackableDataTrie) SaveKeyValue(key []byte, value []byte) error { // MigrateDataTrieLeaves migrates the data trie leaves from oldVersion to newVersion func (tdt *trackableDataTrie) MigrateDataTrieLeaves(args vmcommon.ArgsMigrateDataTrieLeaves) error { - if check.IfNil(tdt.tr) { - return state.ErrNilTrie + err := tdt.loadTrie() + if err != nil { + return err } if check.IfNil(args.TrieMigrator) { return errorsCommon.ErrNilTrieMigrator @@ -141,7 +200,7 @@ func (tdt *trackableDataTrie) MigrateDataTrieLeaves(args vmcommon.ArgsMigrateDat return fmt.Errorf("invalid trie, type is %T", tdt.tr) } - err := dtr.CollectLeavesForMigration(args) + err = dtr.CollectLeavesForMigration(args) if err != nil { return err } @@ -224,13 +283,17 @@ func (tdt *trackableDataTrie) getValueForVersion(key []byte, value []byte, versi return valueWithAppendedData, nil } -// SetDataTrie sets the internal data trie -func (tdt *trackableDataTrie) SetDataTrie(tr common.Trie) { - tdt.tr = tr +// SetRootHash sets the internal root hash from which to recreate the data trie +func (tdt *trackableDataTrie) SetRootHash(rootHash []byte) { + tdt.rootHash = rootHash } // DataTrie sets the internal data trie func (tdt *trackableDataTrie) DataTrie() common.DataTrieHandler { + err := tdt.loadTrie() + if err != nil { + log.Error("failed to load data trie", "error", err, "account", tdt.identifier) + } return tdt.tr } @@ -240,6 +303,10 @@ func (tdt *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]*stateChang return make([]*stateChange.DataTrieChange, 0), make([]core.TrieData, 0), nil } + err := tdt.loadTrie() + if err != nil && !errors.Is(err, state.ErrNilTrie) { + return nil, nil, err + } if check.IfNil(tdt.tr) { emptyRootHash := holders.NewDefaultRootHashesHolder(make([]byte, 0)) newDataTrie, err := mainTrie.Recreate(emptyRootHash) @@ -248,6 +315,7 @@ func (tdt *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]*stateChang } tdt.tr = newDataTrie + tdt.dataTriesHolder.Put(tdt.identifier, newDataTrie) } dtr, ok := tdt.tr.(state.DataTrie) diff --git a/state/trackableDataTrie/trackableDataTrie_test.go b/state/trackableDataTrie/trackableDataTrie_test.go index b68a18edc65..a4d8006c228 100644 --- a/state/trackableDataTrie/trackableDataTrie_test.go +++ b/state/trackableDataTrie/trackableDataTrie_test.go @@ -24,19 +24,27 @@ import ( trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" ) +func getDefaultArgs() trackableDataTrie.TrackableDataTrieArgs { + return trackableDataTrie.TrackableDataTrieArgs{ + Identifier: []byte("identifier"), + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: &stateMock.StateAccessesCollectorStub{}, + DataTriesHolder: &trieMock.TriesHolderStub{}, + DataTrieCreator: &trieMock.TrieStub{}, + } +} + func TestNewTrackableDataTrie(t *testing.T) { t.Parallel() t.Run("create with nil hasher", func(t *testing.T) { t.Parallel() - tdt, err := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - nil, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Hasher = nil + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) assert.Equal(t, state.ErrNilHasher, err) assert.True(t, check.IfNil(tdt)) }) @@ -44,13 +52,9 @@ func TestNewTrackableDataTrie(t *testing.T) { t.Run("create with nil marshaller", func(t *testing.T) { t.Parallel() - tdt, err := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - nil, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Marshaller = nil + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) assert.Equal(t, state.ErrNilMarshalizer, err) assert.True(t, check.IfNil(tdt)) }) @@ -58,13 +62,9 @@ func TestNewTrackableDataTrie(t *testing.T) { t.Run("create with nil enableEpochsHandler", func(t *testing.T) { t.Parallel() - tdt, err := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - nil, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.EnableEpochsHandler = nil + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) assert.Equal(t, state.ErrNilEnableEpochsHandler, err) assert.True(t, check.IfNil(tdt)) }) @@ -72,27 +72,47 @@ func TestNewTrackableDataTrie(t *testing.T) { t.Run("create with invalid enableEpochsHandler", func(t *testing.T) { t.Parallel() - tdt, err := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - enableEpochsHandlerMock.NewEnableEpochsHandlerStubWithNoFlagsDefined(), - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.EnableEpochsHandler = enableEpochsHandlerMock.NewEnableEpochsHandlerStubWithNoFlagsDefined() + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) assert.True(t, errors.Is(err, core.ErrInvalidEnableEpochsHandler)) assert.True(t, check.IfNil(tdt)) }) + t.Run("create with nil stateAccessesCollector", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgs() + args.StateAccessesCollector = nil + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) + assert.Equal(t, state.ErrNilStateAccessesCollector, err) + assert.True(t, check.IfNil(tdt)) + }) + + t.Run("create with nil data tries holder", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgs() + args.DataTriesHolder = nil + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) + assert.Equal(t, errorsCommon.ErrNilDataTriesHolder, err) + assert.True(t, check.IfNil(tdt)) + }) + + t.Run("create with nil data trie creator", func(t *testing.T) { + t.Parallel() + + args := getDefaultArgs() + args.DataTrieCreator = nil + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) + assert.Equal(t, errorsCommon.ErrNilDataTrieCreator, err) + assert.True(t, check.IfNil(tdt)) + }) + t.Run("should work", func(t *testing.T) { t.Parallel() - tdt, err := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + tdt, err := trackableDataTrie.NewTrackableDataTrie(getDefaultArgs()) assert.Nil(t, err) assert.False(t, check.IfNil(tdt)) }) @@ -104,13 +124,7 @@ func TestTrackableDataTrie_SaveKeyValue(t *testing.T) { t.Run("data too large", func(t *testing.T) { t.Parallel() - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + tdt, _ := trackableDataTrie.NewTrackableDataTrie(getDefaultArgs()) err := tdt.SaveKeyValue([]byte("key"), make([]byte, core.MaxLeafSize+1)) assert.Equal(t, err, data.ErrLeafSizeTooBig) @@ -131,15 +145,15 @@ func TestTrackableDataTrie_SaveKeyValue(t *testing.T) { return nil, 0, nil }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + + args := getDefaultArgs() + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) assert.NotNil(t, tdt) - tdt.SetDataTrie(trie) _ = tdt.SaveKeyValue(keyExpected, value) @@ -170,15 +184,17 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { return nil, 0, nil }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + + args := getDefaultArgs() + args.Identifier = identifier + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, err := trackableDataTrie.NewTrackableDataTrie(args) + assert.Nil(t, err) assert.NotNil(t, tdt) - tdt.SetDataTrie(trie) valRecovered, _, err := tdt.RetrieveValue(key) assert.Nil(t, err) @@ -193,13 +209,7 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { t.Run("nil data trie should err", func(t *testing.T) { t.Parallel() - tdt, err := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + tdt, err := trackableDataTrie.NewTrackableDataTrie(getDefaultArgs()) assert.Nil(t, err) assert.NotNil(t, tdt) @@ -229,15 +239,15 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.EnableEpochsHandler = enableEpochsHandler + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) assert.NotNil(t, tdt) - tdt.SetDataTrie(trie) valRecovered, _, err := tdt.RetrieveValue(expectedKey) assert.Nil(t, err) @@ -270,15 +280,16 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { return false }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.EnableEpochsHandler = enableEpochsHandler + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) + assert.NotNil(t, tdt) - tdt.SetDataTrie(trie) valRecovered, _, err := tdt.RetrieveValue(expectedKey) assert.Nil(t, err) @@ -315,15 +326,18 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) assert.NotNil(t, tdt) - tdt.SetDataTrie(trie) valRecovered, _, err := tdt.RetrieveValue(expectedKey) assert.Nil(t, err) @@ -340,15 +354,14 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { return nil, 0, errExpected }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) assert.NotNil(t, tdt) - tdt.SetDataTrie(trie) valRecovered, _, err := tdt.RetrieveValue(keyExpected) assert.Equal(t, errExpected, err) @@ -358,7 +371,6 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { t.Run("val not found in trie - auto balance enabled", func(t *testing.T) { t.Parallel() - identifier := []byte("identifier") expectedKey := []byte("key") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} @@ -376,15 +388,17 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) assert.NotNil(t, tdt) - tdt.SetDataTrie(trie) valRecovered, _, err := tdt.RetrieveValue(expectedKey) assert.Nil(t, err) @@ -394,7 +408,6 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { t.Run("val not found in trie - auto balance disabled", func(t *testing.T) { t.Parallel() - identifier := []byte("identifier") expectedKey := []byte("key") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} @@ -412,15 +425,17 @@ func TestTrackableDataTrie_RetrieveValue(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) assert.NotNil(t, tdt) - tdt.SetDataTrie(trie) valRecovered, _, err := tdt.RetrieveValue(expectedKey) assert.Nil(t, err) @@ -434,13 +449,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Run("no dirty data", func(t *testing.T) { t.Parallel() - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + tdt, _ := trackableDataTrie.NewTrackableDataTrie(getDefaultArgs()) stateChanges, oldValues, err := tdt.SaveDirtyData(&trieMock.TrieStub{}) assert.Nil(t, err) @@ -466,13 +475,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + tdt, _ := trackableDataTrie.NewTrackableDataTrie(getDefaultArgs()) key := []byte("key") val := []byte("val") @@ -491,7 +494,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Run("present in trie as valWithAppendedData", func(t *testing.T) { t.Parallel() - identifier := []byte("identifier") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} deleteCalled := false @@ -501,13 +503,12 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) expectedKey := []byte("key") expectedVal := []byte("value") @@ -534,8 +535,12 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, &stateMock.StateAccessesCollectorStub{}) - tdt.SetDataTrie(trie) + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ = trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, expectedVal) stateChanges, oldValues, err := tdt.SaveDirtyData(trie) @@ -559,7 +564,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Run("present in trie as valWithAppendedData and auto balancing disabled", func(t *testing.T) { t.Parallel() - identifier := []byte("identifier") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false @@ -569,13 +573,11 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) expectedKey := []byte("key") val := []byte("value") @@ -600,8 +602,12 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, &stateMock.StateAccessesCollectorStub{}) - tdt.SetDataTrie(trie) + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ = trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, val) stateChanges, oldValues, err := tdt.SaveDirtyData(trie) @@ -618,7 +624,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Run("present in trie as valAsStruct", func(t *testing.T) { t.Parallel() - identifier := []byte("identifier") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false @@ -627,13 +632,11 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) expectedKey := []byte("key") newVal := []byte("value") @@ -660,14 +663,12 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ = trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) - tdt.SetDataTrie(trie) + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ = trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, newVal) stateChanges, oldValues, err := tdt.SaveDirtyData(trie) @@ -684,7 +685,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Run("not present in trie", func(t *testing.T) { t.Parallel() - identifier := []byte("identifier") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false @@ -693,13 +693,11 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) expectedKey := []byte("key") newVal := []byte("value") @@ -721,8 +719,12 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ = trackableDataTrie.NewTrackableDataTrie(identifier, hasher, marshaller, enableEpochsHandler, &stateMock.StateAccessesCollectorStub{}) - tdt.SetDataTrie(trie) + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ = trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, newVal) stateChanges, oldValues, err := tdt.SaveDirtyData(trie) @@ -750,9 +752,13 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return nil }, } - - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &stateMock.StateAccessesCollectorStub{}) - tdt.SetDataTrie(trie) + args := getDefaultArgs() + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, val) _, _, err := tdt.SaveDirtyData(trie) @@ -776,8 +782,13 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &stateMock.StateAccessesCollectorStub{}) - tdt.SetDataTrie(trie) + args := getDefaultArgs() + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, nil) stateChanges, _, err := tdt.SaveDirtyData(trie) @@ -805,8 +816,13 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &stateMock.StateAccessesCollectorStub{}) - tdt.SetDataTrie(trie) + args := getDefaultArgs() + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, nil) stateChanges, _, err := tdt.SaveDirtyData(trie) @@ -843,13 +859,21 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpchs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), hasher, marshaller, enableEpchs, &stateMock.StateAccessesCollectorStub{}) - tdt.SetDataTrie(trie) + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochs + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, nil) stateChanges, _, err := tdt.SaveDirtyData(trie) @@ -882,13 +906,19 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { }, } - enableEpchs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie([]byte("identifier"), &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs, &stateMock.StateAccessesCollectorStub{}) - tdt.SetDataTrie(trie) + args := getDefaultArgs() + args.EnableEpochsHandler = enableEpochs + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, nil) stateChanges, _, err := tdt.SaveDirtyData(trie) @@ -903,7 +933,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Run("not present in trie - autobalance disabled", func(t *testing.T) { t.Parallel() - identifier := []byte("identifier") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} updateCalled := false @@ -912,13 +941,11 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return false }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) expectedKey := []byte("key") newVal := []byte("value") @@ -939,7 +966,12 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return nil }, } - tdt.SetDataTrie(trie) + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ = trackableDataTrie.NewTrackableDataTrie(args) _ = tdt.SaveKeyValue(expectedKey, newVal) stateChanges, oldValues, err := tdt.SaveDirtyData(trie) @@ -956,7 +988,6 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { t.Run("state accesses are ordered deterministically", func(t *testing.T) { t.Parallel() - identifier := []byte("identifier") hasher := &hashingMocks.HasherMock{} marshaller := &marshallerMock.MarshalizerMock{} enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ @@ -964,13 +995,11 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return true }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - identifier, - hasher, - marshaller, - enableEpochsHandler, - &stateMock.StateAccessesCollectorStub{}, - ) + args := getDefaultArgs() + args.Hasher = hasher + args.Marshaller = marshaller + args.EnableEpochsHandler = enableEpochsHandler + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) key1 := "key1" key2 := "key2" @@ -1001,7 +1030,12 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { return nil }, } - tdt.SetDataTrie(trie) + args.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return trie + }, + } + tdt, _ = trackableDataTrie.NewTrackableDataTrie(args) val := []byte("value") _ = tdt.SaveKeyValue([]byte(key1), val) _ = tdt.SaveKeyValue([]byte(key2), val) @@ -1041,13 +1075,7 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { t.Run("nil trie", func(t *testing.T) { t.Parallel() - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) + tdt, _ := trackableDataTrie.NewTrackableDataTrie(getDefaultArgs()) args := vmcommon.ArgsMigrateDataTrieLeaves{ OldVersion: core.NotSpecified, NewVersion: core.AutoBalanceEnabled, @@ -1060,14 +1088,13 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { t.Run("nil trie migrator", func(t *testing.T) { t.Parallel() - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) - tdt.SetDataTrie(&trieMock.TrieStub{}) + argsTdt := getDefaultArgs() + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return &trieMock.TrieStub{} + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(argsTdt) args := vmcommon.ArgsMigrateDataTrieLeaves{ OldVersion: core.NotSpecified, @@ -1087,15 +1114,13 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { return expectedErr }, } - - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) - tdt.SetDataTrie(tr) + argsTdt := getDefaultArgs() + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return tr + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(argsTdt) args := vmcommon.ArgsMigrateDataTrieLeaves{ OldVersion: core.NotSpecified, NewVersion: core.AutoBalanceEnabled, @@ -1137,20 +1162,21 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { return leavesToBeMigrated }, } - enableEpchs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { return flag == common.AutoBalanceDataTriesFlag }, } - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - address, - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - enableEpchs, - &stateMock.StateAccessesCollectorStub{}, - ) - tdt.SetDataTrie(tr) + argsTdt := getDefaultArgs() + argsTdt.EnableEpochsHandler = enableEpochs + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return tr + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(argsTdt) + args := vmcommon.ArgsMigrateDataTrieLeaves{ OldVersion: core.NotSpecified, NewVersion: 100, @@ -1172,15 +1198,180 @@ func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { func TestTrackableDataTrie_SetAndGetDataTrie(t *testing.T) { t.Parallel() - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - &stateMock.StateAccessesCollectorStub{}, - ) - newTrie := &trieMock.TrieStub{} - tdt.SetDataTrie(newTrie) + argsTdt := getDefaultArgs() + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + return newTrie + }, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(argsTdt) + assert.Equal(t, newTrie, tdt.DataTrie()) } + +func TestTrackableDataTrie_loadTrie(t *testing.T) { + t.Parallel() + + t.Run("trie already loaded", func(t *testing.T) { + t.Parallel() + + cachedTrie := &trieMock.TrieStub{} + numGetCalls := 0 + numRecreateCalls := 0 + argsTdt := getDefaultArgs() + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + numGetCalls++ + assert.Equal(t, argsTdt.Identifier, key) + return cachedTrie + }, + } + argsTdt.DataTrieCreator = &trieMock.TrieStub{ + RecreateCalled: func(_ common.RootHashHolder) (common.Trie, error) { + numRecreateCalls++ + return nil, nil + }, + } + + tdt, err := trackableDataTrie.NewTrackableDataTrie(argsTdt) + assert.Nil(t, err) + + loadedTrie := tdt.DataTrie() + assert.Equal(t, cachedTrie, loadedTrie) + assert.Equal(t, cachedTrie, tdt.GetDataTrie()) + + err = tdt.LoadTrie() + assert.Nil(t, err) + assert.Equal(t, 1, numGetCalls) + assert.Equal(t, 0, numRecreateCalls) + }) + t.Run("trie found in cache", func(t *testing.T) { + t.Parallel() + + cachedTrie := &trieMock.TrieStub{} + numRecreateCalls := 0 + numGetCalls := 0 + argsTdt := getDefaultArgs() + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(key []byte) common.Trie { + assert.Equal(t, argsTdt.Identifier, key) + numGetCalls++ + return cachedTrie + }, + } + argsTdt.DataTrieCreator = &trieMock.TrieStub{ + RecreateCalled: func(_ common.RootHashHolder) (common.Trie, error) { + numRecreateCalls++ + return nil, nil + }, + } + + tdt, err := trackableDataTrie.NewTrackableDataTrie(argsTdt) + assert.Nil(t, err) + assert.Nil(t, tdt.GetDataTrie()) + + err = tdt.LoadTrie() + assert.Nil(t, err) + assert.Equal(t, 1, numGetCalls) + assert.Equal(t, 0, numRecreateCalls) + assert.Equal(t, cachedTrie, tdt.GetDataTrie()) + }) + t.Run("empty root hash", func(t *testing.T) { + t.Parallel() + + numRecreateCalls := 0 + numGetCalls := 0 + argsTdt := getDefaultArgs() + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + numGetCalls++ + return nil + }, + } + argsTdt.DataTrieCreator = &trieMock.TrieStub{ + RecreateCalled: func(_ common.RootHashHolder) (common.Trie, error) { + numRecreateCalls++ + return nil, nil + }, + } + tdt, err := trackableDataTrie.NewTrackableDataTrie(argsTdt) + assert.Nil(t, err) + assert.Nil(t, tdt.GetDataTrie()) + + err = tdt.LoadTrie() + assert.Equal(t, state.ErrNilTrie, err) + assert.Equal(t, 1, numGetCalls) + assert.Equal(t, 0, numRecreateCalls) + assert.Nil(t, tdt.GetDataTrie()) + }) + t.Run("recreate trie from db returns error", func(t *testing.T) { + t.Parallel() + + errExpected := errors.New("expected error") + numPutCalls := 0 + numRecreateCalls := 0 + expectedRootHash := []byte("root hash") + argsTdt := getDefaultArgs() + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return nil + }, + } + argsTdt.DataTrieCreator = &trieMock.TrieStub{ + RecreateCalled: func(rootHolder common.RootHashHolder) (common.Trie, error) { + numRecreateCalls++ + assert.Equal(t, expectedRootHash, rootHolder.GetRootHash()) + return nil, errExpected + }, + } + + tdt, err := trackableDataTrie.NewTrackableDataTrie(argsTdt) + assert.Nil(t, err) + tdt.SetRootHash(expectedRootHash) + assert.Nil(t, tdt.GetDataTrie()) + + _, _, err = tdt.RetrieveValue([]byte("key")) + assert.Equal(t, errExpected, err) + assert.Equal(t, 1, numRecreateCalls) + assert.Equal(t, 0, numPutCalls) + assert.Nil(t, tdt.GetDataTrie()) + }) + t.Run("recreate trie from db", func(t *testing.T) { + t.Parallel() + + recreatedTrie := &trieMock.TrieStub{} + expectedRootHash := []byte("root hash") + numRecreateCalls := 0 + numPutCalls := 0 + argsTdt := getDefaultArgs() + argsTdt.DataTriesHolder = &trieMock.TriesHolderStub{ + GetCalled: func(_ []byte) common.Trie { + return nil + }, + PutCalled: func(key []byte, trie common.Trie) { + numPutCalls++ + assert.Equal(t, argsTdt.Identifier, key) + assert.Equal(t, recreatedTrie, trie) + }, + } + argsTdt.DataTrieCreator = &trieMock.TrieStub{ + RecreateCalled: func(rootHolder common.RootHashHolder) (common.Trie, error) { + numRecreateCalls++ + assert.Equal(t, expectedRootHash, rootHolder.GetRootHash()) + return recreatedTrie, nil + }, + } + + tdt, err := trackableDataTrie.NewTrackableDataTrie(argsTdt) + assert.Nil(t, err) + tdt.SetRootHash(expectedRootHash) + assert.Nil(t, tdt.GetDataTrie()) + + loadedTrie := tdt.DataTrie() + assert.Equal(t, recreatedTrie, loadedTrie) + assert.Equal(t, 1, numRecreateCalls) + assert.Equal(t, 1, numPutCalls) + assert.Equal(t, recreatedTrie, tdt.GetDataTrie()) + }) +} diff --git a/state/triesHolder/dataTriesHolder.go b/state/triesHolder/dataTriesHolder.go new file mode 100644 index 00000000000..bda27fdd1db --- /dev/null +++ b/state/triesHolder/dataTriesHolder.go @@ -0,0 +1,266 @@ +package triesHolder + +import ( + "fmt" + "math" + "sync" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/storage" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-storage-go/lrucache/capacity" +) + +const maxTrieSizeMinValue = 1 * 1024 * 1024 // 1 MB + +var log = logger.GetOrCreate("state/dataTriesHolder") + +// dataTriesHolder is a structure that holds a map of tries and manages their memory usage +// It uses a doubly linked list to keep track of the order in which the tries were used +// and evicts the oldest used tries when the total size exceeds a maximum limit. +type dataTriesHolder struct { + cacher storage.AdaptedSizedLRUCache + dirtyTries map[string]struct{} // These are the tries that have been modified and need to be persisted + touchedTries map[string]struct{} // These are needed to compute an accurate totalTriesSize + evictedBuffer map[string]common.Trie // in case eviction happens for a dirty trie, we keep it here until we commit it + + mutex sync.RWMutex +} + +// NewDataTriesHolder creates a new instance of dataTriesHolder +func NewDataTriesHolder(maxTriesSize uint64) (*dataTriesHolder, error) { + if maxTriesSize < maxTrieSizeMinValue { + return nil, fmt.Errorf("%w, provided %d, minimum %d", ErrInvalidMaxTrieSizeValue, maxTriesSize, maxTrieSizeMinValue) + } + log.Trace("creating new data tries holder", "maxTriesSize", maxTriesSize) + + c, err := capacity.NewCapacityLRU(math.MaxInt, int64(maxTriesSize)) + if err != nil { + return nil, err + } + + return &dataTriesHolder{ + cacher: c, + dirtyTries: make(map[string]struct{}), + touchedTries: make(map[string]struct{}), + evictedBuffer: make(map[string]common.Trie), + mutex: sync.RWMutex{}, + }, nil +} + +// Put adds a trie pointer to the tries map +func (dth *dataTriesHolder) Put(key []byte, tr common.Trie) { + dth.mutex.Lock() + defer dth.mutex.Unlock() + + dth.putNoLock(key, tr) +} + +func (dth *dataTriesHolder) putNoLock(key []byte, tr common.Trie) { + if check.IfNil(tr) || len(key) == 0 { + log.Warn("trying to put nil trie or empty key in dataTriesHolder", "key", key, "trie", tr) + return + } + + dth.clearFromEvictedBuffer(key) + + keyString := string(key) + evicted := dth.cacher.AddSizedAndReturnEvicted(keyString, tr, int64(tr.SizeInMemory())) + dth.dirtyTries[keyString] = struct{}{} + dth.touchedTries[keyString] = struct{}{} + + if log.GetLevel() == logger.LogTrace { + log.Trace("put trie in data tries holder", "key", key, "trieSize", tr.SizeInMemory(), "totalCacheSize", dth.cacher.SizeInBytesContained()) + } + + if len(evicted) == 0 { + return + } + + for evictedKey, evictedValue := range evicted { + evictedKeyString, ok := evictedKey.(string) + if !ok { + log.Warn("invalid data in dataTriesHolder cache", "entry type", fmt.Sprintf("%T", evictedKey)) + continue + } + log.Trace("trie evicted from dataTriesHolder cache", "key", []byte(evictedKeyString)) + + _, ok = dth.dirtyTries[evictedKeyString] + if !ok { + continue + } + + evictedTrie, ok := evictedValue.(common.Trie) + if !ok { + log.Warn("invalid data in dataTriesHolder cache", "entry type", fmt.Sprintf("%T", evictedValue)) + continue + } + log.Trace("storing evicted dirty trie in evicted buffer", "key", []byte(evictedKeyString)) + dth.evictedBuffer[evictedKeyString] = evictedTrie + } +} + +func (dth *dataTriesHolder) clearFromEvictedBuffer(key []byte) { + if len(dth.evictedBuffer) == 0 { + return + } + + _, ok := dth.evictedBuffer[string(key)] + if !ok { + return + } + + // this means that this trie was evicted while being dirty + delete(dth.evictedBuffer, string(key)) + log.Trace("removed trie from evicted buffer", "key", key) +} + +// Get returns the trie pointer that is stored in the map at the given key +func (dth *dataTriesHolder) Get(key []byte) common.Trie { + if len(key) == 0 { + return nil + } + + dth.mutex.Lock() + defer dth.mutex.Unlock() + + keyString := string(key) + + val, ok := dth.cacher.Get(keyString) + if !ok { + // maybe it was evicted while being dirty + evictedTr, exists := dth.evictedBuffer[keyString] + if !exists { + log.Trace("trie not found in data tries holder", "key", key) + return nil + } + + log.Trace("trie found in evicted buffer of data tries holder, cancel eviction", "key", key) + delete(dth.evictedBuffer, keyString) + dth.putNoLock(key, evictedTr) + return evictedTr + } + + dth.touchedTries[keyString] = struct{}{} + tr, ok := val.(common.Trie) + if !ok { + log.Warn("invalid data in dataTriesHolder cache", "entry type", fmt.Sprintf("%T", val)) + return nil + } + log.Trace("trie found in data tries holder cache", "key", key) + + return tr +} + +// GetAll returns all the tries that are marked as dirty for this implementation. +// It also resets their dirty flag and recomputes the total size. +func (dth *dataTriesHolder) GetAll() []common.Trie { + dth.mutex.Lock() + defer dth.mutex.Unlock() + + tries := make([]common.Trie, 0) + for keyString := range dth.dirtyTries { + tr := dth.getDirtyTrieNoLock(keyString) + if check.IfNil(tr) { + continue + } + log.Trace("getting dirty trie from data tries holder", "key", []byte(keyString)) + tries = append(tries, tr) + } + dth.dirtyTries = make(map[string]struct{}) + dth.evictedBuffer = make(map[string]common.Trie) + dth.recomputeTotalSize() + log.Trace("data tries holder returned all dirty tries", "numTries", len(tries), "totalCacheSize", dth.cacher.SizeInBytesContained()) + return tries +} + +func (dth *dataTriesHolder) getDirtyTrieNoLock(key string) common.Trie { + entry, exists := dth.cacher.Get(key) + if exists { + tr, ok := entry.(common.Trie) + if !ok { + log.Warn("invalid data in dataTriesHolder cache", "entry type", fmt.Sprintf("%T", entry)) + return nil + } + + return tr + } + + tr, ok := dth.evictedBuffer[key] + if !ok { + return nil + } + return tr +} + +func (dth *dataTriesHolder) recomputeTotalSize() { + for keyString := range dth.touchedTries { + entry, exists := dth.cacher.Get(keyString) + if !exists { + continue + } + tr, ok := entry.(common.Trie) + if !ok { + log.Warn("invalid data in dataTriesHolder cache", "entry type", fmt.Sprintf("%T", entry)) + continue + } + + evicted := dth.cacher.AddSized(keyString, tr, int64(tr.SizeInMemory())) + if evicted { + log.Warn("unexpected eviction while recomputing total size in dataTriesHolder") + } + } + dth.touchedTries = make(map[string]struct{}) +} + +// Remove evicts the trie associated with the given key from the holder. +// This must be called when an account is deleted so that a subsequent recreation +// of the same account does not inherit the stale data trie from the previous incarnation. +func (dth *dataTriesHolder) Remove(key []byte) { + if len(key) == 0 { + return + } + + dth.mutex.Lock() + defer dth.mutex.Unlock() + + keyString := string(key) + dth.cacher.Remove(keyString) + delete(dth.dirtyTries, keyString) + delete(dth.touchedTries, keyString) + delete(dth.evictedBuffer, keyString) + log.Trace("removed trie from data tries holder", "key", key) +} + +// Reset clears the tries map +func (dth *dataTriesHolder) Reset() { + dth.mutex.Lock() + dth.reset() + log.Trace("data tries holder reset") + dth.mutex.Unlock() +} + +// MarkAsDirty marks the trie at the given key as dirty +func (dth *dataTriesHolder) MarkAsDirty(key []byte) { + dth.mutex.Lock() + stringKey := string(key) + dth.dirtyTries[stringKey] = struct{}{} + dth.touchedTries[stringKey] = struct{}{} + log.Trace("marked trie as dirty in data tries holder", "key", key) + dth.mutex.Unlock() +} + +func (dth *dataTriesHolder) reset() { + log.Trace("reset data tries holder") + + dth.cacher.Purge() + dth.dirtyTries = make(map[string]struct{}) + dth.touchedTries = make(map[string]struct{}) + dth.evictedBuffer = make(map[string]common.Trie) +} + +// IsInterfaceNil returns true if underlying object is nil +func (dth *dataTriesHolder) IsInterfaceNil() bool { + return dth == nil +} diff --git a/state/triesHolder/dataTriesHolder_test.go b/state/triesHolder/dataTriesHolder_test.go new file mode 100644 index 00000000000..86b3a8eb332 --- /dev/null +++ b/state/triesHolder/dataTriesHolder_test.go @@ -0,0 +1,498 @@ +package triesHolder + +import ( + "errors" + "strconv" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" + trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" + "github.com/stretchr/testify/assert" +) + +const ( + dthSize = 2 * 1024 * 1024 // 2MB + oneKB = 1 * 1024 +) + +type testTries struct { + key []byte + trie common.Trie +} + +func getTestTries(numTries int) []testTries { + tries := make([]testTries, 0) + for i := 0; i < numTries; i++ { + tr := &trieMock.TrieStub{ + SizeInMemoryCalled: func() int { + return oneKB + }, + } + key := []byte("trie" + strconv.Itoa(i)) + tries = append(tries, testTries{key: key, trie: tr}) + } + return tries +} + +func TestNewDataTriesHolder(t *testing.T) { + t.Parallel() + + t.Run(" invalid max size", func(t *testing.T) { + t.Parallel() + + dth, err := NewDataTriesHolder(512 * 1024) // less than 1MB + assert.True(t, errors.Is(err, ErrInvalidMaxTrieSizeValue)) + assert.True(t, check.IfNil(dth)) + }) + + t.Run("should create new instance", func(t *testing.T) { + t.Parallel() + + dth, err := NewDataTriesHolder(dthSize) + assert.Nil(t, err) + assert.False(t, check.IfNil(dth)) + assert.NotNil(t, dth.evictedBuffer) + assert.NotNil(t, dth.dirtyTries) + assert.NotNil(t, dth.touchedTries) + assert.Equal(t, uint64(0), dth.cacher.SizeInBytesContained()) + assert.Equal(t, 0, dth.cacher.Len()) + }) +} + +func TestDataTriesHolder_Put(t *testing.T) { + t.Parallel() + + t.Run("put invalid data", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + + dth.Put([]byte("key"), nil) + dth.Put(nil, &trieMock.TrieStub{}) + assert.Equal(t, uint64(0), dth.cacher.SizeInBytesContained()) + assert.Equal(t, 0, dth.cacher.Len()) + assert.Equal(t, 0, len(dth.dirtyTries)) + assert.Equal(t, 0, len(dth.touchedTries)) + assert.Equal(t, 0, len(dth.evictedBuffer)) + }) + t.Run("put in empty tries holder", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + entry := getTestTries(1)[0] + + dth.Put(entry.key, entry.trie) + + assert.Equal(t, 1, dth.cacher.Len()) + assert.Equal(t, 1, len(dth.dirtyTries)) + retrievedEntry, ok := dth.cacher.Get(string(entry.key)) + assert.True(t, ok) + tr, ok := retrievedEntry.(common.Trie) + assert.True(t, ok) + assert.Equal(t, tr, entry.trie) + assert.Equal(t, uint64(oneKB), dth.cacher.SizeInBytesContained()) + assert.Equal(t, 1, len(dth.dirtyTries)) + assert.Equal(t, 1, len(dth.touchedTries)) + assert.Equal(t, 0, len(dth.evictedBuffer)) + }) + t.Run("put in populated tries holder", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + assert.Equal(t, numEntries, dth.cacher.Len()) + assert.Equal(t, numEntries, len(dth.dirtyTries)) + assert.Equal(t, numEntries, len(dth.touchedTries)) + assert.Equal(t, 0, len(dth.evictedBuffer)) + assert.Equal(t, uint64(numEntries*oneKB), dth.cacher.SizeInBytesContained()) + cacherKeys := dth.cacher.Keys() + assert.Equal(t, numEntries, len(cacherKeys)) + for i := 0; i < numEntries; i++ { + retrievedEntry, ok := dth.cacher.Get(cacherKeys[i]) + assert.True(t, ok) + tr, ok := retrievedEntry.(common.Trie) + assert.True(t, ok) + assert.Equal(t, tr, entries[i].trie) + assert.Equal(t, cacherKeys[i], string(entries[i].key)) + } + }) + t.Run("put oldest used trie moves to newest used", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + dth.Put(entries[0].key, entries[0].trie) + + assert.Equal(t, numEntries, dth.cacher.Len()) + assert.Equal(t, numEntries, len(dth.dirtyTries)) + assert.Equal(t, numEntries, len(dth.touchedTries)) + assert.Equal(t, 0, len(dth.evictedBuffer)) + assert.Equal(t, uint64(numEntries*oneKB), dth.cacher.SizeInBytesContained()) + keys := dth.cacher.Keys() + assert.Equal(t, string(entries[0].key), keys[numEntries-1]) + }) + t.Run("put existing trie moves to newest used", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + triePos := 2 + dth.Put(entries[triePos].key, entries[triePos].trie) + + assert.Equal(t, numEntries, dth.cacher.Len()) + assert.Equal(t, numEntries, len(dth.dirtyTries)) + assert.Equal(t, numEntries, len(dth.touchedTries)) + assert.Equal(t, 0, len(dth.evictedBuffer)) + assert.Equal(t, uint64(numEntries*oneKB), dth.cacher.SizeInBytesContained()) + keys := dth.cacher.Keys() + assert.Equal(t, string(entries[triePos].key), keys[numEntries-1]) + }) + t.Run("put with eviction - evicted dirty tries should be in eviction buffer", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + tr := &trieMock.TrieStub{ + SizeInMemoryCalled: func() int { + return dthSize + }, + } + key := []byte("trieEvict") + + dth.Put(key, tr) + + assert.Equal(t, 1, dth.cacher.Len()) + assert.Equal(t, numEntries+1, len(dth.dirtyTries)) + assert.Equal(t, numEntries+1, len(dth.touchedTries)) + assert.Equal(t, numEntries, len(dth.evictedBuffer)) + assert.Equal(t, uint64(dthSize), dth.cacher.SizeInBytesContained()) + }) + t.Run("put with eviction - not dirty tries should evict", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + dth.dirtyTries = make(map[string]struct{}) // reset dirty tries + + sizeToEvictTwoTries := dthSize - (3 * oneKB) + tr := &trieMock.TrieStub{ + SizeInMemoryCalled: func() int { + return sizeToEvictTwoTries + }, + } + key := []byte("trieEvict") + + dth.Put(key, tr) + numEvictedTries := 2 + + assert.Equal(t, numEntries-numEvictedTries+1, dth.cacher.Len()) + assert.Equal(t, 1, len(dth.dirtyTries)) + assert.Equal(t, numEntries+1, len(dth.touchedTries)) + assert.Equal(t, 0, len(dth.evictedBuffer)) + assert.Equal(t, uint64(dthSize), dth.cacher.SizeInBytesContained()) + }) +} + +func TestDataTriesHolder_Get(t *testing.T) { + t.Parallel() + + t.Run("get not existing trie should return nil", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + tr := dth.Get([]byte("notExistingKey")) + assert.Nil(t, tr) + assert.Equal(t, 0, dth.cacher.Len()) + }) + t.Run("get existing trie should move to newest used", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + tr := dth.Get(entries[1].key) + assert.Equal(t, entries[1].trie, tr) + assert.Equal(t, numEntries, dth.cacher.Len()) + keys := dth.cacher.Keys() + assert.Equal(t, string(entries[1].key), keys[numEntries-1]) + assert.Equal(t, uint64(numEntries*oneKB), dth.cacher.SizeInBytesContained()) + }) + t.Run("get from evicted buffer should put back in cache", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + assert.Equal(t, 0, len(dth.evictedBuffer)) + tr := &trieMock.TrieStub{ + SizeInMemoryCalled: func() int { + return dthSize + }, + } + key := []byte("trieEvict") + dth.Put(key, tr) + assert.Equal(t, numEntries, len(dth.evictedBuffer)) + assert.Equal(t, 1, dth.cacher.Len()) + + numGetFromTrie := 3 + for i := 0; i < numGetFromTrie; i++ { + _ = dth.Get(entries[i].key) + } + + assert.Equal(t, 3, len(dth.evictedBuffer)) // 2 original tries + trieEvict + assert.Equal(t, numGetFromTrie, dth.cacher.Len()) + }) +} + +func TestDataTriesHolder_GetAll(t *testing.T) { + t.Parallel() + + t.Run("dirty trie not found in tries map does not panic", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + dth.cacher.Remove(string(entries[0].key)) + + dirtyTries := dth.GetAll() + assert.Equal(t, numEntries-1, len(dirtyTries)) + assert.Equal(t, 0, len(dth.dirtyTries)) + }) + t.Run("trie size is correctly computed after GetAll and eviction", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + assert.Equal(t, uint64(numEntries*oneKB), dth.cacher.SizeInBytesContained()) + + dirtyTries := dth.GetAll() + assert.Equal(t, numEntries, len(dirtyTries)) + assert.Equal(t, 0, len(dth.dirtyTries)) + assert.Equal(t, uint64(numEntries*oneKB), dth.cacher.SizeInBytesContained()) + + trieSize := dthSize - ((numEntries + 1) * oneKB) + tr := &trieMock.TrieStub{ + SizeInMemoryCalled: func() int { + return trieSize + }, + } + key := []byte("newTrie") + dth.Put(key, tr) + assert.Equal(t, numEntries+1, dth.cacher.Len()) + assert.Equal(t, uint64(numEntries*oneKB+trieSize), dth.cacher.SizeInBytesContained()) + dth.dirtyTries = make(map[string]struct{}) // reset dirty tries + + // get a trie and "resolve" some nodes, thus increasing its size in memory + _ = dth.Get(key) + originalSize := trieSize + trieSize = trieSize + oneKB + assert.Equal(t, uint64(numEntries*oneKB+originalSize), dth.cacher.SizeInBytesContained()) // size is not updated + assert.Equal(t, numEntries+1, dth.cacher.Len()) // no eviction + assert.Equal(t, 1, len(dth.touchedTries)) + + dirtyTries = dth.GetAll() + assert.Equal(t, 0, len(dirtyTries)) + assert.Equal(t, 0, len(dth.dirtyTries)) + assert.Equal(t, uint64(dthSize), dth.cacher.SizeInBytesContained()) // size is updated + assert.Equal(t, numEntries+1, dth.cacher.Len()) // no eviction + assert.Equal(t, 0, len(dth.touchedTries)) + + // put again the same trie, now with increased size triggering eviction + _ = dth.Get(key) + assert.Equal(t, 1, len(dth.touchedTries)) + trieSize = trieSize + oneKB/2 + assert.Equal(t, 0, len(dth.dirtyTries)) + assert.Equal(t, uint64(dthSize), dth.cacher.SizeInBytesContained()) // size is not updated + assert.Equal(t, numEntries+1, dth.cacher.Len()) // no eviction + + dirtyTries = dth.GetAll() + assert.Equal(t, 0, len(dirtyTries)) + assert.Equal(t, 0, len(dth.dirtyTries)) + assert.Equal(t, uint64(dthSize-oneKB/2), dth.cacher.SizeInBytesContained()) // size is updated + assert.Equal(t, numEntries, dth.cacher.Len()) // eviction + assert.Equal(t, 0, len(dth.touchedTries)) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + dirtyTries := dth.GetAll() + assert.Equal(t, numEntries, len(dirtyTries)) + assert.Equal(t, 0, len(dth.dirtyTries)) + }) +} + +func TestDataTriesHolder_Reset(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + dth.Reset() + assert.Equal(t, 0, dth.cacher.Len()) + assert.Equal(t, uint64(0), dth.cacher.SizeInBytesContained()) + assert.Equal(t, 0, len(dth.dirtyTries)) + assert.Equal(t, 0, len(dth.touchedTries)) + assert.Equal(t, 0, len(dth.evictedBuffer)) +} + +func TestDataTriesHolder_MarkAsDirty(t *testing.T) { + t.Parallel() + + t.Run("mark existing trie as dirty should track key once and return trie on get all", func(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + entry := getTestTries(1)[0] + dth.Put(entry.key, entry.trie) + dth.dirtyTries = make(map[string]struct{}) + dth.touchedTries = make(map[string]struct{}) + + dth.MarkAsDirty(entry.key) + + assert.Equal(t, 1, len(dth.dirtyTries)) + assert.Equal(t, 1, len(dth.touchedTries)) + + dirtyTries := dth.GetAll() + assert.Equal(t, []common.Trie{entry.trie}, dirtyTries) + assert.Equal(t, 0, len(dth.dirtyTries)) + assert.Equal(t, 0, len(dth.touchedTries)) + }) +} + +func TestDataTriesHolder_Concurrency(t *testing.T) { + t.Parallel() + + dth, _ := NewDataTriesHolder(dthSize) + numEntries := 5 + entries := getTestTries(numEntries) + + wg := sync.WaitGroup{} + wg.Add(numEntries) + + for i := 0; i < numEntries; i++ { + go func(key int) { + dth.Put(entries[key].key, entries[key].trie) + wg.Done() + }(i) + } + + wg.Wait() + + assert.Equal(t, numEntries, dth.cacher.Len()) + assert.Equal(t, numEntries, len(dth.dirtyTries)) + assert.Equal(t, uint64(numEntries*oneKB), dth.cacher.SizeInBytesContained()) +} + +func BenchmarkDataTriesHolder_PutWithEviction(b *testing.B) { + numEntries := 100000 + dth, _ := NewDataTriesHolder(uint64(numEntries * oneKB / 10)) // set max size to 10% of total size to force evictions + entries := getTestTries(numEntries) + b.ResetTimer() + for i := 0; i < b.N; i++ { + entry := entries[i%numEntries] + dth.Put(entry.key, entry.trie) + if i%1000 == 0 { + dth.dirtyTries = make(map[string]struct{}) // reset dirty tries to allow evictions + } + } +} + +func BenchmarkDataTriesHolder_PutNoEviction(b *testing.B) { + numEntries := 100000 + dth, _ := NewDataTriesHolder(uint64(numEntries * oneKB * 2)) // set max size to 200% of total size to avoid evictions + entries := getTestTries(numEntries) + b.ResetTimer() + for i := 0; i < b.N; i++ { + entry := entries[i%numEntries] + dth.Put(entry.key, entry.trie) + } +} + +func BenchmarkDataTriesHolder_Get(b *testing.B) { + numEntries := 100000 + dth, _ := NewDataTriesHolder(uint64(numEntries * oneKB * 2)) + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + entry := entries[i%numEntries] + dth.Get(entry.key) + } +} + +func BenchmarkDataTriesHolder_GetAll(b *testing.B) { + numEntries := 10000 + dth, _ := NewDataTriesHolder(uint64(numEntries * oneKB * 2)) + entries := getTestTries(numEntries) + for i := 0; i < numEntries; i++ { + dth.Put(entries[i].key, entries[i].trie) + } + dth.dirtyTries = make(map[string]struct{}) + dth.touchedTries = make(map[string]struct{}) + numDirty := numEntries / 10 + for i := 0; i < numDirty; i++ { + dth.dirtyTries[string(entries[i].key)] = struct{}{} + dth.touchedTries[string(entries[i].key)] = struct{}{} + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = dth.GetAll() + } +} diff --git a/state/triesHolder/disabledTriesHolder.go b/state/triesHolder/disabledTriesHolder.go new file mode 100644 index 00000000000..70f7becf048 --- /dev/null +++ b/state/triesHolder/disabledTriesHolder.go @@ -0,0 +1,41 @@ +package triesHolder + +import "github.com/multiversx/mx-chain-go/common" + +type disabledDataTriesHolder struct{} + +// NewDisabledDataTriesHolder creates a disabled no-op data tries holder. +func NewDisabledDataTriesHolder() *disabledDataTriesHolder { + return &disabledDataTriesHolder{} +} + +// Put does nothing for the disabled implementation. +func (d *disabledDataTriesHolder) Put(_ []byte, _ common.Trie) { +} + +// Get always returns nil for the disabled implementation. +func (d *disabledDataTriesHolder) Get(_ []byte) common.Trie { + return nil +} + +// GetAll always returns an empty list for the disabled implementation. +func (d *disabledDataTriesHolder) GetAll() []common.Trie { + return make([]common.Trie, 0) +} + +// MarkAsDirty does nothing for the disabled implementation. +func (d *disabledDataTriesHolder) MarkAsDirty(_ []byte) { +} + +// Reset does nothing for the disabled implementation. +func (d *disabledDataTriesHolder) Reset() { +} + +// Remove does nothing for the disabled implementation +func (d *disabledDataTriesHolder) Remove(_ []byte) { +} + +// IsInterfaceNil returns true if the underlying object is nil. +func (d *disabledDataTriesHolder) IsInterfaceNil() bool { + return d == nil +} diff --git a/state/triesHolder/disabledTriesHolder_test.go b/state/triesHolder/disabledTriesHolder_test.go new file mode 100644 index 00000000000..d947896bf5f --- /dev/null +++ b/state/triesHolder/disabledTriesHolder_test.go @@ -0,0 +1,47 @@ +package triesHolder + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" +) + +func TestNewDisabledDataTriesHolder(t *testing.T) { + t.Parallel() + + dth := NewDisabledDataTriesHolder() + assert.False(t, check.IfNil(dth)) +} + +func TestDisabledDataTriesHolder_PutGetGetAll(t *testing.T) { + t.Parallel() + + dth := NewDisabledDataTriesHolder() + dth.Put([]byte("key"), nil) + + trie := dth.Get([]byte("key")) + assert.Nil(t, trie) + assert.Empty(t, dth.GetAll()) +} + +func TestDisabledDataTriesHolder_MarkAsDirtyAndReset(t *testing.T) { + t.Parallel() + + dth := NewDisabledDataTriesHolder() + dth.MarkAsDirty([]byte("key")) + dth.Reset() + + assert.Empty(t, dth.GetAll()) +} + +func TestDisabledDataTriesHolder_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var dth *disabledDataTriesHolder + assert.True(t, dth.IsInterfaceNil()) + + dth = NewDisabledDataTriesHolder() + assert.False(t, dth.IsInterfaceNil()) +} + diff --git a/state/triesHolder/errors.go b/state/triesHolder/errors.go new file mode 100644 index 00000000000..539ef76c3f5 --- /dev/null +++ b/state/triesHolder/errors.go @@ -0,0 +1,6 @@ +package triesHolder + +import "errors" + +// ErrInvalidMaxTrieSizeValue signals that the provided max trie size value is invalid +var ErrInvalidMaxTrieSizeValue = errors.New("invalid max trie size value") diff --git a/state/dataTriesHolder.go b/state/triesHolder/triesHolder.go similarity index 53% rename from state/dataTriesHolder.go rename to state/triesHolder/triesHolder.go index 8333b875fce..b90292d85f7 100644 --- a/state/dataTriesHolder.go +++ b/state/triesHolder/triesHolder.go @@ -1,4 +1,4 @@ -package state +package triesHolder import ( "sync" @@ -7,34 +7,29 @@ import ( logger "github.com/multiversx/mx-chain-logger-go" ) -type dataTriesHolder struct { +type triesHolder struct { tries map[string]common.Trie mutex sync.RWMutex } -// NewDataTriesHolder creates a new instance of dataTriesHolder -func NewDataTriesHolder() *dataTriesHolder { - return &dataTriesHolder{ +// NewTriesHolder creates a new instance of triesHolder +func NewTriesHolder() *triesHolder { + return &triesHolder{ tries: make(map[string]common.Trie), } } // Put adds a trie pointer to the tries map -func (dth *dataTriesHolder) Put(key []byte, tr common.Trie) { - log.Trace("put trie in data tries holder", "key", key) +func (dth *triesHolder) Put(key []byte, tr common.Trie) { + log.Trace("put trie in tries holder", "key", key) dth.mutex.Lock() dth.tries[string(key)] = tr dth.mutex.Unlock() } -// Replace changes a trie pointer to the tries map -func (dth *dataTriesHolder) Replace(key []byte, tr common.Trie) { - dth.Put(key, tr) -} - // Get returns the trie pointer that is stored in the map at the given key -func (dth *dataTriesHolder) Get(key []byte) common.Trie { +func (dth *triesHolder) Get(key []byte) common.Trie { dth.mutex.Lock() defer dth.mutex.Unlock() @@ -42,7 +37,7 @@ func (dth *dataTriesHolder) Get(key []byte) common.Trie { } // GetAll returns all trie pointers from the map -func (dth *dataTriesHolder) GetAll() []common.Trie { +func (dth *triesHolder) GetAll() []common.Trie { dth.mutex.Lock() defer dth.mutex.Unlock() @@ -54,21 +49,15 @@ func (dth *dataTriesHolder) GetAll() []common.Trie { return tries } -// GetAllTries returns the tries with key value map -func (dth *dataTriesHolder) GetAllTries() map[string]common.Trie { +// Remove deletes the trie associated with the given key from the holder +func (dth *triesHolder) Remove(key []byte) { dth.mutex.Lock() - defer dth.mutex.Unlock() - - copyTries := make(map[string]common.Trie, len(dth.tries)) - for key, trie := range dth.tries { - copyTries[key] = trie - } - - return copyTries + delete(dth.tries, string(key)) + dth.mutex.Unlock() } // Reset clears the tries map -func (dth *dataTriesHolder) Reset() { +func (dth *triesHolder) Reset() { dth.mutex.Lock() if log.GetLevel() == logger.LogTrace { @@ -81,7 +70,10 @@ func (dth *dataTriesHolder) Reset() { dth.mutex.Unlock() } +// MarkAsDirty does nothing in this implementation +func (dth *triesHolder) MarkAsDirty(_ []byte) {} + // IsInterfaceNil returns true if underlying object is nil -func (dth *dataTriesHolder) IsInterfaceNil() bool { +func (dth *triesHolder) IsInterfaceNil() bool { return dth == nil } diff --git a/state/triesHolder/triesHolder_test.go b/state/triesHolder/triesHolder_test.go new file mode 100644 index 00000000000..f34f0ef6dfe --- /dev/null +++ b/state/triesHolder/triesHolder_test.go @@ -0,0 +1,88 @@ +package triesHolder + +import ( + "strconv" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" + "github.com/stretchr/testify/assert" +) + +func TestNewTriesHolder(t *testing.T) { + t.Parallel() + + dth := NewTriesHolder() + assert.False(t, check.IfNil(dth)) +} + +func TestTriesHolder_PutAndGet(t *testing.T) { + t.Parallel() + + tr1 := &trieMock.TrieStub{} + + dth := NewTriesHolder() + dth.Put([]byte("trie1"), tr1) + tr := dth.Get([]byte("trie1")) + + assert.True(t, tr == tr1) +} + +func TestTriesHolder_GetAll(t *testing.T) { + t.Parallel() + + tr1 := &trieMock.TrieStub{} + tr2 := &trieMock.TrieStub{} + tr3 := &trieMock.TrieStub{} + + dth := NewTriesHolder() + dth.Put([]byte("trie1"), tr1) + dth.Put([]byte("trie2"), tr2) + dth.Put([]byte("trie3"), tr3) + tries := dth.GetAll() + + assert.Equal(t, 3, len(tries)) +} + +func TestTriesHolder_Reset(t *testing.T) { + t.Parallel() + + tr1 := &trieMock.TrieStub{} + + dth := NewTriesHolder() + dth.Put([]byte("trie1"), tr1) + dth.Reset() + + tr := dth.Get([]byte("trie1")) + assert.Nil(t, tr) +} + +func TestTriesHolder_Concurrency(t *testing.T) { + t.Parallel() + + dth := NewTriesHolder() + numCalls := 5000 + + wg := sync.WaitGroup{} + wg.Add(numCalls) + + for i := 0; i < numCalls; i++ { + go func(key int) { + defer wg.Done() + + switch key % 4 { + case 0: + dth.Put([]byte(strconv.Itoa(key)), &trieMock.TrieStub{}) + case 1: + dth.Get([]byte(strconv.Itoa(key))) + case 2: + dth.GetAll() + case 3: + dth.Reset() + } + }(i) + } + + wg.Wait() +} diff --git a/testscommon/storage/storageManagerArgs.go b/testscommon/common/storageManagerArgs.go similarity index 98% rename from testscommon/storage/storageManagerArgs.go rename to testscommon/common/storageManagerArgs.go index 1f32e18f0d0..6e244bfb162 100644 --- a/testscommon/storage/storageManagerArgs.go +++ b/testscommon/common/storageManagerArgs.go @@ -1,4 +1,4 @@ -package storage +package common import ( "github.com/multiversx/mx-chain-go/common/statistics/disabled" diff --git a/testscommon/components/components.go b/testscommon/components/components.go index a8d7fb25b50..c37846a835e 100644 --- a/testscommon/components/components.go +++ b/testscommon/components/components.go @@ -9,6 +9,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" "github.com/multiversx/mx-chain-core-go/data/outport" + "github.com/multiversx/mx-chain-go/state/triesHolder" + "github.com/multiversx/mx-chain-go/trie/collapseManager" logger "github.com/multiversx/mx-chain-logger-go" wasmConfig "github.com/multiversx/mx-chain-vm-go/config" "github.com/stretchr/testify/require" @@ -37,13 +39,11 @@ import ( p2pConfig "github.com/multiversx/mx-chain-go/p2p/config" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" ) @@ -364,20 +364,20 @@ func GetNetworkFactoryArgs() networkComp.NetworkComponentsFactoryArgs { // GetStateFactoryArgs - func GetStateFactoryArgs(coreComponents factory.CoreComponentsHolder, statusCoreComp factory.StatusCoreComponentsHolder) stateComp.StateComponentsFactoryArgs { - tsm, _ := trie.NewTrieStorageManager(storage.GetStorageManagerArgs()) + tsm, _ := trie.NewTrieStorageManager(commonMocks.GetStorageManagerArgs()) storageManagerUser, _ := trie.NewTrieStorageManagerWithoutPruning(tsm) - tsm, _ = trie.NewTrieStorageManager(storage.GetStorageManagerArgs()) + tsm, _ = trie.NewTrieStorageManager(commonMocks.GetStorageManagerArgs()) storageManagerPeer, _ := trie.NewTrieStorageManagerWithoutPruning(tsm) trieStorageManagers := make(map[string]common.StorageManager) trieStorageManagers[dataRetriever.UserAccountsUnit.String()] = storageManagerUser trieStorageManagers[dataRetriever.PeerAccountsUnit.String()] = storageManagerPeer - triesHolder := state.NewDataTriesHolder() - trieUsers, _ := trie.NewTrie(storageManagerUser, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), coreComponents.EnableEpochsHandler(), 5) - triePeers, _ := trie.NewTrie(storageManagerPeer, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), coreComponents.EnableEpochsHandler(), 5) - triesHolder.Put([]byte(dataRetriever.UserAccountsUnit.String()), trieUsers) - triesHolder.Put([]byte(dataRetriever.PeerAccountsUnit.String()), triePeers) + triesContainer := triesHolder.NewTriesHolder() + trieUsers, _ := trie.NewTrie(storageManagerUser, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), coreComponents.EnableEpochsHandler(), collapseManager.NewDisabledCollapseManager()) + triePeers, _ := trie.NewTrie(storageManagerPeer, coreComponents.InternalMarshalizer(), coreComponents.Hasher(), coreComponents.EnableEpochsHandler(), collapseManager.NewDisabledCollapseManager()) + triesContainer.Put([]byte(dataRetriever.UserAccountsUnit.String()), trieUsers) + triesContainer.Put([]byte(dataRetriever.PeerAccountsUnit.String()), triePeers) stateComponentsFactoryArgs := stateComp.StateComponentsFactoryArgs{ Config: GetGeneralConfig(), diff --git a/testscommon/components/configs.go b/testscommon/components/configs.go index b3672ef1d32..0bd1984d0c4 100644 --- a/testscommon/components/configs.go +++ b/testscommon/components/configs.go @@ -1,6 +1,7 @@ package components import ( + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/testscommon" ) @@ -20,10 +21,12 @@ func GetGeneralConfig() config.Config { SignatureLength: 48, }, StateTriesConfig: config.StateTriesConfig{ - AccountsStatePruningEnabled: true, - PeerStatePruningEnabled: true, - MaxStateTrieLevelInMemory: 5, - MaxPeerTrieLevelInMemory: 5, + AccountsStatePruningEnabled: true, + PeerStatePruningEnabled: true, + MaxUserTrieSizeInMemory: common.TenMbSize, + MaxPeerTrieSizeInMemory: common.TenMbSize, + DataTriesSizeInMemory: common.TenMbSize, + NumLeavesToCollapseSingleRun: 100, }, EvictionWaitingList: config.EvictionWaitingListConfig{ HashesSize: 100, diff --git a/testscommon/generalConfig.go b/testscommon/generalConfig.go index df01b6ec29c..b0949da7c56 100644 --- a/testscommon/generalConfig.go +++ b/testscommon/generalConfig.go @@ -1,6 +1,7 @@ package testscommon import ( + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/storage/storageunit" ) @@ -149,11 +150,13 @@ func GetGeneralConfig() config.Config { }, }, StateTriesConfig: config.StateTriesConfig{ - SnapshotsEnabled: true, - AccountsStatePruningEnabled: false, - PeerStatePruningEnabled: false, - MaxStateTrieLevelInMemory: 5, - MaxPeerTrieLevelInMemory: 5, + SnapshotsEnabled: true, + AccountsStatePruningEnabled: false, + PeerStatePruningEnabled: false, + MaxUserTrieSizeInMemory: common.TenMbSize, + MaxPeerTrieSizeInMemory: common.TenMbSize, + DataTriesSizeInMemory: common.TenMbSize, + NumLeavesToCollapseSingleRun: common.NumLeavesToCollapseSingleRun, }, TrieStorageManagerConfig: config.TrieStorageManagerConfig{ PruningBufferLen: 1000, diff --git a/testscommon/integrationtests/factory.go b/testscommon/integrationtests/factory.go index 80d65d1c703..785aa440b33 100644 --- a/testscommon/integrationtests/factory.go +++ b/testscommon/integrationtests/factory.go @@ -13,15 +13,17 @@ import ( "github.com/multiversx/mx-chain-go/state/lastSnapshotMarker" "github.com/multiversx/mx-chain-go/state/storagePruningManager" "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + testCommon "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" testStorage "github.com/multiversx/mx-chain-go/testscommon/state" - testcommonStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" ) // TestMarshalizer - @@ -30,9 +32,6 @@ var TestMarshalizer = &marshal.GogoProtoMarshalizer{} // TestHasher - var TestHasher = sha256.NewSha256() -// MaxTrieLevelInMemory - -const MaxTrieLevelInMemory = uint(5) - // CreateMemUnit - func CreateMemUnit() storage.Storer { capacity := uint32(10) @@ -94,21 +93,23 @@ func CreateAccountsDB(db storage.Storer, enableEpochs common.EnableEpochsHandler } ewl, _ := evictionWaitingList.NewMemoryEvictionWaitingList(ewlArgs) - args := testcommonStorage.GetStorageManagerArgs() + args := testCommon.GetStorageManagerArgs() args.MainStorer = db args.Marshalizer = TestMarshalizer args.Hasher = TestHasher trieStorage, _ := trie.NewTrieStorageManager(args) - tr, _ := trie.NewTrie(trieStorage, TestMarshalizer, TestHasher, enableEpochs, MaxTrieLevelInMemory) + tr, _ := trie.NewTrie(trieStorage, TestMarshalizer, TestHasher, enableEpochs, collapseManager.NewDisabledCollapseManager()) spm, _ := storagePruningManager.NewStoragePruningManager(ewl, 10) - + dth, _ := triesHolder.NewDataTriesHolder(common.TenMbSize) argsAccCreator := accountFactory.ArgsAccountCreator{ Hasher: TestHasher, Marshaller: TestMarshalizer, EnableEpochsHandler: enableEpochs, StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, + DataTrieCreator: tr, } accCreator, _ := accountFactory.NewAccountCreator(argsAccCreator) @@ -133,6 +134,7 @@ func CreateAccountsDB(db storage.Storer, enableEpochs common.EnableEpochsHandler AddressConverter: &testscommon.PubkeyConverterMock{}, SnapshotsManager: snapshotsManager, StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, } adb, _ := state.NewAccountsDB(argsAccountsDB) diff --git a/testscommon/state/accountAdapterStub.go b/testscommon/state/accountAdapterStub.go index fa9305f8222..59a77804b7a 100644 --- a/testscommon/state/accountAdapterStub.go +++ b/testscommon/state/accountAdapterStub.go @@ -19,7 +19,6 @@ type StateUserAccountHandlerStub struct { GetCodeHashCalled func() []byte SetRootHashCalled func([]byte) GetRootHashCalled func() []byte - SetDataTrieCalled func(trie common.Trie) DataTrieCalled func() common.DataTrieHandler RetrieveValueCalled func(key []byte) ([]byte, uint32, error) SaveKeyValueCalled func(key []byte, value []byte) error @@ -114,13 +113,6 @@ func (aas *StateUserAccountHandlerStub) GetRootHash() []byte { return nil } -// SetDataTrie - -func (aas *StateUserAccountHandlerStub) SetDataTrie(trie common.Trie) { - if aas.SetDataTrieCalled != nil { - aas.SetDataTrieCalled(trie) - } -} - // DataTrie - func (aas *StateUserAccountHandlerStub) DataTrie() common.DataTrieHandler { if aas.DataTrieCalled != nil { diff --git a/testscommon/state/accountFactoryStub.go b/testscommon/state/accountFactoryStub.go index ce2a41259fe..a60c7d13442 100644 --- a/testscommon/state/accountFactoryStub.go +++ b/testscommon/state/accountFactoryStub.go @@ -13,7 +13,10 @@ type AccountsFactoryStub struct { // CreateAccount - func (afs *AccountsFactoryStub) CreateAccount(address []byte) (vmcommon.AccountHandler, error) { - return afs.CreateAccountCalled(address) + if afs.CreateAccountCalled != nil { + return afs.CreateAccountCalled(address) + } + return nil, nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/testscommon/state/accountWrapperMock.go b/testscommon/state/accountWrapperMock.go index 557b36b601c..6e758a61467 100644 --- a/testscommon/state/accountWrapperMock.go +++ b/testscommon/state/accountWrapperMock.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/stateChange" + "github.com/multiversx/mx-chain-go/testscommon/trie" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/multiversx/mx-chain-go/common" @@ -45,13 +46,16 @@ var errInsufficientBalance = fmt.Errorf("insufficient balance") // NewAccountWrapMock - func NewAccountWrapMock(adr []byte) *AccountWrapMock { - tdt, _ := trackableDataTrie.NewTrackableDataTrie( - []byte("identifier"), - &hashingMocks.HasherMock{}, - &marshallerMock.MarshalizerMock{}, - &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - disabled.NewDisabledStateAccessesCollector(), - ) + args := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: []byte("identifier"), + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: &trie.TriesHolderStub{}, + DataTrieCreator: &trie.TrieStub{}, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) return &AccountWrapMock{ address: adr, @@ -60,6 +64,28 @@ func NewAccountWrapMock(adr []byte) *AccountWrapMock { } } +// NewAccountWrapMockWithDataTrieHolder - +func NewAccountWrapMockWithDataTrieHolder( + dataTriesHolder common.TriesHolder, +) *AccountWrapMock { + args := trackableDataTrie.TrackableDataTrieArgs{ + Identifier: []byte("identifier"), + Hasher: &hashingMocks.HasherMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StateAccessesCollector: disabled.NewDisabledStateAccessesCollector(), + DataTriesHolder: dataTriesHolder, + DataTrieCreator: &trie.TrieStub{}, + } + tdt, _ := trackableDataTrie.NewTrackableDataTrie(args) + + return &AccountWrapMock{ + address: args.Identifier, + trackableDataTrie: tdt, + Balance: big.NewInt(0), + } +} + // SetTrackableDataTrie - func (awm *AccountWrapMock) SetTrackableDataTrie(tdt state.DataTrieTracker) { awm.trackableDataTrie = tdt @@ -187,6 +213,12 @@ func (awm *AccountWrapMock) GetRootHash() []byte { // SetRootHash - func (awm *AccountWrapMock) SetRootHash(rootHash []byte) { awm.RootHash = rootHash + awm.trackableDataTrie.SetRootHash(rootHash) +} + +// SetDataTrieRootHash - +func (awm *AccountWrapMock) SetDataTrieRootHash() { + awm.trackableDataTrie.SetRootHash(awm.RootHash) } // AddressBytes - @@ -204,11 +236,6 @@ func (awm *AccountWrapMock) SaveDirtyData(trie common.Trie) ([]*stateChange.Data return awm.trackableDataTrie.SaveDirtyData(trie) } -// SetDataTrie - -func (awm *AccountWrapMock) SetDataTrie(trie common.Trie) { - awm.trackableDataTrie.SetDataTrie(trie) -} - // IncreaseNonce adds the given value to the current nonce func (awm *AccountWrapMock) IncreaseNonce(val uint64) { awm.nonce = awm.nonce + val diff --git a/testscommon/state/accountsAdapterStub.go b/testscommon/state/accountsAdapterStub.go index 60e8898b5e4..a3400538341 100644 --- a/testscommon/state/accountsAdapterStub.go +++ b/testscommon/state/accountsAdapterStub.go @@ -45,6 +45,7 @@ type AccountsStub struct { SetSyncerCalled func(syncer state.AccountsDBSyncer) error StartSnapshotIfNeededCalled func() error SetTxHashForLatestStateAccessesCalled func(txHash []byte) + GetAccountsFactoryCalled func() state.AccountFactory } // CleanCache - @@ -277,6 +278,14 @@ func (as *AccountsStub) SetTxHashForLatestStateAccesses(txHash []byte) { } } +// GetAccountsFactory - +func (as *AccountsStub) GetAccountsFactory() state.AccountFactory { + if as.GetAccountsFactoryCalled != nil { + return as.GetAccountsFactoryCalled() + } + return nil +} + // Close - func (as *AccountsStub) Close() error { if as.CloseCalled != nil { diff --git a/testscommon/state/testTrie.go b/testscommon/state/testTrie.go index 8744009aa18..bed48710ed7 100644 --- a/testscommon/state/testTrie.go +++ b/testscommon/state/testTrie.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" ) // GetDefaultTrieParameters - @@ -40,7 +41,7 @@ func GetDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, has // GetNewTrie - func GetNewTrie() common.Trie { tsm, marshaller, hasher := GetDefaultTrieParameters() - tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) return tr } diff --git a/testscommon/state/userAccountStub.go b/testscommon/state/userAccountStub.go index 22f28958c6a..010eee7927a 100644 --- a/testscommon/state/userAccountStub.go +++ b/testscommon/state/userAccountStub.go @@ -31,7 +31,7 @@ type UserAccountStub struct { IsGuardedCalled func() bool AccountDataHandlerCalled func() vmcommon.AccountDataHandler RetrieveValueCalled func(_ []byte) ([]byte, uint32, error) - SetDataTrieCalled func(dataTrie common.Trie) + SetRootHashCalled func(rootHash []byte) GetRootHashCalled func() []byte SaveKeyValueCalled func(key []byte, value []byte) error } @@ -145,6 +145,10 @@ func (u *UserAccountStub) SetRootHash([]byte) { } +// SetDataTrieRootHash - +func (u *UserAccountStub) SetDataTrieRootHash() { +} + // GetRootHash - func (u *UserAccountStub) GetRootHash() []byte { if u.GetRootHashCalled != nil { @@ -155,9 +159,9 @@ func (u *UserAccountStub) GetRootHash() []byte { } // SetDataTrie - -func (u *UserAccountStub) SetDataTrie(dataTrie common.Trie) { - if u.SetDataTrieCalled != nil { - u.SetDataTrieCalled(dataTrie) +func (u *UserAccountStub) SetDataTrie(rootHash []byte) { + if u.SetRootHashCalled != nil { + u.SetRootHashCalled(rootHash) } } diff --git a/testscommon/trie/snapshotPruningStorerStub.go b/testscommon/storage/snapshotPruningStorerStub.go similarity index 99% rename from testscommon/trie/snapshotPruningStorerStub.go rename to testscommon/storage/snapshotPruningStorerStub.go index 8de709ab2cd..f6186eb49dc 100644 --- a/testscommon/trie/snapshotPruningStorerStub.go +++ b/testscommon/storage/snapshotPruningStorerStub.go @@ -1,4 +1,4 @@ -package trie +package storage import ( "github.com/multiversx/mx-chain-core-go/core" diff --git a/testscommon/trie/dataTrieTrackerStub.go b/testscommon/trie/dataTrieTrackerStub.go index 73c083b30c1..e76886de29b 100644 --- a/testscommon/trie/dataTrieTrackerStub.go +++ b/testscommon/trie/dataTrieTrackerStub.go @@ -15,7 +15,7 @@ type DataTrieTrackerStub struct { RetrieveValueCalled func(key []byte) ([]byte, uint32, error) SaveKeyValueCalled func(key []byte, value []byte) error - SetDataTrieCalled func(tr common.Trie) + SetRootHashCalled func(rootHash []byte) DataTrieCalled func() common.Trie SaveDirtyDataCalled func(trie common.Trie) ([]*stateChange.DataTrieChange, []core.TrieData, error) SaveTrieDataCalled func(trieData core.TrieData) error @@ -40,13 +40,11 @@ func (dtts *DataTrieTrackerStub) SaveKeyValue(key []byte, value []byte) error { return nil } -// SetDataTrie - -func (dtts *DataTrieTrackerStub) SetDataTrie(tr common.Trie) { - if dtts.SetDataTrieCalled != nil { - dtts.SetDataTrieCalled(tr) +// SetRootHash - +func (dtts *DataTrieTrackerStub) SetRootHash(rootHash []byte) { + if dtts.SetRootHashCalled != nil { + dtts.SetRootHashCalled(rootHash) } - - dtts.dataTrie = tr } // DataTrie - diff --git a/testscommon/trie/trieStub.go b/testscommon/trie/trieStub.go index 30e0ba6066e..cd078b5a428 100644 --- a/testscommon/trie/trieStub.go +++ b/testscommon/trie/trieStub.go @@ -33,6 +33,7 @@ type TrieStub struct { CloseCalled func() error CollectLeavesForMigrationCalled func(args vmcommon.ArgsMigrateDataTrieLeaves) error IsMigratedToLatestVersionCalled func() (bool, error) + SizeInMemoryCalled func() int } // GetStorageManager - @@ -215,6 +216,14 @@ func (ts *TrieStub) IsMigratedToLatestVersion() (bool, error) { return false, nil } +// SizeInMemory - +func (ts *TrieStub) SizeInMemory() int { + if ts.SizeInMemoryCalled != nil { + return ts.SizeInMemoryCalled() + } + return 0 +} + // Close - func (ts *TrieStub) Close() error { if ts.CloseCalled != nil { diff --git a/testscommon/trie/triesHolderStub.go b/testscommon/trie/triesHolderStub.go index 42eab41d7f5..3bd234844ec 100644 --- a/testscommon/trie/triesHolderStub.go +++ b/testscommon/trie/triesHolderStub.go @@ -6,11 +6,12 @@ import ( // TriesHolderStub - type TriesHolderStub struct { - PutCalled func([]byte, common.Trie) - RemoveCalled func([]byte, common.Trie) - GetCalled func([]byte) common.Trie - GetAllCalled func() []common.Trie - ResetCalled func() + PutCalled func([]byte, common.Trie) + RemoveCalled func([]byte) + GetCalled func([]byte) common.Trie + GetAllCalled func() []common.Trie + ResetCalled func() + MarkAsDirtyCalled func([]byte) } // Put - @@ -20,10 +21,10 @@ func (ths *TriesHolderStub) Put(key []byte, trie common.Trie) { } } -// Replace - -func (ths *TriesHolderStub) Replace(key []byte, trie common.Trie) { +// Remove - +func (ths *TriesHolderStub) Remove(key []byte) { if ths.RemoveCalled != nil { - ths.RemoveCalled(key, trie) + ths.RemoveCalled(key) } } @@ -50,6 +51,13 @@ func (ths *TriesHolderStub) Reset() { } } +// MarkAsDirty - +func (ths *TriesHolderStub) MarkAsDirty(key []byte) { + if ths.MarkAsDirtyCalled != nil { + ths.MarkAsDirtyCalled(key) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (ths *TriesHolderStub) IsInterfaceNil() bool { return ths == nil diff --git a/trie/baseIterator.go b/trie/baseIterator.go index 8ff558790d8..5f34106ab92 100644 --- a/trie/baseIterator.go +++ b/trie/baseIterator.go @@ -3,11 +3,13 @@ package trie import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" ) type baseIterator struct { currentNode node nextNodes []node + tmc MetricsCollector db common.TrieStorageInteractor } @@ -23,7 +25,8 @@ func newBaseIterator(trie common.Trie) (*baseIterator, error) { } trieStorage := trie.GetStorageManager() - nextNodes, err := pmt.root.getChildren(trieStorage) + tmc := trieMetricsCollector.NewDisabledTrieMetricsCollector() + nextNodes, err := pmt.root.getChildren(tmc, trieStorage) if err != nil { return nil, err } @@ -32,6 +35,7 @@ func newBaseIterator(trie common.Trie) (*baseIterator, error) { currentNode: pmt.root, nextNodes: nextNodes, db: trieStorage, + tmc: tmc, }, nil } @@ -50,7 +54,7 @@ func (it *baseIterator) next() ([]node, error) { } it.currentNode = n - return it.currentNode.getChildren(it.db) + return it.currentNode.getChildren(it.tmc, it.db) } // MarshalizedNode marshalizes the current node, and then returns the serialized node diff --git a/trie/branchNode.go b/trie/branchNode.go index 8e3584d0589..96dc5414270 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -252,8 +252,7 @@ func (bn *branchNode) hashNode() ([]byte, error) { return encodeNodeAndGetHash(bn) } -func (bn *branchNode) commitDirty(level byte, maxTrieLevelInMemory uint, originDb common.TrieStorageInteractor, targetDb common.BaseStorer) error { - level++ +func (bn *branchNode) commitDirty(originDb common.TrieStorageInteractor, targetDb common.BaseStorer) error { err := bn.isEmptyOrNil() if err != nil { return fmt.Errorf("commit error %w", err) @@ -268,7 +267,7 @@ func (bn *branchNode) commitDirty(level byte, maxTrieLevelInMemory uint, originD continue } - err = bn.children[i].commitDirty(level, maxTrieLevelInMemory, originDb, targetDb) + err = bn.children[i].commitDirty(originDb, targetDb) if err != nil { return err } @@ -278,18 +277,32 @@ func (bn *branchNode) commitDirty(level byte, maxTrieLevelInMemory uint, originD if err != nil { return err } - if uint(level) == maxTrieLevelInMemory { - log.Trace("collapse branch node on commit") - var collapsedBn *branchNode - collapsedBn, err = bn.getCollapsedBn() - if err != nil { - return err - } + return nil +} - *bn = *collapsedBn +func (bn *branchNode) shouldCollapseChild(hexKey []byte, tmc MetricsCollector) bool { + if len(hexKey) == 0 { + return false } - return nil + childPos := hexKey[firstByte] + if childPosOutOfRange(childPos) { + return false + } + hexKey = hexKey[1:] + + if bn.children[childPos] == nil { + return false + } + + shouldCollapseChild := bn.children[childPos].shouldCollapseChild(hexKey, tmc) + if shouldCollapseChild { + tmc.AddSizeLoadedInMem(-bn.children[childPos].sizeInBytes()) + bn.children[childPos] = nil + return false + } + + return false } // TODO refactor long parameter list @@ -302,12 +315,16 @@ func (bn *branchNode) commitSnapshot( stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, nodeBytes []byte, - depthLevel int, + tmc MetricsCollector, ) error { if shouldStopIfContextDoneBlockingIfBusy(ctx, idleProvider) { return core.ErrContextClosing } + depthLevel := tmc.GetCurrentDepth() + tmc.SetDepth(depthLevel + 1) + defer tmc.SetDepth(depthLevel) // Reset depth when returning from this level + for i := range bn.EncodedChildren { if len(bn.EncodedChildren[i]) == 0 { continue @@ -323,7 +340,7 @@ func (bn *branchNode) commitSnapshot( ctx, stats, idleProvider, - depthLevel, + tmc, bn.EncodedChildren[i], ) if err != nil { @@ -331,7 +348,7 @@ func (bn *branchNode) commitSnapshot( } } - stats.AddBranchNode(depthLevel, uint64(len(nodeBytes))) + stats.AddBranchNode(int(depthLevel), uint64(len(nodeBytes))) return nil } @@ -348,7 +365,7 @@ func (bn *branchNode) getEncodedNode() ([]byte, error) { return marshaledNode, nil } -func (bn *branchNode) resolveCollapsed(pos byte, db common.TrieStorageInteractor) error { +func (bn *branchNode) resolveCollapsed(pos byte, tmc MetricsCollector, db common.TrieStorageInteractor) error { err := bn.isEmptyOrNil() if err != nil { return fmt.Errorf("resolveCollapsed error %w", err) @@ -363,6 +380,7 @@ func (bn *branchNode) resolveCollapsed(pos byte, db common.TrieStorageInteractor return err } child.setGivenHash(bn.EncodedChildren[pos]) + tmc.AddSizeLoadedInMem(child.sizeInBytes()) bn.children[pos] = child } return nil @@ -381,31 +399,32 @@ func (bn *branchNode) isPosCollapsed(pos int) bool { return bn.children[pos] == nil && len(bn.EncodedChildren[pos]) != 0 } -func (bn *branchNode) tryGet(key []byte, currentDepth uint32, db common.TrieStorageInteractor) (value []byte, maxDepth uint32, err error) { +func (bn *branchNode) tryGet(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) (value []byte, err error) { err = bn.isEmptyOrNil() if err != nil { - return nil, currentDepth, fmt.Errorf("tryGet error %w", err) + return nil, fmt.Errorf("tryGet error %w", err) } if len(key) == 0 { - return nil, currentDepth, nil + return nil, nil } childPos := key[firstByte] if childPosOutOfRange(childPos) { - return nil, currentDepth, ErrChildPosOutOfRange + return nil, ErrChildPosOutOfRange } key = key[1:] - err = resolveIfCollapsed(bn, childPos, db) + err = resolveIfCollapsed(bn, childPos, tmc, db) if err != nil { - return nil, currentDepth, err + return nil, err } if bn.children[childPos] == nil { - return nil, currentDepth, nil + return nil, nil } - return bn.children[childPos].tryGet(key, currentDepth+1, db) + tmc.SetDepth(tmc.GetCurrentDepth() + 1) + return bn.children[childPos].tryGet(key, tmc, db) } -func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) { +func (bn *branchNode) getNext(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) (node, []byte, error) { err := bn.isEmptyOrNil() if err != nil { return nil, nil, fmt.Errorf("getNext error %w", err) @@ -418,7 +437,7 @@ func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node return nil, nil, ErrChildPosOutOfRange } key = key[1:] - err = resolveIfCollapsed(bn, childPos, db) + err = resolveIfCollapsed(bn, childPos, tmc, db) if err != nil { return nil, nil, err } @@ -429,7 +448,7 @@ func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node return bn.children[childPos], key, nil } -func (bn *branchNode) insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { +func (bn *branchNode) insert(newData core.TrieData, tmc MetricsCollector, db common.TrieStorageInteractor) (node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := bn.isEmptyOrNil() if err != nil { @@ -445,19 +464,19 @@ func (bn *branchNode) insert(newData core.TrieData, db common.TrieStorageInterac } newData.Key = newData.Key[1:] - err = resolveIfCollapsed(bn, childPos, db) + err = resolveIfCollapsed(bn, childPos, tmc, db) if err != nil { return nil, emptyHashes, err } if bn.children[childPos] == nil { - return bn.insertOnNilChild(newData, childPos) + return bn.insertOnNilChild(newData, tmc, childPos) } - return bn.insertOnExistingChild(newData, childPos, db) + return bn.insertOnExistingChild(newData, tmc, childPos, db) } -func (bn *branchNode) insertOnNilChild(newData core.TrieData, childPos byte) (node, [][]byte, error) { +func (bn *branchNode) insertOnNilChild(newData core.TrieData, tmc MetricsCollector, childPos byte) (node, [][]byte, error) { newLn, err := newLeafNode(newData, bn.marsh, bn.hasher) if err != nil { return nil, [][]byte{}, err @@ -468,12 +487,13 @@ func (bn *branchNode) insertOnNilChild(newData core.TrieData, childPos byte) (no if err != nil { return nil, [][]byte{}, err } + tmc.AddSizeLoadedInMem(newLn.sizeInBytes()) return bn, modifiedHashes, nil } -func (bn *branchNode) insertOnExistingChild(newData core.TrieData, childPos byte, db common.TrieStorageInteractor) (node, [][]byte, error) { - newNode, modifiedHashes, err := bn.children[childPos].insert(newData, db) +func (bn *branchNode) insertOnExistingChild(newData core.TrieData, tmc MetricsCollector, childPos byte, db common.TrieStorageInteractor) (node, [][]byte, error) { + newNode, modifiedHashes, err := bn.children[childPos].insert(newData, tmc, db) if check.IfNil(newNode) || err != nil { return nil, [][]byte{}, err } @@ -504,7 +524,7 @@ func (bn *branchNode) modifyNodeAfterInsert(modifiedHashes [][]byte, childPos by return modifiedHashes, nil } -func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { +func (bn *branchNode) delete(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := bn.isEmptyOrNil() if err != nil { @@ -518,7 +538,7 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, return false, nil, emptyHashes, ErrChildPosOutOfRange } key = key[1:] - err = resolveIfCollapsed(bn, childPos, db) + err = resolveIfCollapsed(bn, childPos, tmc, db) if err != nil { return false, nil, emptyHashes, err } @@ -527,7 +547,7 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, return false, bn, emptyHashes, nil } - dirty, newNode, oldHashes, err := bn.children[childPos].delete(key, db) + dirty, newNode, oldHashes, err := bn.children[childPos].delete(key, tmc, db) if !dirty || err != nil { return false, bn, emptyHashes, err } @@ -536,7 +556,7 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, oldHashes = append(oldHashes, bn.hash) } - err = bn.setNewChild(childPos, newNode) + err = bn.setNewChild(childPos, newNode, tmc) if err != nil { return false, nil, emptyHashes, err } @@ -544,18 +564,18 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, numChildren, pos := getChildPosition(bn) if numChildren == 1 { - err = resolveIfCollapsed(bn, byte(pos), db) + err = resolveIfCollapsed(bn, byte(pos), tmc, db) if err != nil { return false, nil, emptyHashes, err } - err = resolveIfCollapsed(bn.children[pos], byte(pos), db) + err = resolveIfCollapsed(bn.children[pos], byte(pos), tmc, db) if err != nil { return false, nil, emptyHashes, err } var newChildHash bool - newNode, newChildHash, err = bn.children[pos].reduceNode(pos) + newNode, newChildHash, err = bn.children[pos].reduceNode(pos, tmc) if err != nil { return false, nil, emptyHashes, err } @@ -564,6 +584,7 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, oldHashes = append(oldHashes, bn.children[pos].getHash()) } + tmc.AddSizeLoadedInMem(-bn.sizeInBytes()) return true, newNode, oldHashes, nil } @@ -572,12 +593,13 @@ func (bn *branchNode) delete(key []byte, db common.TrieStorageInteractor) (bool, return true, bn, oldHashes, nil } -func (bn *branchNode) setNewChild(childPos byte, newNode node) error { +func (bn *branchNode) setNewChild(childPos byte, newNode node, tmc MetricsCollector) error { bn.hash = nil bn.children[childPos] = newNode if check.IfNil(newNode) { bn.setVersionForChild(core.NotSpecified, childPos) bn.EncodedChildren[childPos] = nil + tmc.AddSizeLoadedInMem(-hashSizeInBytes) return nil } @@ -602,11 +624,12 @@ func (bn *branchNode) revertChildrenVersionSliceIfNeeded() { bn.ChildrenVersion = []byte(nil) } -func (bn *branchNode) reduceNode(pos int) (node, bool, error) { +func (bn *branchNode) reduceNode(pos int, tmc MetricsCollector) (node, bool, error) { newEn, err := newExtensionNode([]byte{byte(pos)}, bn, bn.marsh, bn.hasher) if err != nil { return nil, false, err } + tmc.AddSizeLoadedInMem(newEn.sizeInBytes()) return newEn, false, nil } @@ -638,7 +661,7 @@ func (bn *branchNode) isEmptyOrNil() error { return ErrEmptyBranchNode } -func (bn *branchNode) print(writer io.Writer, index int, db common.TrieStorageInteractor) { +func (bn *branchNode) print(writer io.Writer, index int, tmc MetricsCollector, db common.TrieStorageInteractor) { if bn == nil { return } @@ -646,7 +669,7 @@ func (bn *branchNode) print(writer io.Writer, index int, db common.TrieStorageIn str := fmt.Sprintf("B: %v - %v", hex.EncodeToString(bn.hash), bn.dirty) _, _ = fmt.Fprintln(writer, str) for i := 0; i < len(bn.children); i++ { - err := resolveIfCollapsed(bn, byte(i), db) + err := resolveIfCollapsed(bn, byte(i), tmc, db) if err != nil { log.Debug("branch node: print trie err", "error", err, "hash", bn.EncodedChildren[i]) } @@ -662,7 +685,7 @@ func (bn *branchNode) print(writer io.Writer, index int, db common.TrieStorageIn str2 := fmt.Sprintf("+ %d: ", i) _, _ = fmt.Fprint(writer, str2) childIndex := index + len(str) - 1 + len(str2) - child.print(writer, childIndex, db) + child.print(writer, childIndex, tmc, db) } } @@ -691,7 +714,7 @@ func (bn *branchNode) getDirtyHashes(hashes common.ModifiedHashes) error { return nil } -func (bn *branchNode) getChildren(db common.TrieStorageInteractor) ([]node, error) { +func (bn *branchNode) getChildren(tmc MetricsCollector, db common.TrieStorageInteractor) ([]node, error) { err := bn.isEmptyOrNil() if err != nil { return nil, fmt.Errorf("getChildren error %w", err) @@ -700,7 +723,7 @@ func (bn *branchNode) getChildren(db common.TrieStorageInteractor) ([]node, erro nextNodes := make([]node, 0) for i := range bn.children { - err = resolveIfCollapsed(bn, byte(i), db) + err = resolveIfCollapsed(bn, byte(i), tmc, db) if err != nil { return nil, err } @@ -766,6 +789,7 @@ func (bn *branchNode) getAllLeavesOnChannel( marshalizer marshal.Marshalizer, chanClose chan struct{}, ctx context.Context, + tmc MetricsCollector, ) error { err := bn.isEmptyOrNil() if err != nil { @@ -781,7 +805,7 @@ func (bn *branchNode) getAllLeavesOnChannel( log.Trace("branchNode.getAllLeavesOnChannel context done") return nil default: - err = resolveIfCollapsed(bn, byte(i), db) + err = resolveIfCollapsed(bn, byte(i), tmc, db) if err != nil { return err } @@ -792,7 +816,7 @@ func (bn *branchNode) getAllLeavesOnChannel( clonedKeyBuilder := keyBuilder.ShallowClone() clonedKeyBuilder.BuildKey([]byte{byte(i)}) - err = bn.children[i].getAllLeavesOnChannel(leavesChannel, clonedKeyBuilder, trieLeafParser, db, marshalizer, chanClose, ctx) + err = bn.children[i].getAllLeavesOnChannel(leavesChannel, clonedKeyBuilder, trieLeafParser, db, marshalizer, chanClose, ctx, tmc) if err != nil { return err } @@ -804,7 +828,7 @@ func (bn *branchNode) getAllLeavesOnChannel( return nil } -func (bn *branchNode) getAllHashes(db common.TrieStorageInteractor) ([][]byte, error) { +func (bn *branchNode) getAllHashes(tmc MetricsCollector, db common.TrieStorageInteractor) ([][]byte, error) { err := bn.isEmptyOrNil() if err != nil { return nil, fmt.Errorf("getAllHashes error: %w", err) @@ -813,7 +837,7 @@ func (bn *branchNode) getAllHashes(db common.TrieStorageInteractor) ([][]byte, e var childrenHashes [][]byte hashes := make([][]byte, 0) for i := range bn.children { - err = resolveIfCollapsed(bn, byte(i), db) + err = resolveIfCollapsed(bn, byte(i), tmc, db) if err != nil { return nil, err } @@ -822,7 +846,7 @@ func (bn *branchNode) getAllHashes(db common.TrieStorageInteractor) ([][]byte, e continue } - childrenHashes, err = bn.children[i].getAllHashes(db) + childrenHashes, err = bn.children[i].getAllHashes(tmc, db) if err != nil { return nil, err } @@ -851,12 +875,15 @@ func (bn *branchNode) sizeInBytes() int { return 0 } - // hasher + marshalizer + dirty flag = numNodeInnerPointers * pointerSizeInBytes + 1 - nodeSize := len(bn.hash) + numNodeInnerPointers*pointerSizeInBytes + 1 - for _, collapsed := range bn.EncodedChildren { - nodeSize += len(collapsed) + nodeSize := baseNodeSizeInBytes + bnChildrenPointersSize + numChildren := 0 + for i := range bn.EncodedChildren { + if bn.children[i] != nil || len(bn.EncodedChildren[i]) != 0 { + numChildren++ + } } - nodeSize += len(bn.children) * pointerSizeInBytes + nodeSize += numChildren * hashSizeInBytes + nodeSize += len(bn.ChildrenVersion) return nodeSize } @@ -865,14 +892,18 @@ func (bn *branchNode) getValue() []byte { return []byte{} } -func (bn *branchNode) collectStats(ts common.TrieStatisticsHandler, depthLevel int, db common.TrieStorageInteractor) error { +func (bn *branchNode) collectStats(ts common.TrieStatisticsHandler, tmc MetricsCollector, db common.TrieStorageInteractor) error { err := bn.isEmptyOrNil() if err != nil { return fmt.Errorf("collectStats error %w", err) } + depthLevel := tmc.GetCurrentDepth() + tmc.SetDepth(depthLevel + 1) + defer tmc.SetDepth(depthLevel) // Reset depth when returning from this level + for i := range bn.children { - err = resolveIfCollapsed(bn, byte(i), db) + err = resolveIfCollapsed(bn, byte(i), tmc, db) if err != nil { return err } @@ -881,7 +912,7 @@ func (bn *branchNode) collectStats(ts common.TrieStatisticsHandler, depthLevel i continue } - err = bn.children[i].collectStats(ts, depthLevel+1, db) + err = bn.children[i].collectStats(ts, tmc, db) if err != nil { return err } @@ -892,7 +923,7 @@ func (bn *branchNode) collectStats(ts common.TrieStatisticsHandler, depthLevel i return err } - ts.AddBranchNode(depthLevel, uint64(len(val))) + ts.AddBranchNode(int(depthLevel), uint64(len(val))) return nil } @@ -936,6 +967,7 @@ func (bn *branchNode) getVersionForChild(childIndex byte) core.TrieNodeVersion { func (bn *branchNode) collectLeavesForMigration( migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, + tmc MetricsCollector, db common.TrieStorageInteractor, keyBuilder common.KeyBuilder, ) (bool, error) { @@ -961,14 +993,14 @@ func (bn *branchNode) collectLeavesForMigration( continue } - err = resolveIfCollapsed(bn, byte(i), db) + err = resolveIfCollapsed(bn, byte(i), tmc, db) if err != nil { return false, err } clonedKeyBuilder := keyBuilder.ShallowClone() clonedKeyBuilder.BuildKey([]byte{byte(i)}) - shouldContinueMigrating, err := bn.children[i].collectLeavesForMigration(migrationArgs, db, clonedKeyBuilder) + shouldContinueMigrating, err := bn.children[i].collectLeavesForMigration(migrationArgs, tmc, db, clonedKeyBuilder) if err != nil { return false, err } diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index 4cea8910bad..9edea9ac49b 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -16,12 +16,16 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/mock" "github.com/multiversx/mx-chain-go/trie/statistics" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" "github.com/stretchr/testify/assert" ) +var dtmc = trieMetricsCollector.NewDisabledTrieMetricsCollector() + func getTestMarshalizerAndHasher() (marshal.Marshalizer, hashing.Hasher) { marsh := &marshal.GogoProtoMarshalizer{} hash := &testscommon.KeccakMock{} @@ -55,6 +59,15 @@ func getBnAndCollapsedBn(marshalizer marshal.Marshalizer, hasher hashing.Hasher) return bn, collapsedBn } +func markNotDirtyBranchNode(bn *branchNode) { + bn.dirty = false + for i := range bn.children { + if bn.children[i] != nil { + bn.children[i].setDirty(false) + } + } +} + func emptyDirtyBranchNode() *branchNode { var children [nrOfChildren]node encChildren := make([][]byte, nrOfChildren) @@ -76,14 +89,14 @@ func newEmptyTrie() (*patriciaMerkleTrie, *trieStorageManager) { args := GetDefaultTrieStorageManagerParameters() trieStorage, _ := NewTrieStorageManager(args) tr := &patriciaMerkleTrie{ - trieStorage: trieStorage, - marshalizer: args.Marshalizer, - hasher: args.Hasher, - oldHashes: make([][]byte, 0), - oldRoot: make([]byte, 0), - maxTrieLevelInMemory: 5, - chanClose: make(chan struct{}), - enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + trieStorage: trieStorage, + marshalizer: args.Marshalizer, + hasher: args.Hasher, + oldHashes: make([][]byte, 0), + oldRoot: make([]byte, 0), + chanClose: make(chan struct{}), + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + collapseManager: collapseManager.NewDisabledCollapseManager(), } return tr, trieStorage @@ -197,10 +210,9 @@ func TestBranchNode_setRootHash(t *testing.T) { trieStorage1, _ := NewTrieStorageManager(GetDefaultTrieStorageManagerParameters()) trieStorage2, _ := NewTrieStorageManager(GetDefaultTrieStorageManagerParameters()) - maxTrieLevelInMemory := uint(5) - tr1, _ := NewTrie(trieStorage1, marsh, hsh, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) - tr2, _ := NewTrie(trieStorage2, marsh, hsh, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory) + tr1, _ := NewTrie(trieStorage1, marsh, hsh, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) + tr2, _ := NewTrie(trieStorage2, marsh, hsh, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) maxIterations := 10000 for i := 0; i < maxIterations; i++ { @@ -359,7 +371,7 @@ func TestBranchNode_commit(t *testing.T) { hash, _ := encodeNodeAndGetHash(collapsedBn) _ = bn.setHash() - err := bn.commitDirty(0, 5, db, db) + err := bn.commitDirty(db, db) assert.Nil(t, err) encNode, _ := db.Get(hash) @@ -374,7 +386,7 @@ func TestBranchNode_commitEmptyNode(t *testing.T) { bn := emptyDirtyBranchNode() - err := bn.commitDirty(0, 5, nil, nil) + err := bn.commitDirty(nil, nil) assert.True(t, errors.Is(err, ErrEmptyBranchNode)) } @@ -383,7 +395,7 @@ func TestBranchNode_commitNilNode(t *testing.T) { var bn *branchNode - err := bn.commitDirty(0, 5, nil, nil) + err := bn.commitDirty(nil, nil) assert.True(t, errors.Is(err, ErrNilBranchNode)) } @@ -428,14 +440,16 @@ func TestBranchNode_resolveCollapsed(t *testing.T) { childPos := byte(2) _ = bn.setHash() - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) resolved, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) resolved.dirty = false resolved.hash = bn.EncodedChildren[childPos] - err := collapsedBn.resolveCollapsed(childPos, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + err := collapsedBn.resolveCollapsed(childPos, tmc, db) assert.Nil(t, err) assert.Equal(t, resolved, collapsedBn.children[childPos]) + assert.Equal(t, collapsedBn.children[childPos].sizeInBytes(), tmc.GetSizeLoadedInMem()) } func TestBranchNode_resolveCollapsedEmptyNode(t *testing.T) { @@ -443,7 +457,7 @@ func TestBranchNode_resolveCollapsedEmptyNode(t *testing.T) { bn := emptyDirtyBranchNode() - err := bn.resolveCollapsed(2, nil) + err := bn.resolveCollapsed(2, nil, nil) assert.True(t, errors.Is(err, ErrEmptyBranchNode)) } @@ -452,7 +466,7 @@ func TestBranchNode_resolveCollapsedENilNode(t *testing.T) { var bn *branchNode - err := bn.resolveCollapsed(2, nil) + err := bn.resolveCollapsed(2, nil, nil) assert.True(t, errors.Is(err, ErrNilBranchNode)) } @@ -461,7 +475,7 @@ func TestBranchNode_resolveCollapsedPosOutOfRange(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - err := bn.resolveCollapsed(17, nil) + err := bn.resolveCollapsed(17, nil, nil) assert.Equal(t, ErrChildPosOutOfRange, err) } @@ -484,10 +498,11 @@ func TestBranchNode_tryGet(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - val, maxDepth, err := bn.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := bn.tryGet(key, tmc, nil) assert.Equal(t, []byte("dog"), val) assert.Nil(t, err) - assert.Equal(t, uint32(1), maxDepth) + assert.Equal(t, uint32(1), tmc.GetMaxDepth()) } func TestBranchNode_tryGetEmptyKey(t *testing.T) { @@ -496,10 +511,11 @@ func TestBranchNode_tryGetEmptyKey(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) var key []byte - val, maxDepth, err := bn.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := bn.tryGet(key, tmc, nil) assert.Nil(t, err) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestBranchNode_tryGetChildPosOutOfRange(t *testing.T) { @@ -508,10 +524,11 @@ func TestBranchNode_tryGetChildPosOutOfRange(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) key := []byte("dog") - val, maxDepth, err := bn.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := bn.tryGet(key, tmc, nil) assert.Equal(t, ErrChildPosOutOfRange, err) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestBranchNode_tryGetNilChild(t *testing.T) { @@ -520,10 +537,11 @@ func TestBranchNode_tryGetNilChild(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) nilChildKey := []byte{3} - val, maxDepth, err := bn.tryGet(nilChildKey, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := bn.tryGet(nilChildKey, tmc, nil) assert.Nil(t, err) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestBranchNode_tryGetCollapsedNode(t *testing.T) { @@ -533,15 +551,16 @@ func TestBranchNode_tryGetCollapsedNode(t *testing.T) { bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) _ = bn.setHash() - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - val, maxDepth, err := collapsedBn.tryGet(key, 0, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := collapsedBn.tryGet(key, tmc, db) assert.Equal(t, []byte("dog"), val) assert.Nil(t, err) - assert.Equal(t, uint32(1), maxDepth) + assert.Equal(t, uint32(1), tmc.GetMaxDepth()) } func TestBranchNode_tryGetEmptyNode(t *testing.T) { @@ -551,10 +570,11 @@ func TestBranchNode_tryGetEmptyNode(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - val, maxDepth, err := bn.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := bn.tryGet(key, tmc, nil) assert.True(t, errors.Is(err, ErrEmptyBranchNode)) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestBranchNode_tryGetNilNode(t *testing.T) { @@ -564,10 +584,11 @@ func TestBranchNode_tryGetNilNode(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - val, maxDepth, err := bn.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := bn.tryGet(key, tmc, nil) assert.True(t, errors.Is(err, ErrNilBranchNode)) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestBranchNode_getNext(t *testing.T) { @@ -578,7 +599,7 @@ func TestBranchNode_getNext(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - n, key, err := bn.getNext(key, nil) + n, key, err := bn.getNext(key, dtmc, nil) h1, _ := encodeNodeAndGetHash(nextNode) h2, _ := encodeNodeAndGetHash(n) @@ -593,7 +614,7 @@ func TestBranchNode_getNextWrongKey(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) key := []byte("dog") - n, key, err := bn.getNext(key, nil) + n, key, err := bn.getNext(key, dtmc, nil) assert.Nil(t, n) assert.Nil(t, key) assert.Equal(t, ErrChildPosOutOfRange, err) @@ -606,7 +627,7 @@ func TestBranchNode_getNextNilChild(t *testing.T) { nilChildPos := byte(4) key := append([]byte{nilChildPos}, []byte("dog")...) - n, key, err := bn.getNext(key, nil) + n, key, err := bn.getNext(key, dtmc, nil) assert.Nil(t, n) assert.Nil(t, key) assert.Equal(t, ErrNodeNotFound, err) @@ -618,9 +639,11 @@ func TestBranchNode_insert(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) nodeKey := []byte{0, 2, 3} - newBn, _, err := bn.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newBn, _, err := bn.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), tmc, nil) assert.NotNil(t, newBn) assert.Nil(t, err) + assert.Equal(t, bn.children[0].sizeInBytes(), tmc.GetSizeLoadedInMem()) nodeKeyRemainder := nodeKey[1:] bn.children[0], _ = newLeafNode(getTrieDataWithDefaultVersion(string(nodeKeyRemainder), "dogs"), bn.marsh, bn.hasher) @@ -632,7 +655,7 @@ func TestBranchNode_insertEmptyKey(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("", "dogs"), nil) + newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("", "dogs"), dtmc, nil) assert.Equal(t, ErrValueTooShort, err) assert.Nil(t, newBn) } @@ -642,7 +665,7 @@ func TestBranchNode_insertChildPosOutOfRange(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) + newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("dog", "dogs"), dtmc, nil) assert.Equal(t, ErrChildPosOutOfRange, err) assert.Nil(t, newBn) } @@ -655,14 +678,20 @@ func TestBranchNode_insertCollapsedNode(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) + originalSize := bn.children[childPos].sizeInBytes() _ = bn.setHash() - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) - newBn, _, err := collapsedBn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newBn, _, err := collapsedBn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), tmc, db) assert.NotNil(t, newBn) assert.Nil(t, err) - val, _, _ := newBn.tryGet(key, 0, db) + newSize := collapsedBn.children[childPos].sizeInBytes() + sizeDiff := newSize - originalSize + assert.Equal(t, originalSize+sizeDiff, tmc.GetSizeLoadedInMem()) + + val, _ := newBn.tryGet(key, trieMetricsCollector.NewTrieMetricsCollector(), db) assert.Equal(t, []byte("dogs"), val) } @@ -674,16 +703,21 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - _ = bn.commitDirty(0, 5, db, db) + originalSize := bn.children[childPos].sizeInBytes() + _ = bn.setHash() + markNotDirtyBranchNode(bn) bnHash := bn.getHash() - ln, _, _ := bn.getNext(key, db) - lnHash := ln.getHash() + lnHash := bn.EncodedChildren[childPos] expectedHashes := [][]byte{lnHash, bnHash} - newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), tmc, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) + + newSize := bn.children[childPos].sizeInBytes() + assert.Equal(t, newSize-originalSize, tmc.GetSizeLoadedInMem()) } func TestBranchNode_insertInStoredBnOnNilPos(t *testing.T) { @@ -694,14 +728,17 @@ func TestBranchNode_insertInStoredBnOnNilPos(t *testing.T) { nilChildPos := byte(11) key := append([]byte{nilChildPos}, []byte("dog")...) - _ = bn.commitDirty(0, 5, db, db) + _ = bn.setHash() + markNotDirtyBranchNode(bn) bnHash := bn.getHash() expectedHashes := [][]byte{bnHash} - newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), tmc, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) + assert.Equal(t, bn.children[nilChildPos].sizeInBytes(), tmc.GetSizeLoadedInMem()) } func TestBranchNode_insertInDirtyBnOnNilPos(t *testing.T) { @@ -711,10 +748,12 @@ func TestBranchNode_insertInDirtyBnOnNilPos(t *testing.T) { nilChildPos := byte(11) key := append([]byte{nilChildPos}, []byte("dog")...) - newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), tmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) + assert.Equal(t, bn.children[nilChildPos].sizeInBytes(), tmc.GetSizeLoadedInMem()) } func TestBranchNode_insertInDirtyBnOnExistingPos(t *testing.T) { @@ -724,10 +763,14 @@ func TestBranchNode_insertInDirtyBnOnExistingPos(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) + originalSize := bn.children[childPos].sizeInBytes() + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := bn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), tmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) + newSize := bn.children[childPos].sizeInBytes() + assert.Equal(t, newSize-originalSize, tmc.GetSizeLoadedInMem()) } func TestBranchNode_insertInNilNode(t *testing.T) { @@ -735,7 +778,7 @@ func TestBranchNode_insertInNilNode(t *testing.T) { var bn *branchNode - newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("key", "dogs"), nil) + newBn, _, err := bn.insert(getTrieDataWithDefaultVersion("key", "dogs"), dtmc, nil) assert.True(t, errors.Is(err, ErrNilBranchNode)) assert.Nil(t, newBn) } @@ -753,9 +796,12 @@ func TestBranchNode_delete(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - dirty, newBn, _, err := bn.delete(key, nil) + originalSize := bn.children[childPos].sizeInBytes() + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, newBn, _, err := bn.delete(key, tmc, nil) assert.True(t, dirty) assert.Nil(t, err) + assert.Equal(t, -(originalSize + hashSizeInBytes), tmc.GetSizeLoadedInMem()) _ = expectedBn.setHash() _ = newBn.setHash() @@ -770,16 +816,19 @@ func TestBranchNode_deleteFromStoredBn(t *testing.T) { childPos := byte(2) lnKey := append([]byte{childPos}, []byte("dog")...) - _ = bn.commitDirty(0, 5, db, db) + originalSize := bn.children[childPos].sizeInBytes() + _ = bn.setHash() + markNotDirtyBranchNode(bn) bnHash := bn.getHash() - ln, _, _ := bn.getNext(lnKey, db) - lnHash := ln.getHash() + lnHash := bn.EncodedChildren[childPos] expectedHashes := [][]byte{lnHash, bnHash} - dirty, _, oldHashes, err := bn.delete(lnKey, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, _, oldHashes, err := bn.delete(lnKey, tmc, db) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) + assert.Equal(t, -(originalSize + hashSizeInBytes), tmc.GetSizeLoadedInMem()) } func TestBranchNode_deleteFromDirtyBn(t *testing.T) { @@ -789,10 +838,13 @@ func TestBranchNode_deleteFromDirtyBn(t *testing.T) { childPos := byte(2) lnKey := append([]byte{childPos}, []byte("dog")...) - dirty, _, oldHashes, err := bn.delete(lnKey, nil) + originalSize := bn.children[childPos].sizeInBytes() + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, _, oldHashes, err := bn.delete(lnKey, tmc, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) + assert.Equal(t, -(originalSize + hashSizeInBytes), tmc.GetSizeLoadedInMem()) } func TestBranchNode_deleteEmptyNode(t *testing.T) { @@ -802,10 +854,12 @@ func TestBranchNode_deleteEmptyNode(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - dirty, newBn, _, err := bn.delete(key, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, newBn, _, err := bn.delete(key, tmc, nil) assert.False(t, dirty) assert.True(t, errors.Is(err, ErrEmptyBranchNode)) assert.Nil(t, newBn) + assert.Equal(t, 0, tmc.GetSizeLoadedInMem()) } func TestBranchNode_deleteNilNode(t *testing.T) { @@ -815,10 +869,12 @@ func TestBranchNode_deleteNilNode(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - dirty, newBn, _, err := bn.delete(key, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, newBn, _, err := bn.delete(key, tmc, nil) assert.False(t, dirty) assert.True(t, errors.Is(err, ErrNilBranchNode)) assert.Nil(t, newBn) + assert.Equal(t, 0, tmc.GetSizeLoadedInMem()) } func TestBranchNode_deleteNonexistentNodeFromChild(t *testing.T) { @@ -829,10 +885,12 @@ func TestBranchNode_deleteNonexistentNodeFromChild(t *testing.T) { childPos := byte(2) key := append([]byte{childPos}, []byte("butterfly")...) - dirty, newBn, _, err := bn.delete(key, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, newBn, _, err := bn.delete(key, tmc, nil) assert.False(t, dirty) assert.Nil(t, err) assert.Equal(t, bn, newBn) + assert.Equal(t, 0, tmc.GetSizeLoadedInMem()) } func TestBranchNode_deleteEmptykey(t *testing.T) { @@ -840,10 +898,12 @@ func TestBranchNode_deleteEmptykey(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - dirty, newBn, _, err := bn.delete([]byte{}, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, newBn, _, err := bn.delete([]byte{}, tmc, nil) assert.False(t, dirty) assert.Equal(t, ErrValueTooShort, err) assert.Nil(t, newBn) + assert.Equal(t, 0, tmc.GetSizeLoadedInMem()) } func TestBranchNode_deleteCollapsedNode(t *testing.T) { @@ -852,16 +912,18 @@ func TestBranchNode_deleteCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) _ = bn.setHash() - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - dirty, newBn, _, err := collapsedBn.delete(key, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, newBn, _, err := collapsedBn.delete(key, tmc, db) assert.True(t, dirty) assert.Nil(t, err) + assert.Equal(t, -hashSizeInBytes, tmc.GetSizeLoadedInMem()) - val, _, err := newBn.tryGet(key, 0, db) + val, err := newBn.tryGet(key, trieMetricsCollector.NewTrieMetricsCollector(), db) assert.Nil(t, val) assert.Nil(t, err) } @@ -876,15 +938,21 @@ func TestBranchNode_deleteAndReduceBn(t *testing.T) { children[firstChildPos], _ = newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) children[secondChildPos], _ = newLeafNode(getTrieDataWithDefaultVersion("doe", "doe"), bn.marsh, bn.hasher) bn.children = children + bn.EncodedChildren[firstChildPos], _ = encodeNodeAndGetHash(children[firstChildPos]) + bn.EncodedChildren[secondChildPos], _ = encodeNodeAndGetHash(children[secondChildPos]) + extraLeafData := 1 + expectedSizeInMem := -bn.children[secondChildPos].sizeInBytes() - bn.sizeInBytes() + extraLeafData key := append([]byte{firstChildPos}, []byte("dog")...) ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string(key), "dog"), bn.marsh, bn.hasher) + tmc := trieMetricsCollector.NewTrieMetricsCollector() key = append([]byte{secondChildPos}, []byte("doe")...) - dirty, newBn, _, err := bn.delete(key, nil) + dirty, newBn, _, err := bn.delete(key, tmc, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, ln, newBn) + assert.Equal(t, expectedSizeInMem, tmc.GetSizeLoadedInMem()) } func TestBranchNode_reduceNode(t *testing.T) { @@ -899,7 +967,7 @@ func TestBranchNode_reduceNode(t *testing.T) { key := append([]byte{childPos}, []byte("dog")...) ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string(key), "dog"), bn.marsh, bn.hasher) - n, newChildHash, err := bn.children[childPos].reduceNode(int(childPos)) + n, newChildHash, err := bn.children[childPos].reduceNode(int(childPos), dtmc) assert.Equal(t, ln, n) assert.Nil(t, err) assert.True(t, newChildHash) @@ -1015,7 +1083,7 @@ func TestBranchNode_getChildren(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - children, err := bn.getChildren(nil) + children, err := bn.getChildren(dtmc, nil) assert.Nil(t, err) assert.Equal(t, 3, len(children)) } @@ -1025,9 +1093,9 @@ func TestBranchNode_getChildrenCollapsedBn(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) - children, err := collapsedBn.getChildren(db) + children, err := collapsedBn.getChildren(dtmc, db) assert.Nil(t, err) assert.Equal(t, 3, len(children)) } @@ -1189,20 +1257,6 @@ func TestBranchNode_setRootHashCollapsedChildren(t *testing.T) { assert.Nil(t, err) } -func TestBranchNode_commitCollapsesTrieIfMaxTrieLevelInMemoryIsReached(t *testing.T) { - t.Parallel() - - bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = collapsedBn.setRootHash() - - err := bn.commitDirty(0, 1, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) - assert.Nil(t, err) - - assert.Equal(t, collapsedBn.EncodedChildren, bn.EncodedChildren) - assert.Equal(t, collapsedBn.children, bn.children) - assert.Equal(t, collapsedBn.hash, bn.hash) -} - func TestBranchNode_reduceNodeBnChild(t *testing.T) { t.Parallel() @@ -1211,7 +1265,7 @@ func TestBranchNode_reduceNodeBnChild(t *testing.T) { pos := 5 expectedNode, _ := newExtensionNode([]byte{byte(pos)}, en.child, marsh, hasher) - newNode, newChildHash, err := en.child.reduceNode(pos) + newNode, newChildHash, err := en.child.reduceNode(pos, dtmc) assert.Nil(t, err) assert.Equal(t, expectedNode, newNode) assert.False(t, newChildHash) @@ -1225,11 +1279,11 @@ func TestBranchNode_printShouldNotPanicEvenIfNodeIsCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.commitDirty(0, 5, db, db) - _ = collapsedBn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) + _ = collapsedBn.commitDirty(db, db) - bn.print(bnWriter, 0, db) - collapsedBn.print(collapsedBnWriter, 0, db) + bn.print(bnWriter, 0, dtmc, db) + collapsedBn.print(collapsedBnWriter, 0, dtmc, db) assert.Equal(t, bnWriter.Bytes(), collapsedBnWriter.Bytes()) } @@ -1239,7 +1293,7 @@ func TestBranchNode_getDirtyHashesFromCleanNode(t *testing.T) { db := testscommon.NewMemDbMock() bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) dirtyHashes := make(common.ModifiedHashes) err := bn.getDirtyHashes(dirtyHashes) @@ -1253,7 +1307,7 @@ func TestBranchNode_getAllHashes(t *testing.T) { trieNodes := 4 bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - hashes, err := bn.getAllHashes(testscommon.NewMemDbMock()) + hashes, err := bn.getAllHashes(dtmc, testscommon.NewMemDbMock()) assert.Nil(t, err) assert.Equal(t, trieNodes, len(hashes)) } @@ -1263,9 +1317,9 @@ func TestBranchNode_getAllHashesResolvesCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) - hashes, err := collapsedBn.getAllHashes(db) + hashes, err := collapsedBn.getAllHashes(dtmc, db) assert.Nil(t, err) assert.Equal(t, 4, len(hashes)) } @@ -1311,10 +1365,11 @@ func TestBranchNode_SizeInBytes(t *testing.T) { collapsed1 := []byte("collapsed1") collapsed2 := []byte("collapsed2") - hash := []byte("hash") + hash := bytes.Repeat([]byte{1}, 32) bn = &branchNode{ CollapsedBn: CollapsedBn{ EncodedChildren: [][]byte{collapsed1, collapsed2}, + ChildrenVersion: []byte("version"), }, children: [17]node{}, baseNode: &baseNode{ @@ -1324,7 +1379,8 @@ func TestBranchNode_SizeInBytes(t *testing.T) { hasher: nil, }, } - assert.Equal(t, len(collapsed1)+len(collapsed2)+len(hash)+1+19*pointerSizeInBytes, bn.sizeInBytes()) + numChildren := 2 + assert.Equal(t, numChildren*hashSizeInBytes+len(hash)+1+19*pointerSizeInBytes+len(bn.ChildrenVersion), bn.sizeInBytes()) } func TestBranchNode_commitContextDone(t *testing.T) { @@ -1335,7 +1391,7 @@ func TestBranchNode_commitContextDone(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - err := bn.commitSnapshot(db, 0, nil, nil, ctx, statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, 0) + err := bn.commitSnapshot(db, 0, nil, nil, ctx, statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, dtmc) assert.Equal(t, core.ErrContextClosing, err) } @@ -1349,7 +1405,7 @@ func TestBranchNode_commitSnapshotDbIsClosing(t *testing.T) { _, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) missingNodesChan := make(chan []byte, 10) - err := collapsedBn.commitSnapshot(db, 0, nil, missingNodesChan, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, 0) + err := collapsedBn.commitSnapshot(db, 0, nil, missingNodesChan, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, dtmc) assert.True(t, core.IsClosingError(err)) assert.Equal(t, 0, len(missingNodesChan)) } @@ -1364,7 +1420,7 @@ func TestBranchNode_commitSnapshotChildIsMissingErr(t *testing.T) { _, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) missingNodesChan := make(chan []byte, 10) - err := collapsedBn.commitSnapshot(db, 0, nil, missingNodesChan, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, 0) + err := collapsedBn.commitSnapshot(db, 0, nil, missingNodesChan, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, dtmc) assert.Nil(t, err) assert.Equal(t, 3, len(missingNodesChan)) } @@ -1448,7 +1504,7 @@ func TestBranchNode_VerifyChildrenVersionIsSetCorrectlyAfterInsertAndDelete(t *t Value: []byte("value"), Version: 0, } - newBn, _, err := bn.insert(data, &testscommon.MemDbMock{}) + newBn, _, err := bn.insert(data, dtmc, &testscommon.MemDbMock{}) assert.Nil(t, err) assert.Nil(t, newBn.(*branchNode).ChildrenVersion) }) @@ -1461,7 +1517,7 @@ func TestBranchNode_VerifyChildrenVersionIsSetCorrectlyAfterInsertAndDelete(t *t bn.ChildrenVersion[2] = byte(core.AutoBalanceEnabled) childKey := []byte{2, 'd', 'o', 'g'} - _, newBn, _, err := bn.delete(childKey, &testscommon.MemDbMock{}) + _, newBn, _, err := bn.delete(childKey, dtmc, &testscommon.MemDbMock{}) assert.Nil(t, err) assert.Nil(t, newBn.(*branchNode).ChildrenVersion) }) @@ -1558,3 +1614,50 @@ func TestBranchNode_getNodeData(t *testing.T) { assert.False(t, thirdChildData.IsLeaf()) }) } + +func TestBranchNode_shouldCollapseChild(t *testing.T) { + t.Parallel() + + t.Run("empty hexKey", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + shouldCollapseChild := bn.shouldCollapseChild([]byte{}, trieMetricsCollector.NewTrieMetricsCollector()) + assert.False(t, shouldCollapseChild) + }) + t.Run("invalid hexKey", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + shouldCollapseChild := bn.shouldCollapseChild([]byte{17}, trieMetricsCollector.NewTrieMetricsCollector()) + assert.False(t, shouldCollapseChild) + }) + t.Run("child is nil", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + shouldCollapseChild := bn.shouldCollapseChild([]byte{4}, trieMetricsCollector.NewTrieMetricsCollector()) + assert.False(t, shouldCollapseChild) + }) + t.Run("collapse child that is already collapsed", func(t *testing.T) { + t.Parallel() + + _, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + leafChildKey := append([]byte{2}, []byte("dog")...) + shouldCollapseChild := collapsedBn.shouldCollapseChild(leafChildKey, trieMetricsCollector.NewTrieMetricsCollector()) + assert.False(t, shouldCollapseChild) + }) + t.Run("successful collapse", func(t *testing.T) { + t.Parallel() + + bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + bn.children[2].setDirty(false) + leafSize := bn.children[2].sizeInBytes() + leafChildKey := append([]byte{2}, []byte("dog")...) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + shouldCollapseChild := bn.shouldCollapseChild(leafChildKey, tmc) + assert.False(t, shouldCollapseChild) + assert.Nil(t, bn.children[2]) + assert.Equal(t, -leafSize, tmc.GetSizeLoadedInMem()) + }) +} diff --git a/trie/collapseManager/collapseManager.go b/trie/collapseManager/collapseManager.go new file mode 100644 index 00000000000..0e806031ca1 --- /dev/null +++ b/trie/collapseManager/collapseManager.go @@ -0,0 +1,147 @@ +package collapseManager + +import ( + "container/list" + "fmt" + + "github.com/multiversx/mx-chain-go/common" + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("trie/collapseManager") + +const ( + // TODO calibrate these values + defaultNumLeavesToCollapseSingleRun = 100 + minNumLeavesToCollapseTrie = 1000 + minSizeInMemory = 1048576 // 1 MB +) + +type collapseManager struct { + accessedKeys map[string]*list.Element + orderAccess *list.List + sizeInMemory int + maxSizeInMem uint64 + numLeavesToCollapseSingleRun int +} + +// NewCollapseManager creates a new collapse manager +func NewCollapseManager(maxSize uint64, numLeavesToCollapseSingleRun uint32) (*collapseManager, error) { + if maxSize < minSizeInMemory { + return nil, fmt.Errorf("invalid max size provided: %d, minimum %d", maxSize, minSizeInMemory) + } + + if numLeavesToCollapseSingleRun == 0 { + numLeavesToCollapseSingleRun = defaultNumLeavesToCollapseSingleRun + } + + return &collapseManager{ + accessedKeys: make(map[string]*list.Element), + orderAccess: list.New(), + sizeInMemory: 0, + maxSizeInMem: maxSize, + numLeavesToCollapseSingleRun: int(numLeavesToCollapseSingleRun), + }, nil +} + +func (cm *collapseManager) addSizeInMemory(size int) { + if cm.sizeInMemory+size < 0 { + log.Warn("trie size in memory is negative after adding size, resetting to 0", "size", size, "currentSize", cm.sizeInMemory) + cm.sizeInMemory = 0 + return + } + cm.sizeInMemory += size +} + +// MarkKeyAsAccessed marks a key as accessed, updating its position in the access order +func (cm *collapseManager) MarkKeyAsAccessed(key []byte, sizeLoadedInMemory int) { + defer cm.addSizeInMemory(sizeLoadedInMemory) + + entry, ok := cm.accessedKeys[string(key)] + if !ok { + e := cm.orderAccess.PushFront(key) + cm.accessedKeys[string(key)] = e + + return + } + + cm.orderAccess.MoveToFront(entry) +} + +// RemoveKey removes a key from the accessed keys list and updates the size in memory +func (cm *collapseManager) RemoveKey(key []byte, sizeLoadedInMemory int) { + defer cm.addSizeInMemory(sizeLoadedInMemory) + + entry, ok := cm.accessedKeys[string(key)] + if !ok { + return + } + + cm.orderAccess.Remove(entry) + delete(cm.accessedKeys, string(key)) +} + +// AddSizeInMemory adds size to the current size in memory +func (cm *collapseManager) AddSizeInMemory(size int) { + cm.addSizeInMemory(size) +} + +// GetSizeInMemory returns the current size in memory +func (cm *collapseManager) GetSizeInMemory() int { + return cm.sizeInMemory +} + +// ShouldCollapseTrie determines if the trie should be collapsed based on memory usage and accessed keys +func (cm *collapseManager) ShouldCollapseTrie() bool { + // we collapse only if we are over the memory limit and there are not enough accessed keys to + // free memory by collapsing only leaves + if uint64(cm.sizeInMemory) > cm.maxSizeInMem && len(cm.accessedKeys) < minNumLeavesToCollapseTrie { + return true + } + + return false +} + +// GetCollapsibleLeaves returns a list of keys that can be collapsed to free memory +func (cm *collapseManager) GetCollapsibleLeaves() ([][]byte, error) { + if uint64(cm.sizeInMemory) < cm.maxSizeInMem { + return make([][]byte, 0), nil + } + + evictedKeys := make([][]byte, 0) + for i := 0; i < cm.numLeavesToCollapseSingleRun; i++ { + if cm.orderAccess.Len() == 0 { + break + } + entry := cm.orderAccess.Back() + if entry == nil { + return nil, fmt.Errorf("unexpected nil entry in collapseManager orderAccess list") + } + cm.orderAccess.Remove(entry) + keyBytes, ok := entry.Value.([]byte) + if !ok { + return nil, fmt.Errorf("invalid key type in collapseManager orderAccess list: %T", entry.Value) + } + delete(cm.accessedKeys, string(keyBytes)) + + evictedKeys = append(evictedKeys, keyBytes) + } + + return evictedKeys, nil +} + +// CloneWithoutState creates a new collapse manager with the same configuration but without the current state +func (cm *collapseManager) CloneWithoutState() common.TrieCollapseManager { + return &collapseManager{ + accessedKeys: make(map[string]*list.Element), + orderAccess: list.New(), + sizeInMemory: 0, + maxSizeInMem: cm.maxSizeInMem, + numLeavesToCollapseSingleRun: cm.numLeavesToCollapseSingleRun, + } +} + +// IsInterfaceNil returns true if there is no value under the interface +func (cm *collapseManager) IsInterfaceNil() bool { + return cm == nil +} diff --git a/trie/collapseManager/collapseManager_test.go b/trie/collapseManager/collapseManager_test.go new file mode 100644 index 00000000000..c036af33b3c --- /dev/null +++ b/trie/collapseManager/collapseManager_test.go @@ -0,0 +1,197 @@ +package collapseManager + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" + "github.com/stretchr/testify/assert" +) + +func TestNewCollapseManager(t *testing.T) { + t.Parallel() + + t.Run("invalid maxSizeInMem should error", func(t *testing.T) { + t.Parallel() + + cm, err := NewCollapseManager(0, common.NumLeavesToCollapseSingleRun) + assert.Nil(t, cm) + assert.NotNil(t, err) + }) + t.Run("invalid numLeavesToCollapseSingleRun should default to defaultNumLeavesToCollapseSingleRun", func(t *testing.T) { + t.Parallel() + + cm, err := NewCollapseManager(2*minSizeInMemory, 0) + assert.False(t, check.IfNil(cm)) + assert.Nil(t, err) + assert.Equal(t, 0, cm.sizeInMemory) + assert.Equal(t, 2*minSizeInMemory, int(cm.maxSizeInMem)) + assert.Equal(t, defaultNumLeavesToCollapseSingleRun, cm.numLeavesToCollapseSingleRun) + assert.Equal(t, 0, len(cm.accessedKeys)) + assert.Equal(t, 0, cm.orderAccess.Len()) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + cm, err := NewCollapseManager(2*minSizeInMemory, common.NumLeavesToCollapseSingleRun) + assert.False(t, check.IfNil(cm)) + assert.Nil(t, err) + assert.Equal(t, 0, cm.sizeInMemory) + assert.Equal(t, 2*minSizeInMemory, int(cm.maxSizeInMem)) + assert.Equal(t, common.NumLeavesToCollapseSingleRun, cm.numLeavesToCollapseSingleRun) + assert.Equal(t, 0, len(cm.accessedKeys)) + assert.Equal(t, 0, cm.orderAccess.Len()) + }) +} + +func TestCollapseManager_MarkKeyAsAccessed(t *testing.T) { + t.Parallel() + + cm, _ := NewCollapseManager(2*minSizeInMemory, common.NumLeavesToCollapseSingleRun) + cm.MarkKeyAsAccessed([]byte("key1"), 500) + assert.Equal(t, 500, cm.sizeInMemory) + assert.Equal(t, 1, len(cm.accessedKeys)) + assert.Equal(t, 1, cm.orderAccess.Len()) + + cm.MarkKeyAsAccessed([]byte("key2"), 600) + assert.Equal(t, 1100, cm.sizeInMemory) + assert.Equal(t, 2, len(cm.accessedKeys)) + assert.Equal(t, 2, cm.orderAccess.Len()) + oldestVal := cm.orderAccess.Back().Value.([]byte) + assert.Equal(t, "key1", string(oldestVal)) + + cm.MarkKeyAsAccessed([]byte("key1"), 0) + assert.Equal(t, 1100, cm.sizeInMemory) + assert.Equal(t, 2, len(cm.accessedKeys)) + assert.Equal(t, 2, cm.orderAccess.Len()) + oldestVal = cm.orderAccess.Back().Value.([]byte) + assert.Equal(t, "key2", string(oldestVal)) +} + +func TestCollapseManager_RemoveKey(t *testing.T) { + t.Parallel() + + cm, _ := NewCollapseManager(2*minSizeInMemory, common.NumLeavesToCollapseSingleRun) + cm.MarkKeyAsAccessed([]byte("key1"), 500) + cm.MarkKeyAsAccessed([]byte("key2"), 600) + assert.Equal(t, 1100, cm.sizeInMemory) + assert.Equal(t, 2, len(cm.accessedKeys)) + assert.Equal(t, 2, cm.orderAccess.Len()) + + cm.RemoveKey([]byte("key1"), -500) + assert.Equal(t, 600, cm.sizeInMemory) + assert.Equal(t, 1, len(cm.accessedKeys)) + assert.Equal(t, 1, cm.orderAccess.Len()) + + // removing non existing key should do nothing + cm.RemoveKey([]byte("key3"), 0) + assert.Equal(t, 600, cm.sizeInMemory) + assert.Equal(t, 1, len(cm.accessedKeys)) + assert.Equal(t, 1, cm.orderAccess.Len()) +} + +func TestCollapseManager_AddSizeInMemory(t *testing.T) { + t.Parallel() + + cm, _ := NewCollapseManager(2*minSizeInMemory, common.NumLeavesToCollapseSingleRun) + cm.AddSizeInMemory(700) + assert.Equal(t, 700, cm.GetSizeInMemory()) + + cm.AddSizeInMemory(-200) + assert.Equal(t, 500, cm.GetSizeInMemory()) + + cm.AddSizeInMemory(0) + assert.Equal(t, 500, cm.GetSizeInMemory()) + + // sizeInMemory should not go below 0 + cm.AddSizeInMemory(-800) + assert.Equal(t, 0, cm.GetSizeInMemory()) +} + +func TestCollapseManager_ShouldCollapseTrie(t *testing.T) { + t.Parallel() + + cm, _ := NewCollapseManager(minSizeInMemory, common.NumLeavesToCollapseSingleRun) + assert.False(t, cm.ShouldCollapseTrie()) + + cm.AddSizeInMemory(minSizeInMemory + 1) + assert.True(t, cm.ShouldCollapseTrie()) +} + +func TestCollapseManager_GetCollapsibleLeaves(t *testing.T) { + t.Parallel() + + t.Run("sizeInMemory below limit should return empty", func(t *testing.T) { + t.Parallel() + + cm, _ := NewCollapseManager(minSizeInMemory, common.NumLeavesToCollapseSingleRun) + cm.MarkKeyAsAccessed([]byte("key1"), 500) + leaves, err := cm.GetCollapsibleLeaves() + assert.Nil(t, err) + assert.Equal(t, 0, len(leaves)) + }) + t.Run("should return evicted keys until sizeInMemory is below limit", func(t *testing.T) { + t.Parallel() + + cm, _ := NewCollapseManager(minSizeInMemory, common.NumLeavesToCollapseSingleRun) + cm.MarkKeyAsAccessed([]byte("key1"), 700) + cm.MarkKeyAsAccessed([]byte("key2"), 600) + cm.MarkKeyAsAccessed([]byte("key3"), 500) + + leaves, err := cm.GetCollapsibleLeaves() + assert.Nil(t, err) + assert.Equal(t, 0, len(leaves)) + + cm.MarkKeyAsAccessed([]byte("key4"), minSizeInMemory-1) + leaves, err = cm.GetCollapsibleLeaves() + assert.Nil(t, err) + assert.Equal(t, 4, len(leaves)) + assert.Equal(t, "key1", string(leaves[0])) + assert.Equal(t, "key2", string(leaves[1])) + assert.Equal(t, "key3", string(leaves[2])) + assert.Equal(t, "key4", string(leaves[3])) + assert.Equal(t, 0, len(cm.accessedKeys)) + assert.Equal(t, 0, cm.orderAccess.Len()) + }) + + t.Run("should return up to numLeavesToCollapseSingleRun evicted keys", func(t *testing.T) { + t.Parallel() + + numLeavesToCollapseSingleRun := uint32(5) + cm, _ := NewCollapseManager(minSizeInMemory, numLeavesToCollapseSingleRun) + for i := 0; i < int(numLeavesToCollapseSingleRun)+2; i++ { + key := []byte("key" + string(rune(i))) + cm.MarkKeyAsAccessed(key, 500) + } + cm.AddSizeInMemory(minSizeInMemory + 1) + + leaves, err := cm.GetCollapsibleLeaves() + assert.Nil(t, err) + assert.Equal(t, int(numLeavesToCollapseSingleRun), len(leaves)) + for i := 0; i < int(numLeavesToCollapseSingleRun); i++ { + expectedKey := "key" + string(rune(i)) + assert.Equal(t, expectedKey, string(leaves[i])) + } + assert.Equal(t, 2, len(cm.accessedKeys)) + assert.Equal(t, 2, cm.orderAccess.Len()) + }) +} + +func TestCollapseManager_CloneWithoutState(t *testing.T) { + t.Parallel() + + cm, _ := NewCollapseManager(2*minSizeInMemory, common.NumLeavesToCollapseSingleRun) + cm.MarkKeyAsAccessed([]byte("key1"), 500) + assert.Equal(t, 500, cm.sizeInMemory) + assert.Equal(t, 1, len(cm.accessedKeys)) + assert.Equal(t, 1, cm.orderAccess.Len()) + clone := cm.CloneWithoutState() + + assert.False(t, check.IfNil(clone)) + clonedCM := clone.(*collapseManager) + assert.Equal(t, 2*minSizeInMemory, int(clonedCM.maxSizeInMem)) + assert.Equal(t, common.NumLeavesToCollapseSingleRun, clonedCM.numLeavesToCollapseSingleRun) + assert.Equal(t, 0, len(clonedCM.accessedKeys)) + assert.Equal(t, 0, clonedCM.orderAccess.Len()) + assert.Equal(t, 0, clone.GetSizeInMemory()) +} diff --git a/trie/collapseManager/disabledCollapseManager.go b/trie/collapseManager/disabledCollapseManager.go new file mode 100644 index 00000000000..6c7c6ccace2 --- /dev/null +++ b/trie/collapseManager/disabledCollapseManager.go @@ -0,0 +1,49 @@ +package collapseManager + +import ( + "github.com/multiversx/mx-chain-go/common" +) + +type disabledCollapseManager struct{} + +// NewDisabledCollapseManager creates a new disabled collapse manager +func NewDisabledCollapseManager() *disabledCollapseManager { + return &disabledCollapseManager{} +} + +// MarkKeyAsAccessed does nothing for this implementation +func (d *disabledCollapseManager) MarkKeyAsAccessed(_ []byte, _ int) { +} + +// RemoveKey does nothing for this implementation +func (d *disabledCollapseManager) RemoveKey(_ []byte, _ int) { +} + +// ShouldCollapseTrie always returns false for this implementation +func (d *disabledCollapseManager) ShouldCollapseTrie() bool { + return false +} + +// GetCollapsibleLeaves always returns nil for this implementation +func (d *disabledCollapseManager) GetCollapsibleLeaves() ([][]byte, error) { + return nil, nil +} + +// AddSizeInMemory does nothing for this implementation +func (d *disabledCollapseManager) AddSizeInMemory(_ int) { +} + +// GetSizeInMemory always returns 0 for this implementation +func (d *disabledCollapseManager) GetSizeInMemory() int { + return 0 +} + +// CloneWithoutState returns a new disabled collapse manager +func (d *disabledCollapseManager) CloneWithoutState() common.TrieCollapseManager { + return NewDisabledCollapseManager() +} + +// IsInterfaceNil returns true if there is no value under the interface +func (d *disabledCollapseManager) IsInterfaceNil() bool { + return d == nil +} diff --git a/trie/collapseManager/disabledCollapseManager_test.go b/trie/collapseManager/disabledCollapseManager_test.go new file mode 100644 index 00000000000..f65f3bbad2e --- /dev/null +++ b/trie/collapseManager/disabledCollapseManager_test.go @@ -0,0 +1,68 @@ +package collapseManager + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewDisabledCollapseManager(t *testing.T) { + t.Parallel() + + dcm := NewDisabledCollapseManager() + assert.False(t, dcm.IsInterfaceNil()) +} + +func TestDisabledCollapseManager_MarkKeyAsAccessed(t *testing.T) { + t.Parallel() + + dcm := NewDisabledCollapseManager() + dcm.MarkKeyAsAccessed([]byte("key1"), 500) + + assert.Equal(t, 0, dcm.GetSizeInMemory()) +} + +func TestDisabledCollapseManager_RemoveKey(t *testing.T) { + t.Parallel() + + dcm := NewDisabledCollapseManager() + dcm.RemoveKey([]byte("key1"), 500) + + assert.Equal(t, 0, dcm.GetSizeInMemory()) +} + +func TestDisabledCollapseManager_ShouldCollapseTrie(t *testing.T) { + t.Parallel() + + dcm := NewDisabledCollapseManager() + shouldCollapse := dcm.ShouldCollapseTrie() + + assert.False(t, shouldCollapse) +} + +func TestDisabledCollapseManager_GetCollapsibleLeaves(t *testing.T) { + t.Parallel() + + dcm := NewDisabledCollapseManager() + leaves, err := dcm.GetCollapsibleLeaves() + + assert.Nil(t, err) + assert.Nil(t, leaves) +} + +func TestDisabledCollapseManager_AddSizeInMemory(t *testing.T) { + t.Parallel() + + dcm := NewDisabledCollapseManager() + dcm.AddSizeInMemory(500) + assert.Equal(t, 0, dcm.GetSizeInMemory()) +} + +func TestDisabledCollapseManager_CloneWithoutState(t *testing.T) { + t.Parallel() + + dcm := NewDisabledCollapseManager() + clone := dcm.CloneWithoutState() + + assert.False(t, clone.IsInterfaceNil()) +} diff --git a/trie/doubleListSync_test.go b/trie/doubleListSync_test.go index 8e631237cc6..cf5684a994f 100644 --- a/trie/doubleListSync_test.go +++ b/trie/doubleListSync_test.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -37,7 +38,7 @@ func createTrieStorageManager(store storage.Storer) (common.StorageManager, stor func createInMemoryTrie() (common.Trie, storage.Storer) { memUnit := testscommon.CreateMemUnit() tsm, _ := createTrieStorageManager(memUnit) - tr, _ := NewTrie(tsm, marshalizer, hasherMock, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 6) + tr, _ := NewTrie(tsm, marshalizer, hasherMock, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) return tr, memUnit } @@ -50,7 +51,7 @@ func createInMemoryTrieFromDB(db storage.Persister) (common.Trie, storage.Storer unit, _ := storageunit.NewStorageUnit(cache, db) tsm, _ := createTrieStorageManager(unit) - tr, _ := NewTrie(tsm, marshalizer, hasherMock, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 6) + tr, _ := NewTrie(tsm, marshalizer, hasherMock, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) return tr, unit } diff --git a/trie/errors.go b/trie/errors.go index a879fd6c94c..25069bfbb14 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -129,3 +129,9 @@ var ErrEmptyInitialIteratorState = errors.New("empty initial iterator state") // ErrInvalidIteratorState signals that an invalid iterator state was provided var ErrInvalidIteratorState = errors.New("invalid iterator state") + +// ErrInvalidMaxSizeInMemory signals that the provided max trie size value is invalid +var ErrInvalidMaxSizeInMemory = errors.New("invalid max size in memory") + +// ErrNilCollapseManager signals that a nil collapse manager has been provided +var ErrNilCollapseManager = errors.New("nil collapse manager") diff --git a/trie/export_test.go b/trie/export_test.go index dea9315ffab..254c3e44815 100644 --- a/trie/export_test.go +++ b/trie/export_test.go @@ -111,3 +111,34 @@ func GetDefaultTrieStorageManagerParameters() NewTrieStorageManagerArgs { StatsCollector: statistics.NewStateStatistics(), } } + +// GetNumCollapsedNodes returns the number of collapsed nodes in the trie +func (tr *patriciaMerkleTrie) GetNumCollapsedNodes() (int, error) { + count := 0 + + nextNodes := make([]node, 0) + nextNodes = append(nextNodes, tr.root) + for len(nextNodes) > 0 { + currentNode := nextNodes[0] + switch current := currentNode.(type) { + case *branchNode: + for i := range current.children { + if current.children[i] != nil { + nextNodes = append(nextNodes, current.children[i]) + } else if len(current.EncodedChildren[i]) != 0 { + count++ + } + } + case *extensionNode: + if current.child != nil { + nextNodes = append(nextNodes, current.child) + } + case *leafNode: + // do nothing + default: + return 0, ErrInvalidNode + } + nextNodes = nextNodes[1:] + } + return count, nil +} diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 0c1a657665b..8bdc5dfe090 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -173,8 +173,7 @@ func (en *extensionNode) hashNode() ([]byte, error) { return encodeNodeAndGetHash(en) } -func (en *extensionNode) commitDirty(level byte, maxTrieLevelInMemory uint, originDb common.TrieStorageInteractor, targetDb common.BaseStorer) error { - level++ +func (en *extensionNode) commitDirty(originDb common.TrieStorageInteractor, targetDb common.BaseStorer) error { err := en.isEmptyOrNil() if err != nil { return fmt.Errorf("commit error %w", err) @@ -185,7 +184,7 @@ func (en *extensionNode) commitDirty(level byte, maxTrieLevelInMemory uint, orig } if en.child != nil { - err = en.child.commitDirty(level, maxTrieLevelInMemory, originDb, targetDb) + err = en.child.commitDirty(originDb, targetDb) if err != nil { return err } @@ -196,18 +195,27 @@ func (en *extensionNode) commitDirty(level byte, maxTrieLevelInMemory uint, orig if err != nil { return err } - if uint(level) == maxTrieLevelInMemory { - log.Trace("collapse extension node on commit") - var collapsedEn *extensionNode - collapsedEn, err = en.getCollapsedEn() - if err != nil { - return err - } + return nil +} - *en = *collapsedEn +func (en *extensionNode) shouldCollapseChild(hexKey []byte, tmc MetricsCollector) bool { + keyTooShort := len(hexKey) < len(en.Key) + if keyTooShort { + return false } - return nil + keysDontMatch := !bytes.Equal(en.Key, hexKey[:len(en.Key)]) + if keysDontMatch { + return false + } + hexKey = hexKey[len(en.Key):] + if en.child == nil { + return false + } + + // an extension node can not have a leaf as child, so no need to check for that + _ = en.child.shouldCollapseChild(hexKey, tmc) + return false } func (en *extensionNode) commitSnapshot( @@ -219,12 +227,16 @@ func (en *extensionNode) commitSnapshot( stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, nodeBytes []byte, - depthLevel int, + tmc MetricsCollector, ) error { if shouldStopIfContextDoneBlockingIfBusy(ctx, idleProvider) { return core.ErrContextClosing } + depthLevel := tmc.GetCurrentDepth() + tmc.SetDepth(depthLevel + 1) + defer tmc.SetDepth(depthLevel) // Reset depth when returning from this level + err := commitSnapshot( db, maxEpochToSearchFrom, @@ -235,14 +247,14 @@ func (en *extensionNode) commitSnapshot( ctx, stats, idleProvider, - depthLevel, + tmc, en.EncodedChild, ) if err != nil { return err } - stats.AddExtensionNode(depthLevel, uint64(len(nodeBytes))) + stats.AddExtensionNode(int(depthLevel), uint64(len(nodeBytes))) return nil } @@ -259,7 +271,7 @@ func (en *extensionNode) getEncodedNode() ([]byte, error) { return marshaledNode, nil } -func (en *extensionNode) resolveCollapsed(_ byte, db common.TrieStorageInteractor) error { +func (en *extensionNode) resolveCollapsed(_ byte, tmc MetricsCollector, db common.TrieStorageInteractor) error { err := en.isEmptyOrNil() if err != nil { return fmt.Errorf("resolveCollapsed error %w", err) @@ -269,6 +281,7 @@ func (en *extensionNode) resolveCollapsed(_ byte, db common.TrieStorageInteracto return err } child.setGivenHash(en.EncodedChild) + tmc.AddSizeLoadedInMem(child.sizeInBytes()) en.child = child return nil } @@ -281,29 +294,30 @@ func (en *extensionNode) isPosCollapsed(_ int) bool { return en.isCollapsed() } -func (en *extensionNode) tryGet(key []byte, currentDepth uint32, db common.TrieStorageInteractor) (value []byte, maxDepth uint32, err error) { +func (en *extensionNode) tryGet(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) (value []byte, err error) { err = en.isEmptyOrNil() if err != nil { - return nil, currentDepth, fmt.Errorf("tryGet error %w", err) + return nil, fmt.Errorf("tryGet error %w", err) } keyTooShort := len(key) < len(en.Key) if keyTooShort { - return nil, currentDepth, nil + return nil, nil } keysDontMatch := !bytes.Equal(en.Key, key[:len(en.Key)]) if keysDontMatch { - return nil, currentDepth, nil + return nil, nil } key = key[len(en.Key):] - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { - return nil, currentDepth, err + return nil, err } - return en.child.tryGet(key, currentDepth+1, db) + tmc.SetDepth(tmc.GetCurrentDepth() + 1) + return en.child.tryGet(key, tmc, db) } -func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) { +func (en *extensionNode) getNext(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) (node, []byte, error) { err := en.isEmptyOrNil() if err != nil { return nil, nil, fmt.Errorf("getNext error %w", err) @@ -316,7 +330,7 @@ func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (n if keysDontMatch { return nil, nil, ErrNodeNotFound } - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { return nil, nil, err } @@ -325,13 +339,13 @@ func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (n return en.child, key, nil } -func (en *extensionNode) insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) { +func (en *extensionNode) insert(newData core.TrieData, tmc MetricsCollector, db common.TrieStorageInteractor) (node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := en.isEmptyOrNil() if err != nil { return nil, emptyHashes, fmt.Errorf("insert error %w", err) } - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { return nil, emptyHashes, err } @@ -341,16 +355,16 @@ func (en *extensionNode) insert(newData core.TrieData, db common.TrieStorageInte // If the whole key matches, keep this extension node as is // and only update the value. if keyMatchLen == len(en.Key) { - return en.insertInSameEn(newData, keyMatchLen, db) + return en.insertInSameEn(newData, keyMatchLen, tmc, db) } // Otherwise branch out at the index where they differ. - return en.insertInNewBn(newData, keyMatchLen) + return en.insertInNewBn(newData, keyMatchLen, tmc) } -func (en *extensionNode) insertInSameEn(newData core.TrieData, keyMatchLen int, db common.TrieStorageInteractor) (node, [][]byte, error) { +func (en *extensionNode) insertInSameEn(newData core.TrieData, keyMatchLen int, tmc MetricsCollector, db common.TrieStorageInteractor) (node, [][]byte, error) { newData.Key = newData.Key[keyMatchLen:] - newNode, oldHashes, err := en.child.insert(newData, db) + newNode, oldHashes, err := en.child.insert(newData, tmc, db) if check.IfNil(newNode) || err != nil { return nil, [][]byte{}, err } @@ -367,7 +381,7 @@ func (en *extensionNode) insertInSameEn(newData core.TrieData, keyMatchLen int, return newEn, oldHashes, nil } -func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, [][]byte, error) { +func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int, tmc MetricsCollector) (node, [][]byte, error) { oldHash := make([][]byte, 0) if !en.dirty { oldHash = append(oldHash, en.hash) @@ -384,16 +398,17 @@ func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int) ( return nil, [][]byte{}, ErrChildPosOutOfRange } - err = en.insertOldChildInBn(bn, oldChildPos, keyMatchLen) + err = en.insertOldChildInBn(bn, oldChildPos, keyMatchLen, tmc) if err != nil { return nil, [][]byte{}, err } - err = en.insertNewChildInBn(bn, newData, newChildPos, keyMatchLen) + err = en.insertNewChildInBn(bn, newData, newChildPos, keyMatchLen, tmc) if err != nil { return nil, [][]byte{}, err } + tmc.AddSizeLoadedInMem(bn.sizeInBytes()) if keyMatchLen == 0 { return bn, oldHash, nil } @@ -403,10 +418,12 @@ func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int) ( return nil, [][]byte{}, err } + tmc.AddSizeLoadedInMem(newEn.sizeInBytes()) + return newEn, oldHash, nil } -func (en *extensionNode) insertOldChildInBn(bn *branchNode, oldChildPos byte, keyMatchLen int) error { +func (en *extensionNode) insertOldChildInBn(bn *branchNode, oldChildPos byte, keyMatchLen int, tmc MetricsCollector) error { keyReminder := en.Key[keyMatchLen+1:] childVersion, err := en.child.getVersion() if err != nil { @@ -415,10 +432,12 @@ func (en *extensionNode) insertOldChildInBn(bn *branchNode, oldChildPos byte, ke bn.setVersionForChild(childVersion, oldChildPos) if len(keyReminder) < 1 { + tmc.AddSizeLoadedInMem(-en.sizeInBytes()) bn.children[oldChildPos] = en.child return nil } + tmc.AddSizeLoadedInMem(-keyMatchLen) followingExtensionNode, err := newExtensionNode(en.Key[keyMatchLen+1:], en.child, en.marsh, en.hasher) if err != nil { return err @@ -428,7 +447,7 @@ func (en *extensionNode) insertOldChildInBn(bn *branchNode, oldChildPos byte, ke return nil } -func (en *extensionNode) insertNewChildInBn(bn *branchNode, newData core.TrieData, newChildPos byte, keyMatchLen int) error { +func (en *extensionNode) insertNewChildInBn(bn *branchNode, newData core.TrieData, newChildPos byte, keyMatchLen int, tmc MetricsCollector) error { newData.Key = newData.Key[keyMatchLen+1:] newLeaf, err := newLeafNode(newData, en.marsh, en.hasher) @@ -436,12 +455,14 @@ func (en *extensionNode) insertNewChildInBn(bn *branchNode, newData core.TrieDat return err } + tmc.AddSizeLoadedInMem(newLeaf.sizeInBytes()) + bn.children[newChildPos] = newLeaf bn.setVersionForChild(newData.Version, newChildPos) return nil } -func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { +func (en *extensionNode) delete(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) (bool, node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := en.isEmptyOrNil() if err != nil { @@ -454,12 +475,12 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo if keyMatchLen < len(en.Key) { return false, en, emptyHashes, nil } - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { return false, nil, emptyHashes, err } - dirty, newNode, oldHashes, err := en.child.delete(key[len(en.Key):], db) + dirty, newNode, oldHashes, err := en.child.delete(key[len(en.Key):], tmc, db) if !dirty || err != nil { return false, en, emptyHashes, err } @@ -468,6 +489,7 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo oldHashes = append(oldHashes, en.hash) } + tmc.AddSizeLoadedInMem(-en.sizeInBytes()) switch newNode := newNode.(type) { case *leafNode: newLeafData := core.TrieData{ @@ -479,6 +501,7 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo if err != nil { return false, nil, emptyHashes, err } + tmc.AddSizeLoadedInMem(len(en.Key)) return true, n, oldHashes, nil case *extensionNode: @@ -486,6 +509,7 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo if err != nil { return false, nil, emptyHashes, err } + tmc.AddSizeLoadedInMem(len(en.Key)) return true, n, oldHashes, nil case *branchNode: @@ -493,6 +517,7 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo if err != nil { return false, nil, emptyHashes, err } + tmc.AddSizeLoadedInMem(n.sizeInBytes()) return true, n, oldHashes, nil case nil: @@ -503,13 +528,15 @@ func (en *extensionNode) delete(key []byte, db common.TrieStorageInteractor) (bo } } -func (en *extensionNode) reduceNode(pos int) (node, bool, error) { - k := append([]byte{byte(pos)}, en.Key...) +func (en *extensionNode) reduceNode(pos int, tmc MetricsCollector) (node, bool, error) { + extraKey := []byte{byte(pos)} + k := append(extraKey, en.Key...) newEn, err := newExtensionNode(k, en.child, en.marsh, en.hasher) if err != nil { return nil, false, err } + tmc.AddSizeLoadedInMem(len(extraKey)) return newEn, true, nil } @@ -529,12 +556,12 @@ func (en *extensionNode) isEmptyOrNil() error { return nil } -func (en *extensionNode) print(writer io.Writer, index int, db common.TrieStorageInteractor) { +func (en *extensionNode) print(writer io.Writer, index int, tmc MetricsCollector, db common.TrieStorageInteractor) { if en == nil { return } - err := resolveIfCollapsed(en, 0, db) + err := resolveIfCollapsed(en, 0, tmc, db) if err != nil { log.Debug("extension node: print trie err", "error", err, "hash", en.EncodedChild) } @@ -550,7 +577,7 @@ func (en *extensionNode) print(writer io.Writer, index int, db common.TrieStorag if en.child == nil { return } - en.child.print(writer, index+len(str), db) + en.child.print(writer, index+len(str), tmc, db) } func (en *extensionNode) getDirtyHashes(hashes common.ModifiedHashes) error { @@ -576,7 +603,7 @@ func (en *extensionNode) getDirtyHashes(hashes common.ModifiedHashes) error { return nil } -func (en *extensionNode) getChildren(db common.TrieStorageInteractor) ([]node, error) { +func (en *extensionNode) getChildren(tmc MetricsCollector, db common.TrieStorageInteractor) ([]node, error) { err := en.isEmptyOrNil() if err != nil { return nil, fmt.Errorf("getChildren error %w", err) @@ -584,7 +611,7 @@ func (en *extensionNode) getChildren(db common.TrieStorageInteractor) ([]node, e nextNodes := make([]node, 0) - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { return nil, err } @@ -638,6 +665,7 @@ func (en *extensionNode) getAllLeavesOnChannel( marshalizer marshal.Marshalizer, chanClose chan struct{}, ctx context.Context, + tmc MetricsCollector, ) error { err := en.isEmptyOrNil() if err != nil { @@ -652,13 +680,13 @@ func (en *extensionNode) getAllLeavesOnChannel( log.Trace("extensionNode.getAllLeavesOnChannel: context done") return nil default: - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { return err } keyBuilder.BuildKey(en.Key) - err = en.child.getAllLeavesOnChannel(leavesChannel, keyBuilder.ShallowClone(), trieLeafParser, db, marshalizer, chanClose, ctx) + err = en.child.getAllLeavesOnChannel(leavesChannel, keyBuilder.ShallowClone(), trieLeafParser, db, marshalizer, chanClose, ctx, tmc) if err != nil { return err } @@ -669,18 +697,18 @@ func (en *extensionNode) getAllLeavesOnChannel( return nil } -func (en *extensionNode) getAllHashes(db common.TrieStorageInteractor) ([][]byte, error) { +func (en *extensionNode) getAllHashes(tmc MetricsCollector, db common.TrieStorageInteractor) ([][]byte, error) { err := en.isEmptyOrNil() if err != nil { return nil, fmt.Errorf("getAllHashes error: %w", err) } - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { return nil, err } - hashes, err := en.child.getAllHashes(db) + hashes, err := en.child.getAllHashes(tmc, db) if err != nil { return nil, err } @@ -706,9 +734,8 @@ func (en *extensionNode) sizeInBytes() int { return 0 } - // hasher + marshalizer + child + dirty flag = 3 * pointerSizeInBytes + 1 - nodeSize := len(en.hash) + len(en.Key) + (numNodeInnerPointers+1)*pointerSizeInBytes + 1 - nodeSize += len(en.EncodedChild) + nodeSize := baseNodeSizeInBytes + len(en.Key) + nodeVersionSizeInBytes + pointerSizeInBytes + nodeSize += hashSizeInBytes // child hash return nodeSize } @@ -717,18 +744,22 @@ func (en *extensionNode) getValue() []byte { return []byte{} } -func (en *extensionNode) collectStats(ts common.TrieStatisticsHandler, depthLevel int, db common.TrieStorageInteractor) error { +func (en *extensionNode) collectStats(ts common.TrieStatisticsHandler, tmc MetricsCollector, db common.TrieStorageInteractor) error { err := en.isEmptyOrNil() if err != nil { return fmt.Errorf("collectStats error %w", err) } - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { return err } - err = en.child.collectStats(ts, depthLevel+1, db) + depthLevel := tmc.GetCurrentDepth() + tmc.SetDepth(depthLevel + 1) + defer tmc.SetDepth(depthLevel) // Reset depth when returning from this level + + err = en.child.collectStats(ts, tmc, db) if err != nil { return err } @@ -738,7 +769,7 @@ func (en *extensionNode) collectStats(ts common.TrieStatisticsHandler, depthLeve return err } - ts.AddExtensionNode(depthLevel, uint64(len(val))) + ts.AddExtensionNode(int(depthLevel), uint64(len(val))) return nil } @@ -753,6 +784,7 @@ func (en *extensionNode) getVersion() (core.TrieNodeVersion, error) { func (en *extensionNode) collectLeavesForMigration( migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, + tmc MetricsCollector, db common.TrieStorageInteractor, keyBuilder common.KeyBuilder, ) (bool, error) { @@ -769,13 +801,13 @@ func (en *extensionNode) collectLeavesForMigration( return true, nil } - err = resolveIfCollapsed(en, 0, db) + err = resolveIfCollapsed(en, 0, tmc, db) if err != nil { return false, err } keyBuilder.BuildKey(en.Key) - return en.child.collectLeavesForMigration(migrationArgs, db, keyBuilder.ShallowClone()) + return en.child.collectLeavesForMigration(migrationArgs, tmc, db, keyBuilder.ShallowClone()) } func (en *extensionNode) getNodeData(keyBuilder common.KeyBuilder) ([]common.TrieNodeData, error) { diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index b68d9f14d79..44c22318e30 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -17,6 +17,7 @@ import ( "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/mock" "github.com/multiversx/mx-chain-go/trie/statistics" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" "github.com/stretchr/testify/assert" ) @@ -31,6 +32,11 @@ func getEnAndCollapsedEn() (*extensionNode, *extensionNode) { return en, collapsedEn } +func markNotDirtyEn(en *extensionNode) { + en.dirty = false + markNotDirtyBranchNode(en.child.(*branchNode)) +} + func TestExtensionNode_newExtensionNode(t *testing.T) { t.Parallel() @@ -243,7 +249,7 @@ func TestExtensionNode_commit(t *testing.T) { hash, _ := encodeNodeAndGetHash(collapsedEn) _ = en.setHash() - err := en.commitDirty(0, 5, db, db) + err := en.commitDirty(db, db) assert.Nil(t, err) encNode, _ := db.Get(hash) @@ -259,7 +265,7 @@ func TestExtensionNode_commitEmptyNode(t *testing.T) { en := &extensionNode{} - err := en.commitDirty(0, 5, nil, nil) + err := en.commitDirty(nil, nil) assert.True(t, errors.Is(err, ErrEmptyExtensionNode)) } @@ -268,7 +274,7 @@ func TestExtensionNode_commitNilNode(t *testing.T) { var en *extensionNode - err := en.commitDirty(0, 5, nil, nil) + err := en.commitDirty(nil, nil) assert.True(t, errors.Is(err, ErrNilExtensionNode)) } @@ -281,7 +287,7 @@ func TestExtensionNode_commitCollapsedNode(t *testing.T) { _ = collapsedEn.setHash() collapsedEn.dirty = true - err := collapsedEn.commitDirty(0, 5, db, db) + err := collapsedEn.commitDirty(db, db) assert.Nil(t, err) encNode, _ := db.Get(hash) @@ -331,12 +337,15 @@ func TestExtensionNode_resolveCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() _ = en.setHash() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) _, resolved := getBnAndCollapsedBn(en.marsh, en.hasher) + expectedSize := resolved.sizeInBytes() - err := collapsedEn.resolveCollapsed(0, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + err := collapsedEn.resolveCollapsed(0, tmc, db) assert.Nil(t, err) assert.Equal(t, en.child.getHash(), collapsedEn.child.getHash()) + assert.Equal(t, expectedSize, tmc.GetSizeLoadedInMem()) h1, _ := encodeNodeAndGetHash(resolved) h2, _ := encodeNodeAndGetHash(collapsedEn.child) @@ -348,7 +357,7 @@ func TestExtensionNode_resolveCollapsedEmptyNode(t *testing.T) { en := &extensionNode{} - err := en.resolveCollapsed(0, nil) + err := en.resolveCollapsed(0, nil, nil) assert.True(t, errors.Is(err, ErrEmptyExtensionNode)) } @@ -357,7 +366,7 @@ func TestExtensionNode_resolveCollapsedNilNode(t *testing.T) { var en *extensionNode - err := en.resolveCollapsed(2, nil) + err := en.resolveCollapsed(2, nil, nil) assert.True(t, errors.Is(err, ErrNilExtensionNode)) } @@ -384,10 +393,11 @@ func TestExtensionNode_tryGet(t *testing.T) { key := append(enKey, bnKey...) key = append(key, lnKey...) - val, maxDepth, err := en.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := en.tryGet(key, tmc, nil) assert.Equal(t, dogBytes, val) assert.Nil(t, err) - assert.Equal(t, uint32(2), maxDepth) + assert.Equal(t, uint32(2), tmc.GetMaxDepth()) } func TestExtensionNode_tryGetEmptyKey(t *testing.T) { @@ -396,10 +406,11 @@ func TestExtensionNode_tryGetEmptyKey(t *testing.T) { en, _ := getEnAndCollapsedEn() var key []byte - val, maxDepth, err := en.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := en.tryGet(key, tmc, nil) assert.Nil(t, err) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestExtensionNode_tryGetWrongKey(t *testing.T) { @@ -408,10 +419,11 @@ func TestExtensionNode_tryGetWrongKey(t *testing.T) { en, _ := getEnAndCollapsedEn() key := []byte("gdo") - val, maxDepth, err := en.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := en.tryGet(key, tmc, nil) assert.Nil(t, err) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestExtensionNode_tryGetCollapsedNode(t *testing.T) { @@ -420,7 +432,7 @@ func TestExtensionNode_tryGetCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() _ = en.setHash() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) enKey := []byte{100} bnKey := []byte{2} @@ -428,10 +440,14 @@ func TestExtensionNode_tryGetCollapsedNode(t *testing.T) { key := append(enKey, bnKey...) key = append(key, lnKey...) - val, maxDepth, err := collapsedEn.tryGet(key, 0, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := collapsedEn.tryGet(key, tmc, db) assert.Equal(t, []byte("dog"), val) assert.Nil(t, err) - assert.Equal(t, uint32(2), maxDepth) + assert.Equal(t, uint32(2), tmc.GetMaxDepth()) + bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) + expectedSize := collapsedBn.sizeInBytes() + bn.children[2].sizeInBytes() + assert.Equal(t, expectedSize, tmc.GetSizeLoadedInMem()) } func TestExtensionNode_tryGetEmptyNode(t *testing.T) { @@ -440,10 +456,11 @@ func TestExtensionNode_tryGetEmptyNode(t *testing.T) { en := &extensionNode{} key := []byte("dog") - val, maxDepth, err := en.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := en.tryGet(key, tmc, nil) assert.True(t, errors.Is(err, ErrEmptyExtensionNode)) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestExtensionNode_tryGetNilNode(t *testing.T) { @@ -452,10 +469,11 @@ func TestExtensionNode_tryGetNilNode(t *testing.T) { var en *extensionNode key := []byte("dog") - val, maxDepth, err := en.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := en.tryGet(key, tmc, nil) assert.True(t, errors.Is(err, ErrNilExtensionNode)) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestExtensionNode_getNext(t *testing.T) { @@ -470,7 +488,7 @@ func TestExtensionNode_getNext(t *testing.T) { key := append(enKey, bnKey...) key = append(key, lnKey...) - n, newKey, err := en.getNext(key, nil) + n, newKey, err := en.getNext(key, dtmc, nil) assert.Equal(t, nextNode, n) assert.Equal(t, key[1:], newKey) assert.Nil(t, err) @@ -484,7 +502,7 @@ func TestExtensionNode_getNextWrongKey(t *testing.T) { lnKey := []byte("dog") key := append(bnKey, lnKey...) - n, key, err := en.getNext(key, nil) + n, key, err := en.getNext(key, nil, nil) assert.Nil(t, n) assert.Nil(t, key) assert.Equal(t, ErrNodeNotFound, err) @@ -495,12 +513,17 @@ func TestExtensionNode_insert(t *testing.T) { en, _ := getEnAndCollapsedEn() key := []byte{100, 15, 5, 6} + newData := getTrieDataWithDefaultVersion(string(key), "dogs") - newNode, _, err := en.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, _, err := en.insert(newData, tmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) + newData.Key = newData.Key[2:] + newLn, _ := newLeafNode(newData, en.marsh, en.hasher) + assert.Equal(t, newLn.sizeInBytes(), tmc.GetSizeLoadedInMem()) - val, _, _ := newNode.tryGet(key, 0, nil) + val, _ := newNode.tryGet(key, trieMetricsCollector.NewTrieMetricsCollector(), nil) assert.Equal(t, []byte("dogs"), val) } @@ -510,15 +533,21 @@ func TestExtensionNode_insertCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() key := []byte{100, 15, 5, 6} + newData := getTrieDataWithDefaultVersion(string(key), "dogs") _ = en.setHash() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) - newNode, _, err := collapsedEn.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, _, err := collapsedEn.insert(newData, tmc, db) assert.NotNil(t, newNode) assert.Nil(t, err) + newData.Key = newData.Key[2:] + newLn, _ := newLeafNode(newData, en.marsh, en.hasher) + expectedSize := newLn.sizeInBytes() + en.child.sizeInBytes() + assert.Equal(t, expectedSize, tmc.GetSizeLoadedInMem()) - val, _, _ := newNode.tryGet(key, 0, db) + val, _ := newNode.tryGet(key, trieMetricsCollector.NewTrieMetricsCollector(), db) assert.Equal(t, []byte("dogs"), val) } @@ -530,16 +559,22 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) { enKey := []byte{100} key := append(enKey, []byte{11, 12}...) - _ = en.commitDirty(0, 5, db, db) + _ = en.setHash() + markNotDirtyEn(en) enHash := en.getHash() - bn, _, _ := en.getNext(enKey, db) + bn, _, _ := en.getNext(enKey, dtmc, db) bnHash := bn.getHash() expectedHashes := [][]byte{bnHash, enHash} + newData := getTrieDataWithDefaultVersion(string(key), "dogs") - newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(key), "dogs"), db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := en.insert(newData, tmc, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) + newData.Key = newData.Key[2:] + newLn, _ := newLeafNode(newData, en.marsh, en.hasher) + assert.Equal(t, newLn.sizeInBytes(), tmc.GetSizeLoadedInMem()) } func TestExtensionNode_insertInStoredEnDifferentKey(t *testing.T) { @@ -551,13 +586,20 @@ func TestExtensionNode_insertInStoredEnDifferentKey(t *testing.T) { en, _ := newExtensionNode(enKey, bn, bn.marsh, bn.hasher) nodeKey := []byte{11, 12} - _ = en.commitDirty(0, 5, db, db) + _ = en.setHash() + markNotDirtyEn(en) expectedHashes := [][]byte{en.getHash()} + originalSize := en.sizeInBytes() - newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), db) + newData := getTrieDataWithDefaultVersion(string(nodeKey), "dogs") + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := en.insert(newData, tmc, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) + + expectedSize := newNode.sizeInBytes() + newNode.(*branchNode).children[nodeKey[0]].sizeInBytes() - originalSize + assert.Equal(t, expectedSize, tmc.GetSizeLoadedInMem()) } func TestExtensionNode_insertInDirtyEnSameKey(t *testing.T) { @@ -566,10 +608,15 @@ func TestExtensionNode_insertInDirtyEnSameKey(t *testing.T) { en, _ := getEnAndCollapsedEn() nodeKey := []byte{100, 11, 12} - newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) + newData := getTrieDataWithDefaultVersion(string(nodeKey), "dogs") + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := en.insert(newData, tmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) + newLn, _ := newLeafNode(newData, en.marsh, en.hasher) + expectedSize := newLn.sizeInBytes() - 2 // 2 because of the parts of the key that is added to the parents + assert.Equal(t, expectedSize, tmc.GetSizeLoadedInMem()) } func TestExtensionNode_insertInDirtyEnDifferentKey(t *testing.T) { @@ -580,10 +627,15 @@ func TestExtensionNode_insertInDirtyEnDifferentKey(t *testing.T) { en, _ := newExtensionNode(enKey, bn, bn.marsh, bn.hasher) nodeKey := []byte{11, 12} - newNode, oldHashes, err := en.insert(getTrieDataWithDefaultVersion(string(nodeKey), "dogs"), nil) + newData := getTrieDataWithDefaultVersion(string(nodeKey), "dogs") + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := en.insert(newData, tmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) + + expectedSize := newNode.sizeInBytes() + newNode.(*branchNode).children[nodeKey[0]].sizeInBytes() - en.sizeInBytes() + assert.Equal(t, expectedSize, tmc.GetSizeLoadedInMem()) } func TestExtensionNode_insertInNilNode(t *testing.T) { @@ -591,7 +643,7 @@ func TestExtensionNode_insertInNilNode(t *testing.T) { var en *extensionNode - newNode, _, err := en.insert(getTrieDataWithDefaultVersion("key", "val"), nil) + newNode, _, err := en.insert(getTrieDataWithDefaultVersion("key", "val"), nil, nil) assert.Nil(t, newNode) assert.True(t, errors.Is(err, ErrNilExtensionNode)) assert.Nil(t, newNode) @@ -609,13 +661,18 @@ func TestExtensionNode_delete(t *testing.T) { key := append(enKey, bnKey...) key = append(key, lnKey...) - val, _, _ := en.tryGet(key, 0, nil) + val, _ := en.tryGet(key, trieMetricsCollector.NewTrieMetricsCollector(), nil) assert.Equal(t, dogBytes, val) - dirty, _, _, err := en.delete(key, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, _, _, err := en.delete(key, tmc, nil) assert.True(t, dirty) assert.Nil(t, err) - val, _, _ = en.tryGet(key, 0, nil) + + deletedLeaf, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), en.marsh, en.hasher) + assert.Equal(t, -(deletedLeaf.sizeInBytes() + hashSizeInBytes), tmc.GetSizeLoadedInMem()) + + val, _ = en.tryGet(key, trieMetricsCollector.NewTrieMetricsCollector(), nil) assert.Nil(t, val) } @@ -631,12 +688,12 @@ func TestExtensionNode_deleteFromStoredEn(t *testing.T) { key = append(key, lnKey...) lnPathKey := key - _ = en.commitDirty(0, 5, db, db) - bn, key, _ := en.getNext(key, db) - ln, _, _ := bn.getNext(key, db) + _ = en.commitDirty(db, db) + bn, key, _ := en.getNext(key, dtmc, db) + ln, _, _ := bn.getNext(key, dtmc, db) expectedHashes := [][]byte{ln.getHash(), bn.getHash(), en.getHash()} - dirty, _, oldHashes, err := en.delete(lnPathKey, db) + dirty, _, oldHashes, err := en.delete(lnPathKey, dtmc, db) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, expectedHashes, oldHashes) @@ -648,7 +705,7 @@ func TestExtensionNode_deleteFromDirtyEn(t *testing.T) { en, _ := getEnAndCollapsedEn() lnKey := []byte{100, 2, 100, 111, 103} - dirty, _, oldHashes, err := en.delete(lnKey, nil) + dirty, _, oldHashes, err := en.delete(lnKey, dtmc, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -659,7 +716,7 @@ func TestExtendedNode_deleteEmptyNode(t *testing.T) { en := &extensionNode{} - dirty, newNode, _, err := en.delete([]byte("dog"), nil) + dirty, newNode, _, err := en.delete([]byte("dog"), nil, nil) assert.False(t, dirty) assert.True(t, errors.Is(err, ErrEmptyExtensionNode)) assert.Nil(t, newNode) @@ -670,7 +727,7 @@ func TestExtensionNode_deleteNilNode(t *testing.T) { var en *extensionNode - dirty, newNode, _, err := en.delete([]byte("dog"), nil) + dirty, newNode, _, err := en.delete([]byte("dog"), nil, nil) assert.False(t, dirty) assert.True(t, errors.Is(err, ErrNilExtensionNode)) assert.Nil(t, newNode) @@ -681,7 +738,7 @@ func TestExtensionNode_deleteEmptykey(t *testing.T) { en, _ := getEnAndCollapsedEn() - dirty, newNode, _, err := en.delete([]byte{}, nil) + dirty, newNode, _, err := en.delete([]byte{}, nil, nil) assert.False(t, dirty) assert.Equal(t, ErrValueTooShort, err) assert.Nil(t, newNode) @@ -693,7 +750,7 @@ func TestExtensionNode_deleteCollapsedNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() _ = en.setHash() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) enKey := []byte{100} bnKey := []byte{2} @@ -701,13 +758,16 @@ func TestExtensionNode_deleteCollapsedNode(t *testing.T) { key := append(enKey, bnKey...) key = append(key, lnKey...) - val, _, _ := en.tryGet(key, 0, db) + val, _ := en.tryGet(key, trieMetricsCollector.NewTrieMetricsCollector(), db) assert.Equal(t, []byte("dog"), val) - dirty, newNode, _, err := collapsedEn.delete(key, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, newNode, _, err := collapsedEn.delete(key, tmc, db) assert.True(t, dirty) assert.Nil(t, err) - val, _, _ = newNode.tryGet(key, 0, db) + assert.Equal(t, collapsedEn.child.sizeInBytes(), tmc.GetSizeLoadedInMem()) + + val, _ = newNode.tryGet(key, trieMetricsCollector.NewTrieMetricsCollector(), db) assert.Nil(t, val) } @@ -723,7 +783,7 @@ func TestExtensionNode_reduceNode(t *testing.T) { expected.hasher = en.hasher expected.child = en.child - n, newChildPos, err := en.reduceNode(2) + n, newChildPos, err := en.reduceNode(2, dtmc) assert.Equal(t, expected, n) assert.Nil(t, err) assert.True(t, newChildPos) @@ -768,7 +828,7 @@ func TestExtensionNode_getChildren(t *testing.T) { en, _ := getEnAndCollapsedEn() - children, err := en.getChildren(nil) + children, err := en.getChildren(dtmc, nil) assert.Nil(t, err) assert.Equal(t, 1, len(children)) } @@ -778,9 +838,9 @@ func TestExtensionNode_getChildrenCollapsedEn(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) - children, err := collapsedEn.getChildren(db) + children, err := collapsedEn.getChildren(dtmc, db) assert.Nil(t, err) assert.Equal(t, 1, len(children)) } @@ -886,20 +946,6 @@ func TestExtensionNode_getMarshalizer(t *testing.T) { assert.Equal(t, marsh, en.getMarshalizer()) } -func TestExtensionNode_commitCollapsesTrieIfMaxTrieLevelInMemoryIsReached(t *testing.T) { - t.Parallel() - - en, collapsedEn := getEnAndCollapsedEn() - _ = collapsedEn.setRootHash() - - err := en.commitDirty(0, 1, testscommon.NewMemDbMock(), testscommon.NewMemDbMock()) - assert.Nil(t, err) - - assert.Equal(t, collapsedEn.EncodedChild, en.EncodedChild) - assert.Equal(t, collapsedEn.child, en.child) - assert.Equal(t, collapsedEn.hash, en.hash) -} - func TestExtensionNode_printShouldNotPanicEvenIfNodeIsCollapsed(t *testing.T) { t.Parallel() @@ -908,11 +954,11 @@ func TestExtensionNode_printShouldNotPanicEvenIfNodeIsCollapsed(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) _ = collapsedEn.setHash() - en.print(enWriter, 0, db) - collapsedEn.print(collapsedEnWriter, 0, db) + en.print(enWriter, 0, dtmc, db) + collapsedEn.print(collapsedEnWriter, 0, dtmc, db) assert.Equal(t, enWriter.Bytes(), collapsedEnWriter.Bytes()) } @@ -922,7 +968,7 @@ func TestExtensionNode_getDirtyHashesFromCleanNode(t *testing.T) { db := testscommon.NewMemDbMock() en, _ := getEnAndCollapsedEn() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) dirtyHashes := make(common.ModifiedHashes) err := en.getDirtyHashes(dirtyHashes) @@ -936,7 +982,7 @@ func TestExtensionNode_getAllHashes(t *testing.T) { trieNodes := 5 en, _ := getEnAndCollapsedEn() - hashes, err := en.getAllHashes(testscommon.NewMemDbMock()) + hashes, err := en.getAllHashes(dtmc, testscommon.NewMemDbMock()) assert.Nil(t, err) assert.Equal(t, trieNodes, len(hashes)) } @@ -947,9 +993,9 @@ func TestExtensionNode_getAllHashesResolvesCollapsed(t *testing.T) { trieNodes := 5 db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) - hashes, err := collapsedEn.getAllHashes(db) + hashes, err := collapsedEn.getAllHashes(dtmc, db) assert.Nil(t, err) assert.Equal(t, trieNodes, len(hashes)) } @@ -995,7 +1041,7 @@ func TestExtensionNode_SizeInBytes(t *testing.T) { collapsed := []byte("collapsed") key := []byte("key") - hash := []byte("hash") + hash := bytes.Repeat([]byte{1}, 32) en = &extensionNode{ CollapsedEn: CollapsedEn{ Key: key, @@ -1009,7 +1055,7 @@ func TestExtensionNode_SizeInBytes(t *testing.T) { hasher: nil, }, } - assert.Equal(t, len(collapsed)+len(key)+len(hash)+1+3*pointerSizeInBytes, en.sizeInBytes()) + assert.Equal(t, hashSizeInBytes+len(key)+nodeVersionSizeInBytes+len(hash)+1+3*pointerSizeInBytes, en.sizeInBytes()) } func TestExtensionNode_commitContextDone(t *testing.T) { @@ -1020,7 +1066,7 @@ func TestExtensionNode_commitContextDone(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - err := en.commitSnapshot(db, 0, nil, nil, ctx, statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, 0) + err := en.commitSnapshot(db, 0, nil, nil, ctx, statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, dtmc) assert.Equal(t, core.ErrContextClosing, err) } @@ -1041,7 +1087,7 @@ func TestExtensionNode_commitSnapshotDbIsClosing(t *testing.T) { _, collapsedEn := getEnAndCollapsedEn() missingNodesChan := make(chan []byte, 10) - err := collapsedEn.commitSnapshot(db, 0, nil, missingNodesChan, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, 0) + err := collapsedEn.commitSnapshot(db, 0, nil, missingNodesChan, context.Background(), statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, dtmc) assert.True(t, core.IsClosingError(err)) assert.Equal(t, 0, len(missingNodesChan)) } @@ -1112,3 +1158,42 @@ func TestExtensionNode_getNodeData(t *testing.T) { assert.False(t, nodeData[0].IsLeaf()) }) } + +func TestExtensionNode_shouldCollapseChild(t *testing.T) { + t.Parallel() + + t.Run("key too short", func(t *testing.T) { + t.Parallel() + + en, _ := getEnAndCollapsedEn() + shouldCollapse := en.shouldCollapseChild([]byte{}, nil) + assert.False(t, shouldCollapse) + }) + t.Run("keys do not match", func(t *testing.T) { + t.Parallel() + + en, _ := getEnAndCollapsedEn() + shouldCollapse := en.shouldCollapseChild([]byte{1, 2, 3}, nil) + assert.False(t, shouldCollapse) + }) + t.Run("nil child", func(t *testing.T) { + t.Parallel() + + _, collapsedEn := getEnAndCollapsedEn() + shouldCollapse := collapsedEn.shouldCollapseChild(collapsedEn.Key, nil) + assert.False(t, shouldCollapse) + }) + t.Run("calls collapse for child", func(t *testing.T) { + t.Parallel() + + en, _ := getEnAndCollapsedEn() + en.child.(*branchNode).children[2].setDirty(false) + bn, _ := getBnAndCollapsedBn(en.marsh, en.hasher) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + leafKey := append(en.Key, byte(2)) + leafKey = append(leafKey, []byte("dog")...) + shouldCollapse := en.shouldCollapseChild(leafKey, tmc) + assert.False(t, shouldCollapse) + assert.Equal(t, -bn.children[2].sizeInBytes(), tmc.GetSizeLoadedInMem()) + }) +} diff --git a/trie/factory/trieCreator.go b/trie/factory/trieCreator.go index 198b33a0455..ff43466afb1 100644 --- a/trie/factory/trieCreator.go +++ b/trie/factory/trieCreator.go @@ -7,21 +7,23 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" - "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" ) // TrieCreateArgs holds arguments for calling the Create method on the TrieFactory type TrieCreateArgs struct { - MainStorer storage.Storer - PruningEnabled bool - SnapshotsEnabled bool - MaxTrieLevelInMem uint - IdleProvider trie.IdleNodeProvider - Identifier string - EnableEpochsHandler common.EnableEpochsHandler - StatsCollector common.StateStatisticsHandler + MainStorer storage.Storer + PruningEnabled bool + SnapshotsEnabled bool + IdleProvider trie.IdleNodeProvider + Identifier string + EnableEpochsHandler common.EnableEpochsHandler + StatsCollector common.StateStatisticsHandler + MaxSizeInMemory uint64 + NumLeavesToCollapseSingleRun uint32 } type trieCreator struct { @@ -78,7 +80,12 @@ func (tc *trieCreator) Create(args TrieCreateArgs) (common.StorageManager, commo return nil, nil, err } - newTrie, err := trie.NewTrie(trieStorage, tc.marshalizer, tc.hasher, args.EnableEpochsHandler, args.MaxTrieLevelInMem) + cm, err := collapseManager.NewCollapseManager(args.MaxSizeInMemory, args.NumLeavesToCollapseSingleRun) + if err != nil { + return nil, nil, err + } + + newTrie, err := trie.NewTrie(trieStorage, tc.marshalizer, tc.hasher, args.EnableEpochsHandler, cm) if err != nil { return nil, nil, err } @@ -115,21 +122,22 @@ func CreateTriesComponentsForShardId( } args := TrieCreateArgs{ - MainStorer: mainStorer, - PruningEnabled: generalConfig.StateTriesConfig.AccountsStatePruningEnabled, - MaxTrieLevelInMem: generalConfig.StateTriesConfig.MaxStateTrieLevelInMemory, - SnapshotsEnabled: generalConfig.StateTriesConfig.SnapshotsEnabled, - IdleProvider: coreComponentsHolder.ProcessStatusHandler(), - Identifier: dataRetriever.UserAccountsUnit.String(), - EnableEpochsHandler: coreComponentsHolder.EnableEpochsHandler(), - StatsCollector: stateStatsHandler, + MainStorer: mainStorer, + PruningEnabled: generalConfig.StateTriesConfig.AccountsStatePruningEnabled, + SnapshotsEnabled: generalConfig.StateTriesConfig.SnapshotsEnabled, + IdleProvider: coreComponentsHolder.ProcessStatusHandler(), + Identifier: dataRetriever.UserAccountsUnit.String(), + EnableEpochsHandler: coreComponentsHolder.EnableEpochsHandler(), + StatsCollector: stateStatsHandler, + MaxSizeInMemory: generalConfig.StateTriesConfig.MaxUserTrieSizeInMemory, + NumLeavesToCollapseSingleRun: generalConfig.StateTriesConfig.NumLeavesToCollapseSingleRun, } userStorageManager, userAccountTrie, err := trFactory.Create(args) if err != nil { return nil, nil, err } - trieContainer := state.NewDataTriesHolder() + trieContainer := triesHolder.NewTriesHolder() trieStorageManagers := make(map[string]common.StorageManager) trieContainer.Put([]byte(dataRetriever.UserAccountsUnit.String()), userAccountTrie) @@ -141,14 +149,15 @@ func CreateTriesComponentsForShardId( } args = TrieCreateArgs{ - MainStorer: mainStorer, - PruningEnabled: generalConfig.StateTriesConfig.PeerStatePruningEnabled, - MaxTrieLevelInMem: generalConfig.StateTriesConfig.MaxPeerTrieLevelInMemory, - SnapshotsEnabled: generalConfig.StateTriesConfig.SnapshotsEnabled, - IdleProvider: coreComponentsHolder.ProcessStatusHandler(), - Identifier: dataRetriever.PeerAccountsUnit.String(), - EnableEpochsHandler: coreComponentsHolder.EnableEpochsHandler(), - StatsCollector: stateStatsHandler, + MainStorer: mainStorer, + PruningEnabled: generalConfig.StateTriesConfig.PeerStatePruningEnabled, + SnapshotsEnabled: generalConfig.StateTriesConfig.SnapshotsEnabled, + IdleProvider: coreComponentsHolder.ProcessStatusHandler(), + Identifier: dataRetriever.PeerAccountsUnit.String(), + EnableEpochsHandler: coreComponentsHolder.EnableEpochsHandler(), + StatsCollector: stateStatsHandler, + MaxSizeInMemory: generalConfig.StateTriesConfig.MaxPeerTrieSizeInMemory, + NumLeavesToCollapseSingleRun: generalConfig.StateTriesConfig.NumLeavesToCollapseSingleRun, } peerStorageManager, peerAccountsTrie, err := trFactory.Create(args) if err != nil { diff --git a/trie/factory/trieCreator_test.go b/trie/factory/trieCreator_test.go index c4a716e2cc4..815576978aa 100644 --- a/trie/factory/trieCreator_test.go +++ b/trie/factory/trieCreator_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -33,14 +34,15 @@ func getArgs() factory.TrieFactoryArgs { func getCreateArgs() factory.TrieCreateArgs { return factory.TrieCreateArgs{ - MainStorer: testscommon.CreateMemUnit(), - PruningEnabled: false, - SnapshotsEnabled: true, - MaxTrieLevelInMem: 5, - IdleProvider: &testscommon.ProcessStatusHandlerStub{}, - Identifier: dataRetriever.UserAccountsUnit.String(), - EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - StatsCollector: disabled.NewStateStatistics(), + MainStorer: testscommon.CreateMemUnit(), + PruningEnabled: false, + SnapshotsEnabled: true, + IdleProvider: &testscommon.ProcessStatusHandlerStub{}, + Identifier: dataRetriever.UserAccountsUnit.String(), + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + StatsCollector: disabled.NewStateStatistics(), + MaxSizeInMemory: common.TenMbSize, + NumLeavesToCollapseSingleRun: common.NumLeavesToCollapseSingleRun, } } @@ -140,20 +142,6 @@ func TestTrieCreator_CreateWithNilMainStorerShouldErr(t *testing.T) { require.True(t, strings.Contains(err.Error(), trie.ErrNilStorer.Error())) } -func TestTrieCreator_CreateWithInvalidMaxTrieLevelInMemShouldErr(t *testing.T) { - t.Parallel() - - args := getArgs() - tf, _ := factory.NewTrieFactory(args) - - createArgs := getCreateArgs() - createArgs.MaxTrieLevelInMem = 0 - _, tr, err := tf.Create(createArgs) - require.Nil(t, tr) - require.NotNil(t, err) - require.Contains(t, err.Error(), trie.ErrInvalidLevelValue.Error()) -} - func TestTrieCreator_CreateTriesComponentsForShardId(t *testing.T) { t.Parallel() diff --git a/trie/interface.go b/trie/interface.go index af1eb4c1eef..c6d58991a54 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -24,38 +24,39 @@ type node interface { isPosCollapsed(pos int) bool isDirty() bool getEncodedNode() ([]byte, error) - resolveCollapsed(pos byte, db common.TrieStorageInteractor) error + resolveCollapsed(pos byte, tmc MetricsCollector, db common.TrieStorageInteractor) error hashNode() ([]byte, error) hashChildren() error - tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error) - getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) - insert(newData core.TrieData, db common.TrieStorageInteractor) (node, [][]byte, error) - delete(key []byte, db common.TrieStorageInteractor) (bool, node, [][]byte, error) - reduceNode(pos int) (node, bool, error) + tryGet(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) ([]byte, error) + getNext(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) (node, []byte, error) + insert(newData core.TrieData, tmc MetricsCollector, db common.TrieStorageInteractor) (node, [][]byte, error) + delete(key []byte, tmc MetricsCollector, db common.TrieStorageInteractor) (bool, node, [][]byte, error) + reduceNode(pos int, tmc MetricsCollector) (node, bool, error) isEmptyOrNil() error - print(writer io.Writer, index int, db common.TrieStorageInteractor) + print(writer io.Writer, index int, tmc MetricsCollector, db common.TrieStorageInteractor) getDirtyHashes(common.ModifiedHashes) error - getChildren(db common.TrieStorageInteractor) ([]node, error) + getChildren(tmc MetricsCollector, db common.TrieStorageInteractor) ([]node, error) isValid() bool getNodeData(common.KeyBuilder) ([]common.TrieNodeData, error) setDirty(bool) loadChildren(func([]byte) (node, error)) ([][]byte, []node, error) - getAllLeavesOnChannel(chan core.KeyValueHolder, common.KeyBuilder, common.TrieLeafParser, common.TrieStorageInteractor, marshal.Marshalizer, chan struct{}, context.Context) error - getAllHashes(db common.TrieStorageInteractor) ([][]byte, error) + getAllLeavesOnChannel(chan core.KeyValueHolder, common.KeyBuilder, common.TrieLeafParser, common.TrieStorageInteractor, marshal.Marshalizer, chan struct{}, context.Context, MetricsCollector) error + getAllHashes(tmc MetricsCollector, db common.TrieStorageInteractor) ([][]byte, error) getNextHashAndKey([]byte) (bool, []byte, []byte) getValue() []byte getVersion() (core.TrieNodeVersion, error) - collectLeavesForMigration(migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, db common.TrieStorageInteractor, keyBuilder common.KeyBuilder) (bool, error) + collectLeavesForMigration(migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, tmc MetricsCollector, db common.TrieStorageInteractor, keyBuilder common.KeyBuilder) (bool, error) + shouldCollapseChild([]byte, MetricsCollector) bool - commitDirty(level byte, maxTrieLevelInMemory uint, originDb common.TrieStorageInteractor, targetDb common.BaseStorer) error - commitSnapshot(originDb snapshotDb, maxEpochToSearchFrom uint32, leavesChan chan core.KeyValueHolder, missingNodesChan chan []byte, ctx context.Context, stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, encodedRoot []byte, depthLevel int) error + commitDirty(originDb common.TrieStorageInteractor, targetDb common.BaseStorer) error + commitSnapshot(originDb snapshotDb, maxEpochToSearchFrom uint32, leavesChan chan core.KeyValueHolder, missingNodesChan chan []byte, ctx context.Context, stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, encodedRoot []byte, tmc MetricsCollector) error getMarshalizer() marshal.Marshalizer setMarshalizer(marshal.Marshalizer) getHasher() hashing.Hasher setHasher(hashing.Hasher) sizeInBytes() int - collectStats(handler common.TrieStatisticsHandler, depthLevel int, db common.TrieStorageInteractor) error + collectStats(handler common.TrieStatisticsHandler, tmc MetricsCollector, db common.TrieStorageInteractor) error IsInterfaceNil() bool } @@ -65,7 +66,7 @@ type dbWithGetFromEpoch interface { } type snapshotNode interface { - commitSnapshot(originDb snapshotDb, maxEpochToSearchFrom uint32, leavesChan chan core.KeyValueHolder, missingNodesChan chan []byte, ctx context.Context, stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, encodedRoot []byte, depthLevel int) error + commitSnapshot(originDb snapshotDb, maxEpochToSearchFrom uint32, leavesChan chan core.KeyValueHolder, missingNodesChan chan []byte, ctx context.Context, stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, encodedRoot []byte, tmc MetricsCollector) error } // RequestHandler defines the methods through which request to data can be made @@ -118,3 +119,12 @@ type snapshotDb interface { PutInEpochWithoutCache(key []byte, data []byte) error GetIdentifier() string } + +// MetricsCollector is used to collect metrics about the trie +type MetricsCollector interface { + SetDepth(depth uint32) + GetCurrentDepth() uint32 + GetMaxDepth() uint32 + AddSizeLoadedInMem(size int) + GetSizeLoadedInMem() int +} diff --git a/trie/leafNode.go b/trie/leafNode.go index 4f82059518d..43c74f12b8d 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -119,7 +119,7 @@ func (ln *leafNode) hashNode() ([]byte, error) { return encodeNodeAndGetHash(ln) } -func (ln *leafNode) commitDirty(_ byte, _ uint, _ common.TrieStorageInteractor, targetDb common.BaseStorer) error { +func (ln *leafNode) commitDirty(_ common.TrieStorageInteractor, targetDb common.BaseStorer) error { err := ln.isEmptyOrNil() if err != nil { return fmt.Errorf("commit error %w", err) @@ -135,6 +135,13 @@ func (ln *leafNode) commitDirty(_ byte, _ uint, _ common.TrieStorageInteractor, return err } +func (ln *leafNode) shouldCollapseChild(hexKey []byte, _ MetricsCollector) bool { + if bytes.Equal(hexKey, ln.Key) && !ln.dirty { + return true + } + return false +} + func (ln *leafNode) commitSnapshot( _ snapshotDb, _ uint32, @@ -144,7 +151,7 @@ func (ln *leafNode) commitSnapshot( stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, nodeBytes []byte, - depthLevel int, + tmc MetricsCollector, ) error { if shouldStopIfContextDoneBlockingIfBusy(ctx, idleProvider) { return core.ErrContextClosing @@ -160,7 +167,7 @@ func (ln *leafNode) commitSnapshot( return err } - stats.AddLeafNode(depthLevel, uint64(len(nodeBytes)), version) + stats.AddLeafNode(int(tmc.GetCurrentDepth()), uint64(len(nodeBytes)), version) return nil } @@ -194,7 +201,7 @@ func (ln *leafNode) getEncodedNode() ([]byte, error) { return marshaledNode, nil } -func (ln *leafNode) resolveCollapsed(_ byte, _ common.TrieStorageInteractor) error { +func (ln *leafNode) resolveCollapsed(_ byte, _ MetricsCollector, _ common.TrieStorageInteractor) error { return nil } @@ -206,19 +213,19 @@ func (ln *leafNode) isPosCollapsed(_ int) bool { return false } -func (ln *leafNode) tryGet(key []byte, currentDepth uint32, _ common.TrieStorageInteractor) (value []byte, maxDepth uint32, err error) { +func (ln *leafNode) tryGet(key []byte, _ MetricsCollector, _ common.TrieStorageInteractor) (value []byte, err error) { err = ln.isEmptyOrNil() if err != nil { - return nil, currentDepth, fmt.Errorf("tryGet error %w", err) + return nil, fmt.Errorf("tryGet error %w", err) } if bytes.Equal(key, ln.Key) { - return ln.Value, currentDepth, nil + return ln.Value, nil } - return nil, currentDepth, nil + return nil, nil } -func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, []byte, error) { +func (ln *leafNode) getNext(key []byte, _ MetricsCollector, _ common.TrieStorageInteractor) (node, []byte, error) { err := ln.isEmptyOrNil() if err != nil { return nil, nil, fmt.Errorf("getNext error %w", err) @@ -228,7 +235,7 @@ func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, [ } return nil, nil, ErrNodeNotFound } -func (ln *leafNode) insert(newData core.TrieData, _ common.TrieStorageInteractor) (node, [][]byte, error) { +func (ln *leafNode) insert(newData core.TrieData, tmc MetricsCollector, _ common.TrieStorageInteractor) (node, [][]byte, error) { err := ln.isEmptyOrNil() if err != nil { return nil, [][]byte{}, fmt.Errorf("insert error %w", err) @@ -242,14 +249,15 @@ func (ln *leafNode) insert(newData core.TrieData, _ common.TrieStorageInteractor nodeKey := ln.Key if bytes.Equal(newData.Key, nodeKey) { - return ln.insertInSameLn(newData, oldHash) + return ln.insertInSameLn(newData, oldHash, tmc) } keyMatchLen := prefixLen(newData.Key, nodeKey) - bn, err := ln.insertInNewBn(newData, keyMatchLen) + bn, err := ln.insertInNewBn(newData, keyMatchLen, tmc) if err != nil { return nil, [][]byte{}, err } + tmc.AddSizeLoadedInMem(bn.sizeInBytes()) if keyMatchLen == 0 { return bn, oldHash, nil @@ -259,15 +267,19 @@ func (ln *leafNode) insert(newData core.TrieData, _ common.TrieStorageInteractor if err != nil { return nil, [][]byte{}, err } + tmc.AddSizeLoadedInMem(newEn.sizeInBytes()) return newEn, oldHash, nil } -func (ln *leafNode) insertInSameLn(newData core.TrieData, oldHashes [][]byte) (node, [][]byte, error) { +func (ln *leafNode) insertInSameLn(newData core.TrieData, oldHashes [][]byte, tmc MetricsCollector) (node, [][]byte, error) { if bytes.Equal(ln.Value, newData.Value) { return nil, [][]byte{}, nil } + sizeDiff := len(newData.Value) - len(ln.Value) + tmc.AddSizeLoadedInMem(sizeDiff) + ln.Value = newData.Value ln.Version = uint32(newData.Version) ln.dirty = true @@ -275,12 +287,13 @@ func (ln *leafNode) insertInSameLn(newData core.TrieData, oldHashes [][]byte) (n return ln, oldHashes, nil } -func (ln *leafNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, error) { +func (ln *leafNode) insertInNewBn(newData core.TrieData, keyMatchLen int, tmc MetricsCollector) (node, error) { bn, err := newBranchNode(ln.marsh, ln.hasher) if err != nil { return nil, err } + originalSize := ln.sizeInBytes() oldChildPos := ln.Key[keyMatchLen] newChildPos := newData.Key[keyMatchLen] if childPosOutOfRange(oldChildPos) || childPosOutOfRange(newChildPos) { @@ -303,6 +316,7 @@ func (ln *leafNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, } bn.children[oldChildPos] = newLnOldChildPos bn.setVersionForChild(oldLnVersion, oldChildPos) + newSize := newLnOldChildPos.sizeInBytes() newData.Key = newData.Key[keyMatchLen+1:] newLnNewChildPos, err := newLeafNode(newData, ln.marsh, ln.hasher) @@ -311,24 +325,28 @@ func (ln *leafNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, } bn.children[newChildPos] = newLnNewChildPos bn.setVersionForChild(newData.Version, newChildPos) + newSize += newLnNewChildPos.sizeInBytes() + tmc.AddSizeLoadedInMem(newSize - originalSize) return bn, nil } -func (ln *leafNode) delete(key []byte, _ common.TrieStorageInteractor) (bool, node, [][]byte, error) { +func (ln *leafNode) delete(key []byte, tmc MetricsCollector, _ common.TrieStorageInteractor) (bool, node, [][]byte, error) { if bytes.Equal(key, ln.Key) { oldHash := make([][]byte, 0) if !ln.dirty { oldHash = append(oldHash, ln.hash) } + tmc.AddSizeLoadedInMem(-ln.sizeInBytes()) return true, nil, oldHash, nil } return false, ln, [][]byte{}, nil } -func (ln *leafNode) reduceNode(pos int) (node, bool, error) { - k := append([]byte{byte(pos)}, ln.Key...) +func (ln *leafNode) reduceNode(pos int, tmc MetricsCollector) (node, bool, error) { + extraKey := []byte{byte(pos)} + k := append(extraKey, ln.Key...) oldLnVersion, err := ln.getVersion() if err != nil { @@ -345,6 +363,7 @@ func (ln *leafNode) reduceNode(pos int) (node, bool, error) { if err != nil { return nil, false, err } + tmc.AddSizeLoadedInMem(len(extraKey)) return newLn, true, nil } @@ -359,7 +378,7 @@ func (ln *leafNode) isEmptyOrNil() error { return nil } -func (ln *leafNode) print(writer io.Writer, _ int, _ common.TrieStorageInteractor) { +func (ln *leafNode) print(writer io.Writer, _ int, _ MetricsCollector, _ common.TrieStorageInteractor) { if ln == nil { return } @@ -391,7 +410,7 @@ func (ln *leafNode) getDirtyHashes(hashes common.ModifiedHashes) error { return nil } -func (ln *leafNode) getChildren(_ common.TrieStorageInteractor) ([]node, error) { +func (ln *leafNode) getChildren(_ MetricsCollector, _ common.TrieStorageInteractor) ([]node, error) { return nil, nil } @@ -415,6 +434,7 @@ func (ln *leafNode) getAllLeavesOnChannel( _ marshal.Marshalizer, chanClose chan struct{}, ctx context.Context, + _ MetricsCollector, ) error { err := ln.isEmptyOrNil() if err != nil { @@ -451,7 +471,7 @@ func (ln *leafNode) getAllLeavesOnChannel( } } -func (ln *leafNode) getAllHashes(_ common.TrieStorageInteractor) ([][]byte, error) { +func (ln *leafNode) getAllHashes(_ MetricsCollector, _ common.TrieStorageInteractor) ([][]byte, error) { err := ln.isEmptyOrNil() if err != nil { return nil, fmt.Errorf("getAllHashes error: %w", err) @@ -477,8 +497,7 @@ func (ln *leafNode) sizeInBytes() int { return 0 } - // hasher + marshalizer + dirty flag = numNodeInnerPointers * pointerSizeInBytes + 1 - nodeSize := len(ln.hash) + len(ln.Key) + len(ln.Value) + numNodeInnerPointers*pointerSizeInBytes + 1 + nodeSize := baseNodeSizeInBytes + len(ln.Key) + len(ln.Value) + nodeVersionSizeInBytes return nodeSize } @@ -487,7 +506,7 @@ func (ln *leafNode) getValue() []byte { return ln.Value } -func (ln *leafNode) collectStats(ts common.TrieStatisticsHandler, depthLevel int, _ common.TrieStorageInteractor) error { +func (ln *leafNode) collectStats(ts common.TrieStatisticsHandler, tmc MetricsCollector, _ common.TrieStorageInteractor) error { err := ln.isEmptyOrNil() if err != nil { return fmt.Errorf("collectStats error %w", err) @@ -503,7 +522,7 @@ func (ln *leafNode) collectStats(ts common.TrieStatisticsHandler, depthLevel int return err } - ts.AddLeafNode(depthLevel, uint64(len(val)), version) + ts.AddLeafNode(int(tmc.GetCurrentDepth()), uint64(len(val)), version) return nil } @@ -518,6 +537,7 @@ func (ln *leafNode) getVersion() (core.TrieNodeVersion, error) { func (ln *leafNode) collectLeavesForMigration( migrationArgs vmcommon.ArgsMigrateDataTrieLeaves, + _ MetricsCollector, _ common.TrieStorageInteractor, keyBuilder common.KeyBuilder, ) (bool, error) { diff --git a/trie/leafNode_test.go b/trie/leafNode_test.go index bbcdbbf9fce..2b4a0f9a0b0 100644 --- a/trie/leafNode_test.go +++ b/trie/leafNode_test.go @@ -1,6 +1,7 @@ package trie import ( + "bytes" "context" "errors" "math" @@ -16,6 +17,7 @@ import ( "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/mock" "github.com/multiversx/mx-chain-go/trie/statistics" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" "github.com/stretchr/testify/assert" ) @@ -158,7 +160,7 @@ func TestLeafNode_commit(t *testing.T) { hash, _ := encodeNodeAndGetHash(ln) _ = ln.setHash() - err := ln.commitDirty(0, 5, db, db) + err := ln.commitDirty(db, db) assert.Nil(t, err) encNode, _ := db.Get(hash) @@ -173,7 +175,7 @@ func TestLeafNode_commitEmptyNode(t *testing.T) { ln := &leafNode{} - err := ln.commitDirty(0, 5, nil, nil) + err := ln.commitDirty(nil, nil) assert.True(t, errors.Is(err, ErrEmptyLeafNode)) } @@ -182,7 +184,7 @@ func TestLeafNode_commitNilNode(t *testing.T) { var ln *leafNode - err := ln.commitDirty(0, 5, nil, nil) + err := ln.commitDirty(nil, nil) assert.True(t, errors.Is(err, ErrNilLeafNode)) } @@ -223,7 +225,7 @@ func TestLeafNode_resolveCollapsed(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) - assert.Nil(t, ln.resolveCollapsed(0, nil)) + assert.Nil(t, ln.resolveCollapsed(0, nil, nil)) } func TestLeafNode_isCollapsed(t *testing.T) { @@ -239,10 +241,11 @@ func TestLeafNode_tryGet(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) key := []byte("dog") - val, maxDepth, err := ln.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := ln.tryGet(key, tmc, nil) assert.Equal(t, []byte("dog"), val) assert.Nil(t, err) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestLeafNode_tryGetWrongKey(t *testing.T) { @@ -251,10 +254,11 @@ func TestLeafNode_tryGetWrongKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) wrongKey := []byte{1, 2, 3} - val, maxDepth, err := ln.tryGet(wrongKey, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := ln.tryGet(wrongKey, tmc, nil) assert.Nil(t, val) assert.Nil(t, err) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestLeafNode_tryGetEmptyNode(t *testing.T) { @@ -263,10 +267,11 @@ func TestLeafNode_tryGetEmptyNode(t *testing.T) { ln := &leafNode{} key := []byte("dog") - val, maxDepth, err := ln.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := ln.tryGet(key, tmc, nil) assert.True(t, errors.Is(err, ErrEmptyLeafNode)) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestLeafNode_tryGetNilNode(t *testing.T) { @@ -275,10 +280,11 @@ func TestLeafNode_tryGetNilNode(t *testing.T) { var ln *leafNode key := []byte("dog") - val, maxDepth, err := ln.tryGet(key, 0, nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := ln.tryGet(key, tmc, nil) assert.True(t, errors.Is(err, ErrNilLeafNode)) assert.Nil(t, val) - assert.Equal(t, uint32(0), maxDepth) + assert.Equal(t, uint32(0), tmc.GetMaxDepth()) } func TestLeafNode_getNext(t *testing.T) { @@ -287,7 +293,7 @@ func TestLeafNode_getNext(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) key := []byte("dog") - n, key, err := ln.getNext(key, nil) + n, key, err := ln.getNext(key, nil, nil) assert.Nil(t, n) assert.Nil(t, key) assert.Nil(t, err) @@ -299,7 +305,7 @@ func TestLeafNode_getNextWrongKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) wrongKey := append([]byte{2}, []byte("dog")...) - n, key, err := ln.getNext(wrongKey, nil) + n, key, err := ln.getNext(wrongKey, nil, nil) assert.Nil(t, n) assert.Nil(t, key) assert.Equal(t, ErrNodeNotFound, err) @@ -311,7 +317,7 @@ func TestLeafNode_getNextNilNode(t *testing.T) { var ln *leafNode key := []byte("dog") - n, key, err := ln.getNext(key, nil) + n, key, err := ln.getNext(key, nil, nil) assert.Nil(t, n) assert.Nil(t, key) assert.True(t, errors.Is(err, ErrNilLeafNode)) @@ -324,11 +330,13 @@ func TestLeafNode_insertAtSameKey(t *testing.T) { key := "dog" expectedVal := "dogs" - newNode, _, err := ln.insert(getTrieDataWithDefaultVersion(key, expectedVal), nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, _, err := ln.insert(getTrieDataWithDefaultVersion(key, expectedVal), tmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) + assert.Equal(t, 1, tmc.GetSizeLoadedInMem()) - val, _, _ := newNode.tryGet([]byte(key), 0, nil) + val, _ := newNode.tryGet([]byte(key), dtmc, nil) assert.Equal(t, []byte(expectedVal), val) } @@ -343,11 +351,14 @@ func TestLeafNode_insertAtDifferentKey(t *testing.T) { nodeKey := []byte{3, 4, 5} nodeVal := []byte{3, 4, 5} - newNode, _, err := ln.insert(getTrieDataWithDefaultVersion(string(nodeKey), string(nodeVal)), nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, _, err := ln.insert(getTrieDataWithDefaultVersion(string(nodeKey), string(nodeVal)), tmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) + expectedSize := newNode.sizeInBytes() + newNode.(*branchNode).children[3].sizeInBytes() - 1 + assert.Equal(t, expectedSize, tmc.GetSizeLoadedInMem()) - val, _, _ := newNode.tryGet(nodeKey, 0, nil) + val, _ := newNode.tryGet(nodeKey, trieMetricsCollector.NewTrieMetricsCollector(), nil) assert.Equal(t, nodeVal, val) assert.IsType(t, &branchNode{}, newNode) } @@ -357,13 +368,15 @@ func TestLeafNode_insertInStoredLnAtSameKey(t *testing.T) { db := testscommon.NewMemDbMock() ln := getLn(getTestMarshalizerAndHasher()) - _ = ln.commitDirty(0, 5, db, db) + _ = ln.commitDirty(db, db) lnHash := ln.getHash() - newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), tmc, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{lnHash}, oldHashes) + assert.Equal(t, 1, tmc.GetSizeLoadedInMem()) } func TestLeafNode_insertInStoredLnAtDifferentKey(t *testing.T) { @@ -372,13 +385,16 @@ func TestLeafNode_insertInStoredLnAtDifferentKey(t *testing.T) { db := testscommon.NewMemDbMock() marsh, hasher := getTestMarshalizerAndHasher() ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3}), "dog"), marsh, hasher) - _ = ln.commitDirty(0, 5, db, db) + _ = ln.commitDirty(db, db) lnHash := ln.getHash() - newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs"), db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs"), tmc, db) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{lnHash}, oldHashes) + expectedSize := newNode.sizeInBytes() + newNode.(*branchNode).children[4].sizeInBytes() - 1 + assert.Equal(t, expectedSize, tmc.GetSizeLoadedInMem()) } func TestLeafNode_insertInDirtyLnAtSameKey(t *testing.T) { @@ -386,7 +402,7 @@ func TestLeafNode_insertInDirtyLnAtSameKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) - newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) + newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), dtmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -398,7 +414,7 @@ func TestLeafNode_insertInDirtyLnAtDifferentKey(t *testing.T) { marsh, hasher := getTestMarshalizerAndHasher() ln, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{1, 2, 3}), "dog"), marsh, hasher) - newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs"), nil) + newNode, oldHashes, err := ln.insert(getTrieDataWithDefaultVersion(string([]byte{4, 5, 6}), "dogs"), dtmc, nil) assert.NotNil(t, newNode) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -409,7 +425,7 @@ func TestLeafNode_insertInNilNode(t *testing.T) { var ln *leafNode - newNode, _, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil) + newNode, _, err := ln.insert(getTrieDataWithDefaultVersion("dog", "dogs"), nil, nil) assert.Nil(t, newNode) assert.True(t, errors.Is(err, ErrNilLeafNode)) assert.Nil(t, newNode) @@ -420,10 +436,12 @@ func TestLeafNode_deletePresent(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) - dirty, newNode, _, err := ln.delete([]byte("dog"), nil) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, newNode, _, err := ln.delete([]byte("dog"), tmc, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Nil(t, newNode) + assert.Equal(t, -ln.sizeInBytes(), tmc.GetSizeLoadedInMem()) } func TestLeafNode_deleteFromStoredLnAtSameKey(t *testing.T) { @@ -431,10 +449,10 @@ func TestLeafNode_deleteFromStoredLnAtSameKey(t *testing.T) { db := testscommon.NewMemDbMock() ln := getLn(getTestMarshalizerAndHasher()) - _ = ln.commitDirty(0, 5, db, db) + _ = ln.commitDirty(db, db) lnHash := ln.getHash() - dirty, _, oldHashes, err := ln.delete([]byte("dog"), db) + dirty, _, oldHashes, err := ln.delete([]byte("dog"), dtmc, db) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{lnHash}, oldHashes) @@ -445,13 +463,15 @@ func TestLeafNode_deleteFromLnAtDifferentKey(t *testing.T) { db := testscommon.NewMemDbMock() ln := getLn(getTestMarshalizerAndHasher()) - _ = ln.commitDirty(0, 5, db, db) + _ = ln.commitDirty(db, db) wrongKey := []byte{1, 2, 3} - dirty, _, oldHashes, err := ln.delete(wrongKey, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + dirty, _, oldHashes, err := ln.delete(wrongKey, tmc, db) assert.False(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) + assert.Equal(t, 0, tmc.GetSizeLoadedInMem()) } func TestLeafNode_deleteFromDirtyLnAtSameKey(t *testing.T) { @@ -459,7 +479,7 @@ func TestLeafNode_deleteFromDirtyLnAtSameKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) - dirty, _, oldHashes, err := ln.delete([]byte("dog"), nil) + dirty, _, oldHashes, err := ln.delete([]byte("dog"), dtmc, nil) assert.True(t, dirty) assert.Nil(t, err) assert.Equal(t, [][]byte{}, oldHashes) @@ -471,7 +491,7 @@ func TestLeafNode_deleteNotPresent(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) wrongKey := []byte{1, 2, 3} - dirty, newNode, _, err := ln.delete(wrongKey, nil) + dirty, newNode, _, err := ln.delete(wrongKey, dtmc, nil) assert.False(t, dirty) assert.Nil(t, err) assert.Equal(t, ln, newNode) @@ -485,7 +505,7 @@ func TestLeafNode_reduceNode(t *testing.T) { expected, _ := newLeafNode(getTrieDataWithDefaultVersion(string([]byte{2, 100, 111, 103}), ""), marsh, hasher) expected.dirty = true - n, newChildHash, err := ln.reduceNode(2) + n, newChildHash, err := ln.reduceNode(2, dtmc) assert.Equal(t, expected, n) assert.Nil(t, err) assert.True(t, newChildHash) @@ -506,7 +526,7 @@ func TestLeafNode_getChildren(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) - children, err := ln.getChildren(nil) + children, err := ln.getChildren(nil, nil) assert.Nil(t, err) assert.Equal(t, 0, len(children)) } @@ -654,7 +674,7 @@ func TestLeafNode_getAllHashes(t *testing.T) { t.Parallel() ln := getLn(getTestMarshalizerAndHasher()) - hashes, err := ln.getAllHashes(testscommon.NewMemDbMock()) + hashes, err := ln.getAllHashes(nil, testscommon.NewMemDbMock()) assert.Nil(t, err) assert.Equal(t, 1, len(hashes)) assert.Equal(t, ln.hash, hashes[0]) @@ -690,7 +710,7 @@ func TestLeafNode_SizeInBytes(t *testing.T) { value := []byte("value") key := []byte("key") - hash := []byte("hash") + hash := bytes.Repeat([]byte{1}, 32) ln = &leafNode{ CollapsedLn: CollapsedLn{ Key: key, @@ -703,7 +723,7 @@ func TestLeafNode_SizeInBytes(t *testing.T) { hasher: nil, }, } - assert.Equal(t, len(key)+len(value)+len(hash)+1+2*pointerSizeInBytes, ln.sizeInBytes()) + assert.Equal(t, len(key)+len(value)+len(hash)+1+2*pointerSizeInBytes+nodeVersionSizeInBytes, ln.sizeInBytes()) } func TestLeafNode_writeNodeOnChannel(t *testing.T) { @@ -729,7 +749,7 @@ func TestLeafNode_commitContextDone(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - err := ln.commitSnapshot(db, 0, nil, nil, ctx, statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, 0) + err := ln.commitSnapshot(db, 0, nil, nil, ctx, statistics.NewTrieStatistics(), &testscommon.ProcessStatusHandlerStub{}, []byte{}, dtmc) assert.Equal(t, core.ErrContextClosing, err) } @@ -804,3 +824,14 @@ func TestLeafNode_getNodeData(t *testing.T) { assert.True(t, nodeData[0].IsLeaf()) }) } + +func TestLeafNode_shouldCollapse(t *testing.T) { + t.Parallel() + + ln := getLn(getTestMarshalizerAndHasher()) + ln.setDirty(false) + shouldCollapse := ln.shouldCollapseChild([]byte("doge"), nil) + assert.False(t, shouldCollapse) + shouldCollapse = ln.shouldCollapseChild([]byte("dog"), nil) + assert.True(t, shouldCollapse) +} diff --git a/trie/node.go b/trie/node.go index f15f182c1cb..c896eecfa39 100644 --- a/trie/node.go +++ b/trie/node.go @@ -16,13 +16,17 @@ import ( ) const ( - nrOfChildren = 17 - firstByte = 0 - hexTerminator = 16 - nibbleMask = 0x0f - pointerSizeInBytes = 8 - numNodeInnerPointers = 2 // each trie node contains a marshalizer and a hasher - pollingIdleNode = time.Millisecond + nrOfChildren = 17 + firstByte = 0 + hexTerminator = 16 + nibbleMask = 0x0f + pointerSizeInBytes = 8 + numNodeInnerPointers = 2 // each trie node contains a marshalizer and a hasher + pollingIdleNode = time.Millisecond + hashSizeInBytes = 32 // size of the hash in bytes + baseNodeSizeInBytes = hashSizeInBytes + numNodeInnerPointers*pointerSizeInBytes + 1 // 1 for the dirty flag + bnChildrenPointersSize = nrOfChildren * pointerSizeInBytes + nodeVersionSizeInBytes = 4 ) type baseNode struct { @@ -136,7 +140,7 @@ func treatLogError(logInstance logger.Logger, err error, key []byte) { logInstance.Trace(core.GetNodeFromDBErrorString, "error", err, "key", key, "stack trace", string(debug.Stack())) } -func resolveIfCollapsed(n node, pos byte, db common.TrieStorageInteractor) error { +func resolveIfCollapsed(n node, pos byte, tmc MetricsCollector, db common.TrieStorageInteractor) error { err := n.isEmptyOrNil() if err != nil { return err @@ -147,7 +151,7 @@ func resolveIfCollapsed(n node, pos byte, db common.TrieStorageInteractor) error return nil } - return n.resolveCollapsed(pos, db) + return n.resolveCollapsed(pos, tmc, db) } func handleStorageInteractorStats(db common.TrieStorageInteractor) { @@ -319,7 +323,7 @@ func commitSnapshot( ctx context.Context, stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, - depthLevel int, + tmc MetricsCollector, hash []byte, ) error { encChild, foundInEpoch, err := db.GetWithoutAddingToCache(hash, maxEpochToSearchFrom) @@ -340,7 +344,7 @@ func commitSnapshot( return err } - err = child.commitSnapshot(db, foundInEpoch, leavesChan, missingNodesChan, ctx, stats, idleProvider, encChild, depthLevel+1) + err = child.commitSnapshot(db, foundInEpoch, leavesChan, missingNodesChan, ctx, stats, idleProvider, encChild, tmc) if err != nil { return err } diff --git a/trie/node_test.go b/trie/node_test.go index d5e8774a289..459c0586b7e 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/trie/keyBuilder" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -165,7 +166,7 @@ func TestNode_getNodeFromDBAndDecodeBranchNode(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) encNode, _ := bn.marsh.Marshal(collapsedBn) encNode = append(encNode, branch) @@ -184,7 +185,7 @@ func TestNode_getNodeFromDBAndDecodeExtensionNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) encNode, _ := en.marsh.Marshal(collapsedEn) encNode = append(encNode, extension) @@ -203,7 +204,7 @@ func TestNode_getNodeFromDBAndDecodeLeafNode(t *testing.T) { db := testscommon.NewMemDbMock() ln := getLn(getTestMarshalizerAndHasher()) - _ = ln.commitDirty(0, 5, db, db) + _ = ln.commitDirty(db, db) encNode, _ := ln.marsh.Marshal(ln) encNode = append(encNode, leaf) @@ -223,11 +224,13 @@ func TestNode_resolveIfCollapsedBranchNode(t *testing.T) { db := testscommon.NewMemDbMock() bn, collapsedBn := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) childPos := byte(2) - _ = bn.commitDirty(0, 5, db, db) + _ = bn.commitDirty(db, db) - err := resolveIfCollapsed(collapsedBn, childPos, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + err := resolveIfCollapsed(collapsedBn, childPos, tmc, db) assert.Nil(t, err) assert.False(t, collapsedBn.isCollapsed()) + assert.Equal(t, collapsedBn.children[childPos].sizeInBytes(), tmc.GetSizeLoadedInMem()) } func TestNode_resolveIfCollapsedExtensionNode(t *testing.T) { @@ -235,11 +238,13 @@ func TestNode_resolveIfCollapsedExtensionNode(t *testing.T) { db := testscommon.NewMemDbMock() en, collapsedEn := getEnAndCollapsedEn() - _ = en.commitDirty(0, 5, db, db) + _ = en.commitDirty(db, db) - err := resolveIfCollapsed(collapsedEn, 0, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + err := resolveIfCollapsed(collapsedEn, 0, tmc, db) assert.Nil(t, err) assert.False(t, collapsedEn.isCollapsed()) + assert.Equal(t, collapsedEn.child.sizeInBytes(), tmc.GetSizeLoadedInMem()) } func TestNode_resolveIfCollapsedLeafNode(t *testing.T) { @@ -247,11 +252,13 @@ func TestNode_resolveIfCollapsedLeafNode(t *testing.T) { db := testscommon.NewMemDbMock() ln := getLn(getTestMarshalizerAndHasher()) - _ = ln.commitDirty(0, 5, db, db) + _ = ln.commitDirty(db, db) - err := resolveIfCollapsed(ln, 0, db) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + err := resolveIfCollapsed(ln, 0, tmc, db) assert.Nil(t, err) assert.False(t, ln.isCollapsed()) + assert.Equal(t, 0, tmc.GetSizeLoadedInMem()) } func TestNode_resolveIfCollapsedNilNode(t *testing.T) { @@ -259,7 +266,7 @@ func TestNode_resolveIfCollapsedNilNode(t *testing.T) { var nodeInstance *extensionNode - err := resolveIfCollapsed(nodeInstance, 0, nil) + err := resolveIfCollapsed(nodeInstance, 0, nil, nil) assert.Equal(t, ErrNilExtensionNode, err) } diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index ed92942eabe..688d09deeb3 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/statistics" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" ) var log = logger.GetOrCreate("trie") @@ -31,8 +32,6 @@ const ( branch ) -const rootDepthLevel = 0 - type patriciaMerkleTrie struct { root node @@ -42,11 +41,11 @@ type patriciaMerkleTrie struct { enableEpochsHandler common.EnableEpochsHandler trieNodeVersionVerifier core.TrieNodeVersionVerifier mutOperation sync.RWMutex + collapseManager common.TrieCollapseManager - oldHashes [][]byte - oldRoot []byte - maxTrieLevelInMemory uint - chanClose chan struct{} + oldHashes [][]byte + oldRoot []byte + chanClose chan struct{} } // NewTrie creates a new Patricia Merkle Trie @@ -55,7 +54,7 @@ func NewTrie( msh marshal.Marshalizer, hsh hashing.Hasher, enableEpochsHandler common.EnableEpochsHandler, - maxTrieLevelInMemory uint, + collapseManager common.TrieCollapseManager, ) (*patriciaMerkleTrie, error) { if check.IfNil(trieStorage) { return nil, ErrNilTrieStorage @@ -69,10 +68,9 @@ func NewTrie( if check.IfNil(enableEpochsHandler) { return nil, errors.ErrNilEnableEpochsHandler } - if maxTrieLevelInMemory == 0 { - return nil, ErrInvalidLevelValue + if check.IfNil(collapseManager) { + return nil, ErrNilCollapseManager } - log.Trace("created new trie", "max trie level in memory", maxTrieLevelInMemory) tnvv, err := core.NewTrieNodeVersionVerifier(enableEpochsHandler) if err != nil { @@ -85,10 +83,10 @@ func NewTrie( hasher: hsh, oldHashes: make([][]byte, 0), oldRoot: make([]byte, 0), - maxTrieLevelInMemory: maxTrieLevelInMemory, chanClose: make(chan struct{}), enableEpochsHandler: enableEpochsHandler, trieNodeVersionVerifier: tnvv, + collapseManager: collapseManager, }, nil } @@ -103,13 +101,15 @@ func (tr *patriciaMerkleTrie) Get(key []byte) ([]byte, uint32, error) { } hexKey := keyBytesToHex(key) - val, depth, err := tr.root.tryGet(hexKey, rootDepthLevel, tr.trieStorage) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + val, err := tr.root.tryGet(hexKey, tmc, tr.trieStorage) + tr.collapseManager.MarkKeyAsAccessed(hexKey, tmc.GetSizeLoadedInMem()) if err != nil { err = fmt.Errorf("trie get error: %w, for key %v", err, hex.EncodeToString(key)) - return nil, depth, err + return nil, tmc.GetMaxDepth(), err } - return val, depth, nil + return val, tmc.GetMaxDepth(), nil } // Update updates the value at the given key. @@ -148,6 +148,7 @@ func (tr *patriciaMerkleTrie) update(key []byte, value []byte, version core.Trie if err != nil { return err } + tr.collapseManager.MarkKeyAsAccessed(hexKey, newRoot.sizeInBytes()) tr.root = newRoot return nil @@ -157,7 +158,9 @@ func (tr *patriciaMerkleTrie) update(key []byte, value []byte, version core.Trie tr.oldRoot = tr.root.getHash() } - newRoot, oldHashes, err := tr.root.insert(newData, tr.trieStorage) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + newRoot, oldHashes, err := tr.root.insert(newData, tmc, tr.trieStorage) + tr.collapseManager.MarkKeyAsAccessed(hexKey, tmc.GetSizeLoadedInMem()) if err != nil { return err } @@ -195,7 +198,9 @@ func (tr *patriciaMerkleTrie) delete(hexKey []byte) error { tr.oldRoot = tr.root.getHash() } - _, newRoot, oldHashes, err := tr.root.delete(hexKey, tr.trieStorage) + tmc := trieMetricsCollector.NewTrieMetricsCollector() + _, newRoot, oldHashes, err := tr.root.delete(hexKey, tmc, tr.trieStorage) + tr.collapseManager.RemoveKey(hexKey, tmc.GetSizeLoadedInMem()) if err != nil { return err } @@ -255,14 +260,52 @@ func (tr *patriciaMerkleTrie) Commit() error { log.Trace("started committing trie", "trie", tr.root.getHash()) } - err = tr.root.commitDirty(0, tr.maxTrieLevelInMemory, tr.trieStorage, tr.trieStorage) + err = tr.root.commitDirty(tr.trieStorage, tr.trieStorage) + if err != nil { + return err + } + + if tr.collapseManager.ShouldCollapseTrie() { + return tr.collapseTrie() + } + + collapsibleLeaves, err := tr.collapseManager.GetCollapsibleLeaves() if err != nil { return err } + tr.collapseLeaves(collapsibleLeaves) return nil } +func (tr *patriciaMerkleTrie) collapseTrie() error { + collapsedRoot, err := tr.root.getCollapsed() + if err != nil { + return err + } + + tr.root = collapsedRoot + tr.collapseManager = tr.collapseManager.CloneWithoutState() + tr.collapseManager.AddSizeInMemory(tr.root.sizeInBytes()) + log.Info("trie collapsed", "root", tr.root.getHash()) + return nil +} + +func (tr *patriciaMerkleTrie) collapseLeaves(collapsibleLeaves [][]byte) { + if len(collapsibleLeaves) == 0 { + return + } + + for _, hexKey := range collapsibleLeaves { + tmc := trieMetricsCollector.NewTrieMetricsCollector() + if check.IfNil(tr.root) { + return + } + _ = tr.root.shouldCollapseChild(hexKey, tmc) + tr.collapseManager.AddSizeInMemory(tmc.GetSizeLoadedInMem()) + } +} + // Recreate returns a new trie, given the options func (tr *patriciaMerkleTrie) Recreate(options common.RootHashHolder) (common.Trie, error) { if check.IfNil(options) { @@ -288,7 +331,7 @@ func (tr *patriciaMerkleTrie) recreate(root []byte, tsm common.StorageManager) ( tr.marshalizer, tr.hasher, tr.enableEpochsHandler, - tr.maxTrieLevelInMemory, + tr.collapseManager.CloneWithoutState(), ) } @@ -316,7 +359,7 @@ func (tr *patriciaMerkleTrie) String() string { if tr.root == nil { _, _ = fmt.Fprintln(writer, "*** EMPTY TRIE ***") } else { - tr.root.print(writer, 0, tr.trieStorage) + tr.root.print(writer, 0, trieMetricsCollector.NewDisabledTrieMetricsCollector(), tr.trieStorage) } return writer.String() @@ -369,7 +412,7 @@ func (tr *patriciaMerkleTrie) recreateFromDb(rootHash []byte, tsm common.Storage tr.marshalizer, tr.hasher, tr.enableEpochsHandler, - tr.maxTrieLevelInMemory, + tr.collapseManager.CloneWithoutState(), ) if err != nil { return nil, nil, err @@ -382,6 +425,7 @@ func (tr *patriciaMerkleTrie) recreateFromDb(rootHash []byte, tsm common.Storage newRoot.setGivenHash(rootHash) newTr.root = newRoot + newTr.collapseManager.AddSizeInMemory(newRoot.sizeInBytes()) return newTr, newRoot, nil } @@ -500,6 +544,7 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel( tr.marshalizer, tr.chanClose, ctx, + trieMetricsCollector.NewDisabledTrieMetricsCollector(), ) if err != nil { leavesChannels.ErrChan.WriteInChanNonBlocking(err) @@ -530,7 +575,7 @@ func (tr *patriciaMerkleTrie) GetAllHashes() ([][]byte, error) { return nil, err } - hashes, err = tr.root.getAllHashes(tr.trieStorage) + hashes, err = tr.root.getAllHashes(trieMetricsCollector.NewDisabledTrieMetricsCollector(), tr.trieStorage) if err != nil { return nil, err } @@ -580,7 +625,7 @@ func (tr *patriciaMerkleTrie) GetProof(key []byte) ([][]byte, []byte, error) { proof = append(proof, encodedNode) value := currentNode.getValue() - currentNode, hexKey, errGet = currentNode.getNext(hexKey, tr.trieStorage) + currentNode, hexKey, errGet = currentNode.getNext(hexKey, trieMetricsCollector.NewDisabledTrieMetricsCollector(), tr.trieStorage) if errGet != nil { return nil, nil, errGet } @@ -656,7 +701,7 @@ func (tr *patriciaMerkleTrie) GetTrieStats(address string, rootHash []byte) (com } ts := statistics.NewTrieStatistics() - err = newTrie.root.collectStats(ts, rootDepthLevel, newTrie.trieStorage) + err = newTrie.root.collectStats(ts, trieMetricsCollector.NewTrieMetricsCollector(), newTrie.trieStorage) if err != nil { return nil, err } @@ -683,12 +728,10 @@ func (tr *patriciaMerkleTrie) CollectLeavesForMigration(args vmcommon.ArgsMigrat return err } - _, err = tr.root.collectLeavesForMigration(args, tr.trieStorage, keyBuilder.NewKeyBuilder()) - if err != nil { - return err - } - - return nil + tmc := trieMetricsCollector.NewTrieMetricsCollector() + _, err = tr.root.collectLeavesForMigration(args, tmc, tr.trieStorage, keyBuilder.NewKeyBuilder()) + tr.collapseManager.AddSizeInMemory(tmc.GetSizeLoadedInMem()) + return err } func (tr *patriciaMerkleTrie) checkIfMigrationPossible(args vmcommon.ArgsMigrateDataTrieLeaves) error { @@ -735,6 +778,14 @@ func GetNodeDataFromHash(hash []byte, keyBuilder common.KeyBuilder, db common.Tr return n.getNodeData(keyBuilder) } +// SizeInMemory returns the size in memory of the trie +func (tr *patriciaMerkleTrie) SizeInMemory() int { + tr.mutOperation.RLock() + defer tr.mutOperation.RUnlock() + + return tr.collapseManager.GetSizeInMemory() +} + // Close stops all the active goroutines started by the trie func (tr *patriciaMerkleTrie) Close() error { tr.mutOperation.Lock() diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index 76f4e34e230..3246e2694be 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -17,6 +17,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/hashing/keccak" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/trie/collapseManager" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -44,18 +45,17 @@ func emptyTrie() common.Trie { } func emptyTrieWithCustomEnableEpochsHandler(handler common.EnableEpochsHandler) common.Trie { - storage, marshaller, hasher, _, maxTrieLevelInMem := getDefaultTrieParameters() + storage, marshaller, hasher, _, maxSizeInMem := getDefaultTrieParameters() - tr, _ := trie.NewTrie(storage, marshaller, hasher, handler, maxTrieLevelInMem) + tr, _ := trie.NewTrie(storage, marshaller, hasher, handler, maxSizeInMem) return tr } -func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, common.EnableEpochsHandler, uint) { +func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, common.EnableEpochsHandler, common.TrieCollapseManager) { args := trie.GetDefaultTrieStorageManagerParameters() trieStorageManager, _ := trie.NewTrieStorageManager(args) - maxTrieLevelInMemory := uint(1) - return trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, maxTrieLevelInMemory + return trieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager() } func initTrieMultipleValues(nr int) (common.Trie, [][]byte) { @@ -88,8 +88,8 @@ func addDefaultDataToTrie(tr common.Trie) { func TestNewTrieWithNilTrieStorage(t *testing.T) { t.Parallel() - _, marshalizer, hasher, enableEpochsHandler, maxTrieLevelInMemory := getDefaultTrieParameters() - tr, err := trie.NewTrie(nil, marshalizer, hasher, enableEpochsHandler, maxTrieLevelInMemory) + _, marshalizer, hasher, enableEpochsHandler, cm := getDefaultTrieParameters() + tr, err := trie.NewTrie(nil, marshalizer, hasher, enableEpochsHandler, cm) assert.Nil(t, tr) assert.Equal(t, trie.ErrNilTrieStorage, err) @@ -98,8 +98,8 @@ func TestNewTrieWithNilTrieStorage(t *testing.T) { func TestNewTrieWithNilMarshalizer(t *testing.T) { t.Parallel() - trieStorage, _, hasher, enableEpochsHandler, maxTrieLevelInMemory := getDefaultTrieParameters() - tr, err := trie.NewTrie(trieStorage, nil, hasher, enableEpochsHandler, maxTrieLevelInMemory) + trieStorage, _, hasher, enableEpochsHandler, cm := getDefaultTrieParameters() + tr, err := trie.NewTrie(trieStorage, nil, hasher, enableEpochsHandler, cm) assert.Nil(t, tr) assert.Equal(t, trie.ErrNilMarshalizer, err) @@ -108,8 +108,8 @@ func TestNewTrieWithNilMarshalizer(t *testing.T) { func TestNewTrieWithNilHasher(t *testing.T) { t.Parallel() - trieStorage, marshalizer, _, enableEpochsHandler, maxTrieLevelInMemory := getDefaultTrieParameters() - tr, err := trie.NewTrie(trieStorage, marshalizer, nil, enableEpochsHandler, maxTrieLevelInMemory) + trieStorage, marshalizer, _, enableEpochsHandler, cm := getDefaultTrieParameters() + tr, err := trie.NewTrie(trieStorage, marshalizer, nil, enableEpochsHandler, cm) assert.Nil(t, tr) assert.Equal(t, trie.ErrNilHasher, err) @@ -118,21 +118,21 @@ func TestNewTrieWithNilHasher(t *testing.T) { func TestNewTrieWithNilEnableEpochsHandler(t *testing.T) { t.Parallel() - trieStorage, marshalizer, hasher, _, maxTrieLevelInMemory := getDefaultTrieParameters() - tr, err := trie.NewTrie(trieStorage, marshalizer, hasher, nil, maxTrieLevelInMemory) + trieStorage, marshalizer, hasher, _, cm := getDefaultTrieParameters() + tr, err := trie.NewTrie(trieStorage, marshalizer, hasher, nil, cm) assert.Nil(t, tr) assert.Equal(t, errorsCommon.ErrNilEnableEpochsHandler, err) } -func TestNewTrieWithInvalidMaxTrieLevelInMemory(t *testing.T) { +func TestNewTrieWithNilCollapseManager(t *testing.T) { t.Parallel() trieStorage, marshalizer, hasher, enableEpochsHandler, _ := getDefaultTrieParameters() - tr, err := trie.NewTrie(trieStorage, marshalizer, hasher, enableEpochsHandler, 0) + tr, err := trie.NewTrie(trieStorage, marshalizer, hasher, enableEpochsHandler, nil) assert.Nil(t, tr) - assert.Equal(t, trie.ErrInvalidLevelValue, err) + assert.True(t, errors.Is(err, trie.ErrNilCollapseManager)) } func TestPatriciaMerkleTree_Get(t *testing.T) { @@ -900,6 +900,7 @@ func TestPatriciaMerkleTrie_GetTrieStats(t *testing.T) { _ = tr.Update([]byte("dog"), []byte("reindeer")) _ = tr.Update([]byte("fog"), []byte("puppy")) _ = tr.Update([]byte("dogglesworth"), []byte("cat")) + _ = tr.Update([]byte("abch"), []byte("car")) _ = tr.Commit() rootHash, _ := tr.RootHash() @@ -911,10 +912,10 @@ func TestPatriciaMerkleTrie_GetTrieStats(t *testing.T) { stats, err := ts.GetTrieStats(address, rootHash) assert.Nil(t, err) - assert.Equal(t, uint64(2), stats.GetNumBranchNodes()) - assert.Equal(t, uint64(1), stats.GetNumExtensionNodes()) - assert.Equal(t, uint64(3), stats.GetNumLeafNodes()) - assert.Equal(t, uint64(6), stats.GetTotalNumNodes()) + assert.Equal(t, uint64(3), stats.GetNumBranchNodes()) + assert.Equal(t, uint64(2), stats.GetNumExtensionNodes()) + assert.Equal(t, uint64(4), stats.GetNumLeafNodes()) + assert.Equal(t, uint64(9), stats.GetTotalNumNodes()) assert.Equal(t, uint32(3), stats.GetMaxTrieDepth()) } @@ -1062,7 +1063,7 @@ func TestPatriciaMerkleTrie_GetSerializedNodesShouldSerializeTheCalls(t *testing }, } - tr, _ := trie.NewTrie(testTrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + tr, _ := trie.NewTrie(testTrieStorageManager, args.Marshalizer, args.Hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, collapseManager.NewDisabledCollapseManager()) numGoRoutines := 100 wg := sync.WaitGroup{} wg.Add(numGoRoutines) @@ -1484,13 +1485,13 @@ func TestPatriciaMerkleTrie_IsMigrated(t *testing.T) { t.Run("not migrated", func(t *testing.T) { t.Parallel() - tsm, marshaller, hasher, _, maxTrieInMem := getDefaultTrieParameters() + tsm, marshaller, hasher, _, cm := getDefaultTrieParameters() enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { return flag == common.AutoBalanceDataTriesFlag }, } - tr, _ := trie.NewTrie(tsm, marshaller, hasher, enableEpochs, maxTrieInMem) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, enableEpochs, cm) _ = tr.Update([]byte("dog"), []byte("reindeer")) isMigrated, err := tr.IsMigratedToLatestVersion() @@ -1501,13 +1502,13 @@ func TestPatriciaMerkleTrie_IsMigrated(t *testing.T) { t.Run("migrated", func(t *testing.T) { t.Parallel() - tsm, marshaller, hasher, _, maxTrieInMem := getDefaultTrieParameters() + tsm, marshaller, hasher, _, cm := getDefaultTrieParameters() enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { return flag == common.AutoBalanceDataTriesFlag }, } - tr, _ := trie.NewTrie(tsm, marshaller, hasher, enableEpochs, maxTrieInMem) + tr, _ := trie.NewTrie(tsm, marshaller, hasher, enableEpochs, cm) _ = tr.UpdateWithVersion([]byte("dog"), []byte("reindeer"), core.AutoBalanceEnabled) isMigrated, err := tr.IsMigratedToLatestVersion() @@ -1549,6 +1550,83 @@ func TestGetNodeDataFromHash(t *testing.T) { } +func TestPatriciaMerkleTree_SizeInMemory(t *testing.T) { + t.Parallel() + + db, msh, hsh, epochs, _ := getDefaultTrieParameters() + cm, err := collapseManager.NewCollapseManager(common.TenMbSize, common.NumLeavesToCollapseSingleRun) + assert.Nil(t, err) + tr, _ := trie.NewTrie(db, msh, hsh, epochs, cm) + + assert.Equal(t, 0, tr.SizeInMemory()) + addDefaultDataToTrie(tr) + + assert.Equal(t, 779, tr.SizeInMemory()) // 3 leaves + 2 branch nodes + 1 extension node + err = tr.Commit() + assert.Nil(t, err) + + err = tr.Delete([]byte("dog")) + assert.Nil(t, err) + assert.Equal(t, 380, tr.SizeInMemory()) // 1 branch node + 2 leaves + + err = tr.Commit() + assert.Nil(t, err) + + err = tr.Update([]byte("dog"), []byte("puppy")) + assert.Nil(t, err) + assert.Equal(t, 779, tr.SizeInMemory()) + + rootHash, err := tr.RootHash() + assert.Nil(t, err) + newTrie, err := tr.Recreate(holders.NewDefaultRootHashesHolder(rootHash)) + assert.Nil(t, err) + assert.Equal(t, 249, newTrie.SizeInMemory()) // only root node is in memory + + val, depth, err := newTrie.Get([]byte("dog")) + assert.Nil(t, err) + assert.Equal(t, []byte("puppy"), val) + assert.Equal(t, uint32(3), depth) + assert.Equal(t, 654, newTrie.SizeInMemory()) + + val, depth, err = newTrie.Get([]byte("dog")) + assert.Nil(t, err) + assert.Equal(t, []byte("puppy"), val) + assert.Equal(t, uint32(3), depth) + assert.Equal(t, 654, newTrie.SizeInMemory()) + + err = tr.Delete([]byte("doe")) // delete collapsed node + assert.Nil(t, err) + assert.Equal(t, 464, tr.SizeInMemory()) +} + +func TestPatriciaMerkleTree_CollapseTrie(t *testing.T) { + t.Parallel() + + db, msh, hsh, epochs, _ := getDefaultTrieParameters() + oneMbSize := uint64(1048576) + cm, _ := collapseManager.NewCollapseManager(oneMbSize, common.NumLeavesToCollapseSingleRun) + tr, _ := trie.NewTrie(db, msh, hsh, epochs, cm) + + for uint64(tr.SizeInMemory()) < oneMbSize { + randomKey := make([]byte, 32) + _, _ = cryptoRand.Read(randomKey) + _ = tr.Update(randomKey, randomKey) + } + + sizeBeforeCollapse := tr.SizeInMemory() + numCollapsed, err := tr.GetNumCollapsedNodes() + assert.Nil(t, err) + assert.Equal(t, 0, numCollapsed) + err = tr.Commit() + assert.Nil(t, err) + sizeAfterCollapse := tr.SizeInMemory() + assert.Less(t, sizeAfterCollapse, sizeBeforeCollapse) + assert.Less(t, sizeAfterCollapse, int(oneMbSize)) + numCollapsed, err = tr.GetNumCollapsedNodes() + assert.Nil(t, err) + assert.Equal(t, 100, numCollapsed) +} + func BenchmarkPatriciaMerkleTree_Insert(b *testing.B) { tr := emptyTrie() hsh := keccak.NewKeccak() diff --git a/trie/snapshotTrieStorageManager_test.go b/trie/snapshotTrieStorageManager_test.go index c8cc2df3ce2..9415edd6566 100644 --- a/trie/snapshotTrieStorageManager_test.go +++ b/trie/snapshotTrieStorageManager_test.go @@ -10,7 +10,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/multiversx/mx-chain-go/testscommon/trie" + storageMock "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" ) @@ -30,7 +30,7 @@ func TestNewSnapshotTrieStorageManager(t *testing.T) { t.Parallel() _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{} + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{} stsm, err := newSnapshotTrieStorageManager(trieStorage, 0) assert.Nil(t, err) assert.False(t, check.IfNil(stsm)) @@ -43,7 +43,7 @@ func TestSnapshotTrieStorageManager_GetWithoutAddingToCache(t *testing.T) { t.Parallel() _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(key []byte, maxEpochToSearchFrom uint32) ([]byte, core.OptionalUint32, error) { return nil, core.OptionalUint32{}, core.ErrContextClosing }, @@ -59,7 +59,7 @@ func TestSnapshotTrieStorageManager_GetWithoutAddingToCache(t *testing.T) { t.Parallel() _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(_ []byte, _ uint32) ([]byte, core.OptionalUint32, error) { return nil, core.OptionalUint32{}, storage.ErrDBIsClosed }, @@ -75,7 +75,7 @@ func TestSnapshotTrieStorageManager_GetWithoutAddingToCache(t *testing.T) { _, trieStorage := newEmptyTrie() getFromOldEpochsWithoutCacheCalled := false - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(_ []byte, _ uint32) ([]byte, core.OptionalUint32, error) { getFromOldEpochsWithoutCacheCalled = true return nil, core.OptionalUint32{}, nil @@ -95,7 +95,7 @@ func TestSnapshotTrieStorageManager_PutInEpochWithoutCache(t *testing.T) { t.Parallel() _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ PutInEpochWithoutCacheCalled: func(_ []byte, _ []byte, _ uint32) error { return core.ErrContextClosing }, @@ -111,7 +111,7 @@ func TestSnapshotTrieStorageManager_PutInEpochWithoutCache(t *testing.T) { _, trieStorage := newEmptyTrie() putWithoutCacheCalled := false - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ PutInEpochWithoutCacheCalled: func(_ []byte, _ []byte, _ uint32) error { putWithoutCacheCalled = true return nil @@ -131,7 +131,7 @@ func TestSnapshotTrieStorageManager_GetFromLastEpoch(t *testing.T) { t.Parallel() _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetFromLastEpochCalled: func(_ []byte) ([]byte, error) { return nil, core.ErrContextClosing }, @@ -148,7 +148,7 @@ func TestSnapshotTrieStorageManager_GetFromLastEpoch(t *testing.T) { _, trieStorage := newEmptyTrie() getFromLastEpochCalled := false - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetFromLastEpochCalled: func(_ []byte) ([]byte, error) { getFromLastEpochCalled = true return nil, nil @@ -167,7 +167,7 @@ func TestSnapshotTrieStorageManager_AlsoAddInPreviousEpoch(t *testing.T) { t.Run("HasValue is false", func(t *testing.T) { val := []byte("val") _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(_ []byte, _ uint32) ([]byte, core.OptionalUint32, error) { return val, core.OptionalUint32{}, nil }, @@ -184,7 +184,7 @@ func TestSnapshotTrieStorageManager_AlsoAddInPreviousEpoch(t *testing.T) { t.Run("epoch is previous epoch", func(t *testing.T) { val := []byte("val") _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(_ []byte, _ uint32) ([]byte, core.OptionalUint32, error) { epoch := core.OptionalUint32{ Value: 4, @@ -205,7 +205,7 @@ func TestSnapshotTrieStorageManager_AlsoAddInPreviousEpoch(t *testing.T) { t.Run("epoch is 0", func(t *testing.T) { val := []byte("val") _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(_ []byte, _ uint32) ([]byte, core.OptionalUint32, error) { epoch := core.OptionalUint32{ Value: 4, @@ -226,7 +226,7 @@ func TestSnapshotTrieStorageManager_AlsoAddInPreviousEpoch(t *testing.T) { t.Run("key is ActiveDBKey", func(t *testing.T) { val := []byte("val") _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(_ []byte, _ uint32) ([]byte, core.OptionalUint32, error) { epoch := core.OptionalUint32{ Value: 3, @@ -247,7 +247,7 @@ func TestSnapshotTrieStorageManager_AlsoAddInPreviousEpoch(t *testing.T) { t.Run("key is TrieSyncedKey", func(t *testing.T) { val := []byte("val") _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(_ []byte, _ uint32) ([]byte, core.OptionalUint32, error) { epoch := core.OptionalUint32{ Value: 3, @@ -269,7 +269,7 @@ func TestSnapshotTrieStorageManager_AlsoAddInPreviousEpoch(t *testing.T) { val := []byte("val") putInEpochCalled := false _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storageMock.SnapshotPruningStorerStub{ GetWithoutAddingToCacheCalled: func(_ []byte, _ uint32) ([]byte, core.OptionalUint32, error) { epoch := core.OptionalUint32{ Value: 3, diff --git a/trie/sync.go b/trie/sync.go index ce48f8c8e6b..9e356405e8c 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" ) type trieNodeInfo struct { @@ -175,6 +176,7 @@ func (ts *trieSyncer) checkIfSynced() (bool, error) { missingNodes := make(map[string]struct{}) currentMissingNodes := make([][]byte, 0) checkedNodes := make(map[string]struct{}) + tmc := trieMetricsCollector.NewDisabledTrieMetricsCollector() newElement := true shouldRetryAfterRequest := false @@ -233,7 +235,7 @@ func (ts *trieSyncer) checkIfSynced() (bool, error) { continue } - nextNodes, err = currentNode.getChildren(ts.db) + nextNodes, err = currentNode.getChildren(tmc, ts.db) if err != nil { return false, err } diff --git a/trie/syncTrieStorageManager_test.go b/trie/syncTrieStorageManager_test.go index 0e7c7532433..5e7fe9ce3e5 100644 --- a/trie/syncTrieStorageManager_test.go +++ b/trie/syncTrieStorageManager_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/multiversx/mx-chain-go/testscommon/trie" + "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" ) @@ -34,7 +34,7 @@ func TestNewSyncTrieStorageManager(t *testing.T) { t.Parallel() _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{} + trieStorage.mainStorer = &storage.SnapshotPruningStorerStub{} stsm, err := NewSyncTrieStorageManager(trieStorage) assert.Nil(t, err) assert.NotNil(t, stsm) @@ -45,7 +45,7 @@ func TestNewSyncTrieStorageManager_PutInFirstEpoch(t *testing.T) { _, trieStorage := newEmptyTrie() putInEpochCalled := 0 - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storage.SnapshotPruningStorerStub{ PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { putInEpochCalled++ return nil @@ -66,7 +66,7 @@ func TestNewSyncTrieStorageManager_PutInEpochError(t *testing.T) { expectedErr := errors.New("expected error") _, trieStorage := newEmptyTrie() - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storage.SnapshotPruningStorerStub{ PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { return expectedErr }, @@ -82,7 +82,7 @@ func TestNewSyncTrieStorageManager_PutInEpoch(t *testing.T) { _, trieStorage := newEmptyTrie() putInEpochCalled := 0 - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &storage.SnapshotPruningStorerStub{ PutInEpochCalled: func(_ []byte, _ []byte, _ uint32) error { putInEpochCalled++ return nil diff --git a/trie/sync_test.go b/trie/sync_test.go index 77fe8a6c75b..adc11a625d6 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -3,6 +3,7 @@ package trie import ( "context" "errors" + "sync" "testing" "time" @@ -10,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + trieMock "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,8 +20,8 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" - trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie/statistics" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" ) func createMockArgument(timeout time.Duration) ArgTrieSyncer { @@ -225,10 +227,10 @@ func TestTrieSync_FoundInStorageShouldNotRequest(t *testing.T) { }, } - err = bn.commitDirty(0, 5, db, db) + err = bn.commitDirty(db, db) require.Nil(t, err) - leaves, err := bn.getChildren(db) + leaves, err := bn.getChildren(trieMetricsCollector.NewDisabledTrieMetricsCollector(), db) require.Nil(t, err) numLeaves := len(leaves) diff --git a/trie/trieMetricsCollector/disabledTrieMetricsCollector.go b/trie/trieMetricsCollector/disabledTrieMetricsCollector.go new file mode 100644 index 00000000000..edb051b948e --- /dev/null +++ b/trie/trieMetricsCollector/disabledTrieMetricsCollector.go @@ -0,0 +1,31 @@ +package trieMetricsCollector + +type disabledTrieMetricsCollector struct{} + +// NewDisabledTrieMetricsCollector returns a new instance of disabledTrieMetricsCollector +func NewDisabledTrieMetricsCollector() *disabledTrieMetricsCollector { + return &disabledTrieMetricsCollector{} +} + +// SetDepth is a no-op for the disabled metrics collector +func (d *disabledTrieMetricsCollector) SetDepth(_ uint32) { +} + +// GetCurrentDepth returns 0 for the disabled metrics collector +func (d *disabledTrieMetricsCollector) GetCurrentDepth() uint32 { + return 0 +} + +// GetMaxDepth returns 0 for the disabled metrics collector +func (d *disabledTrieMetricsCollector) GetMaxDepth() uint32 { + return 0 +} + +// AddSizeLoadedInMem is a no-op for the disabled metrics collector +func (d *disabledTrieMetricsCollector) AddSizeLoadedInMem(_ int) { +} + +// GetSizeLoadedInMem returns 0 for the disabled metrics collector +func (d *disabledTrieMetricsCollector) GetSizeLoadedInMem() int { + return 0 +} diff --git a/trie/trieMetricsCollector/disabledTrieMetricsCollector_test.go b/trie/trieMetricsCollector/disabledTrieMetricsCollector_test.go new file mode 100644 index 00000000000..3b3aa14d21b --- /dev/null +++ b/trie/trieMetricsCollector/disabledTrieMetricsCollector_test.go @@ -0,0 +1,49 @@ +package trieMetricsCollector + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewDisabledTrieMetricsCollector(t *testing.T) { + t.Parallel() + + assert.NotNil(t, NewDisabledTrieMetricsCollector()) +} + +func TestDisabledTrieMetricsCollector_SetDepthDoesNotPanic(t *testing.T) { + t.Parallel() + + collector := NewDisabledTrieMetricsCollector() + collector.SetDepth(5) + + // No assertion needed, just checking that it doesn't panic +} + +func TestDisabledTrieMetricsCollector_GetMaxDepthReturnsZero(t *testing.T) { + t.Parallel() + + collector := NewDisabledTrieMetricsCollector() + maxDepth := collector.GetMaxDepth() + + assert.Equal(t, uint32(0), maxDepth) +} + +func TestDisabledTrieMetricsCollector_AddSizeLoadedInMemDoesNotPanic(t *testing.T) { + t.Parallel() + + collector := NewDisabledTrieMetricsCollector() + collector.AddSizeLoadedInMem(100) + + // No assertion needed, just checking that it doesn't panic +} + +func TestDisabledTrieMetricsCollector_GetSizeLoadedInMemReturnsZero(t *testing.T) { + t.Parallel() + + collector := NewDisabledTrieMetricsCollector() + size := collector.GetSizeLoadedInMem() + + assert.Equal(t, 0, size) +} diff --git a/trie/trieMetricsCollector/trieMetricsCollector.go b/trie/trieMetricsCollector/trieMetricsCollector.go new file mode 100644 index 00000000000..77af2411a21 --- /dev/null +++ b/trie/trieMetricsCollector/trieMetricsCollector.go @@ -0,0 +1,45 @@ +package trieMetricsCollector + +type trieMetricsCollector struct { + currentDepth int + maxDepth int + sizeLoadedInMem int +} + +// NewTrieMetricsCollector creates a new instance of trieMetricsCollector +func NewTrieMetricsCollector() *trieMetricsCollector { + return &trieMetricsCollector{ + maxDepth: 0, + sizeLoadedInMem: 0, + } +} + +// SetDepth sets the maxDepth to the provided value if it is greater than the current maxDepth +func (tmc *trieMetricsCollector) SetDepth(depth uint32) { + tmc.currentDepth = int(depth) + if depth <= uint32(tmc.maxDepth) { + return + } + + tmc.maxDepth = int(depth) +} + +// GetCurrentDepth returns the current depth stored in the collector +func (tmc *trieMetricsCollector) GetCurrentDepth() uint32 { + return uint32(tmc.currentDepth) +} + +// GetMaxDepth returns the collected maxDepth +func (tmc *trieMetricsCollector) GetMaxDepth() uint32 { + return uint32(tmc.maxDepth) +} + +// AddSizeLoadedInMem adds the size of the loaded data in memory to the collector +func (tmc *trieMetricsCollector) AddSizeLoadedInMem(size int) { + tmc.sizeLoadedInMem += size +} + +// GetSizeLoadedInMem returns the total size of data loaded in memory +func (tmc *trieMetricsCollector) GetSizeLoadedInMem() int { + return tmc.sizeLoadedInMem +} diff --git a/trie/trieMetricsCollector/trieMetricsCollector_test.go b/trie/trieMetricsCollector/trieMetricsCollector_test.go new file mode 100644 index 00000000000..87d21c91691 --- /dev/null +++ b/trie/trieMetricsCollector/trieMetricsCollector_test.go @@ -0,0 +1,80 @@ +package trieMetricsCollector + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewTrieMetricsCollector(t *testing.T) { + t.Parallel() + + collector := NewTrieMetricsCollector() + assert.NotNil(t, collector) + assert.Equal(t, 0, collector.maxDepth) + assert.Equal(t, 0, collector.sizeLoadedInMem) +} + +func TestTrieMetricsCollector_SetDepth(t *testing.T) { + t.Parallel() + + collector := NewTrieMetricsCollector() + collector.SetDepth(5) + + assert.Equal(t, 5, collector.maxDepth) + collector.SetDepth(3) + assert.Equal(t, 5, collector.maxDepth) // Should not change since 3 < 5 + collector.SetDepth(7) + assert.Equal(t, 7, collector.maxDepth) // Should update to 7 since it's greater than 5 +} + +func TestTrieMetricsCollector_GetCurrentDepth(t *testing.T) { + t.Parallel() + + collector := NewTrieMetricsCollector() + assert.Equal(t, uint32(0), collector.GetCurrentDepth()) + + collector.SetDepth(8) + assert.Equal(t, uint32(8), collector.GetCurrentDepth()) + + collector.SetDepth(4) + assert.Equal(t, uint32(4), collector.GetCurrentDepth()) // Current depth should update to 4 + assert.Equal(t, uint32(8), collector.GetMaxDepth()) // Max depth should remain 8 +} + +func TestTrieMetricsCollector_GetMaxDepth(t *testing.T) { + t.Parallel() + + collector := NewTrieMetricsCollector() + assert.Equal(t, uint32(0), collector.GetMaxDepth()) + + collector.SetDepth(10) + assert.Equal(t, uint32(10), collector.GetMaxDepth()) + + collector.SetDepth(5) + assert.Equal(t, uint32(10), collector.GetMaxDepth()) // Should not change since 5 < 10 +} + +func TestTrieMetricsCollector_AddSizeLoadedInMem(t *testing.T) { + t.Parallel() + + collector := NewTrieMetricsCollector() + collector.AddSizeLoadedInMem(100) + assert.Equal(t, 100, collector.sizeLoadedInMem) + + collector.AddSizeLoadedInMem(50) + assert.Equal(t, 150, collector.sizeLoadedInMem) // Should accumulate size +} + +func TestTrieMetricsCollector_GetSizeLoadedInMem(t *testing.T) { + t.Parallel() + + collector := NewTrieMetricsCollector() + assert.Equal(t, 0, collector.GetSizeLoadedInMem()) + + collector.AddSizeLoadedInMem(200) + assert.Equal(t, 200, collector.GetSizeLoadedInMem()) + + collector.AddSizeLoadedInMem(300) + assert.Equal(t, 500, collector.GetSizeLoadedInMem()) // Should accumulate size +} diff --git a/trie/trieStorageManager.go b/trie/trieStorageManager.go index 796eace218a..0c3cd4b97c2 100644 --- a/trie/trieStorageManager.go +++ b/trie/trieStorageManager.go @@ -18,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/trie/statistics" + "github.com/multiversx/mx-chain-go/trie/trieMetricsCollector" ) // trieStorageManager manages all the storage operations of the trie (commit, snapshot, checkpoint, pruning) @@ -336,7 +337,7 @@ func (tsm *trieStorageManager) takeSnapshot(snapshotEntry *snapshotsQueueEntry, } stats := statistics.NewTrieStatistics() - err = newRoot.commitSnapshot(stsm, foundInEpoch, snapshotEntry.iteratorChannels.LeavesChan, snapshotEntry.missingNodesChan, ctx, stats, tsm.idleProvider, encodedRoot, rootDepthLevel) + err = newRoot.commitSnapshot(stsm, foundInEpoch, snapshotEntry.iteratorChannels.LeavesChan, snapshotEntry.missingNodesChan, ctx, stats, tsm.idleProvider, encodedRoot, trieMetricsCollector.NewTrieMetricsCollector()) if err != nil { snapshotEntry.iteratorChannels.ErrChan.WriteInChanNonBlocking(err) treatSnapshotError(err, diff --git a/trie/trieStorageManagerInEpoch_test.go b/trie/trieStorageManagerInEpoch_test.go index bedf3734529..ca0717242d2 100644 --- a/trie/trieStorageManagerInEpoch_test.go +++ b/trie/trieStorageManagerInEpoch_test.go @@ -7,8 +7,8 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + testCommon "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/storageManager" - "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/stretchr/testify/assert" ) @@ -72,7 +72,7 @@ func TestTrieStorageManagerInEpoch_GetFromEpoch(t *testing.T) { _, trieStorage := newEmptyTrie() getFromEpochCalled := false - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &testCommon.SnapshotPruningStorerStub{ GetFromEpochCalled: func(_ []byte, _ uint32) ([]byte, error) { getFromEpochCalled = true return nil, nil @@ -89,7 +89,7 @@ func TestTrieStorageManagerInEpoch_GetFromEpoch(t *testing.T) { _, trieStorage := newEmptyTrie() getFromEpochCalled := false - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &testCommon.SnapshotPruningStorerStub{ GetFromEpochCalled: func(_ []byte, _ uint32) ([]byte, error) { getFromEpochCalled = true return nil, storage.ErrDBIsClosed @@ -107,7 +107,7 @@ func TestTrieStorageManagerInEpoch_GetFromEpoch(t *testing.T) { _, trieStorage := newEmptyTrie() getFromEpochCalled := false - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &testCommon.SnapshotPruningStorerStub{ GetFromEpochCalled: func(_ []byte, _ uint32) ([]byte, error) { getFromEpochCalled = true return nil, errors.New("not closing error") @@ -128,7 +128,7 @@ func TestTrieStorageManagerInEpoch_GetFromEpoch(t *testing.T) { getFromPreviousEpochCalled := false currentEpoch := uint32(5) expectedKey := []byte("key") - trieStorage.mainStorer = &trie.SnapshotPruningStorerStub{ + trieStorage.mainStorer = &testCommon.SnapshotPruningStorerStub{ GetFromEpochCalled: func(key []byte, epoch uint32) ([]byte, error) { assert.Equal(t, expectedKey, key) if epoch == currentEpoch { diff --git a/trie/trieStorageManager_test.go b/trie/trieStorageManager_test.go index 104c6e2578c..88c9bb7888a 100644 --- a/trie/trieStorageManager_test.go +++ b/trie/trieStorageManager_test.go @@ -207,7 +207,7 @@ func TestTrieStorageManager_RemoveFromAllActiveEpochs(t *testing.T) { RemoveFromAllActiveEpochsCalled := false args := trie.GetDefaultTrieStorageManagerParameters() - args.MainStorer = &trieMock.SnapshotPruningStorerStub{ + args.MainStorer = &storage.SnapshotPruningStorerStub{ MemDbMock: testscommon.NewMemDbMock(), RemoveFromAllActiveEpochsCalled: func(key []byte) error { RemoveFromAllActiveEpochsCalled = true @@ -237,7 +237,7 @@ func TestTrieStorageManager_PutInEpoch(t *testing.T) { putInEpochCalled := false args := trie.GetDefaultTrieStorageManagerParameters() - args.MainStorer = &trieMock.SnapshotPruningStorerStub{ + args.MainStorer = &storage.SnapshotPruningStorerStub{ MemDbMock: testscommon.NewMemDbMock(), PutInEpochCalled: func(key []byte, data []byte, epoch uint32) error { putInEpochCalled = true @@ -268,7 +268,7 @@ func TestTrieStorageManager_GetLatestStorageEpoch(t *testing.T) { getLatestSorageCalled := false args := trie.GetDefaultTrieStorageManagerParameters() - args.MainStorer = &trieMock.SnapshotPruningStorerStub{ + args.MainStorer = &storage.SnapshotPruningStorerStub{ MemDbMock: testscommon.NewMemDbMock(), GetLatestStorageEpochCalled: func() (uint32, error) { getLatestSorageCalled = true @@ -384,7 +384,7 @@ func TestTrieStorageManager_ShouldTakeSnapshot(t *testing.T) { t.Parallel() args := trie.GetDefaultTrieStorageManagerParameters() - args.MainStorer = &trieMock.SnapshotPruningStorerStub{ + args.MainStorer = &storage.SnapshotPruningStorerStub{ GetFromCurrentEpochCalled: func(key []byte) ([]byte, error) { return []byte(common.TrieSyncedVal), nil }, @@ -398,7 +398,7 @@ func TestTrieStorageManager_ShouldTakeSnapshot(t *testing.T) { t.Parallel() args := trie.GetDefaultTrieStorageManagerParameters() - args.MainStorer = &trieMock.SnapshotPruningStorerStub{ + args.MainStorer = &storage.SnapshotPruningStorerStub{ GetFromCurrentEpochCalled: func(key []byte) ([]byte, error) { return []byte("invalid marker"), nil }, @@ -412,7 +412,7 @@ func TestTrieStorageManager_ShouldTakeSnapshot(t *testing.T) { t.Parallel() args := trie.GetDefaultTrieStorageManagerParameters() - args.MainStorer = &trieMock.SnapshotPruningStorerStub{ + args.MainStorer = &storage.SnapshotPruningStorerStub{ GetFromCurrentEpochCalled: func(key []byte) ([]byte, error) { return nil, expectedErr // isTrieSynced returns false }, @@ -477,7 +477,7 @@ func TestNewSnapshotTrieStorageManager_GetFromCurrentEpoch(t *testing.T) { getFromCurrentEpochCalled := false args := trie.GetDefaultTrieStorageManagerParameters() - args.MainStorer = &trieMock.SnapshotPruningStorerStub{ + args.MainStorer = &storage.SnapshotPruningStorerStub{ MemDbMock: testscommon.NewMemDbMock(), GetFromCurrentEpochCalled: func(_ []byte) ([]byte, error) { getFromCurrentEpochCalled = true diff --git a/update/factory/accountDBSyncerContainerFactory.go b/update/factory/accountDBSyncerContainerFactory.go index 58684996dca..aaa6150f673 100644 --- a/update/factory/accountDBSyncerContainerFactory.go +++ b/update/factory/accountDBSyncerContainerFactory.go @@ -29,7 +29,6 @@ type ArgsNewAccountsDBSyncersContainerFactory struct { Marshalizer marshal.Marshalizer TrieStorageManager common.StorageManager TimoutGettingTrieNode time.Duration - MaxTrieLevelInMemory uint NumConcurrentTrieSyncers int MaxHardCapForMissingNodes int TrieSyncerVersion int @@ -47,7 +46,6 @@ type accountDBSyncersContainerFactory struct { marshalizer marshal.Marshalizer timeoutGettingTrieNode time.Duration trieStorageManager common.StorageManager - maxTrieLevelinMemory uint numConcurrentTrieSyncers int maxHardCapForMissingNodes int trieSyncerVersion int @@ -101,7 +99,6 @@ func NewAccountsDBSContainerFactory(args ArgsNewAccountsDBSyncersContainerFactor marshalizer: args.Marshalizer, trieStorageManager: args.TrieStorageManager, timeoutGettingTrieNode: args.TimoutGettingTrieNode, - maxTrieLevelinMemory: args.MaxTrieLevelInMemory, numConcurrentTrieSyncers: args.NumConcurrentTrieSyncers, maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes, trieSyncerVersion: args.TrieSyncerVersion, @@ -151,7 +148,6 @@ func (a *accountDBSyncersContainerFactory) createUserAccountsSyncer(shardId uint RequestHandler: a.requestHandler, Timeout: a.timeoutGettingTrieNode, Cacher: a.trieCacher, - MaxTrieLevelInMemory: a.maxTrieLevelinMemory, MaxHardCapForMissingNodes: a.maxHardCapForMissingNodes, TrieSyncerVersion: a.trieSyncerVersion, CheckNodesOnDisk: a.checkNodesOnDisk, @@ -181,7 +177,6 @@ func (a *accountDBSyncersContainerFactory) createValidatorAccountsSyncer(shardId RequestHandler: a.requestHandler, Timeout: a.timeoutGettingTrieNode, Cacher: a.trieCacher, - MaxTrieLevelInMemory: a.maxTrieLevelinMemory, MaxHardCapForMissingNodes: a.maxHardCapForMissingNodes, TrieSyncerVersion: a.trieSyncerVersion, CheckNodesOnDisk: a.checkNodesOnDisk, diff --git a/update/factory/dataTrieFactory.go b/update/factory/dataTrieFactory.go index 10483099780..968ffef10ab 100644 --- a/update/factory/dataTrieFactory.go +++ b/update/factory/dataTrieFactory.go @@ -13,33 +13,32 @@ import ( "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/sharding" - "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/triesHolder" storageFactory "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/multiversx/mx-chain-go/update" "github.com/multiversx/mx-chain-go/update/genesis" ) // ArgsNewDataTrieFactory is the argument structure for the new data trie factory type ArgsNewDataTrieFactory struct { - StorageConfig config.StorageConfig - SyncFolder string - Marshalizer marshal.Marshalizer - Hasher hashing.Hasher - ShardCoordinator sharding.Coordinator - EnableEpochsHandler common.EnableEpochsHandler - StateStatsCollector common.StateStatisticsHandler - MaxTrieLevelInMemory uint + StorageConfig config.StorageConfig + SyncFolder string + Marshalizer marshal.Marshalizer + Hasher hashing.Hasher + ShardCoordinator sharding.Coordinator + EnableEpochsHandler common.EnableEpochsHandler + StateStatsCollector common.StateStatisticsHandler } type dataTrieFactory struct { - shardCoordinator sharding.Coordinator - trieStorage common.StorageManager - marshalizer marshal.Marshalizer - hasher hashing.Hasher - enableEpochsHandler common.EnableEpochsHandler - maxTrieLevelInMemory uint + shardCoordinator sharding.Coordinator + trieStorage common.StorageManager + marshalizer marshal.Marshalizer + hasher hashing.Hasher + enableEpochsHandler common.EnableEpochsHandler } // NewDataTrieFactory creates a data trie factory @@ -100,12 +99,11 @@ func NewDataTrieFactory(args ArgsNewDataTrieFactory) (*dataTrieFactory, error) { } d := &dataTrieFactory{ - shardCoordinator: args.ShardCoordinator, - trieStorage: trieStorage, - marshalizer: args.Marshalizer, - hasher: args.Hasher, - maxTrieLevelInMemory: args.MaxTrieLevelInMemory, - enableEpochsHandler: args.EnableEpochsHandler, + shardCoordinator: args.ShardCoordinator, + trieStorage: trieStorage, + marshalizer: args.Marshalizer, + hasher: args.Hasher, + enableEpochsHandler: args.EnableEpochsHandler, } return d, nil @@ -118,7 +116,7 @@ func (d *dataTrieFactory) TrieStorageManager() common.StorageManager { // Create creates a TriesHolder container to hold all the states func (d *dataTrieFactory) Create() (common.TriesHolder, error) { - container := state.NewDataTriesHolder() + container := triesHolder.NewTriesHolder() for i := uint32(0); i < d.shardCoordinator.NumberOfShards(); i++ { err := d.createAndAddOneTrie(i, genesis.UserAccount, container) @@ -141,7 +139,7 @@ func (d *dataTrieFactory) Create() (common.TriesHolder, error) { } func (d *dataTrieFactory) createAndAddOneTrie(shId uint32, accType genesis.Type, container common.TriesHolder) error { - dataTrie, err := trie.NewTrie(d.trieStorage, d.marshalizer, d.hasher, d.enableEpochsHandler, d.maxTrieLevelInMemory) + dataTrie, err := trie.NewTrie(d.trieStorage, d.marshalizer, d.hasher, d.enableEpochsHandler, collapseManager.NewDisabledCollapseManager()) if err != nil { return err } diff --git a/update/factory/exportHandlerFactory.go b/update/factory/exportHandlerFactory.go index 44bb0904c86..1bce5c5a6b6 100644 --- a/update/factory/exportHandlerFactory.go +++ b/update/factory/exportHandlerFactory.go @@ -54,7 +54,6 @@ type ArgsExporter struct { ExportTriesStorageConfig config.StorageConfig ExportStateStorageConfig config.StorageConfig ExportStateKeysConfig config.StorageConfig - MaxTrieLevelInMemory uint WhiteListHandler process.WhiteListHandler WhiteListerVerifiedTxs process.WhiteListHandler MainInterceptorsContainer process.InterceptorsContainer @@ -88,7 +87,6 @@ type exportHandlerFactory struct { exportTriesStorageConfig config.StorageConfig exportStateStorageConfig config.StorageConfig exportStateKeysConfig config.StorageConfig - maxTrieLevelInMemory uint whiteListHandler process.WhiteListHandler whiteListerVerifiedTxs process.WhiteListHandler mainInterceptorsContainer process.InterceptorsContainer @@ -260,7 +258,6 @@ func NewExportHandlerFactory(args ArgsExporter) (*exportHandlerFactory, error) { headerSigVerifier: args.HeaderSigVerifier, headerIntegrityVerifier: args.HeaderIntegrityVerifier, validityAttester: args.ValidityAttester, - maxTrieLevelInMemory: args.MaxTrieLevelInMemory, roundHandler: args.RoundHandler, maxHardCapForMissingNodes: args.MaxHardCapForMissingNodes, numConcurrentTrieSyncers: args.NumConcurrentTrieSyncers, @@ -321,14 +318,13 @@ func (e *exportHandlerFactory) Create() (update.ExportHandler, error) { } argsDataTrieFactory := ArgsNewDataTrieFactory{ - StorageConfig: e.exportTriesStorageConfig, - SyncFolder: e.exportFolder, - Marshalizer: e.coreComponents.InternalMarshalizer(), - Hasher: e.coreComponents.Hasher(), - ShardCoordinator: e.shardCoordinator, - MaxTrieLevelInMemory: e.maxTrieLevelInMemory, - EnableEpochsHandler: e.coreComponents.EnableEpochsHandler(), - StateStatsCollector: e.statusCoreComponents.StateStatsHandler(), + StorageConfig: e.exportTriesStorageConfig, + SyncFolder: e.exportFolder, + Marshalizer: e.coreComponents.InternalMarshalizer(), + Hasher: e.coreComponents.Hasher(), + ShardCoordinator: e.shardCoordinator, + EnableEpochsHandler: e.coreComponents.EnableEpochsHandler(), + StateStatsCollector: e.statusCoreComponents.StateStatsHandler(), } dataTriesContainerFactory, err := NewDataTrieFactory(argsDataTrieFactory) if err != nil { @@ -413,7 +409,6 @@ func (e *exportHandlerFactory) Create() (update.ExportHandler, error) { Marshalizer: e.coreComponents.InternalMarshalizer(), TrieStorageManager: trieStorageManager, TimoutGettingTrieNode: common.TimeoutGettingTrieNodesInHardfork, - MaxTrieLevelInMemory: e.maxTrieLevelInMemory, MaxHardCapForMissingNodes: e.maxHardCapForMissingNodes, NumConcurrentTrieSyncers: e.numConcurrentTrieSyncers, TrieSyncerVersion: e.trieSyncerVersion, diff --git a/update/genesis/import.go b/update/genesis/import.go index 8e59e45b7f4..f4c2da29518 100644 --- a/update/genesis/import.go +++ b/update/genesis/import.go @@ -21,14 +21,14 @@ import ( disabledState "github.com/multiversx/mx-chain-go/state/disabled" "github.com/multiversx/mx-chain-go/state/factory" "github.com/multiversx/mx-chain-go/state/storagePruningManager/disabled" + "github.com/multiversx/mx-chain-go/state/triesHolder" "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/collapseManager" "github.com/multiversx/mx-chain-go/update" ) var _ update.ImportHandler = (*stateImport)(nil) -const maxTrieLevelInMemory = uint(5) - // ArgsNewStateImport is the arguments structure to create a new state importer type ArgsNewStateImport struct { Hasher hashing.Hasher @@ -280,23 +280,35 @@ func (si *stateImport) importMiniBlocks(identifier string, keys [][]byte) error func newAccountCreator( accType Type, + tr common.Trie, hasher hashing.Hasher, marshaller marshal.Marshalizer, handler common.EnableEpochsHandler, -) (state.AccountFactory, error) { +) (state.AccountFactory, common.TriesHolder, error) { switch accType { case UserAccount: + dth, err := triesHolder.NewDataTriesHolder(common.TenMbSize) + if err != nil { + return nil, nil, err + } args := factory.ArgsAccountCreator{ Hasher: hasher, Marshaller: marshaller, EnableEpochsHandler: handler, StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, + DataTrieCreator: tr, } - return factory.NewAccountCreator(args) + accCreator, err := factory.NewAccountCreator(args) + if err != nil { + return nil, nil, err + } + + return accCreator, dth, nil case ValidatorAccount: - return factory.NewPeerAccountCreator(), nil + return factory.NewPeerAccountCreator(), triesHolder.NewDisabledDataTriesHolder(), nil } - return nil, update.ErrUnknownType + return nil, nil, update.ErrUnknownType } func (si *stateImport) getTrie(shardID uint32, accType Type) (common.Trie, error) { @@ -315,7 +327,7 @@ func (si *stateImport) getTrie(shardID uint32, accType Type) (common.Trie, error trieStorageManager = si.trieStorageManagers[dataRetriever.PeerAccountsUnit.String()] } - trieForShard, err := trie.NewTrie(trieStorageManager, si.marshalizer, si.hasher, si.enableEpochsHandler, maxTrieLevelInMemory) + trieForShard, err := trie.NewTrie(trieStorageManager, si.marshalizer, si.hasher, si.enableEpochsHandler, collapseManager.NewDisabledCollapseManager()) if err != nil { return nil, err } @@ -347,7 +359,7 @@ func (si *stateImport) importDataTrie(identifier string, shID uint32, keys [][]b return fmt.Errorf("%w wanted a roothash", update.ErrWrongTypeAssertion) } - dataTrie, err := trie.NewTrie(si.trieStorageManagers[dataRetriever.UserAccountsUnit.String()], si.marshalizer, si.hasher, si.enableEpochsHandler, maxTrieLevelInMemory) + dataTrie, err := trie.NewTrie(si.trieStorageManagers[dataRetriever.UserAccountsUnit.String()], si.marshalizer, si.hasher, si.enableEpochsHandler, collapseManager.NewDisabledCollapseManager()) if err != nil { return err } @@ -405,12 +417,7 @@ func (si *stateImport) importDataTrie(identifier string, shID uint32, keys [][]b return nil } -func (si *stateImport) getAccountsDB(accType Type, shardID uint32, accountFactory state.AccountFactory) (state.AccountsDBImporter, common.Trie, error) { - currentTrie, err := si.getTrie(shardID, accType) - if err != nil { - return nil, nil, err - } - +func (si *stateImport) getAccountsDB(accType Type, currentTrie common.Trie, dth common.TriesHolder, shardID uint32, accountFactory state.AccountFactory) (state.AccountsDBImporter, common.Trie, error) { if accType == ValidatorAccount { if check.IfNil(si.validatorDB) { argsAccountDB := state.ArgsAccountsDB{ @@ -422,6 +429,7 @@ func (si *stateImport) getAccountsDB(accType Type, shardID uint32, accountFactor AddressConverter: si.addressConverter, SnapshotsManager: disabledState.NewDisabledSnapshotsManager(), StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, } accountsDB, errCreate := state.NewAccountsDB(argsAccountDB) if errCreate != nil { @@ -429,7 +437,7 @@ func (si *stateImport) getAccountsDB(accType Type, shardID uint32, accountFactor } si.validatorDB = accountsDB } - return si.validatorDB, currentTrie, err + return si.validatorDB, currentTrie, nil } accountsDB, ok := si.accountDBsMap[shardID] @@ -446,8 +454,9 @@ func (si *stateImport) getAccountsDB(accType Type, shardID uint32, accountFactor AddressConverter: si.addressConverter, SnapshotsManager: disabledState.NewDisabledSnapshotsManager(), StateAccessesCollector: disabledState.NewDisabledStateAccessesCollector(), + DataTriesHolder: dth, } - accountsDB, err = state.NewAccountsDB(argsAccountDB) + accountsDB, err := state.NewAccountsDB(argsAccountDB) si.accountDBsMap[shardID] = accountsDB return accountsDB, currentTrie, err } @@ -467,12 +476,17 @@ func (si *stateImport) importState(identifier string, keys [][]byte) error { return si.importDataTrie(identifier, shId, keys) } - accountFactory, err := newAccountCreator(accType, si.hasher, si.marshalizer, si.enableEpochsHandler) + currentTrie, err := si.getTrie(shId, accType) + if err != nil { + return err + } + + accountFactory, dth, err := newAccountCreator(accType, currentTrie, si.hasher, si.marshalizer, si.enableEpochsHandler) if err != nil { return err } - accountsDB, mainTrie, err := si.getAccountsDB(accType, shId, accountFactory) + accountsDB, mainTrie, err := si.getAccountsDB(accType, currentTrie, dth, shId, accountFactory) if err != nil { return err }