From d45201f932ea2b57e2e38fee14b734059435315b Mon Sep 17 00:00:00 2001 From: BAder82t <41265463+BAder82t@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:24:11 +0200 Subject: [PATCH] fix: null-deref when dbgKey unset under HELIB_DEBUG (#501) Several debug-only paths dereference *dbgKey / *dbgEa unconditionally, segfaulting when setupDebugGlobals() was never called (e.g. the bundled BGV_binary_arithmetic example under Ctxt::multiplyBy - > checkNoise) Guard each debug block with a null check, matching the existing pattern in extractDigits.cpp and debugging.cpp Closes #501 --- src/Ctxt.cpp | 5 +++- src/binaryArith.cpp | 51 +++++++++++++++++++-------------- src/binaryCompare.cpp | 66 +++++++++++++++++++++++-------------------- 3 files changed, 69 insertions(+), 53 deletions(-) diff --git a/src/Ctxt.cpp b/src/Ctxt.cpp index b346fcd52..80ecf9c2f 100644 --- a/src/Ctxt.cpp +++ b/src/Ctxt.cpp @@ -1769,7 +1769,10 @@ void Ctxt::multiplyBy(const Ctxt& other) this->multLowLvl(other); // perform the multiplication reLinearize(); // re-linearize #ifdef HELIB_DEBUG - checkNoise(*this, *dbgKey, "reLinearize " + std::to_string(size_t(this))); + // Skip when setupDebugGlobals() was not called; dereferencing a null + // dbgKey here previously caused a segfault (issue #501). + if (dbgKey) + checkNoise(*this, *dbgKey, "reLinearize " + std::to_string(size_t(this))); #endif } diff --git a/src/binaryArith.cpp b/src/binaryArith.cpp index 6b90c34a6..991f25415 100644 --- a/src/binaryArith.cpp +++ b/src/binaryArith.cpp @@ -1008,15 +1008,18 @@ static void multByNegative(CtPtrs& product, CtPtrMat_VecCt nums(numbers); // Wrapper around numbers #ifdef HELIB_DEBUG - long pa, pb; - std::vector slots; - decryptBinaryNums(slots, a, *dbgKey, *dbgEa, false); - pa = slots[0]; - decryptBinaryNums(slots, b, *dbgKey, *dbgEa, true); - pb = slots[0]; - decryptAndSum((std::cout << " multByNegative: " << pa << '*' << pb << " = "), - nums, - true); + if (dbgKey && dbgEa) { + long pa, pb; + std::vector slots; + decryptBinaryNums(slots, a, *dbgKey, *dbgEa, false); + pa = slots[0]; + decryptBinaryNums(slots, b, *dbgKey, *dbgEa, true); + pb = slots[0]; + decryptAndSum( + (std::cout << " multByNegative: " << pa << '*' << pb << " = "), + nums, + true); + } #endif addManyNumbers(product, nums, resSize, unpackSlotEncoding); } @@ -1107,16 +1110,18 @@ void multTwoNumbers(CtPtrs& product, CtPtrMat_VecCt nums(numbers); // A wrapper around numbers #ifdef HELIB_DEBUG - long plaintext_lhs, plaintext_rhs; - std::vector slots; - decryptBinaryNums(slots, lhs, *dbgKey, *dbgEa, false); - plaintext_lhs = slots[0]; - decryptBinaryNums(slots, rhs, *dbgKey, *dbgEa, false); - plaintext_rhs = slots[0]; - decryptAndSum((std::cout << " multTwoNumbers: " << plaintext_lhs << '*' - << plaintext_rhs << " = "), - nums, - false); + if (dbgKey && dbgEa) { + long plaintext_lhs, plaintext_rhs; + std::vector slots; + decryptBinaryNums(slots, lhs, *dbgKey, *dbgEa, false); + plaintext_lhs = slots[0]; + decryptBinaryNums(slots, rhs, *dbgKey, *dbgEa, false); + plaintext_rhs = slots[0]; + decryptAndSum((std::cout << " multTwoNumbers: " << plaintext_lhs << '*' + << plaintext_rhs << " = "), + nums, + false); + } #endif addManyNumbers(product, nums, resSize, unpackSlotEncoding); } @@ -1402,7 +1407,7 @@ void AddDAG::printAddDAG(bool printCT) if (node->parent1) std::cout << ", prnt1=" << node->parent1->nodeName(); std::cout << " }\n"; - if (printCT && node->ct != nullptr) + if (printCT && node->ct != nullptr && dbgKey && dbgEa) decryptAndPrint(std::cout, *(node->ct), *dbgKey, @@ -1430,7 +1435,7 @@ void AddDAG::printAddDAG(bool printCT) if (node->parent1) std::cout << ", prnt1=" << node->parent1->nodeName(); std::cout << " }\n"; - if (printCT && node->ct != nullptr) + if (printCT && node->ct != nullptr && dbgKey && dbgEa) decryptAndPrint(std::cout, *(node->ct), *dbgKey, @@ -1445,6 +1450,10 @@ void decryptAndSum(std::ostream& s, const CtPtrMat& numbers, bool twosComplement) { + if (!dbgKey || !dbgEa) { + s << "(skipped: debug globals not set)" << std::endl; + return; + } s << "sum("; long sum = 0; for (long i = 0; i < numbers.size(); i++) { diff --git a/src/binaryCompare.cpp b/src/binaryCompare.cpp index 3376a7a45..25b59171a 100644 --- a/src/binaryCompare.cpp +++ b/src/binaryCompare.cpp @@ -72,25 +72,27 @@ static void compProducts(const CtPtrs_slice& e, const CtPtrs_slice& g) } NTL_EXEC_RANGE_END #ifdef HELIB_DEBUG - std::cout << " g[" << g.start << ".." << (g.start + g.sz - 1) << "], " - << " e[" << e.start << ".." << (e.start + e.sz - 1) - << "]:" << std::endl; - for (long i = 0; i < g.size(); i++) - decryptAndPrint((std::cout << " g[" << (i + g.start) << "] (" - << ((void*)g[i]) << "): "), - *g[i], - *dbgKey, - *dbgEa, - FLAG_PRINT_POLY); - for (long i = 0; i < e.size(); i++) - decryptAndPrint((std::cout << " e[" << (i + e.start) << "] (" - << ((void*)e[i]) << "): "), - *e[i], - *dbgKey, - *dbgEa, - FLAG_PRINT_POLY); + if (dbgKey && dbgEa) { + std::cout << " g[" << g.start << ".." << (g.start + g.sz - 1) << "], " + << " e[" << e.start << ".." << (e.start + e.sz - 1) + << "]:" << std::endl; + for (long i = 0; i < g.size(); i++) + decryptAndPrint((std::cout << " g[" << (i + g.start) << "] (" + << ((void*)g[i]) << "): "), + *g[i], + *dbgKey, + *dbgEa, + FLAG_PRINT_POLY); + for (long i = 0; i < e.size(); i++) + decryptAndPrint((std::cout << " e[" << (i + e.start) << "] (" + << ((void*)e[i]) << "): "), + *e[i], + *dbgKey, + *dbgEa, + FLAG_PRINT_POLY); - std::cout << std::endl; + std::cout << std::endl; + } #endif } @@ -141,19 +143,21 @@ static void compEqGt(CtPtrs& aeqb, HELIB_NTIMER_STOP(compEqGt2); #ifdef HELIB_DEBUG - for (long i = 0; i < lsize(b); i++) - decryptAndPrint((std::cout << " e[" << i << "]: "), - *aeqb[i], - *dbgKey, - *dbgEa, - FLAG_PRINT_POLY); - for (long i = 0; i < lsize(a); i++) - decryptAndPrint((std::cout << " ag[" << i << "]: "), - *agtb[i], - *dbgKey, - *dbgEa, - FLAG_PRINT_POLY); - std::cout << std::endl; + if (dbgKey && dbgEa) { + for (long i = 0; i < lsize(b); i++) + decryptAndPrint((std::cout << " e[" << i << "]: "), + *aeqb[i], + *dbgKey, + *dbgEa, + FLAG_PRINT_POLY); + for (long i = 0; i < lsize(a); i++) + decryptAndPrint((std::cout << " ag[" << i << "]: "), + *agtb[i], + *dbgKey, + *dbgEa, + FLAG_PRINT_POLY); + std::cout << std::endl; + } #endif // Call a recursive function to compute: