@@ -182,13 +182,22 @@ namespace SaveState
182182 return CChunkFileReader::ERROR_BAD_FILE;
183183
184184 static std::vector<u8 > buffer;
185- LockedDecompress (buffer, states_[n], bases_[baseMapping_[n]]);
186- CChunkFileReader::Error error = LoadFromRam (buffer, errorString);
187- rewindLastTime_ = time_now_d ();
188- return error;
185+ if (LockedDecompress (buffer, states_[n], bases_[baseMapping_[n]])) {
186+ CChunkFileReader::Error error = LoadFromRam (buffer, errorString);
187+ if (error == CChunkFileReader::ERROR_NONE) {
188+ INFO_LOG (SAVESTATE, " Rewinding to recent savestate snapshot (%d bytes compressed)" , states_[n].zstd_compressed .size ());
189+ rewindLastTime_ = time_now_d ();
190+ }
191+ return error;
192+ } else {
193+ WARN_LOG (SAVESTATE, " Failed to load rewind savestate" );
194+ // Unclear what CChunkFileReader error code we should pass in this case, which I'm not sure will
195+ // happen in practice barring memory corruption.
196+ }
197+ return CChunkFileReader::ERROR_NONE;
189198 }
190199
191- void ScheduleCompress (StateBuffer *result, std::vector<u8 > *state, const std::vector<u8 > *base)
200+ void ScheduleCompress (StateBuffer *result, const std::vector<u8 > *state, const std::vector<u8 > *base)
192201 {
193202 if (compressThread_.joinable ())
194203 compressThread_.join ();
@@ -202,7 +211,7 @@ namespace SaveState
202211
203212 const bool USE_XOR = false ;
204213
205- void Compress (StateBuffer *result, std::vector<u8 > &state, const std::vector<u8 > &base)
214+ void Compress (StateBuffer *result, const std::vector<u8 > &state, const std::vector<u8 > &base)
206215 {
207216 std::lock_guard<std::mutex> guard (lock_);
208217 // Bail if we were cleared before locking.
@@ -240,28 +249,30 @@ namespace SaveState
240249
241250 // Temporarily allocate a buffer to do compression in.
242251 size_t compressCapacity = ZSTD_compressBound (compressed.size ());
243- u8 *compress_buf = ( u8 *) malloc (compressCapacity);
244- result->compressed_size = ZSTD_compress (compress_buf , compressCapacity, compressed.data (), compressed.size (), 0 );
252+ result-> zstd_compressed . resize (compressCapacity);
253+ result->compressed_size = ZSTD_compress (&result-> zstd_compressed [ 0 ] , compressCapacity, compressed.data (), compressed.size (), 1 );
245254 if (result->compressed_size ) {
246- result->zstd_compressed = std::vector<u8 >(result->compressed_size , 0 );
247- memcpy (&result->zstd_compressed [0 ], compress_buf, result->compressed_size );
255+ result->zstd_compressed .resize (result->compressed_size );
248256 result->decompressed_size = compressed.size ();
249257 }
250- free (compress_buf);
251258
252259 double zstd_s = time_now_d () - start_time - taken_s;
253260 DEBUG_LOG (SAVESTATE, " Rewind: ZSTD compressed to %d in %0.2f ms." , (int )result->compressed_size , zstd_s * 1000.0 );
254261 }
255262
256- void LockedDecompress (std::vector<u8 > &result, const StateBuffer &buffer, const std::vector<u8 > &base)
263+ bool LockedDecompress (std::vector<u8 > &result, const StateBuffer &buffer, const std::vector<u8 > &base)
257264 {
258265 result.clear ();
259266 result.reserve (base.size ());
260267 auto basePos = base.begin ();
261268
262269 // OK, zstd decompress first.
263270 std::vector<u8 > compressed = std::vector<u8 >(buffer.decompressed_size , 0 );
264- ZSTD_decompress (&compressed[0 ], compressed.size (), buffer.zstd_compressed .data (), buffer.zstd_compressed .size ());
271+ size_t retval = ZSTD_decompress (&compressed[0 ], compressed.size (), buffer.zstd_compressed .data (), buffer.zstd_compressed .size ());
272+ if (ZSTD_isError (retval)) {
273+ WARN_LOG (SAVESTATE, " Failed to decompress zstd-compressed rewind savestate" );
274+ return false ;
275+ }
265276
266277 if (USE_XOR) {
267278 result.resize (compressed.size ());
@@ -295,6 +306,7 @@ namespace SaveState
295306 }
296307 }
297308 }
309+ return true ;
298310 }
299311
300312 void Clear ()
@@ -1052,7 +1064,6 @@ namespace SaveState
10521064 break ;
10531065
10541066 case SAVESTATE_REWIND:
1055- INFO_LOG (SAVESTATE, " Rewinding to recent savestate snapshot" );
10561067 result = rewindStates.Restore (&errorString);
10571068 if (result == CChunkFileReader::ERROR_NONE) {
10581069 callbackMessage = sc->T (" Loaded State" );
0 commit comments