Skip to content

Commit

Permalink
net/proxy/ares support custom server
Browse files Browse the repository at this point in the history
  • Loading branch information
iceboy233 committed Oct 22, 2023
1 parent 91658f1 commit af90030
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 3 deletions.
2 changes: 2 additions & 0 deletions net/proxy/ares/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ cc_library(
"//util:int-allocator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/types:span",
"@org_iceboy_trunk//net:asio",
"@org_iceboy_trunk//net:endpoint",
"@org_iceboy_trunk//net:timer-list",
],
)
Expand Down
27 changes: 27 additions & 0 deletions net/proxy/ares/resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ Resolver::Resolver(
abort();
}
ares_set_socket_functions(channel_, &funcs_, this);
if (!options.servers.empty()) {
set_servers(options.servers);
}
}

Resolver::~Resolver() {
Expand Down Expand Up @@ -92,6 +95,30 @@ void Resolver::wait() {
});
}

void Resolver::set_servers(absl::Span<const Endpoint> servers) {
ares_addr_port_node nodes[servers.size()];
for (size_t i = 0; i < servers.size(); ++i) {
nodes[i].next = i + 1 < servers.size() ? &nodes[i + 1] : nullptr;
const address &address = servers[i].address();
if (address.is_v4()) {
nodes[i].family = AF_INET;
auto address_bytes = address.to_v4().to_bytes();
static_assert(address_bytes.size() == 4);
memcpy(&nodes[i].addr.addr4, address_bytes.data(), 4);
} else {
nodes[i].family = AF_INET6;
auto address_bytes = address.to_v6().to_bytes();
static_assert(address_bytes.size() == 16);
memcpy(&nodes[i].addr.addr6, address_bytes.data(), 16);
}
nodes[i].udp_port = servers[i].port();
nodes[i].tcp_port = servers[i].port();
}
if (ares_set_servers_ports(channel_, nodes) != ARES_SUCCESS) {
abort();
}
}

ares_socket_t Resolver::asocket(
int domain, int type, int protocol, void *user_data) {
auto *resolver = reinterpret_cast<Resolver *>(user_data);
Expand Down
4 changes: 4 additions & 0 deletions net/proxy/ares/resolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/types/span.h"
#include "net/asio.h"
#include "net/endpoint.h"
#include "net/proxy/ares/socket.h"
#include "net/proxy/connector.h"
#include "net/timer-list.h"
Expand All @@ -22,6 +24,7 @@ namespace ares {
class Resolver {
public:
struct Options {
std::vector<Endpoint> servers;
std::chrono::milliseconds query_timeout = std::chrono::seconds(1);
std::chrono::nanoseconds cache_timeout = std::chrono::minutes(1);
};
Expand All @@ -41,6 +44,7 @@ class Resolver {
class Operation;

void wait();
void set_servers(absl::Span<const Endpoint> servers);

static ares_socket_t asocket(
int domain, int type, int protocol, void *user_data);
Expand Down
1 change: 1 addition & 0 deletions net/proxy/system/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ cc_library(
":connector",
"//net/proxy",
"@org_boost_boost//:property_tree",
"@org_iceboy_trunk//base:logging",
],
alwayslink = 1,
)
Expand Down
13 changes: 13 additions & 0 deletions net/proxy/system/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <memory>
#include <boost/property_tree/ptree.hpp>

#include "base/logging.h"
#include "net/proxy/proxy.h"
#include "net/proxy/registry.h"
#include "net/proxy/system/connector.h"
Expand All @@ -17,6 +18,18 @@ REGISTER_CONNECTOR(system, [](
options.timeout = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::duration<double>(config.get<double>("timeout", 300)));
options.tcp_no_delay = config.get<bool>("tcp_no_delay", true);
const auto &resolver_config = config.get_child("resolver", {});
for (auto iters = resolver_config.equal_range("server");
iters.first != iters.second;
++iters.first) {
std::string server_str = iters.first->second.get_value<std::string>();
auto server_endpoint = Endpoint::from_string(server_str);
if (!server_endpoint) {
LOG(error) << "invalid server endpoint: " << server_str;
continue;
}
options.resolver_options.servers.push_back(*server_endpoint);
}
return std::make_unique<Connector>(proxy.executor(), options);
});

Expand Down
2 changes: 1 addition & 1 deletion net/proxy/system/connector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace system {

Connector::Connector(const any_io_executor &executor, const Options &options)
: executor_(executor),
resolver_(executor_, *this, {}),
resolver_(executor_, *this, options.resolver_options),
timer_list_(executor_, options.timeout),
tcp_no_delay_(options.tcp_no_delay) {}

Expand Down
3 changes: 3 additions & 0 deletions net/proxy/system/connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Connector : public proxy::Connector {
struct Options {
std::chrono::nanoseconds timeout = std::chrono::minutes(5);
bool tcp_no_delay = true;
ares::Resolver::Options resolver_options;
};

Connector(const any_io_executor &executor, const Options &options);
Expand Down Expand Up @@ -49,6 +50,8 @@ class Connector : public proxy::Connector {
std::error_code bind_udp_v4(std::unique_ptr<Datagram> &datagram) override;
std::error_code bind_udp_v6(std::unique_ptr<Datagram> &datagram) override;

ares::Resolver &resolver() { return resolver_; }

private:
template <typename EndpointsT>
void connect_tcp(
Expand Down
3 changes: 1 addition & 2 deletions net/tools/ares-resolve.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ int main(int argc, char *argv[]) {
io_context io_context;
auto executor = io_context.get_executor();
proxy::system::Connector connector(executor, {});
proxy::ares::Resolver resolver(executor, connector, {});

using Result = BlockingResult<std::error_code, std::vector<address>>;
auto results = std::make_unique<Result[]>(argc - 1);
for (int i = 1; i < argc; ++i) {
resolver.resolve(argv[i], results[i - 1].callback());
connector.resolver().resolve(argv[i], results[i - 1].callback());
}

io::OStream os(io::posix::stdout);
Expand Down

0 comments on commit af90030

Please sign in to comment.