diff --git a/Makefile b/Makefile index 16ff470ba6..bd4578fed4 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,15 @@ -CXXFLAGS = -std=gnu++1y -g -Wall -O2 -MD -MT $@ -MP -flto +sanitize = -fsanitize=address -fsanitize=leak -fsanitize=undefined +CXXFLAGS = -std=gnu++1y -g -Wall -O0 -MD -MT $@ -MP -flto $(sanitize) tests = test-reactor all: seastar $(tests) +clean: + rm seastar $(tests) *.o + seastar: main.o reactor.o $(CXX) $(CXXFLAGS) -o $@ $^ diff --git a/reactor.cc b/reactor.cc index 295df005c0..e4b4f73a6d 100644 --- a/reactor.cc +++ b/reactor.cc @@ -21,7 +21,7 @@ reactor::~reactor() { void reactor::epoll_add_in(pollable_fd& pfd, std::unique_ptr t) { auto ctl = pfd.events ? EPOLL_CTL_MOD : EPOLL_CTL_ADD; - pfd.events |= EPOLLIN; + pfd.events |= EPOLLIN | EPOLLONESHOT; assert(!pfd.pollin); pfd.pollin = std::move(t); ::epoll_event eevt; @@ -33,7 +33,7 @@ void reactor::epoll_add_in(pollable_fd& pfd, std::unique_ptr t) { void reactor::epoll_add_out(pollable_fd& pfd, std::unique_ptr t) { auto ctl = pfd.events ? EPOLL_CTL_MOD : EPOLL_CTL_ADD; - pfd.events |= EPOLLOUT; + pfd.events |= EPOLLOUT | EPOLLONESHOT; assert(!pfd.pollout); pfd.pollout = std::move(t); ::epoll_event eevt; @@ -44,10 +44,14 @@ void reactor::epoll_add_out(pollable_fd& pfd, std::unique_ptr t) { } std::unique_ptr -reactor::listen(socket_address sa) +reactor::listen(socket_address sa, listen_options opts) { - int fd = ::socket(sa.u.sa.sa_family, SOCK_STREAM, SOCK_NONBLOCK | SOCK_CLOEXEC); + int fd = ::socket(sa.u.sa.sa_family, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); assert(fd != -1); + if (opts.reuse_address) { + int opt = 1; + ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + } int r = ::bind(fd, &sa.u.sa, sizeof(sa.u.sas)); assert(r != -1); ::listen(fd, 100); diff --git a/reactor.hh b/reactor.hh index 47ee018b64..c0f01ba557 100644 --- a/reactor.hh +++ b/reactor.hh @@ -17,6 +17,8 @@ #include #include #include +#include +#include class socket_address; class reactor; @@ -40,88 +42,212 @@ private: friend class reactor; }; -template +template class promise; +template +class future; + +class task { +public: + virtual ~task() {} + virtual void run() = 0; +}; + +template +class lambda_task : public task { + Func _func; +public: + lambda_task(const Func& func) : _func(func) {} + lambda_task(Func&& func) : _func(std::move(func)) {} + virtual void run() { _func(); } +}; + +template +std::unique_ptr +make_task(const Func& func) { + return std::unique_ptr(new lambda_task(func)); +} + +template +std::unique_ptr +make_task(Func&& func) { + return std::unique_ptr(new lambda_task(std::move(func))); +} + + template struct future_state { - virtual ~future_state(); - promise* promise; - bool value_valid = false; - bool ex_valid = false; - union { + promise* _promise = nullptr; + future* _future = nullptr; + std::unique_ptr _task; + enum class state { + invalid, + future, + result, + exception, + } _state = state::future; + union any { + any() {} + ~any() {} T value; std::exception_ptr ex; - } u; - void set(const T& value); - void set(T&& value); - void set_exception(std::exception_ptr ex); - T get() { - while (promise) { - promise->wait(); + } _u; + ~future_state() noexcept { + switch (_state) { + case state::future: + break; + case state::result: + _u.value.~T(); + break; + case state::exception: + _u.ex.~exception_ptr(); + break; + default: + abort(); } - if (ex) { - std::rethrow_exception(ex); - } - return std::move(u.value); } + bool has_promise() const { return _promise; } + bool has_future() const { return _future; } + void wait(); + void set(const T& value) { + assert(_state == state::future); + _state = state::result; + new (&_u.value) T(value); + if (_task) { + _task->run(); + } + } + void set(T&& value) { + assert(_state == state::future); + _state = state::result; + new (&_u.value) T(std::move(value)); + if (_task) { + _task->run(); + } + } + template + void set(A... a) { + assert(_state == state::future); + _state = state::result; + new (&_u.value) T(std::forward(a)...); + std::cout << "checking task at " << &_task << "\n"; + if (_task) { + _task->run(); + } + } + void set_exception(std::exception_ptr ex) { + assert(_state == state::future); + _state = state::exception; + new (&_u.ex) std::exception(ex); + if (_task) { + _task->run(); + } + } + T get() { + while (_state == state::future) { + abort(); + } + if (_state == state::exception) { + std::rethrow_exception(_u.ex); + } + return std::move(_u.value); + } + template + void schedule(Func&& func) { + std::cout << "scheduling task at " << &_task << "\n"; + _task = make_task(std::forward(func)); + } +}; + +template +class promise { + future_state* _state; +public: + promise() : _state(new future_state()) { _state->_promise = this; } + promise(promise&& x) : _state(std::move(x._state)) { x._state = nullptr; } + promise(const promise&) = delete; + ~promise() { + if (_state) { + _state->_promise = nullptr; + if (!_state->has_future()) { + delete _state; + } + } + } + promise& operator=(promise&&); + void operator=(const T&) = delete; + future get_future(); + void set_value(const T& result) { _state->set(result); } + void set_value(T&& result) { _state->set(std::move(result)); } }; template class future { - std::unique_ptr> _state; + future_state* _state; +private: + future(future_state* state) : _state(state) { _state->_future = this; } public: + future(future&& x) : _state(x._state) { x._state = nullptr; } + future(const future&) = delete; + future& operator=(future&& x); + void operator=(const future&) = delete; + ~future() { + if (_state) { + _state->_future = nullptr; + if (!_state->has_promise()) { + delete _state; + } + } + } T get() { - return _state.get(); + return _state->get(); } template - void then(Func func) { - + void then(Func&& func) { + auto state = _state; + state->schedule([fut = std::move(*this), func = std::forward(func)] () mutable { + std::cout << "running task\n"; + func(std::move(fut)); + }); } + friend class promise; +}; + +template +inline +future +promise::get_future() +{ + assert(!_state->_future); + return future(_state); +} + +using accept_result = std::tuple, socket_address>; + +struct listen_options { + bool reuse_address = false; }; class reactor { - class task; public: int _epollfd; io_context_t _io_context; private: - class task { - public: - virtual ~task() {} - virtual void run() = 0; - }; - template - class lambda_task : public task { - Func _func; - public: - lambda_task(Func func) : _func(func) {} - virtual void run() { _func(); } - }; - - template - std::unique_ptr - make_task(Func func) { - return std::make_unique>(func); - } - void epoll_add_in(pollable_fd& fd, std::unique_ptr t); void epoll_add_out(pollable_fd& fd, std::unique_ptr t); void abort_on_error(int ret); public: reactor(); + reactor(const reactor&) = delete; + void operator=(const reactor&) = delete; ~reactor(); - std::unique_ptr listen(socket_address sa); + std::unique_ptr listen(socket_address sa, listen_options opts = {}); - template - void accept(pollable_fd& listenfd, Func with_pfd_sockaddr); - - future> accept(pollable_fd& listen_fd) + future accept(pollable_fd& listen_fd); future read_some(pollable_fd& fd, void* buffer, size_t size); - template - void read_some(pollable_fd& fd, void* buffer, size_t len, Func with_len); void run(); @@ -135,33 +261,37 @@ protected: void operator=(const pollable_fd&) = delete; int fd; int events = 0; - std::unique_ptr pollin; - std::unique_ptr pollout; + std::unique_ptr pollin; + std::unique_ptr pollout; friend class reactor; }; -template inline -void reactor::accept(pollable_fd& listenfd, Func with_pfd_sockaddr) { - auto lfd = listenfd.fd; - epoll_add_in(listenfd, make_task([=] { +future +reactor::accept(pollable_fd& listenfd) { + promise pr; + future fut = pr.get_future(); + epoll_add_in(listenfd, make_task([pr = std::move(pr), lfd = listenfd.fd] () mutable { socket_address sa; socklen_t sl = sizeof(&sa.u.sas); int fd = ::accept4(lfd, &sa.u.sa, &sl, SOCK_NONBLOCK | SOCK_CLOEXEC); assert(fd != -1); - auto pfd = std::unique_ptr(new pollable_fd(fd)); - with_pfd_sockaddr(std::move(pfd), sa); + pr.set_value(accept_result{std::unique_ptr(new pollable_fd(fd)), sa}); })); + return fut; } -template -void reactor::read_some(pollable_fd& fd, void* buffer, size_t len, Func with_len) { - auto rfd = fd.fd; - epoll_add_in(fd, make_task([=] { +inline +future +reactor::read_some(pollable_fd& fd, void* buffer, size_t len) { + promise pr; + auto fut = pr.get_future(); + epoll_add_in(fd, make_task([pr = std::move(pr), rfd = fd.fd, buffer, len] () mutable { ssize_t r = ::recv(rfd, buffer, len, 0); assert(r != -1); - with_len(len); + pr.set_value(r); })); + return fut; } diff --git a/test-reactor.cc b/test-reactor.cc index 85aad2ba71..db7b9133ba 100644 --- a/test-reactor.cc +++ b/test-reactor.cc @@ -11,19 +11,39 @@ struct test { reactor r; std::unique_ptr listener; - void on_accept(std::unique_ptr pfd, socket_address sa) { + struct connection { + reactor& r; + std::unique_ptr fd; + }; + void new_connection(accept_result&& accepted) { std::cout << "got connection\n"; - r.accept(*listener, [=] (std::unique_ptr pfd, socket_address sa) { - on_accept(std::move(pfd), sa); + copy_data(std::move(std::get<0>(accepted))); + } + void copy_data(std::unique_ptr fd) { + char buffer[8192]; + r.read_some(*fd, buffer, sizeof(buffer)).then([this] (future fut) { + auto n = fut.get(); + std::cout << "got data: " << n << "\n"; }); } + void start_accept() { + r.accept(*listener).then([this] (future fut) { + std::cout << "accept future returned\n"; + new_connection(fut.get()); + start_accept(); + }); + } +}; int main(int ac, char** av) { test t; ipv4_addr addr{{}, 10000}; - t.listener = r.listen(make_ipv4_address(addr)); - r.accept(*listener, [&] - r.run(); + listen_options lo; + lo.reuse_address = true; + t.listener = t.r.listen(make_ipv4_address(addr), lo); + t.start_accept(); + t.r.run(); + return 0; }