Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions co/coroutine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
#include <cassert>
#include <fcntl.h>
#include <iostream>
#include <signal.h>

#include "bitset.h"

#include "absl/container/flat_hash_map.h"

#if defined(__APPLE__)
#include <sys/event.h>
#include <sys/time.h>
Expand All @@ -24,6 +27,7 @@
#elif defined(__linux__)
#include <sys/eventfd.h>
#include <sys/timerfd.h>
#include <sys/signalfd.h>

#endif

Expand Down Expand Up @@ -1058,6 +1062,172 @@ void Coroutine::GetAllFds(std::vector<int> &fds) const {
}
}

class SignalCatcher {
public:
SignalCatcher(CoroutineScheduler& scheduler) :
#if defined(__linux__)
coroutine_(scheduler,
[this](Coroutine *c) {
CatchSignals(c);
},
"CoroutineScheduler Signal Catcher")
#endif // defined(__linux__)
{ ::sigemptyset(&signal_set_); }

SignalCatcher(const SignalCatcher&) = delete;
SignalCatcher(SignalCatcher&&) = delete;
SignalCatcher& operator=(const SignalCatcher&) = delete;
SignalCatcher& operator=(SignalCatcher&&) = delete;

~SignalCatcher() {
#if defined(__linux__)
if (signal_fd_ != -1) {
TearDown();
}
#endif // defined(__linux__)
}

std::size_t Size() const { return handlers_.size(); }
bool Empty() const { return handlers_.empty(); }

::sigset_t SignalSet() const { return signal_set_; }

void HandleSig(int signum, CoroutineScheduler& scheduler) {
if (!handlers_.contains(signum)) {
std::cerr << "No handler added for signal " << signum << " ("
<< ::strsignal(signum) << ") but caught anyway!" << std::endl;
return;
}
handlers_[signum](signum, scheduler);
}

#if defined(__linux__)
void CatchSignals(Coroutine *c) {
assert(c != nullptr);
// Scheduler must have setup signalfd before hitting this coroutine.
assert(signal_fd_ != -1);
while (true) {
c->Wait(signal_fd_, POLLIN);
::signalfd_siginfo siginfo{};
auto n_bytes = ::read(signal_fd_, &siginfo, sizeof(siginfo));
assert(n_bytes == sizeof(siginfo));
if (n_bytes != sizeof(siginfo)) {
continue;
}
HandleSig(static_cast<int>(siginfo.ssi_signo), c->Scheduler());
}
}
#else // defined(__linux__)
void CheckPendingSignals(CoroutineScheduler& scheduler) {
while (true) {
::sigset_t pending_set;
if (::sigpending(&pending_set) != 0) {
// Impossible in reality.
assert(false);
}
#if defined(_GNU_SOURCE)
if (::sigisemptyset(&pending_set)) {
// Done!
return;
}
#else // defined(_GNU_SOURCE)
::sigset_t empty_set;
::sigemptyset(&empty_set);
if (pending_set == empty_set) {
// Done!
return;
}
#endif // defined(_GNU_SOURCE)
int signum = 0;
if (::sigwait(&pending_set, &signum) != 0) {
// Impossible in reality, we checked pending already.
assert(false);
}
HandleSig(signum, scheduler);
}
#endif // defined(__linux__)

void AddHandler(int signum, CoroutineScheduler::SignalHandler&& handler) {
#if defined(__linux__)
if (signal_fd_ != -1) {
std::cerr << "Coroutine scheduler is running!" << std::endl;
abort();
}
#endif // defined(__linux__)
auto [iter, inserted] = handlers_.emplace(signum, std::move(handler));
if (!inserted) {
std::cerr << "Signal handler " << signum << " already exists!" << std::endl;
abort();
}
::sigaddset(&signal_set_, signum);
}

void RemoveHandler(int signum) {
#if defined(__linux__)
if (signal_fd_ != -1) {
std::cerr << "Coroutine scheduler is running!" << std::endl;
abort();
}
#endif // defined(__linux__)
::sigdelset(&signal_set_, signum);
handlers_.erase(signum);
}

void SetUp() {
if (::sigprocmask(SIG_SETMASK, &signal_set_, &original_signal_set_) == -1) {
std::cerr << "Failed to update signal mask: "
<< std::strerror(errno) << std::endl;
abort();
}
#if defined(__linux__)
assert(signal_fd_ == -1);
signal_fd_ = ::signalfd(-1, &signal_set_, SFD_NONBLOCK | SFD_CLOEXEC);
if (signal_fd_ == -1) {
std::cerr << "Failed to update signal file descriptor: "
<< std::strerror(errno) << std::endl;
abort();
}
#endif // defined(__linux__)
}

void TearDown() {
if (::sigprocmask(SIG_SETMASK, &original_signal_set_, nullptr) == -1) {
std::cerr << "Failed to update signal mask: "
<< std::strerror(errno) << std::endl;
abort();
}
#if defined(__linux__)
assert(signal_fd_ != -1);
close(signal_fd_);
#endif // defined(__linux__)
}

private:
#if defined(__linux__)
Coroutine coroutine_;
int signal_fd_ = -1;
#endif // defined(__linux__)
::sigset_t signal_set_{};
::sigset_t original_signal_set_{};

absl::flat_hash_map<int, CoroutineScheduler::SignalHandler> handlers_;
};

void CoroutineScheduler::AddSignalHandler(int signum, SignalHandler handler) {
if (signal_catcher_ == nullptr) {
signal_catcher_ = std::make_unique<SignalCatcher>(*this);
}
signal_catcher_->AddHandler(signum, std::move(handler));
}

void CoroutineScheduler::RemoveSignalHandler(int signum) {
assert(signal_catcher_ != nullptr);
signal_catcher_->RemoveHandler(signum);
if (signal_catcher_->Empty()) {
signal_catcher_.reset();
}
}

CoroutineScheduler::CoroutineScheduler() {
#if CO_POLL_MODE == CO_POLL_EPOLL
interrupt_fd_ = NewEventFd();
Expand Down Expand Up @@ -1208,6 +1378,10 @@ void CoroutineScheduler::Run() {
// not a big overhead.
std::vector<YieldedCoroutine> events;

if (signal_catcher_ != nullptr) {
signal_catcher_->SetUp();
}

while (running_) {
if (coroutines_.empty()) {
// No coroutines, nothing to do.
Expand Down Expand Up @@ -1272,6 +1446,11 @@ void CoroutineScheduler::Run() {
BuildPollFds(&poll_state_);
int num_ready =
::poll(poll_state_.pollfds.data(), poll_state_.pollfds.size(), -1);
#if !defined(__linux__)
if (signal_catcher_ != nullptr && num_ready < 0 && errno == EINTR) {
signal_catcher_->CheckPendingSignals();
}
#endif // !defined(__linux__)
if (num_ready <= 0) {
continue;
}
Expand Down Expand Up @@ -1409,6 +1588,10 @@ void CoroutineScheduler::Run() {
}
CommitDeletions();
}

if (signal_catcher_ != nullptr) {
signal_catcher_->TearDown();
}
}

co::Coroutine *CoroutineScheduler::Spawn(std::function<void(co::Coroutine *)> f,
Expand Down
8 changes: 8 additions & 0 deletions co/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,8 @@ struct PollState {
std::vector<Coroutine *> coroutines;
};

class SignalCatcher; // Forward declaration.

class CoroutineScheduler {
public:
CoroutineScheduler();
Expand Down Expand Up @@ -690,6 +692,10 @@ class CoroutineScheduler {

std::vector<int> GetAllFds() const;

using SignalHandler = std::function<void(int, CoroutineScheduler&)>;
void AddSignalHandler(int signum, SignalHandler handler_callback);
void RemoveSignalHandler(int signum);

co::Coroutine *Spawn(std::function<void(co::Coroutine *)> f,
CoroutineOptions opts = {});
co::Coroutine *Spawn(std::function<void()> f, CoroutineOptions opts = {});
Expand Down Expand Up @@ -735,6 +741,8 @@ class CoroutineScheduler {
struct pollfd co_interrupt_fd_ = {-1, POLLIN, 0};
#endif

std::unique_ptr<SignalCatcher> signal_catcher_;

uint64_t tick_count_ = 0;
CompletionCallback completion_callback_;
absl::flat_hash_set<const Coroutine *> deletions_;
Expand Down
19 changes: 19 additions & 0 deletions co/coroutines_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ TEST(CoroutineTest, Sleep) {
std::cerr << "done" << std::endl;
}

TEST(CoroutineTest, CatchSignal) {
co::CoroutineScheduler scheduler;
bool sleeper_woke_up = false;
co::Coroutine c_sleeper(scheduler, [&sleeper_woke_up](co::Coroutine *c) {
c->Millisleep(1000);
sleeper_woke_up = true;
});
co::Coroutine c_killer(scheduler, [](co::Coroutine *c) {
c->Millisleep(10);
::kill(getpid(), SIGUSR1);
});
scheduler.AddSignalHandler(SIGUSR1, [](int signum, co::CoroutineScheduler &sched) {
EXPECT_EQ(signum, SIGUSR1);
sched.Stop();
});
scheduler.Run();
EXPECT_FALSE(sleeper_woke_up);
}

TEST(CoroutinesTest, Wait) {
co::CoroutineScheduler scheduler;
int pipes[2];
Expand Down