~singpolyma/biboumi

b86547dc1ef407ca3838444533bc7145e32a0d90 — Florent Le Coz 9 years ago a171357
Implement async DNS resolution using c-ares

fix #2533
M CMakeLists.txt => CMakeLists.txt +13 -0
@@ 39,6 39,12 @@ elseif(NOT WITHOUT_BOTAN)
  find_package(BOTAN)
endif()

if(WITH_CARES)
  find_package(CARES REQUIRED)
elseif(NOT WITHOUT_CARES)
  find_package(CARES)
endif()

#
## Get the software version
#


@@ 84,6 90,10 @@ if(BOTAN_FOUND)
  include_directories(SYSTEM ${BOTAN_INCLUDE_DIRS})
endif()

if(CARES_FOUND)
  include_directories(${CARES_INCLUDE_DIRS})
endif()

set(POLLER_DOCSTRING "Choose the poller between POLL and EPOLL (Linux-only)")
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
 set(POLLER "EPOLL" CACHE STRING ${POLLER_DOCSTRING})


@@ 145,6 155,9 @@ target_link_libraries(network logger)
if(BOTAN_FOUND)
  target_link_libraries(network ${BOTAN_LIBRARIES})
endif()
if(CARES_FOUND)
  target_link_libraries(network ${CARES_LIBRARIES})
endif()

#
## irclib

M src/config.h.cmake => src/config.h.cmake +1 -0
@@ 4,4 4,5 @@
#cmakedefine SYSTEMD_FOUND
#cmakedefine POLLER ${POLLER}
#cmakedefine BOTAN_FOUND
#cmakedefine CARES_FOUND
#cmakedefine BIBOUMI_VERSION "${BIBOUMI_VERSION}"

M src/main.cpp => src/main.cpp +15 -1
@@ 1,4 1,3 @@
#include <network/tcp_socket_handler.hpp>
#include <xmpp/xmpp_component.hpp>
#include <utils/timed_events.hpp>
#include <network/poller.hpp>


@@ 11,6 10,10 @@

#include <signal.h>

#ifdef CARES_FOUND
# include <network/dns_handler.hpp>
#endif

// A flag set by the SIGINT signal handler.
static volatile std::atomic<bool> stop(false);
// Flag set by the SIGUSR1/2 signal handler.


@@ 95,6 98,10 @@ int main(int ac, char** av)

  xmpp_component->start();


#ifdef CARES_FOUND
  DNSHandler::instance.watch_dns_sockets(p);
#endif
  auto timeout = TimedEventsManager::instance().get_timeout();
  while (p->poll(timeout) != -1)
  {


@@ 108,6 115,9 @@ int main(int ac, char** av)
      exiting = true;
      stop.store(false);
      xmpp_component->shutdown();
#ifdef CARES_FOUND
      DNSHandler::instance.destroy();
#endif
      // Cancel the timer for an potential reconnection
      TimedEventsManager::instance().cancel("XMPP reconnection");
    }


@@ 153,6 163,10 @@ int main(int ac, char** av)
      xmpp_component->close();
    if (exiting && p->size() == 1 && xmpp_component->is_document_open())
      xmpp_component->close_document();
#ifdef CARES_FOUND
    if (!exiting)
      DNSHandler::instance.watch_dns_sockets(p);
#endif
    if (exiting) // If we are exiting, do not wait for any timed event
      timeout = utils::no_timeout;
    else

A src/network/dns_handler.cpp => src/network/dns_handler.cpp +112 -0
@@ 0,0 1,112 @@
#include <config.h>
#ifdef CARES_FOUND

#include <network/dns_socket_handler.hpp>
#include <network/tcp_socket_handler.hpp>
#include <network/dns_handler.hpp>
#include <network/poller.hpp>

#include <algorithm>
#include <stdexcept>

DNSHandler DNSHandler::instance;

using namespace std::string_literals;

void on_hostname4_resolved(void* arg, int status, int, struct hostent* hostent)
{
  TCPSocketHandler* socket_handler = static_cast<TCPSocketHandler*>(arg);
  socket_handler->on_hostname4_resolved(status, hostent);
}

void on_hostname6_resolved(void* arg, int status, int, struct hostent* hostent)
{
  TCPSocketHandler* socket_handler = static_cast<TCPSocketHandler*>(arg);
  socket_handler->on_hostname6_resolved(status, hostent);
}

DNSHandler::DNSHandler()
{
  int ares_error;
  if ((ares_error = ::ares_library_init(ARES_LIB_INIT_ALL)) != 0)
    throw std::runtime_error("Failed to initialize c-ares lib: "s + ares_strerror(ares_error));
  if ((ares_error = ::ares_init(&this->channel)) != ARES_SUCCESS)
    throw std::runtime_error("Failed to initialize c-ares channel: "s + ares_strerror(ares_error));
}

ares_channel& DNSHandler::get_channel()
{
  return this->channel;
}

void DNSHandler::destroy()
{
  this->socket_handlers.clear();
  ::ares_destroy(this->channel);
  ::ares_library_cleanup();
}

void DNSHandler::gethostbyname(const std::string& name,
                               TCPSocketHandler* socket_handler, int family)
{
  socket_handler->free_cares_addrinfo();
  if (family == AF_INET)
    ::ares_gethostbyname(this->channel, name.data(), family,
                         &::on_hostname4_resolved, socket_handler);
  else
    ::ares_gethostbyname(this->channel, name.data(), family,
                         &::on_hostname6_resolved, socket_handler);
}

void DNSHandler::watch_dns_sockets(std::shared_ptr<Poller>& poller)
{
  fd_set readers;
  fd_set writers;

  FD_ZERO(&readers);
  FD_ZERO(&writers);

  int ndfs = ::ares_fds(this->channel, &readers, &writers);
  // For each existing DNS socket, see if we are still supposed to watch it,
  // if not then erase it
  this->socket_handlers.erase(
      std::remove_if(this->socket_handlers.begin(), this->socket_handlers.end(),
                     [&readers](const auto& dns_socket)
                     {
                       return !FD_ISSET(dns_socket->get_socket(), &readers);
                     }),
      this->socket_handlers.end());

  for (auto i = 0; i < ndfs; ++i)
    {
      bool read = FD_ISSET(i, &readers);
      bool write = FD_ISSET(i, &writers);
      // Look for the DNSSocketHandler with this fd
      auto it = std::find_if(this->socket_handlers.begin(),
                             this->socket_handlers.end(),
                             [i](const auto& socket_handler)
                             {
        return i == socket_handler->get_socket();
      });
      if (!read && !write)      // No need to read or write to it
        { // If found, erase it and stop watching it because it is not
          // needed anymore
          if (it != this->socket_handlers.end())
            // The socket destructor removes it from the poller
            this->socket_handlers.erase(it);
        }
      else            // We need to write and/or read to it
        { // If not found, create it because we need to watch it
          if (it == this->socket_handlers.end())
            {
              this->socket_handlers.emplace_front(std::make_unique<DNSSocketHandler>(poller, i));
              it = this->socket_handlers.begin();
            }
          poller->add_socket_handler(it->get());
          if (write)
            poller->watch_send_events(it->get());
        }
    }
}

#endif /* CARES_FOUND */

A src/network/dns_handler.hpp => src/network/dns_handler.hpp +62 -0
@@ 0,0 1,62 @@
#ifndef DNS_HANDLER_HPP_INCLUDED
#define DNS_HANDLER_HPP_INCLUDED

#include <config.h>
#ifdef CARES_FOUND

class TCPSocketHandler;
class Poller;
class DNSSocketHandler;

# include <ares.h>
# include <memory>
# include <string>
# include <list>

void on_hostname4_resolved(void* arg, int status, int, struct hostent* hostent);
void on_hostname6_resolved(void* arg, int status, int, struct hostent* hostent);

/**
 * Class managing DNS resolution.  It should only be statically instanciated
 * once in SocketHandler.  It manages ares channel and calls various
 * functions of that library.
 */

class DNSHandler
{
public:
  DNSHandler();
  ~DNSHandler() = default;
  void gethostbyname(const std::string& name, TCPSocketHandler* socket_handler,
                     int family);
  /**
   * Call ares_fds to know what fd needs to be watched by the poller, create
   * or destroy DNSSocketHandlers depending on the result.
   */
  void watch_dns_sockets(std::shared_ptr<Poller>& poller);
  /**
   * Destroy and stop watching all the DNS sockets. Then de-init the channel
   * and library.
   */
  void destroy();
  ares_channel& get_channel();

  static DNSHandler instance;

private:
  /**
   * The list of sockets that needs to be watched, according to the last
   * call to ares_fds.  DNSSocketHandlers are added to it or removed from it
   * in the watch_dns_sockets() method
   */
  std::list<std::unique_ptr<DNSSocketHandler>> socket_handlers;
  ares_channel channel;

  DNSHandler(const DNSHandler&) = delete;
  DNSHandler(DNSHandler&&) = delete;
  DNSHandler& operator=(const DNSHandler&) = delete;
  DNSHandler& operator=(DNSHandler&&) = delete;
};

#endif /* CARES_FOUND */
#endif /* DNS_HANDLER_HPP_INCLUDED */

A src/network/dns_socket_handler.cpp => src/network/dns_socket_handler.cpp +45 -0
@@ 0,0 1,45 @@
#include <config.h>
#ifdef CARES_FOUND

#include <network/dns_socket_handler.hpp>
#include <network/dns_handler.hpp>
#include <network/poller.hpp>

#include <ares.h>

DNSSocketHandler::DNSSocketHandler(std::shared_ptr<Poller> poller,
                                   const socket_t socket):
  SocketHandler(poller, socket)
{
}

DNSSocketHandler::~DNSSocketHandler()
{
}

void DNSSocketHandler::connect()
{
}

void DNSSocketHandler::on_recv()
{
  // always stop watching send and read events. We will re-watch them if the
  // next call to ares_fds tell us to
  this->poller->remove_socket_handler(this->socket);
  ::ares_process_fd(DNSHandler::instance.get_channel(), this->socket, ARES_SOCKET_BAD);
}

void DNSSocketHandler::on_send()
{
  // always stop watching send and read events. We will re-watch them if the
  // next call to ares_fds tell us to
  this->poller->remove_socket_handler(this->socket);
  ::ares_process_fd(DNSHandler::instance.get_channel(), ARES_SOCKET_BAD, this->socket);
}

bool DNSSocketHandler::is_connected() const
{
  return true;
}

#endif /* CARES_FOUND */

A src/network/dns_socket_handler.hpp => src/network/dns_socket_handler.hpp +46 -0
@@ 0,0 1,46 @@
#ifndef DNS_SOCKET_HANDLER_HPP
# define DNS_SOCKET_HANDLER_HPP

#include <config.h>
#ifdef CARES_FOUND

#include <network/socket_handler.hpp>
#include <ares.h>

/**
 * Manage a socket returned by ares_fds. We do not create, open or close the
 * socket ourself: this is done by c-ares.  We just call ares_process_fd()
 * with the correct parameters, depending on what can be done on that socket
 * (Poller reported it to be writable or readeable)
 */

class DNSSocketHandler: public SocketHandler
{
public:
  explicit DNSSocketHandler(std::shared_ptr<Poller> poller, const socket_t socket);
  ~DNSSocketHandler();
  /**
   * Just call dns_process_fd, c-ares will do its work of send()ing or
   * recv()ing the data it wants on that socket.
   */
  void on_recv() override final;
  void on_send() override final;
  /**
   * Do nothing, because we are always considered to be connected, since the
   * connection is done by c-ares and not by us.
   */
  void connect() override final;
  /**
   * Always true, see the comment for connect()
   */
  bool is_connected() const override final;

private:
  DNSSocketHandler(const DNSSocketHandler&) = delete;
  DNSSocketHandler(DNSSocketHandler&&) = delete;
  DNSSocketHandler& operator=(const DNSSocketHandler&) = delete;
  DNSSocketHandler& operator=(DNSSocketHandler&&) = delete;
};

#endif // CARES_FOUND
#endif // DNS_SOCKET_HANDLER_HPP

M src/network/socket_handler.hpp => src/network/socket_handler.hpp +2 -0
@@ 1,6 1,7 @@
#ifndef SOCKET_HANDLER_HPP
# define SOCKET_HANDLER_HPP

#include <config.h>
#include <memory>

class Poller;


@@ 19,6 20,7 @@ public:
  virtual void on_send() = 0;
  virtual void connect() = 0;
  virtual bool is_connected() const = 0;

  socket_t get_socket() const
  { return this->socket; }


M src/network/tcp_socket_handler.cpp => src/network/tcp_socket_handler.cpp +161 -5
@@ 1,4 1,5 @@
#include <network/tcp_socket_handler.hpp>
#include <network/dns_handler.hpp>

#include <utils/timed_events.hpp>
#include <utils/scopeguard.hpp>


@@ 42,8 43,22 @@ TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller):
  use_tls(false),
  connected(false),
  connecting(false)
#ifdef CARES_FOUND
  ,resolved(false),
  resolved4(false),
  resolved6(false),
  cares_addrinfo(nullptr),
  cares_error()
#endif
{}

TCPSocketHandler::~TCPSocketHandler()
{
#ifdef CARES_FOUND
  this->free_cares_addrinfo();
#endif
}

void TCPSocketHandler::init_socket(const struct addrinfo* rp)
{
  if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)


@@ 72,9 87,35 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po

  if (!this->connecting)
    {
      // Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if
      // this is the first call of this function.
#ifdef CARES_FOUND
      if (!this->resolved)
        {
          log_info("Trying to connect to " << address << ":" << port);
          // Start the asynchronous process of resolving the hostname. Once
          // the addresses have been found and `resolved` has been set to true
          // (but connecting will still be false), TCPSocketHandler::connect()
          // needs to be called, again.
          DNSHandler::instance.gethostbyname(address, this, AF_INET6);
          DNSHandler::instance.gethostbyname(address, this, AF_INET);
          return;
        }
      else
        {
          // The c-ares resolved the hostname and the available addresses
          // where saved in the cares_addrinfo linked list. Now, just use
          // this list to try to connect.
          addr_res = this->cares_addrinfo;
          if (!addr_res)
            {
              this->close();
              this->on_connection_failed(this->cares_error);
              return ;
            }
        }
#else
      log_info("Trying to connect to " << address << ":" << port);
      // Get the addrinfo from getaddrinfo, only if this is the first call
      // of this function.
      struct addrinfo hints;
      memset(&hints, 0, sizeof(struct addrinfo));
      hints.ai_flags = 0;


@@ 94,6 135,7 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po
      // Make sure the alloced structure is always freed at the end of the
      // function
      sg.add_callback([&addr_res](){ freeaddrinfo(addr_res); });
#endif
    }
  else
    { // This function is called again, use the saved addrinfo structure,


@@ 144,9 186,9 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po
          // If the connection has not succeeded or failed in 5s, we consider
          // it to have failed
          TimedEventsManager::instance().add_event(
                TimedEvent(std::chrono::steady_clock::now() + 5s,
                           std::bind(&TCPSocketHandler::on_connection_timeout, this),
                           "connection_timeout"s + std::to_string(this->socket)));
                                                   TimedEvent(std::chrono::steady_clock::now() + 5s,
                                                              std::bind(&TCPSocketHandler::on_connection_timeout, this),
                                                              "connection_timeout"s + std::to_string(this->socket)));
          return ;
        }
      log_info("Connection failed:" << strerror(errno));


@@ 321,7 363,11 @@ bool TCPSocketHandler::is_connected() const

bool TCPSocketHandler::is_connecting() const
{
#ifdef CARES_FOUND
  return this->connecting || !this->resolved;
#else
  return this->connecting;
#endif
}

void* TCPSocketHandler::get_receive_buffer(const size_t) const


@@ 413,4 459,114 @@ void TCPSocketHandler::on_tls_activated()
{
  this->send_data("");
}

#endif // BOTAN_FOUND

#ifdef CARES_FOUND

void TCPSocketHandler::on_hostname4_resolved(int status, struct hostent* hostent)
{
  this->resolved4 = true;
  if (status == ARES_SUCCESS)
    this->fill_ares_addrinfo4(hostent);
  else
    this->cares_error = ::ares_strerror(status);

  if (this->resolved4 && this->resolved6)
    {
      this->resolved = true;
      this->connect();
    }
}

void TCPSocketHandler::on_hostname6_resolved(int status, struct hostent* hostent)
{
  this->resolved6 = true;
  if (status == ARES_SUCCESS)
    this->fill_ares_addrinfo6(hostent);
  else
    this->cares_error = ::ares_strerror(status);

  if (this->resolved4 && this->resolved6)
    {
      this->resolved = true;
      this->connect();
    }
}

void TCPSocketHandler::fill_ares_addrinfo4(const struct hostent* hostent)
{
  struct addrinfo* prev = this->cares_addrinfo;
  struct in_addr** address = reinterpret_cast<struct in_addr**>(hostent->h_addr_list);

  while (*address)
    {
       // Create a new addrinfo list element, and fill it
      struct addrinfo* current = new struct addrinfo;
      current->ai_flags = 0;
      current->ai_family = hostent->h_addrtype;
      current->ai_socktype = SOCK_STREAM;
      current->ai_protocol = 0;
      current->ai_addrlen = sizeof(struct sockaddr_in);

      struct sockaddr_in* addr = new struct sockaddr_in;
      addr->sin_family = hostent->h_addrtype;
      addr->sin_port = htons(strtoul(this->port.data(), nullptr, 10));
      addr->sin_addr.s_addr = (*address)->s_addr;

      current->ai_addr = reinterpret_cast<struct sockaddr*>(addr);
      current->ai_next = nullptr;
      current->ai_canonname = nullptr;

      current->ai_next = prev;
      this->cares_addrinfo = current;
      prev = current;
      ++address;
    }
}

void TCPSocketHandler::fill_ares_addrinfo6(const struct hostent* hostent)
{
  struct addrinfo* prev = this->cares_addrinfo;
  struct in6_addr** address = reinterpret_cast<struct in6_addr**>(hostent->h_addr_list);

  while (*address)
    {
       // Create a new addrinfo list element, and fill it
      struct addrinfo* current = new struct addrinfo;
      current->ai_flags = 0;
      current->ai_family = hostent->h_addrtype;
      current->ai_socktype = SOCK_STREAM;
      current->ai_protocol = 0;
      current->ai_addrlen = sizeof(struct sockaddr_in6);

      struct sockaddr_in6* addr = new struct sockaddr_in6;
      addr->sin6_family = hostent->h_addrtype;
      addr->sin6_port = htons(strtoul(this->port.data(), nullptr, 10));
      ::memcpy(addr->sin6_addr.s6_addr, (*address)->s6_addr, 16);
      addr->sin6_flowinfo = 0;
      addr->sin6_scope_id = 0;

      current->ai_addr = reinterpret_cast<struct sockaddr*>(addr);
      current->ai_next = nullptr;
      current->ai_canonname = nullptr;

      current->ai_next = prev;
      this->cares_addrinfo = current;
      prev = current;
      ++address;
    }
}

void TCPSocketHandler::free_cares_addrinfo()
{
  while (this->cares_addrinfo)
    {
      delete this->cares_addrinfo->ai_addr;
      auto next = this->cares_addrinfo->ai_next;
      delete this->cares_addrinfo;
      this->cares_addrinfo = next;
    }
}

#endif  // CARES_FOUND

M src/network/tcp_socket_handler.hpp => src/network/tcp_socket_handler.hpp +41 -6
@@ 17,6 17,10 @@

#include "config.h"

#ifdef CARES_FOUND
# include <ares.h>
#endif

#ifdef BOTAN_FOUND
# include <botan/botan.h>
# include <botan/tls_client.h>


@@ 44,7 48,7 @@ public:
class TCPSocketHandler: public SocketHandler
{
protected:
  ~TCPSocketHandler() {}
  ~TCPSocketHandler();

public:
  explicit TCPSocketHandler(std::shared_ptr<Poller> poller);


@@ 54,16 58,16 @@ public:
   * start_tls() when the connection succeeds.
   */
  void connect(const std::string& address, const std::string& port, const bool tls);
  void connect();
  void connect() override final;
  /**
   * Reads raw data from the socket. And pass it to parse_in_buffer()
   * If we are using TLS on this connection, we call tls_recv()
   */
  void on_recv();
  void on_recv() override final;
  /**
   * Write as much data from out_buf as possible, in the socket.
   */
  void on_send();
  void on_send() override final;
  /**
   * Add the given data to out_buf and tell our poller that we want to be
   * notified when a send event is ready.


@@ 107,9 111,19 @@ public:
   * The size argument is the size of the last chunk of data that was added to the buffer.
   */
  virtual void parse_in_buffer(const size_t size) = 0;
  bool is_connected() const;
  bool is_connected() const override final;
  bool is_connecting() const;

#ifdef CARES_FOUND
  void on_hostname4_resolved(int status, struct hostent* hostent);
  void on_hostname6_resolved(int status, struct hostent* hostent);

  void free_cares_addrinfo();

  void fill_ares_addrinfo4(const struct hostent* hostent);
  void fill_ares_addrinfo6(const struct hostent* hostent);
#endif

private:
  /**
   * Initialize the socket with the parameters contained in the given


@@ 185,7 199,7 @@ private:
   */
  std::list<std::string> out_buf;
  /**
   * Keep the details of the addrinfo the triggered a EINPROGRESS error when
   * Keep the details of the addrinfo that triggered a EINPROGRESS error when
   * connect()ing to it, to reuse it directly when connect() is called
   * again.
   */


@@ 225,6 239,27 @@ protected:
  bool connected;
  bool connecting;

#ifdef CARES_FOUND
  /**
   * Whether or not the DNS resolution was successfully done
   */
  bool resolved;
  bool resolved4;
  bool resolved6;
  /**
   * When using c-ares to resolve the host asynchronously, we need the
   * c-ares callback to fill a structure (a struct addrinfo, for
   * compatibility with getaddrinfo and the rest of the code that works when
   * c-ares is not used) with all returned values (for example an IPv6 and
   * an IPv4). The next call of connect() will then try all these values
   * (exactly like we do with the result of getaddrinfo) and save the one
   * that worked (or returned EINPROGRESS) in the other struct addrinfo (see
   * the members addrinfo, ai_addrlen, and ai_addr).
   */
  struct addrinfo* cares_addrinfo;
  std::string cares_error;
#endif  // CARES_FOUND

private:
  TCPSocketHandler(const TCPSocketHandler&) = delete;
  TCPSocketHandler(TCPSocketHandler&&) = delete;

M src/xmpp/xmpp_component.cpp => src/xmpp/xmpp_component.cpp +1 -1
@@ 70,7 70,7 @@ XmppComponent::~XmppComponent()

void XmppComponent::start()
{
  this->connect("127.0.0.1", Config::get("port", "5347"), false);
  this->connect("localhost", Config::get("port", "5347"), false);
}

bool XmppComponent::is_document_open() const