Skip to content
95 changes: 49 additions & 46 deletions state/accountsDB.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,54 +346,68 @@ func (adb *AccountsDB) saveCode(newAcc, oldAcc baseAccountHandler) error {
return nil
}

unmodifiedOldCodeEntry, err := adb.updateOldCodeEntry(oldCodeHash)
oldCodeEntry, err := adb.getCodeEntry(oldCodeHash)
if err != nil {
return err
}

err = adb.updateNewCodeEntry(newCodeHash, newCode)
newCodeEntry, err := adb.getCodeEntry(newCodeHash)
if err != nil {
return err
}

entry, err := NewJournalEntryCode(unmodifiedOldCodeEntry, oldCodeHash, newCodeHash, adb.mainTrie, adb.marshaller)
entry, err := NewJournalEntryCode(oldCodeEntry, oldCodeHash, newCodeEntry, newCodeHash, adb.mainTrie, adb.marshaller)
if err != nil {
return err
}
adb.journalize(entry)

err = adb.updateOldCodeEntry(oldCodeHash, oldCodeEntry)
if err != nil {
return err
}

err = adb.updateNewCodeEntry(newCodeHash, newCodeEntry, newCode)
if err != nil {
return err
}

newAcc.SetCodeHash(newCodeHash)
return nil
}

func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error) {
oldCodeEntry, err := getCodeEntry(oldCodeHash, adb.mainTrie, adb.marshaller)
func (adb *AccountsDB) getCodeEntry(hash []byte) (*CodeEntry, error) {
codeEntry, err := getCodeEntry(hash, adb.mainTrie, adb.marshaller)
if err != nil {
return nil, err
}

if oldCodeEntry == nil {
return nil, nil
}

sc := &stateChange.StateAccess{
Type: stateChange.Read,
MainTrieKey: oldCodeHash,
MainTrieKey: hash,
MainTrieVal: nil,
Operation: stateChange.GetCode,
DataTrieChanges: nil,
}
adb.stateAccessesCollector.AddStateAccess(sc)

unmodifiedOldCodeEntry := &CodeEntry{
return codeEntry, nil
}

func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte, oldCodeEntry *CodeEntry) error {
if oldCodeEntry == nil {
return nil
}

codeEntryClone := &CodeEntry{
Code: oldCodeEntry.Code,
NumReferences: oldCodeEntry.NumReferences,
}

if oldCodeEntry.NumReferences <= 1 {
err = adb.mainTrie.Delete(oldCodeHash)
if codeEntryClone.NumReferences <= 1 {
err := adb.mainTrie.Delete(oldCodeHash)
if err != nil {
return nil, err
return err
}

sc1 := &stateChange.StateAccess{
Expand All @@ -405,16 +419,16 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error
}
adb.stateAccessesCollector.AddStateAccess(sc1)

return unmodifiedOldCodeEntry, nil
return nil
}

oldCodeEntry.NumReferences--
codeEntryBytes, err := saveCodeEntry(oldCodeHash, oldCodeEntry, adb.mainTrie, adb.marshaller)
codeEntryClone.NumReferences--
codeEntryBytes, err := saveCodeEntry(oldCodeHash, codeEntryClone, adb.mainTrie, adb.marshaller)
if err != nil {
return nil, err
return err
}

sc = &stateChange.StateAccess{
sc := &stateChange.StateAccess{
Type: stateChange.Write,
MainTrieKey: oldCodeHash,
MainTrieVal: codeEntryBytes,
Expand All @@ -423,41 +437,27 @@ func (adb *AccountsDB) updateOldCodeEntry(oldCodeHash []byte) (*CodeEntry, error
}
adb.stateAccessesCollector.AddStateAccess(sc)

return unmodifiedOldCodeEntry, nil
return nil
}

func (adb *AccountsDB) updateNewCodeEntry(newCodeHash []byte, newCode []byte) error {
func (adb *AccountsDB) updateNewCodeEntry(newCodeHash []byte, newCodeEntry *CodeEntry, newCode []byte) error {
if len(newCode) == 0 {
return nil
}

newCodeEntry, err := getCodeEntry(newCodeHash, adb.mainTrie, adb.marshaller)
if err != nil {
return err
}

sc := &stateChange.StateAccess{
Type: stateChange.Read,
MainTrieKey: newCodeHash,
MainTrieVal: nil,
Operation: stateChange.GetCode,
DataTrieChanges: nil,
}
adb.stateAccessesCollector.AddStateAccess(sc)

if newCodeEntry == nil {
newCodeEntry = &CodeEntry{
Code: newCode,
}
codeEntry := &CodeEntry{}
codeEntry.Code = newCode
if newCodeEntry != nil {
codeEntry.NumReferences = newCodeEntry.NumReferences
}
newCodeEntry.NumReferences++
codeEntry.NumReferences++

codeEntryBytes, err := saveCodeEntry(newCodeHash, newCodeEntry, adb.mainTrie, adb.marshaller)
codeEntryBytes, err := saveCodeEntry(newCodeHash, codeEntry, adb.mainTrie, adb.marshaller)
if err != nil {
return err
}

sc = &stateChange.StateAccess{
sc := &stateChange.StateAccess{
Type: stateChange.Write,
MainTrieKey: newCodeHash,
MainTrieVal: codeEntryBytes,
Expand Down Expand Up @@ -676,18 +676,19 @@ func (adb *AccountsDB) removeDataTrie(baseAcc baseAccountHandler) error {

func (adb *AccountsDB) removeCode(baseAcc baseAccountHandler) error {
oldCodeHash := baseAcc.GetCodeHash()
unmodifiedOldCodeEntry, err := adb.updateOldCodeEntry(oldCodeHash)

oldCodeEntry, err := adb.getCodeEntry(oldCodeHash)
if err != nil {
return err
}

codeChangeEntry, err := NewJournalEntryCode(unmodifiedOldCodeEntry, oldCodeHash, nil, adb.mainTrie, adb.marshaller)
entry, err := NewJournalEntryCode(oldCodeEntry, oldCodeHash, nil, nil, adb.mainTrie, adb.marshaller)
if err != nil {
return err
}
adb.journalize(codeChangeEntry)
adb.journalize(entry)

return nil
return adb.updateOldCodeEntry(oldCodeHash, oldCodeEntry)
}

// LoadAccount fetches the account based on the address. Creates an empty account if the account is missing.
Expand Down Expand Up @@ -853,12 +854,14 @@ func (adb *AccountsDB) RevertToSnapshot(snapshot int) error {
for i := len(adb.entries) - 1; i >= snapshot; i-- {
account, err := adb.entries[i].Revert()
if err != nil {
adb.entries = adb.entries[:i+1]
return err
}

if !check.IfNil(account) {
_, err = adb.saveAccountToTrie(account, adb.mainTrie)
if err != nil {
adb.entries = adb.entries[:i+1]
return err
}
}
Expand Down
Loading
Loading