diff --git a/co/coroutine.cc b/co/coroutine.cc index 8ba3d9e..40fec40 100644 --- a/co/coroutine.cc +++ b/co/coroutine.cc @@ -13,9 +13,12 @@ #include #include #include +#include #include "bitset.h" +#include "absl/container/flat_hash_map.h" + #if defined(__APPLE__) #include #include @@ -24,6 +27,7 @@ #elif defined(__linux__) #include #include +#include #endif @@ -1058,6 +1062,172 @@ void Coroutine::GetAllFds(std::vector &fds) const { } } +class SignalCatcher { +public: + SignalCatcher(CoroutineScheduler& scheduler) : +#if defined(__linux__) + coroutine_(scheduler, + [this](Coroutine *c) { + CatchSignals(c); + }, + "CoroutineScheduler Signal Catcher") +#endif // defined(__linux__) + { ::sigemptyset(&signal_set_); } + + SignalCatcher(const SignalCatcher&) = delete; + SignalCatcher(SignalCatcher&&) = delete; + SignalCatcher& operator=(const SignalCatcher&) = delete; + SignalCatcher& operator=(SignalCatcher&&) = delete; + + ~SignalCatcher() { +#if defined(__linux__) + if (signal_fd_ != -1) { + TearDown(); + } +#endif // defined(__linux__) + } + + std::size_t Size() const { return handlers_.size(); } + bool Empty() const { return handlers_.empty(); } + + ::sigset_t SignalSet() const { return signal_set_; } + + void HandleSig(int signum, CoroutineScheduler& scheduler) { + if (!handlers_.contains(signum)) { + std::cerr << "No handler added for signal " << signum << " (" + << ::strsignal(signum) << ") but caught anyway!" << std::endl; + return; + } + handlers_[signum](signum, scheduler); + } + +#if defined(__linux__) + void CatchSignals(Coroutine *c) { + assert(c != nullptr); + // Scheduler must have setup signalfd before hitting this coroutine. + assert(signal_fd_ != -1); + while (true) { + c->Wait(signal_fd_, POLLIN); + ::signalfd_siginfo siginfo{}; + auto n_bytes = ::read(signal_fd_, &siginfo, sizeof(siginfo)); + assert(n_bytes == sizeof(siginfo)); + if (n_bytes != sizeof(siginfo)) { + continue; + } + HandleSig(static_cast(siginfo.ssi_signo), c->Scheduler()); + } + } +#else // defined(__linux__) + void CheckPendingSignals(CoroutineScheduler& scheduler) { + while (true) { + ::sigset_t pending_set; + if (::sigpending(&pending_set) != 0) { + // Impossible in reality. + assert(false); + } +#if defined(_GNU_SOURCE) + if (::sigisemptyset(&pending_set)) { + // Done! + return; + } +#else // defined(_GNU_SOURCE) + ::sigset_t empty_set; + ::sigemptyset(&empty_set); + if (pending_set == empty_set) { + // Done! + return; + } +#endif // defined(_GNU_SOURCE) + int signum = 0; + if (::sigwait(&pending_set, &signum) != 0) { + // Impossible in reality, we checked pending already. + assert(false); + } + HandleSig(signum, scheduler); + } +#endif // defined(__linux__) + + void AddHandler(int signum, CoroutineScheduler::SignalHandler&& handler) { +#if defined(__linux__) + if (signal_fd_ != -1) { + std::cerr << "Coroutine scheduler is running!" << std::endl; + abort(); + } +#endif // defined(__linux__) + auto [iter, inserted] = handlers_.emplace(signum, std::move(handler)); + if (!inserted) { + std::cerr << "Signal handler " << signum << " already exists!" << std::endl; + abort(); + } + ::sigaddset(&signal_set_, signum); + } + + void RemoveHandler(int signum) { +#if defined(__linux__) + if (signal_fd_ != -1) { + std::cerr << "Coroutine scheduler is running!" << std::endl; + abort(); + } +#endif // defined(__linux__) + ::sigdelset(&signal_set_, signum); + handlers_.erase(signum); + } + + void SetUp() { + if (::sigprocmask(SIG_SETMASK, &signal_set_, &original_signal_set_) == -1) { + std::cerr << "Failed to update signal mask: " + << std::strerror(errno) << std::endl; + abort(); + } +#if defined(__linux__) + assert(signal_fd_ == -1); + signal_fd_ = ::signalfd(-1, &signal_set_, SFD_NONBLOCK | SFD_CLOEXEC); + if (signal_fd_ == -1) { + std::cerr << "Failed to update signal file descriptor: " + << std::strerror(errno) << std::endl; + abort(); + } +#endif // defined(__linux__) + } + + void TearDown() { + if (::sigprocmask(SIG_SETMASK, &original_signal_set_, nullptr) == -1) { + std::cerr << "Failed to update signal mask: " + << std::strerror(errno) << std::endl; + abort(); + } +#if defined(__linux__) + assert(signal_fd_ != -1); + close(signal_fd_); +#endif // defined(__linux__) + } + +private: +#if defined(__linux__) + Coroutine coroutine_; + int signal_fd_ = -1; +#endif // defined(__linux__) + ::sigset_t signal_set_{}; + ::sigset_t original_signal_set_{}; + + absl::flat_hash_map handlers_; +}; + +void CoroutineScheduler::AddSignalHandler(int signum, SignalHandler handler) { + if (signal_catcher_ == nullptr) { + signal_catcher_ = std::make_unique(*this); + } + signal_catcher_->AddHandler(signum, std::move(handler)); +} + +void CoroutineScheduler::RemoveSignalHandler(int signum) { + assert(signal_catcher_ != nullptr); + signal_catcher_->RemoveHandler(signum); + if (signal_catcher_->Empty()) { + signal_catcher_.reset(); + } +} + CoroutineScheduler::CoroutineScheduler() { #if CO_POLL_MODE == CO_POLL_EPOLL interrupt_fd_ = NewEventFd(); @@ -1208,6 +1378,10 @@ void CoroutineScheduler::Run() { // not a big overhead. std::vector events; + if (signal_catcher_ != nullptr) { + signal_catcher_->SetUp(); + } + while (running_) { if (coroutines_.empty()) { // No coroutines, nothing to do. @@ -1272,6 +1446,11 @@ void CoroutineScheduler::Run() { BuildPollFds(&poll_state_); int num_ready = ::poll(poll_state_.pollfds.data(), poll_state_.pollfds.size(), -1); +#if !defined(__linux__) + if (signal_catcher_ != nullptr && num_ready < 0 && errno == EINTR) { + signal_catcher_->CheckPendingSignals(); + } +#endif // !defined(__linux__) if (num_ready <= 0) { continue; } @@ -1409,6 +1588,10 @@ void CoroutineScheduler::Run() { } CommitDeletions(); } + + if (signal_catcher_ != nullptr) { + signal_catcher_->TearDown(); + } } co::Coroutine *CoroutineScheduler::Spawn(std::function f, diff --git a/co/coroutine.h b/co/coroutine.h index b681d89..22fa928 100644 --- a/co/coroutine.h +++ b/co/coroutine.h @@ -621,6 +621,8 @@ struct PollState { std::vector coroutines; }; +class SignalCatcher; // Forward declaration. + class CoroutineScheduler { public: CoroutineScheduler(); @@ -690,6 +692,10 @@ class CoroutineScheduler { std::vector GetAllFds() const; + using SignalHandler = std::function; + void AddSignalHandler(int signum, SignalHandler handler_callback); + void RemoveSignalHandler(int signum); + co::Coroutine *Spawn(std::function f, CoroutineOptions opts = {}); co::Coroutine *Spawn(std::function f, CoroutineOptions opts = {}); @@ -735,6 +741,8 @@ class CoroutineScheduler { struct pollfd co_interrupt_fd_ = {-1, POLLIN, 0}; #endif + std::unique_ptr signal_catcher_; + uint64_t tick_count_ = 0; CompletionCallback completion_callback_; absl::flat_hash_set deletions_; diff --git a/co/coroutines_test.cc b/co/coroutines_test.cc index a7f7cd0..70d21f0 100644 --- a/co/coroutines_test.cc +++ b/co/coroutines_test.cc @@ -55,6 +55,25 @@ TEST(CoroutineTest, Sleep) { std::cerr << "done" << std::endl; } +TEST(CoroutineTest, CatchSignal) { + co::CoroutineScheduler scheduler; + bool sleeper_woke_up = false; + co::Coroutine c_sleeper(scheduler, [&sleeper_woke_up](co::Coroutine *c) { + c->Millisleep(1000); + sleeper_woke_up = true; + }); + co::Coroutine c_killer(scheduler, [](co::Coroutine *c) { + c->Millisleep(10); + ::kill(getpid(), SIGUSR1); + }); + scheduler.AddSignalHandler(SIGUSR1, [](int signum, co::CoroutineScheduler &sched) { + EXPECT_EQ(signum, SIGUSR1); + sched.Stop(); + }); + scheduler.Run(); + EXPECT_FALSE(sleeper_woke_up); +} + TEST(CoroutinesTest, Wait) { co::CoroutineScheduler scheduler; int pipes[2];