diff --git a/locator/ec2_multi_region_snitch.cc b/locator/ec2_multi_region_snitch.cc index 591bc608ad..a824822fc6 100644 --- a/locator/ec2_multi_region_snitch.cc +++ b/locator/ec2_multi_region_snitch.cc @@ -30,10 +30,12 @@ future<> ec2_multi_region_snitch::start() { if (this_shard_id() == io_cpu_id()) { inet_address local_public_address; + auto token = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, TOKEN_REQ_ENDPOINT, std::nullopt).get0(); + try { auto broadcast = utils::fb_utilities::get_broadcast_address(); if (broadcast.addr().is_ipv6()) { - auto macs = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, PRIVATE_MAC_QUERY).get0(); + auto macs = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, PRIVATE_MAC_QUERY, token).get0(); // we should just get a single line, ending in '/'. If there are more than one mac, we should // maybe try to loop the addresses and exclude local/link-local etc, but these addresses typically // are already filtered by aws, so probably does not help. For now, just warn and pick first address. @@ -42,11 +44,11 @@ future<> ec2_multi_region_snitch::start() { if (i != std::string::npos && ++i != macs.size()) { logger().warn("Ec2MultiRegionSnitch (ipv6): more than one MAC address listed ({}). Will use first.", macs); } - auto ipv6 = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, format(PUBLIC_IPV6_QUERY_REQ, mac)).get0(); + auto ipv6 = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, format(PUBLIC_IPV6_QUERY_REQ, mac), token).get0(); local_public_address = inet_address(ipv6); _local_private_address = ipv6; } else { - auto pub_addr = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, PUBLIC_IP_QUERY_REQ).get0(); + auto pub_addr = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, PUBLIC_IP_QUERY_REQ, token).get0(); local_public_address = inet_address(pub_addr); } } catch (...) { @@ -66,7 +68,7 @@ future<> ec2_multi_region_snitch::start() { } if (!local_public_address.addr().is_ipv6()) { - sstring priv_addr = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, PRIVATE_IP_QUERY_REQ).get0(); + sstring priv_addr = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, PRIVATE_IP_QUERY_REQ, token).get0(); _local_private_address = priv_addr; } diff --git a/locator/ec2_snitch.cc b/locator/ec2_snitch.cc index 50805ee0fa..64fb3619f3 100644 --- a/locator/ec2_snitch.cc +++ b/locator/ec2_snitch.cc @@ -21,7 +21,8 @@ future<> ec2_snitch::load_config(bool prefer_local) { using namespace boost::algorithm; if (this_shard_id() == io_cpu_id()) { - return aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, ZONE_NAME_QUERY_REQ).then([this, prefer_local](sstring az) { + auto token = aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, TOKEN_REQ_ENDPOINT, std::nullopt).get0(); + return aws_api_call(AWS_QUERY_SERVER_ADDR, AWS_QUERY_SERVER_PORT, ZONE_NAME_QUERY_REQ, token).then([this, prefer_local](sstring az) { assert(az.size()); std::vector splits; @@ -63,17 +64,26 @@ future<> ec2_snitch::start() { }); } -future ec2_snitch::aws_api_call(sstring addr, uint16_t port, sstring cmd) { +future ec2_snitch::aws_api_call(sstring addr, uint16_t port, sstring cmd, std::optional token) { return connect(socket_address(inet_address{addr}, port)) - .then([this, addr, cmd] (connected_socket fd) { + .then([this, addr, cmd, token] (connected_socket fd) { _sd = std::move(fd); _in = _sd.input(); _out = _sd.output(); - _zone_req = sstring("GET ") + cmd + - sstring(" HTTP/1.1\r\nHost: ") +addr + - sstring("\r\n\r\n"); - return _out.write(_zone_req.c_str()).then([this] { + if (token) { + _req = sstring("GET ") + cmd + + sstring(" HTTP/1.1\r\nHost: ") +addr + + sstring("\r\nX-aws-ec2-metadata-token: ") + *token + + sstring("\r\n\r\n"); + } else { + _req = sstring("PUT ") + cmd + + sstring(" HTTP/1.1\r\nHost: ") + addr + + sstring("\r\nX-aws-ec2-metadata-token-ttl-seconds: 60") + + sstring("\r\n\r\n"); + } + + return _out.write(_req.c_str()).then([this] { return _out.flush(); }); }).then([this] { @@ -85,6 +95,12 @@ future ec2_snitch::aws_api_call(sstring addr, uint16_t port, sstring cm // Read HTTP response header first auto _rsp = _parser.get_parsed_response(); + auto rc = _rsp->_status_code; + // Verify EC2 instance metadata access + if (rc == 403) { + return make_exception_future(std::runtime_error("Error: Unauthorized response received when trying to communicate with instance metadata service.")); + } + auto it = _rsp->_headers.find("Content-Length"); if (it == _rsp->_headers.end()) { return make_exception_future("Error: HTTP response does not contain: Content-Length\n"); diff --git a/locator/ec2_snitch.hh b/locator/ec2_snitch.hh index f19286b4b7..88ee1fc769 100644 --- a/locator/ec2_snitch.hh +++ b/locator/ec2_snitch.hh @@ -13,6 +13,7 @@ namespace locator { class ec2_snitch : public production_snitch_base { public: + static constexpr const char* TOKEN_REQ_ENDPOINT = "/latest/api/token"; static constexpr const char* ZONE_NAME_QUERY_REQ = "/latest/meta-data/placement/availability-zone"; static constexpr const char* AWS_QUERY_SERVER_ADDR = "169.254.169.254"; static constexpr uint16_t AWS_QUERY_SERVER_PORT = 80; @@ -24,13 +25,13 @@ public: } protected: future<> load_config(bool prefer_local); - future aws_api_call(sstring addr, uint16_t port, const sstring cmd); + future aws_api_call(sstring addr, uint16_t port, const sstring cmd, std::optional token); future read_property_file(); private: connected_socket _sd; input_stream _in; output_stream _out; http_response_parser _parser; - sstring _zone_req; + sstring _req; }; } // namespace locator