Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions src/export_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ cache_t* pypluginCache_init(
py::function cache_init_hook, py::function cache_hit_hook,
py::function cache_miss_hook, py::function cache_eviction_hook,
py::function cache_remove_hook, py::function cache_free_hook) {
py::gil_scoped_acquire acquire;
// Initialize base cache structure with exception safety
cache_t* cache = nullptr;
std::unique_ptr<pypluginCache_params_t, PypluginCacheParamsDeleter> params;
Expand Down Expand Up @@ -163,20 +164,28 @@ cache_t* pypluginCache_init(
}

static void pypluginCache_free(cache_t* cache) {
if (!cache || !cache->eviction_params) {
if (!cache) {
return;
}
py::gil_scoped_acquire acquire;
if (!cache->eviction_params) {
Comment thread
haochengxia marked this conversation as resolved.
// No params, just free the cache structure
cache_struct_free(cache);
return;
}

// Use smart pointer for automatic cleanup
std::unique_ptr<pypluginCache_params_t, PypluginCacheParamsDeleter> params(
static_cast<pypluginCache_params_t*>(cache->eviction_params));

auto* raw_params = static_cast<pypluginCache_params_t*>(cache->eviction_params);
cache->eviction_params = nullptr;
// The smart pointer destructor will handle cleanup automatically
std::unique_ptr<pypluginCache_params_t, PypluginCacheParamsDeleter> params(raw_params);
params.reset();

cache_struct_free(cache);
}

static bool pypluginCache_get(cache_t* cache, const request_t* req) {
bool hit = cache_get_base(cache, req);
py::gil_scoped_acquire acquire;
pypluginCache_params_t* params =
(pypluginCache_params_t*)cache->eviction_params;

Expand Down Expand Up @@ -204,6 +213,7 @@ static cache_obj_t* pypluginCache_to_evict(cache_t* cache,
}

static void pypluginCache_evict(cache_t* cache, const request_t* req) {
py::gil_scoped_acquire acquire;
pypluginCache_params_t* params =
(pypluginCache_params_t*)cache->eviction_params;

Expand All @@ -223,6 +233,7 @@ static void pypluginCache_evict(cache_t* cache, const request_t* req) {
}

static bool pypluginCache_remove(cache_t* cache, const obj_id_t obj_id) {
py::gil_scoped_acquire acquire;
pypluginCache_params_t* params =
(pypluginCache_params_t*)cache->eviction_params;

Expand Down Expand Up @@ -568,7 +579,8 @@ void export_cache(py::module& m) {
bytes_req > 0 ? 1.0 - (double)bytes_hit / bytes_req : 0.0;
return std::make_tuple(obj_miss_ratio, byte_miss_ratio);
},
"cache"_a, "reader"_a, "start_req"_a = 0, "max_req"_a = -1);
"cache"_a, "reader"_a, "start_req"_a = 0, "max_req"_a = -1,
py::call_guard<py::gil_scoped_release>());
}

} // namespace libcachesim
50 changes: 50 additions & 0 deletions tests/test_gil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import libcachesim as lcs
import threading
import time

S3_URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"

def run_heavy_simulation(name):
# Create a large synthetic trace
reader = lcs.TraceReader(trace=S3_URI)
cache = lcs.LRU(cache_size=1024*1024)

print(f"Thread {name} starting simulation...")
start = time.time()
# Call C++ core logic
lcs.Util.process_trace(cache, reader)
end = time.time()
print(f"Thread {name} completed in {end - start:.2f}s")

def test_gil_release():
"""
Test to verify that the GIL is released during heavy C++ processing.
We run two threads that perform heavy simulations and measure total time.
If the total time is close to the single-thread time, it indicates GIL release.
If the total time is close to double the single-thread time, it indicates GIL is still held.
"""
# --- Experiment start ---

# test single-thread time for reference
start_single = time.time()
run_heavy_simulation("Single")
end_single = time.time()
single_thread_time = end_single - start_single
print(f"\nSingle-thread time: {single_thread_time:.2f}s")

start_total = time.time()

t1 = threading.Thread(target=run_heavy_simulation, args=("A",))
t2 = threading.Thread(target=run_heavy_simulation, args=("B",))

t1.start()
t2.start()

t1.join()
t2.join()

end_total = time.time()
print(f"\nTotal elapsed time: {end_total - start_total:.2f}s")

assert single_thread_time * 1.5 > (end_total - start_total), "GIL release test failed: Total time should be close to single-thread time if GIL is released."
Loading