#include <network/tcp_client_socket_handler.hpp>

#include <fcntl.h>
#include <unistd.h>

#include <cerrno>
#include <cstring>
#include <stdexcept>
#include <string>

#include <logger/logger.hpp>
#include <network/poller.hpp>
#include <utils/scopeguard.hpp>
#include <utils/timed_events.hpp>

using namespace std::string_literals;
TCPClientSocketHandler::TCPClientSocketHandler(std::shared_ptr<Poller>& poller):
TCPSocketHandler(poller),
hostname_resolution_failed(false),
connected(false),
connecting(false)
{}
/**
 * Tear the connection down before the object disappears.
 *
 * close() also cancels the pending "connection_timeout" event, whose
 * callback (armed in connect()) holds a raw `this` pointer. The call is
 * qualified because, inside a destructor, only this class's close() can
 * run anyway.
 */
TCPClientSocketHandler::~TCPClientSocketHandler()
{
  TCPClientSocketHandler::close();
}
/**
 * Create a new socket matching the family/type/protocol of @p rp and
 * prepare it for a non-blocking connect().
 *
 * Any socket left from a previous attempt is closed first. When bind_addr
 * is set, it is parsed as a numeric address (AI_NUMERICHOST) and the
 * socket is bound to the first resolved local address that bind() accepts.
 * A bind failure is only logged, never fatal. TCP keepalive is then
 * enabled (best effort) and O_NONBLOCK is set.
 *
 * @param rp  Remote address candidate the socket will be connected to.
 * @throws std::runtime_error if socket() or fcntl() fails. In the fcntl
 *         case the new descriptor is still stored in this->socket.
 */
void TCPClientSocketHandler::init_socket(const struct addrinfo* rp)
{
  // Drop the descriptor of a previous (failed) attempt, if any.
  if (this->socket != -1)
    ::close(this->socket);

  this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
  if (this->socket == -1)
    throw std::runtime_error("Could not create socket: "s + std::strerror(errno));

  if (!this->bind_addr.empty())
    {
      // Turn the configured textual address into sockaddr(s) usable by bind().
      struct addrinfo hints{};
      hints.ai_family = AF_UNSPEC;
      hints.ai_flags = AI_NUMERICHOST;
      struct addrinfo* local_addrs = nullptr;
      const int gai_res = ::getaddrinfo(this->bind_addr.data(), nullptr, &hints, &local_addrs);
      if (gai_res == 0 && local_addrs)
        {
          utils::ScopeGuard free_local_addrs([local_addrs](){ freeaddrinfo(local_addrs); });
          // Walk the list until one local address can be bound.
          const struct addrinfo* candidate = local_addrs;
          while (candidate &&
                 ::bind(this->socket,
                        reinterpret_cast<const struct sockaddr*>(candidate->ai_addr),
                        candidate->ai_addrlen) != 0)
            candidate = candidate->ai_next;
          if (candidate)
            log_info("Socket successfully bound to ", this->bind_addr);
          else
            log_error("Failed to bind socket to ", this->bind_addr, ": ",
                      strerror(errno));
        }
      else
        log_error("Failed to bind socket to ", this->bind_addr, ": ",
                  gai_strerror(gai_res));
    }

  // Keepalive is a nice-to-have: failing to enable it only warrants a warning.
  const int keepalive = 1;
  if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) == -1)
    log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno));

  // Non-blocking mode makes connect() return EINPROGRESS instead of
  // stalling the whole process on an unresponsive remote.
  const int flags = ::fcntl(this->socket, F_GETFL, 0);
  if (flags == -1 || ::fcntl(this->socket, F_SETFL, flags | O_NONBLOCK) == -1)
    throw std::runtime_error("Could not initialize socket: "s + std::strerror(errno));
}
/**
 * Start, or resume, a non-blocking TCP connection to address:port.
 *
 * This function is re-entered several times during a single connection:
 *  - First call (hostname not resolved yet): an asynchronous resolution is
 *    started and the function returns. Both resolver callbacks call
 *    connect() again.
 *  - After resolution: if the resolver produced no address, the handler is
 *    closed and on_connection_failed() receives the resolver's error
 *    message. Otherwise each resolved address is tried in order, with a
 *    fresh socket created for each one by init_socket().
 *  - When ::connect() reports EINPROGRESS/EALREADY, the tried address is
 *    copied into this->addrinfo as a single-element list. The socket is
 *    registered with the poller for write events and a 5 s timeout is
 *    armed. On the next call (with `connecting` set), only that saved
 *    address is re-checked; no new resolution happens.
 *
 * On success, the timeout is cancelled and TLS is started when requested
 * (BOTAN_FOUND builds only). The local port is then recorded and
 * on_connected() is called. If every attempt fails, the handler is closed
 * and on_connection_failed() receives strerror() of the last failure.
 *
 * @param address  Hostname or IP address of the remote end.
 * @param port     Remote port/service string, passed to the resolver.
 * @param tls      Whether to start TLS once the TCP connection is up.
 */
void TCPClientSocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
{
  this->address = address;
  this->port = port;
  this->use_tls = tls;

  struct addrinfo* addr_res = nullptr;

  if (!this->connecting)
    {
      // Get the addrinfo from the resolver, only if this is the first call
      // of this function.
      if (!this->resolver.is_resolved())
        {
          log_info("Trying to connect to ", address, ":", port);
          // Start the asynchronous process of resolving the hostname. Once
          // the addresses have been found and `resolved` has been set to true
          // (but connecting will still be false), TCPClientSocketHandler::connect()
          // needs to be called, again.
          this->resolver.resolve(address, port,
                                 [this](const struct addrinfo*)
                                 {
                                   log_debug("Resolution success, calling connect() again");
                                   this->connect();
                                 },
                                 [this](const char*)
                                 {
                                   log_debug("Resolution failed, calling connect() again");
                                   this->connect();
                                 });
          return;
        }
      else
        {
          // The resolver resolved the hostname and the available addresses
          // were saved in the addrinfo linked list. Now, just use this list
          // to try to connect.
          addr_res = this->resolver.get_result().get();
          if (!addr_res)
            {
              this->hostname_resolution_failed = true;
              const auto msg = this->resolver.get_error_message();
              this->close();
              this->on_connection_failed(msg);
              return ;
            }
        }
    }
  else
    { // This function is called again: use the saved addrinfo structure,
      // instead of re-doing the whole resolution process.
      addr_res = &this->addrinfo;
    }

  // errno of the most recent failure. It must be saved right after the
  // failing call: log_*(), close() and the other calls made before
  // on_connection_failed() are free to overwrite errno.
  int last_errno = errno;

  for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
    {
      if (!this->connecting)
        {
          try {
            this->init_socket(rp);
          }
          catch (const std::runtime_error& error) {
            last_errno = errno;
            log_error("Failed to init socket: ", error.what());
            break;
          }
        }
      this->display_resolved_ip(rp);

      const int connect_res = ::connect(this->socket, rp->ai_addr, rp->ai_addrlen);
      // Only meaningful when connect_res != 0; saved before anything else runs.
      const int connect_errno = errno;

      if (connect_res == 0 || connect_errno == EISCONN)
        {
          log_info("Connection success.");
          TimedEventsManager::instance().cancel("connection_timeout" +
                                                std::to_string(this->socket));
          this->poller->add_socket_handler(this);
          this->connected = true;
          this->connecting = false;
#ifdef BOTAN_FOUND
          if (this->use_tls)
            this->start_tls(this->address, this->port);
#endif
          this->connection_date = std::chrono::system_clock::now();

          // Get our local TCP port and store it (65535 if it can't be read).
          this->local_port = static_cast<uint16_t>(-1);
          if (rp->ai_family == AF_INET6)
            {
              struct sockaddr_in6 a{};
              socklen_t l = sizeof(a);
              if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1)
                this->local_port = ntohs(a.sin6_port);
            }
          else if (rp->ai_family == AF_INET)
            {
              struct sockaddr_in a{};
              socklen_t l = sizeof(a);
              if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1)
                this->local_port = ntohs(a.sin_port);
            }
          log_debug("Local port: ", this->local_port, ", and remote port: ", this->port);

          this->on_connected();
          return ;
        }
      else if (connect_errno == EINPROGRESS || connect_errno == EALREADY)
        { // Retry this process later, when the socket is ready to be
          // written on.
          this->connecting = true;
          this->poller->add_socket_handler(this);
          this->poller->watch_send_events(this);
          // Save the addrinfo structure (with a private copy of its address,
          // and no successor) to use it on the next call.
          this->ai_addrlen = rp->ai_addrlen;
          memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
          memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
          this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr);
          this->addrinfo.ai_next = nullptr;
          // If the connection has not succeeded or failed in 5s, we consider
          // it to have failed.
          TimedEventsManager::instance().add_event(
              TimedEvent(std::chrono::steady_clock::now() + 5s,
                         std::bind(&TCPClientSocketHandler::on_connection_timeout, this),
                         "connection_timeout" + std::to_string(this->socket)));
          return ;
        }

      last_errno = connect_errno;
      log_info("Connection failed:", std::strerror(last_errno));
    }

  log_error("All connection attempts failed.");
  this->close();
  // Report the saved error: errno itself may have been clobbered by now.
  this->on_connection_failed(std::strerror(last_errno));
  return ;
}
void TCPClientSocketHandler::on_connection_timeout()
{
this->close();
this->on_connection_failed("connection timed out");
}
/**
 * Re-enter connect() with the address, port and TLS flag stored by the
 * previous call. The resolver callbacks in connect() use this; presumably
 * the poller does too when the socket becomes writable (not visible here).
 */
void TCPClientSocketHandler::connect()
{
  const bool with_tls = this->use_tls;
  connect(this->address, this->port, with_tls);
}
/**
 * Close the connection and reset the connection state.
 *
 * Cancels any pending connect timeout, then runs the base class close().
 * It clears the connected/connecting flags, empties the stored port and
 * clears the resolver, so a later connect() starts a fresh resolution.
 */
void TCPClientSocketHandler::close()
{
  // The timeout event is keyed on the socket number, so build the key
  // before TCPSocketHandler::close(), which presumably invalidates it.
  const std::string timeout_event = "connection_timeout" + std::to_string(this->socket);
  TimedEventsManager::instance().cancel(timeout_event);
  TCPSocketHandler::close();
  this->connected = this->connecting = false;
  this->port.clear();
  this->resolver.clear();
}
/**
 * Log (at debug level) the IPv4/IPv6 address about to be tried. Other
 * address families are silently ignored.
 */
void TCPClientSocketHandler::display_resolved_ip(struct addrinfo* rp) const
{
  switch (rp->ai_family)
    {
    case AF_INET:
      log_debug("Trying IPv4 address ", addr_to_string(rp));
      break;
    case AF_INET6:
      log_debug("Trying IPv6 address ", addr_to_string(rp));
      break;
    default:
      break;
    }
}
bool TCPClientSocketHandler::is_connected() const
{
return this->connected;
}
/**
 * True while a connection is underway: either a non-blocking connect() is
 * in progress, or the hostname is still being resolved.
 */
bool TCPClientSocketHandler::is_connecting() const
{
  if (this->connecting)
    return true;
  return this->resolver.is_resolving();
}
std::string TCPClientSocketHandler::get_port() const
{
return this->port;
}
/**
 * Check whether (local, remote) is the port pair of this live connection.
 *
 * @param local   Local TCP port to compare with the one recorded at connect time.
 * @param remote  Remote TCP port to compare with this->port.
 * @return true only if connected, the local port matches, and this->port
 *         is a valid decimal port number equal to @p remote. Never throws.
 *         The previous version called std::stoi before checking the
 *         connection state, so it threw on a closed handler (close()
 *         empties the port) or on a non-numeric port.
 */
bool TCPClientSocketHandler::match_port_pairt(const uint16_t local, const uint16_t remote) const
{
  // Cheap checks first; this also avoids parsing the empty port left by close().
  if (!this->is_connected() || local != this->local_port)
    return false;

  unsigned long remote_port = 0;
  std::size_t parsed = 0;
  try
    {
      remote_port = std::stoul(this->port, &parsed);
    }
  catch (const std::logic_error&)
    { // std::invalid_argument (not a number) or std::out_of_range
      return false;
    }
  // Reject trailing garbage and values that don't fit in a TCP port,
  // rather than silently truncating them to uint16_t.
  if (parsed != this->port.size() || remote_port > 0xFFFFul)
    return false;
  return remote == static_cast<uint16_t>(remote_port);
}