From 81b16de918a23801e6fc6db3461efa92342d2c53 Mon Sep 17 00:00:00 2001 From: Alexandre Dutra Date: Fri, 17 Jan 2025 17:24:50 +0100 Subject: [PATCH 1/2] Auth Manager API part 4: RESTClient, HTTPClient --- .../apache/iceberg/rest/BaseHTTPClient.java | 220 +++++ .../org/apache/iceberg/rest/HTTPClient.java | 287 ++---- .../org/apache/iceberg/rest/RESTClient.java | 6 + .../iceberg/rest/RESTCatalogAdapter.java | 120 +-- .../iceberg/rest/RESTCatalogServlet.java | 14 +- .../apache/iceberg/rest/TestRESTCatalog.java | 886 ++++++++---------- .../iceberg/rest/TestRESTViewCatalog.java | 43 +- 7 files changed, 775 insertions(+), 801 deletions(-) create mode 100644 core/src/main/java/org/apache/iceberg/rest/BaseHTTPClient.java diff --git a/core/src/main/java/org/apache/iceberg/rest/BaseHTTPClient.java b/core/src/main/java/org/apache/iceberg/rest/BaseHTTPClient.java new file mode 100644 index 000000000000..0f441bc346f8 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/BaseHTTPClient.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest; + +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.auth.AuthSession; +import org.apache.iceberg.rest.responses.ErrorResponse; + +/** + * A base class for {@link RESTClient} implementations. + * + *

All methods in {@link RESTClient} are implemented in the same way: first, an {@link + * HTTPRequest} is {@linkplain #buildRequest(HTTPMethod, String, Map, Map, Object) built from the + * method arguments}, then {@linkplain #execute(HTTPRequest, Class, Consumer, Consumer) executed}. + * + *

This allows subclasses to provide a consistent way to execute all requests, regardless of the + * method or arguments. + */ +public abstract class BaseHTTPClient implements RESTClient { + + @Override + public abstract RESTClient withAuthSession(AuthSession session); + + @Override + public void head( + String path, Supplier> headers, Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.HEAD, path, null, headers.get(), null); + execute(request, null, errorHandler, h -> {}); + } + + @Override + public void head(String path, Map headers, Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.HEAD, path, null, headers, null); + execute(request, null, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Map queryParams, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, queryParams, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, null, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, null, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T delete( + String path, + Map queryParams, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.DELETE, path, queryParams, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, null, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, null, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Map queryParams, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, queryParams, headers.get(), null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T get( + String path, + Map queryParams, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.GET, path, queryParams, headers, null); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), body); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Supplier> headers, + Consumer errorHandler, + Consumer> responseHeaders) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), body); + return execute(request, responseType, errorHandler, responseHeaders); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Map headers, + Consumer errorHandler, + Consumer> responseHeaders) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, body); + return execute(request, responseType, errorHandler, responseHeaders); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, body); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T postForm( + String path, + Map formData, + Class responseType, + Supplier> headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers.get(), formData); + return execute(request, responseType, errorHandler, h -> {}); + } + + @Override + public T postForm( + String path, + Map formData, + Class responseType, + Map headers, + Consumer errorHandler) { + HTTPRequest request = buildRequest(HTTPMethod.POST, path, null, headers, formData); + return execute(request, responseType, errorHandler, h -> {}); + } + + protected abstract HTTPRequest buildRequest( + HTTPMethod method, + String path, + Map queryParams, + Map headers, + Object body); + + protected abstract T execute( + HTTPRequest request, + Class responseType, + Consumer errorHandler, + Consumer> responseHeaders); +} diff --git a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java index 4391d75ec710..f982ba6feb97 100644 --- a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java +++ b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java @@ -23,14 +23,11 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; -import java.util.stream.Collectors; import org.apache.hc.client5.http.auth.CredentialsProvider; -import org.apache.hc.client5.http.classic.methods.HttpUriRequest; import org.apache.hc.client5.http.classic.methods.HttpUriRequestBase; import org.apache.hc.client5.http.config.ConnectionConfig; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; @@ -45,14 +42,11 @@ import org.apache.hc.core5.http.HttpHost; import org.apache.hc.core5.http.HttpRequestInterceptor; import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.Method; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.impl.EnglishReasonPhraseCatalog; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.io.entity.StringEntity; -import org.apache.hc.core5.http.message.BasicHeader; import org.apache.hc.core5.io.CloseMode; -import org.apache.hc.core5.net.URIBuilder; import org.apache.iceberg.IcebergBuild; import org.apache.iceberg.common.DynConstructors; import org.apache.iceberg.common.DynMethods; @@ -60,13 +54,15 @@ import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.util.PropertyUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** An HttpClient for usage with the REST catalog. */ -public class HTTPClient implements RESTClient { +public class HTTPClient extends BaseHTTPClient { private static final Logger LOG = LoggerFactory.getLogger(HTTPClient.class); private static final String SIGV4_ENABLED = "rest.sigv4-enabled"; @@ -88,12 +84,14 @@ public class HTTPClient implements RESTClient { @VisibleForTesting static final String REST_SOCKET_TIMEOUT_MS = "rest.client.socket-timeout-ms"; - private final String uri; + private final URI baseUri; private final CloseableHttpClient httpClient; + private final Map baseHeaders; private final ObjectMapper mapper; + private final AuthSession authSession; private HTTPClient( - String uri, + URI baseUri, HttpHost proxy, CredentialsProvider proxyCredsProvider, Map baseHeaders, @@ -101,20 +99,15 @@ private HTTPClient( HttpRequestInterceptor requestInterceptor, Map properties, HttpClientConnectionManager connectionManager) { - this.uri = uri; + this.baseUri = baseUri; + this.baseHeaders = baseHeaders; this.mapper = objectMapper; + this.authSession = AuthSession.EMPTY; HttpClientBuilder clientBuilder = HttpClients.custom(); clientBuilder.setConnectionManager(connectionManager); - if (baseHeaders != null) { - clientBuilder.setDefaultHeaders( - baseHeaders.entrySet().stream() - .map(e -> new BasicHeader(e.getKey(), e.getValue())) - .collect(Collectors.toList())); - } - if (requestInterceptor != null) { clientBuilder.addRequestInterceptorLast(requestInterceptor); } @@ -133,6 +126,24 @@ private HTTPClient( this.httpClient = clientBuilder.build(); } + /** + * Constructor for creating a child HTTPClient associated with an AuthSession. The returned child + * shares the same base uri, mapper, and HTTP client as the parent, thus not requiring any + * additional resource allocation. + */ + private HTTPClient(HTTPClient parent, AuthSession authSession) { + this.baseUri = parent.baseUri; + this.httpClient = parent.httpClient; + this.mapper = parent.mapper; + this.baseHeaders = parent.baseHeaders; + this.authSession = authSession; + } + + @Override + public HTTPClient withAuthSession(AuthSession session) { + return new HTTPClient(this, session); + } + private static String extractResponseBodyAsString(CloseableHttpResponse response) { try { if (response.getEntity() == null) { @@ -214,92 +225,63 @@ private static void throwFailure( throw new RESTException("Unhandled error: %s", errorResponse); } - private URI buildUri(String path, Map params) { - // if full path is provided, use the input path as path - if (path.startsWith("/")) { - throw new RESTException( - "Received a malformed path for a REST request: %s. Paths should not start with /", path); - } - String fullPath = - (path.startsWith("https://") || path.startsWith("http://")) - ? path - : String.format("%s/%s", uri, path); - try { - URIBuilder builder = new URIBuilder(fullPath); - if (params != null) { - params.forEach(builder::addParameter); - } - return builder.build(); - } catch (URISyntaxException e) { - throw new RESTException( - "Failed to create request URI from base %s, params %s", fullPath, params); - } - } - - /** - * Method to execute an HTTP request and process the corresponding response. - * - * @param method - HTTP method, such as GET, POST, HEAD, etc. - * @param queryParams - A map of query parameters - * @param path - URI to send the request to - * @param requestBody - Content to place in the request body - * @param responseType - Class of the Response type. Needs to have serializer registered with - * ObjectMapper - * @param errorHandler - Error handler delegated for HTTP responses which handles server error - * responses - * @param - Class type of the response for deserialization. Must be registered with the - * ObjectMapper. - * @return The response entity, parsed and converted to its type T - */ - private T execute( - Method method, + @Override + protected HTTPRequest buildRequest( + HTTPMethod method, String path, Map queryParams, - Object requestBody, - Class responseType, Map headers, - Consumer errorHandler) { - return execute( - method, path, queryParams, requestBody, responseType, headers, errorHandler, h -> {}); + Object body) { + + ImmutableHTTPRequest.Builder builder = + ImmutableHTTPRequest.builder() + .baseUri(baseUri) + .mapper(mapper) + .method(method) + .path(path) + .body(body) + .queryParameters(queryParams == null ? Map.of() : queryParams); + + Map allHeaders = Maps.newLinkedHashMap(); + if (headers != null) { + allHeaders.putAll(headers); + } + + allHeaders.putIfAbsent(HttpHeaders.ACCEPT, ContentType.APPLICATION_JSON.getMimeType()); + + // Many systems require that content type is set regardless and will fail, + // even on an empty bodied request. + // Encode maps as form data (application/x-www-form-urlencoded), + // and other requests are assumed to contain JSON bodies (application/json). + ContentType mimeType = + body instanceof Map + ? ContentType.APPLICATION_FORM_URLENCODED + : ContentType.APPLICATION_JSON; + allHeaders.putIfAbsent(HttpHeaders.CONTENT_TYPE, mimeType.getMimeType()); + + // Apply base headers now to mimic the behavior of + // org.apache.hc.client5.http.protocol.RequestDefaultHeaders + // We want these headers applied *before* the AuthSession authenticates the request. + if (baseHeaders != null) { + baseHeaders.forEach(allHeaders::putIfAbsent); + } + + return authSession.authenticate(builder.headers(HTTPHeaders.of(allHeaders)).build()); } - /** - * Method to execute an HTTP request and process the corresponding response. - * - * @param method - HTTP method, such as GET, POST, HEAD, etc. - * @param queryParams - A map of query parameters - * @param path - URL to send the request to - * @param requestBody - Content to place in the request body - * @param responseType - Class of the Response type. Needs to have serializer registered with - * ObjectMapper - * @param errorHandler - Error handler delegated for HTTP responses which handles server error - * responses - * @param responseHeaders The consumer of the response headers - * @param - Class type of the response for deserialization. Must be registered with the - * ObjectMapper. - * @return The response entity, parsed and converted to its type T - */ - private T execute( - Method method, - String path, - Map queryParams, - Object requestBody, + @Override + protected T execute( + HTTPRequest req, Class responseType, - Map headers, Consumer errorHandler, Consumer> responseHeaders) { - HttpUriRequestBase request = new HttpUriRequestBase(method.name(), buildUri(path, queryParams)); - - if (requestBody instanceof Map) { - // encode maps as form data, application/x-www-form-urlencoded - addRequestHeaders(request, headers, ContentType.APPLICATION_FORM_URLENCODED.getMimeType()); - request.setEntity(toFormEncoding((Map) requestBody)); - } else if (requestBody != null) { - // other request bodies are serialized as JSON, application/json - addRequestHeaders(request, headers, ContentType.APPLICATION_JSON.getMimeType()); - request.setEntity(toJson(requestBody)); - } else { - addRequestHeaders(request, headers, ContentType.APPLICATION_JSON.getMimeType()); + HttpUriRequestBase request = new HttpUriRequestBase(req.method().name(), req.requestUri()); + + req.headers().entries().forEach(e -> request.addHeader(e.name(), e.value())); + + String encodedBody = req.encodedBody(); + if (encodedBody != null) { + request.setEntity(new StringEntity(encodedBody)); } try (CloseableHttpResponse response = httpClient.execute(request)) { @@ -326,7 +308,7 @@ private T execute( if (responseBody == null) { throw new RESTException( "Invalid (null) response body for request (expected %s): method=%s, path=%s, status=%d", - responseType.getSimpleName(), method.name(), path, response.getCode()); + responseType.getSimpleName(), req.method(), req.path(), response.getCode()); } try { @@ -339,88 +321,17 @@ private T execute( responseType.getSimpleName()); } } catch (IOException e) { - throw new RESTException(e, "Error occurred while processing %s request", method); + throw new RESTException(e, "Error occurred while processing %s request", req.method()); } } - @Override - public void head(String path, Map headers, Consumer errorHandler) { - execute(Method.HEAD, path, null, null, null, headers, errorHandler); - } - - @Override - public T get( - String path, - Map queryParams, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.GET, path, queryParams, null, responseType, headers, errorHandler); - } - - @Override - public T post( - String path, - RESTRequest body, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.POST, path, null, body, responseType, headers, errorHandler); - } - - @Override - public T post( - String path, - RESTRequest body, - Class responseType, - Map headers, - Consumer errorHandler, - Consumer> responseHeaders) { - return execute( - Method.POST, path, null, body, responseType, headers, errorHandler, responseHeaders); - } - - @Override - public T delete( - String path, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.DELETE, path, null, null, responseType, headers, errorHandler); - } - - @Override - public T delete( - String path, - Map queryParams, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.DELETE, path, queryParams, null, responseType, headers, errorHandler); - } - - @Override - public T postForm( - String path, - Map formData, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(Method.POST, path, null, formData, responseType, headers, errorHandler); - } - - private void addRequestHeaders( - HttpUriRequest request, Map requestHeaders, String bodyMimeType) { - request.setHeader(HttpHeaders.ACCEPT, ContentType.APPLICATION_JSON.getMimeType()); - // Many systems require that content type is set regardless and will fail, even on an empty - // bodied request. - request.setHeader(HttpHeaders.CONTENT_TYPE, bodyMimeType); - requestHeaders.forEach(request::setHeader); - } - @Override public void close() throws IOException { - httpClient.close(CloseMode.GRACEFUL); + try { + authSession.close(); + } finally { + httpClient.close(CloseMode.GRACEFUL); + } } @VisibleForTesting @@ -510,7 +421,7 @@ public static Builder builder(Map properties) { public static class Builder { private final Map properties; private final Map baseHeaders = Maps.newHashMap(); - private String uri; + private URI uri; private ObjectMapper mapper = RESTObjectMapper.mapper(); private HttpHost proxy; private CredentialsProvider proxyCredentialsProvider; @@ -519,9 +430,19 @@ private Builder(Map properties) { this.properties = properties; } - public Builder uri(String path) { - Preconditions.checkNotNull(path, "Invalid uri for http client: null"); - this.uri = RESTUtil.stripTrailingSlash(path); + public Builder uri(String baseUri) { + Preconditions.checkNotNull(baseUri, "Invalid uri for http client: null"); + try { + this.uri = URI.create(RESTUtil.stripTrailingSlash(baseUri)); + } catch (IllegalArgumentException e) { + throw new RESTException(e, "Failed to create request URI from base %s", baseUri); + } + return this; + } + + public Builder uri(URI baseUri) { + Preconditions.checkNotNull(baseUri, "Invalid uri for http client: null"); + this.uri = baseUri; return this; } @@ -579,16 +500,4 @@ public HTTPClient build() { configureConnectionManager(properties)); } } - - private StringEntity toJson(Object requestBody) { - try { - return new StringEntity(mapper.writeValueAsString(requestBody), StandardCharsets.UTF_8); - } catch (JsonProcessingException e) { - throw new RESTException(e, "Failed to write request body: %s", requestBody); - } - } - - private StringEntity toFormEncoding(Map formData) { - return new StringEntity(RESTUtil.encodeFormData(formData), StandardCharsets.UTF_8); - } } diff --git a/core/src/main/java/org/apache/iceberg/rest/RESTClient.java b/core/src/main/java/org/apache/iceberg/rest/RESTClient.java index 0f17d9a127e2..2843972fee45 100644 --- a/core/src/main/java/org/apache/iceberg/rest/RESTClient.java +++ b/core/src/main/java/org/apache/iceberg/rest/RESTClient.java @@ -23,6 +23,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.responses.ErrorResponse; /** Interface for a basic HTTP Client for interfacing with the REST catalog. */ @@ -158,4 +159,9 @@ T postForm( Class responseType, Map headers, Consumer errorHandler); + + /** Returns a REST client that authenticates requests using the given session. */ + default RESTClient withAuthSession(AuthSession session) { + return this; + } } diff --git a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java index 2fb4defd1224..3d0984a09092 100644 --- a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java +++ b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java @@ -18,7 +18,9 @@ */ package org.apache.iceberg.rest; +import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -50,6 +52,8 @@ import org.apache.iceberg.relocated.com.google.common.base.Splitter; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; +import org.apache.iceberg.rest.auth.AuthSession; import org.apache.iceberg.rest.requests.CommitTransactionRequest; import org.apache.iceberg.rest.requests.CreateNamespaceRequest; import org.apache.iceberg.rest.requests.CreateTableRequest; @@ -73,7 +77,7 @@ import org.apache.iceberg.util.PropertyUtil; /** Adaptor class to translate REST requests into {@link Catalog} API calls. */ -public class RESTCatalogAdapter implements RESTClient { +public class RESTCatalogAdapter extends BaseHTTPClient { private static final Splitter SLASH = Splitter.on('/'); private static final Map, Integer> EXCEPTION_ERROR_CODES = @@ -98,6 +102,8 @@ public class RESTCatalogAdapter implements RESTClient { private final SupportsNamespaces asNamespaceCatalog; private final ViewCatalog asViewCatalog; + private AuthSession authSession = AuthSession.EMPTY; + public RESTCatalogAdapter(Catalog catalog) { this.catalog = catalog; this.asNamespaceCatalog = @@ -105,13 +111,6 @@ public RESTCatalogAdapter(Catalog catalog) { this.asViewCatalog = catalog instanceof ViewCatalog ? (ViewCatalog) catalog : null; } - enum HTTPMethod { - GET, - HEAD, - POST, - DELETE - } - enum Route { TOKENS(HTTPMethod.POST, "v1/oauth/tokens", null, OAuthTokenResponse.class), SEPARATE_AUTH_TOKENS_URI( @@ -280,6 +279,12 @@ private static OAuthTokenResponse handleOAuthRequest(Object body) { } } + @Override + public RESTClient withAuthSession(AuthSession session) { + this.authSession = session; + return this; + } + @SuppressWarnings({"MethodLength", "checkstyle:CyclomaticComplexity"}) public T handleRequest( Route route, Map vars, Object body, Class responseType) { @@ -567,25 +572,49 @@ private static void commitTransaction(Catalog catalog, CommitTransactionRequest transactions.forEach(Transaction::commitTransaction); } - public T execute( + @Override + protected HTTPRequest buildRequest( HTTPMethod method, String path, Map queryParams, - Object body, - Class responseType, Map headers, - Consumer errorHandler) { + Object body) { + URI baseUri = URI.create("https://localhost:8080"); + ObjectMapper mapper = RESTObjectMapper.mapper(); + ImmutableHTTPRequest.Builder builder = + ImmutableHTTPRequest.builder() + .baseUri(baseUri) + .mapper(mapper) + .method(method) + .path(path) + .body(body); + + if (queryParams != null) { + builder.queryParameters(queryParams); + } + + if (headers != null) { + builder.headers(HTTPHeaders.of(headers)); + } + + return authSession.authenticate(builder.build()); + } + + @Override + protected T execute( + HTTPRequest request, + Class responseType, + Consumer errorHandler, + Consumer> responseHeaders) { ErrorResponse.Builder errorBuilder = ErrorResponse.builder(); - Pair> routeAndVars = Route.from(method, path); + Pair> routeAndVars = Route.from(request.method(), request.path()); if (routeAndVars != null) { try { ImmutableMap.Builder vars = ImmutableMap.builder(); - if (queryParams != null) { - vars.putAll(queryParams); - } + vars.putAll(request.queryParameters()); vars.putAll(routeAndVars.second()); - return handleRequest(routeAndVars.first(), vars.build(), body, responseType); + return handleRequest(routeAndVars.first(), vars.build(), request.body(), responseType); } catch (RuntimeException e) { configureResponseFromException(e, errorBuilder); @@ -595,7 +624,8 @@ public T execute( errorBuilder .responseCode(400) .withType("BadRequestException") - .withMessage(String.format("No route for request: %s %s", method, path)); + .withMessage( + String.format("No route for request: %s %s", request.method(), request.path())); } ErrorResponse error = errorBuilder.build(); @@ -605,60 +635,6 @@ public T execute( throw new RESTException("Unhandled error: %s", error); } - @Override - public T delete( - String path, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.DELETE, path, null, null, responseType, headers, errorHandler); - } - - @Override - public T delete( - String path, - Map queryParams, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.DELETE, path, queryParams, null, responseType, headers, errorHandler); - } - - @Override - public T post( - String path, - RESTRequest body, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.POST, path, null, body, responseType, headers, errorHandler); - } - - @Override - public T get( - String path, - Map queryParams, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.GET, path, queryParams, null, responseType, headers, errorHandler); - } - - @Override - public void head(String path, Map headers, Consumer errorHandler) { - execute(HTTPMethod.HEAD, path, null, null, null, headers, errorHandler); - } - - @Override - public T postForm( - String path, - Map formData, - Class responseType, - Map headers, - Consumer errorHandler) { - return execute(HTTPMethod.POST, path, null, formData, responseType, headers, errorHandler); - } - @Override public void close() throws IOException { // The calling test is responsible for closing the underlying catalog backing this REST catalog diff --git a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java index f456bb4d354d..97c8b6134482 100644 --- a/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java +++ b/core/src/test/java/org/apache/iceberg/rest/RESTCatalogServlet.java @@ -38,7 +38,7 @@ import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.io.CharStreams; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.RESTCatalogAdapter.Route; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.util.Pair; @@ -96,15 +96,17 @@ protected void execute(ServletRequestContext context, HttpServletResponse respon } try { - Object responseBody = - restCatalogAdapter.execute( + + HTTPRequest request = + restCatalogAdapter.buildRequest( context.method(), context.path(), context.queryParams(), - context.body(), - context.route().responseClass(), context.headers(), - handle(response)); + context.body()); + Object responseBody = + restCatalogAdapter.execute( + request, context.route().responseClass(), handle(response), h -> {}); if (responseBody != null) { RESTObjectMapper.mapper().writeValue(response.getWriter(), responseBody); diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java index 768d6c3777ee..43a94c401eb4 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java @@ -22,6 +22,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.times; @@ -32,7 +34,9 @@ import java.io.File; import java.io.IOException; import java.nio.file.Path; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -65,7 +69,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.RESTSessionCatalog.SnapshotMode; import org.apache.iceberg.rest.auth.AuthSessionUtil; import org.apache.iceberg.rest.auth.OAuth2Properties; @@ -115,35 +119,31 @@ public void createCatalog() throws Exception { "in-memory", ImmutableMap.of(CatalogProperties.WAREHOUSE_LOCATION, warehouse.getAbsolutePath())); - Map catalogHeaders = - ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); - Map contextHeaders = - ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=user"); + HTTPHeaders catalogHeaders = + HTTPHeaders.of(Map.of("Authorization", "Bearer client-credentials-token:sub=catalog")); + HTTPHeaders contextHeaders = + HTTPHeaders.of(Map.of("Authorization", "Bearer client-credentials-token:sub=user")); RESTCatalogAdapter adaptor = new RESTCatalogAdapter(backendCatalog) { @Override public T execute( - RESTCatalogAdapter.HTTPMethod method, - String path, - Map queryParams, - Object body, + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { + Consumer errorHandler, + Consumer> responseHeaders) { // this doesn't use a Mockito spy because this is used for catalog tests, which have // different method calls - if (!"v1/oauth/tokens".equals(path)) { - if ("v1/config".equals(path)) { - assertThat(headers).containsAllEntriesOf(catalogHeaders); + if (!"v1/oauth/tokens".equals(request.path())) { + if ("v1/config".equals(request.path())) { + assertThat(request.headers().entries()).containsAll(catalogHeaders.entries()); } else { - assertThat(headers).containsAllEntriesOf(contextHeaders); + assertThat(request.headers().entries()).containsAll(contextHeaders.entries()); } } - Object request = roundTripSerialize(body, "request"); - T response = - super.execute( - method, path, queryParams, request, responseType, headers, errorHandler); + Object body = roundTripSerialize(request.body(), "request"); + HTTPRequest req = ImmutableHTTPRequest.builder().from(request).body(body).build(); + T response = super.execute(req, responseType, errorHandler, responseHeaders); T responseAfterSerialization = roundTripSerialize(response, "response"); return responseAfterSerialization; } @@ -259,13 +259,12 @@ public void testConfigRoute() throws IOException { RESTClient testClient = new RESTCatalogAdapter(backendCatalog) { @Override - public T get( - String path, - Map queryParams, + public T execute( + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { - if ("v1/config".equals(path)) { + Consumer errorHandler, + Consumer> responseHeaders) { + if ("v1/config".equals(request.path())) { return castResponse( responseType, ConfigResponse.builder() @@ -275,10 +274,11 @@ public T get( CatalogProperties.CACHE_ENABLED, "false", CatalogProperties.WAREHOUSE_LOCATION, - queryParams.get(CatalogProperties.WAREHOUSE_LOCATION) + "warehouse")) + request.queryParameters().get(CatalogProperties.WAREHOUSE_LOCATION) + + "warehouse")) .build()); } - return super.get(path, queryParams, responseType, headers, errorHandler); + return super.execute(request, responseType, errorHandler, responseHeaders); } }; @@ -337,27 +337,20 @@ public void testCatalogBasicBearerToken() { // the bearer token should be used for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - any(), - eq(catalogHeaders), any()); } @Test public void testCatalogCredentialNoOauth2ServerUri() { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -373,39 +366,29 @@ public void testCatalogCredentialNoOauth2ServerUri() { // no token or credential for catalog token exchange Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/oauth/tokens"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/oauth/tokens", Map.of()), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // no token or credential for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the catalog token for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - any(), - eq(catalogHeaders), any()); } @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogCredential(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -428,32 +411,23 @@ public void testCatalogCredential(String oauth2ServerUri) { // no token or credential for catalog token exchange Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // no token or credential for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the catalog token for all interactions Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", catalogHeaders), any(), any(), - eq(catalogHeaders), any()); } @@ -489,39 +463,29 @@ public void testCatalogBearerTokenWithClientCredential(String oauth2ServerUri) { // use the bearer token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the bearer token to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - any(), - eq(contextHeaders), any()); } @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogCredentialWithClientCredential(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map contextHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=user"); Map catalogHeaders = @@ -552,42 +516,30 @@ public void testCatalogCredentialWithClientCredential(String oauth2ServerUri) { // call client credentials with no initial auth Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the client credential to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - eq(contextHeaders), any()); } @@ -627,42 +579,30 @@ public void testCatalogBearerTokenAndCredentialWithClientCredential(String oauth // use the bearer token for client credentials Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, initHeaders), eq(OAuthTokenResponse.class), - eq(initHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // use the client credential to fetch the context token Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the context token for table existence check Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", contextHeaders), any(), any(), - any(), - eq(contextHeaders), any()); } @@ -822,12 +762,9 @@ private void testClientAuth( Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // token passes a static token. otherwise, validate a client credentials or token exchange @@ -835,34 +772,22 @@ private void testClientAuth( if (!credentials.containsKey("token")) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); } Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", expectedHeaders), any(), any(), - eq(expectedHeaders), any()); if (!optionalOAuthParams.isEmpty()) { Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat( - body -> - ((Map) body) - .keySet() - .containsAll(optionalOAuthParams.keySet())), + Mockito.argThat(body -> body.keySet().containsAll(optionalOAuthParams.keySet())), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -984,10 +909,7 @@ public void testTableSnapshotLoading() { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -998,10 +920,7 @@ public void testTableSnapshotLoading() { // verify that the table was loaded with the refs argument verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1010,10 +929,7 @@ public void testTableSnapshotLoading() { assertThat(refsTables.snapshots()).containsExactlyInAnyOrderElementsOf(table.snapshots()); verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "all")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "all")), eq(LoadTableResponse.class), any(), any()); @@ -1110,10 +1026,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1124,10 +1037,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { // verify that the table was loaded with the refs argument verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1137,10 +1047,7 @@ public void testTableSnapshotLoadingWithDivergedBranches(String formatVersion) { .containsExactlyInAnyOrderElementsOf(table.snapshots()); verify(adapter, times(1)) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "all")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "all")), eq(LoadTableResponse.class), any(), any()); @@ -1226,10 +1133,7 @@ public void lazySnapshotLoadingWithDivergedHistory() { Mockito.doAnswer(refsAnswer) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq(paths.table(TABLE)), - eq(ImmutableMap.of("snapshots", "refs")), - any(), + reqMatcher(HTTPMethod.GET, paths.table(TABLE), Map.of(), Map.of("snapshots", "refs")), eq(LoadTableResponse.class), any(), any()); @@ -1267,23 +1171,17 @@ public void testTableAuth( Mockito.doAnswer(addTableConfig) .when(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces/ns/tables"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces/ns/tables", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); Mockito.doAnswer(addTableConfig) .when(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); SessionCatalog.SessionContext context = @@ -1324,33 +1222,24 @@ public void testTableAuth( Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // session client credentials flow Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, catalogHeaders), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // create table request Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces/ns/tables"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces/ns/tables", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); // if the table returned a bearer token or a credential, there will be no token request @@ -1358,12 +1247,9 @@ public void testTableAuth( // token exchange to get a table token Mockito.verify(adapter, times(1)) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, expectedContextHeaders), eq(OAuthTokenResponse.class), - eq(expectedContextHeaders), + any(), any()); } @@ -1371,34 +1257,25 @@ public void testTableAuth( // load table from catalog + refresh loaded table Mockito.verify(adapter, times(2)) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedTableHeaders), eq(LoadTableResponse.class), - eq(expectedTableHeaders), + any(), any()); } else { // load table from catalog Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedContextHeaders), eq(LoadTableResponse.class), - eq(expectedContextHeaders), + any(), any()); // refresh loaded table Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/namespaces/ns/tables/table"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/namespaces/ns/tables/table", expectedTableHeaders), eq(LoadTableResponse.class), - eq(expectedTableHeaders), + any(), any()); } } @@ -1406,7 +1283,6 @@ public void testTableAuth( @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogTokenRefresh(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1426,14 +1302,7 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1458,23 +1327,17 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { // call client credentials with no initial auth Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange @@ -1486,12 +1349,14 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + catalogHeaders, + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // verify that a second exchange occurs @@ -1508,12 +1373,14 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(secondRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + secondRefreshHeaders, + Map.of(), + secondRefreshRequest), eq(OAuthTokenResponse.class), - eq(secondRefreshHeaders), + any(), any()); }); } @@ -1521,7 +1388,6 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1541,14 +1407,7 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1574,25 +1433,30 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "secret", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + Map.of(), + Map.of(), + clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange @@ -1604,12 +1468,14 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + catalogHeaders, + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(catalogHeaders), + any(), any()); // use the refreshed context token for table existence check @@ -1619,12 +1485,10 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", refreshedCatalogHeader), any(), any(), - eq(refreshedCatalogHeader), any()); }); } @@ -1653,7 +1517,6 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 // expires at epoch second = 1 String token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjF9.gQADTbdEv-rpDWKSkGLbmafyB5UUjTdm9B_1izpuZ6E"; - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1680,24 +1543,25 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "12345", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, oauth2ServerUri, Map.of(), Map.of(), clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Map firstRefreshRequest = @@ -1708,12 +1572,14 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + OAuth2Util.basicAuthHeaders(credential), + Map.of(), + firstRefreshRequest), eq(OAuthTokenResponse.class), - eq(OAuth2Util.basicAuthHeaders(credential)), + any(), any()); // verify that a second exchange occurs @@ -1725,22 +1591,24 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - Mockito.argThat(secondRefreshRequest::equals), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + OAuth2Util.basicAuthHeaders(credential), + Map.of(), + secondRefreshRequest), eq(OAuthTokenResponse.class), - eq(OAuth2Util.basicAuthHeaders(credential)), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher( + HTTPMethod.HEAD, + "v1/namespaces/ns/tables/table", + Map.of("Authorization", "Bearer token-exchange-token:sub=" + token)), any(), any(), - eq(ImmutableMap.of("Authorization", "Bearer token-exchange-token:sub=" + token)), any()); } @@ -1767,29 +1635,23 @@ public void testCatalogValidBearerTokenIsNotRefreshed() { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), - any(), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", OAuth2Util.authHeaders(token)), any(), any(), - eq(OAuth2Util.authHeaders(token)), any()); } @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1812,14 +1674,7 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map firstRefreshRequest = ImmutableMap.of( @@ -1831,11 +1686,9 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth // simulate that the token expired when it was about to be refreshed Mockito.doThrow(new RuntimeException("token expired")) .when(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -1867,47 +1720,47 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth assertThat(catalog.tableExists(TableIdentifier.of("ns", "table"))).isFalse(); // call client credentials with no initial auth + Map clientCredentialsRequest = + ImmutableMap.of( + "grant_type", "client_credentials", + "client_id", "catalog", + "client_secret", "secret", + "scope", "catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + oauth2ServerUri, + Map.of(), + Map.of(), + clientCredentialsRequest), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + any(), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the first token exchange - since an exception is thrown, we're performing // retries Mockito.verify(adapter, times(2)) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); - // here we make sure that the basic auth header is used after token refresh retries // failed Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(basicHeaders), any()); @@ -1919,12 +1772,10 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth "Bearer token-exchange-token:sub=client-credentials-token:sub=catalog"); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/tables/table"), + reqMatcher( + HTTPMethod.HEAD, "v1/namespaces/ns/tables/table", refreshedCatalogHeader), any(), any(), - any(), - eq(refreshedCatalogHeader), any()); }); } @@ -1932,7 +1783,6 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { - Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1952,14 +1802,7 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -1986,24 +1829,19 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { () -> { // call client credentials with no initial auth Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - any(), + anyMap(), eq(OAuthTokenResponse.class), - eq(emptyHeaders), + eq(Map.of()), any()); // use the client credential token for config Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); // verify the token exchange uses the right scope @@ -2014,11 +1852,9 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { "subject_token_type", "urn:ietf:params:oauth:token-type:access_token", "scope", scope); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(firstRefreshRequest::equals), + argThat(firstRefreshRequest::equals), eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); @@ -2047,14 +1883,7 @@ public void testCatalogTokenRefreshDisabledWithToken(String oauth2ServerUri) { Mockito.doAnswer(addOneSecondExpiration) .when(adapter) - .execute( - eq(HTTPMethod.POST), - eq(oauth2ServerUri), - any(), - any(), - eq(OAuthTokenResponse.class), - any(), - any()); + .postForm(eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), anyMap(), any()); Map contextCredentials = ImmutableMap.of(); SessionCatalog.SessionContext context = @@ -2076,12 +1905,9 @@ public void testCatalogTokenRefreshDisabledWithToken(String oauth2ServerUri) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); } @@ -2122,23 +1948,18 @@ public void testCatalogTokenRefreshDisabledWithCredential(String oauth2ServerUri "scope", "catalog"); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.POST), + .postForm( eq(oauth2ServerUri), - any(), - Mockito.argThat(fetchTokenFromCredential::equals), + argThat(fetchTokenFromCredential::equals), eq(OAuthTokenResponse.class), - eq(ImmutableMap.of()), + anyMap(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", catalogHeaders), eq(ConfigResponse.class), - eq(catalogHeaders), + any(), any()); } @@ -2311,20 +2132,14 @@ public void testPaginationForListNamespaces(int numberOfItems) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq("v1/namespaces"), - any(), - any(), + reqMatcher(HTTPMethod.POST, "v1/namespaces", Map.of(), Map.of()), eq(CreateNamespaceResponse.class), any(), any()); @@ -2375,20 +2190,18 @@ public void testPaginationForListTables(int numberOfItems) { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq(String.format("v1/namespaces/%s/tables", namespaceName)), - any(), - any(), + reqMatcher( + HTTPMethod.POST, + String.format("v1/namespaces/%s/tables", namespaceName), + Map.of(), + Map.of()), eq(LoadTableResponse.class), any(), any()); @@ -2436,21 +2249,27 @@ public void testCleanupUncommitedFilesForCleanableFailures() { .build(); Table table = catalog.loadTable(TABLE); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(any(), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST), any(), any(), any()); assertThatThrownBy(() -> catalog.loadTable(TABLE).newFastAppend().appendFile(file).commit()) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); // Extract the UpdateTableRequest to determine the path of the manifest list that should be // cleaned up - UpdateTableRequest request = captor.getValue(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) request.updates().get(0); - assertThatThrownBy(() -> table.io().newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) body.updates().get(0); + assertThatThrownBy( + () -> table.io().newInputFile(addSnapshot.snapshot().manifestListLocation())) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2465,21 +2284,26 @@ public void testNoCleanupForNonCleanableExceptions() { catalog.createTable(TABLE, SCHEMA); Table table = catalog.loadTable(TABLE); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(any(), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST), any(), any(), any()); assertThatThrownBy(() -> catalog.loadTable(TABLE).newFastAppend().appendFile(FILE_A).commit()) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); // Extract the UpdateTableRequest to determine the path of the manifest list that should still // exist even though the commit failed - UpdateTableRequest request = captor.getValue(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) request.updates().get(0); - assertThat(table.io().newInputFile(addSnapshot.snapshot().manifestListLocation()).exists()) - .isTrue(); + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) body.updates().get(0); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(table.io().newInputFile(manifestListLocation).exists()).isTrue(); + }); } @Test @@ -2493,32 +2317,38 @@ public void testCleanupCleanableExceptionsCreate() { catalog.createTable(TABLE, SCHEMA); TableIdentifier newTable = TableIdentifier.of(TABLE.namespace(), "some_table"); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(newTable)), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(newTable)), any(), any(), any()); Transaction createTableTransaction = catalog.newCreateTableTransaction(newTable, SCHEMA); createTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(createTableTransaction::commitTransaction) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(newTable)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - - assertThat(appendSnapshot).isPresent(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThatThrownBy( - () -> - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(newTable)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + body.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + + assertThat(appendSnapshot).isPresent(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + assertThatThrownBy( + () -> + catalog + .loadTable(TABLE) + .io() + .newInputFile(addSnapshot.snapshot().manifestListLocation())) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2534,29 +2364,32 @@ public void testNoCleanupForNonCleanableCreateTransaction() { TableIdentifier newTable = TableIdentifier.of(TABLE.namespace(), "some_table"); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(newTable)), any(), any(), any(Map.class), any()); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(newTable)), any(), any(), any()); + Transaction createTableTransaction = catalog.newCreateTableTransaction(newTable, SCHEMA); createTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(createTableTransaction::commitTransaction) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(newTable)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - assertThat(appendSnapshot).isPresent(); - - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThat( - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation()) - .exists()) - .isTrue(); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(newTable)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest body = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + body.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + assertThat(appendSnapshot).isPresent(); + + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(catalog.loadTable(TABLE).io().newInputFile(manifestListLocation).exists()) + .isTrue(); + }); } @Test @@ -2569,32 +2402,35 @@ public void testCleanupCleanableExceptionsReplace() { } catalog.createTable(TABLE, SCHEMA); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); Mockito.doThrow(new NotAuthorizedException("not authorized")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(TABLE)), any(), any(), any(Map.class), any()); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(TABLE)), any(), any(), any()); Transaction replaceTableTransaction = catalog.newReplaceTableTransaction(TABLE, SCHEMA, false); replaceTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(replaceTableTransaction::commitTransaction) .isInstanceOf(NotAuthorizedException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - - assertThat(appendSnapshot).isPresent(); - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThatThrownBy( - () -> - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation())) - .isInstanceOf(NotFoundException.class); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest request = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + request.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + + assertThat(appendSnapshot).isPresent(); + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThatThrownBy( + () -> catalog.loadTable(TABLE).io().newInputFile(manifestListLocation)) + .isInstanceOf(NotFoundException.class); + }); } @Test @@ -2609,29 +2445,32 @@ public void testNoCleanupForNonCleanableReplaceTransaction() { catalog.createTable(TABLE, SCHEMA); Mockito.doThrow(new ServiceFailureException("some service failure")) .when(adapter) - .post(eq(RESOURCE_PATHS.table(TABLE)), any(), any(), any(Map.class), any()); - ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateTableRequest.class); + .execute(reqMatcher(HTTPMethod.POST, RESOURCE_PATHS.table(TABLE)), any(), any(), any()); + Transaction replaceTableTransaction = catalog.newReplaceTableTransaction(TABLE, SCHEMA, false); replaceTableTransaction.newAppend().appendFile(FILE_A).commit(); assertThatThrownBy(replaceTableTransaction::commitTransaction) .isInstanceOf(ServiceFailureException.class); - verify(adapter, atLeastOnce()) - .post(eq(RESOURCE_PATHS.table(TABLE)), captor.capture(), any(), any(Map.class), any()); - UpdateTableRequest request = captor.getValue(); - Optional appendSnapshot = - request.updates().stream() - .filter(update -> update instanceof MetadataUpdate.AddSnapshot) - .findFirst(); - assertThat(appendSnapshot).isPresent(); - - MetadataUpdate.AddSnapshot addSnapshot = (MetadataUpdate.AddSnapshot) appendSnapshot.get(); - assertThat( - catalog - .loadTable(TABLE) - .io() - .newInputFile(addSnapshot.snapshot().manifestListLocation()) - .exists()) - .isTrue(); + + assertThat(allRequests(adapter)) + .anySatisfy( + req -> { + assertThat(req.method()).isEqualTo(HTTPMethod.POST); + assertThat(req.path()).isEqualTo(RESOURCE_PATHS.table(TABLE)); + assertThat(req.body()).isInstanceOf(UpdateTableRequest.class); + UpdateTableRequest request = (UpdateTableRequest) req.body(); + Optional appendSnapshot = + request.updates().stream() + .filter(update -> update instanceof MetadataUpdate.AddSnapshot) + .findFirst(); + assertThat(appendSnapshot).isPresent(); + + MetadataUpdate.AddSnapshot addSnapshot = + (MetadataUpdate.AddSnapshot) appendSnapshot.get(); + String manifestListLocation = addSnapshot.snapshot().manifestListLocation(); + assertThat(catalog.loadTable(TABLE).io().newInputFile(manifestListLocation).exists()) + .isTrue(); + }); } @Test @@ -2645,19 +2484,13 @@ public void testNamespaceExistsViaHEADRequest() { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/non-existing"), - any(), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/non-existing", Map.of(), Map.of()), any(), any(), any()); @@ -2672,4 +2505,51 @@ private RESTCatalog catalog(RESTCatalogAdapter adapter) { CatalogProperties.FILE_IO_IMPL, "org.apache.iceberg.inmemory.InMemoryFileIO")); return catalog; } + + static HTTPRequest reqMatcher(HTTPMethod method) { + return argThat(req -> req.method() == method); + } + + static HTTPRequest reqMatcher(HTTPMethod method, String path) { + return argThat(req -> req.method() == method && req.path().equals(path)); + } + + static HTTPRequest reqMatcher(HTTPMethod method, String path, Map headers) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(HTTPHeaders.of(headers))); + } + + static HTTPRequest reqMatcher( + HTTPMethod method, String path, Map headers, Map parameters) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(HTTPHeaders.of(headers)) + && req.queryParameters().equals(parameters)); + } + + static HTTPRequest reqMatcher( + HTTPMethod method, + String path, + Map headers, + Map parameters, + Object body) { + return argThat( + req -> + req.method() == method + && req.path().equals(path) + && req.headers().equals(HTTPHeaders.of(headers)) + && req.queryParameters().equals(parameters) + && Objects.equals(req.body(), body)); + } + + private static List allRequests(RESTCatalogAdapter adapter) { + ArgumentCaptor captor = ArgumentCaptor.forClass(HTTPRequest.class); + verify(adapter, atLeastOnce()).execute(captor.capture(), any(), any(), any()); + return captor.getAllValues(); + } } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java index cfc22ca767f6..8dfbf0df6dd7 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTViewCatalog.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.rest; +import static org.apache.iceberg.rest.TestRESTCatalog.reqMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -38,7 +39,7 @@ import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.inmemory.InMemoryCatalog; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.responses.ConfigResponse; import org.apache.iceberg.rest.responses.ErrorResponse; import org.apache.iceberg.rest.responses.ListTablesResponse; @@ -82,17 +83,13 @@ public void createCatalog() throws Exception { new RESTCatalogAdapter(backendCatalog) { @Override public T execute( - HTTPMethod method, - String path, - Map queryParams, - Object body, + HTTPRequest request, Class responseType, - Map headers, - Consumer errorHandler) { - Object request = roundTripSerialize(body, "request"); - T response = - super.execute( - method, path, queryParams, request, responseType, headers, errorHandler); + Consumer errorHandler, + Consumer> responseHeaders) { + Object body = roundTripSerialize(request.body(), "request"); + HTTPRequest req = ImmutableHTTPRequest.builder().from(request).body(body).build(); + T response = super.execute(req, responseType, errorHandler, responseHeaders); T responseAfterSerialization = roundTripSerialize(response, "response"); return responseAfterSerialization; } @@ -187,21 +184,11 @@ public void testPaginationForListViews(int numberOfItems) { assertThat(views).hasSize(numberOfItems); Mockito.verify(adapter) - .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), - eq(ConfigResponse.class), - any(), - any()); + .execute(reqMatcher(HTTPMethod.GET, "v1/config"), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter, times(numberOfItems)) .execute( - eq(HTTPMethod.POST), - eq(String.format("v1/namespaces/%s/views", namespaceName)), - any(), - any(), + reqMatcher(HTTPMethod.POST, String.format("v1/namespaces/%s/views", namespaceName)), eq(LoadViewResponse.class), any(), any()); @@ -244,19 +231,13 @@ public void viewExistsViaHEADRequest() { Mockito.verify(adapter) .execute( - eq(HTTPMethod.GET), - eq("v1/config"), - any(), - any(), + reqMatcher(HTTPMethod.GET, "v1/config", Map.of(), Map.of()), eq(ConfigResponse.class), any(), any()); Mockito.verify(adapter) .execute( - eq(HTTPMethod.HEAD), - eq("v1/namespaces/ns/views/view"), - any(), - any(), + reqMatcher(HTTPMethod.HEAD, "v1/namespaces/ns/views/view", Map.of(), Map.of()), any(), any(), any()); From a513a59fcc09129f2edc1f916f3d1d1f51ad1165 Mon Sep 17 00:00:00 2001 From: Alexandre Dutra Date: Mon, 20 Jan 2025 16:46:57 +0100 Subject: [PATCH 2/2] review --- .../apache/iceberg/rest/TestRESTCatalog.java | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java index 43a94c401eb4..6a5f22075c6e 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestRESTCatalog.java @@ -351,6 +351,7 @@ public void testCatalogBasicBearerToken() { @Test public void testCatalogCredentialNoOauth2ServerUri() { + Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -366,7 +367,7 @@ public void testCatalogCredentialNoOauth2ServerUri() { // no token or credential for catalog token exchange Mockito.verify(adapter) .execute( - reqMatcher(HTTPMethod.POST, "v1/oauth/tokens", Map.of()), + reqMatcher(HTTPMethod.POST, "v1/oauth/tokens", emptyHeaders), eq(OAuthTokenResponse.class), any(), any()); @@ -389,6 +390,7 @@ public void testCatalogCredentialNoOauth2ServerUri() { @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogCredential(String oauth2ServerUri) { + Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -411,7 +413,7 @@ public void testCatalogCredential(String oauth2ServerUri) { // no token or credential for catalog token exchange Mockito.verify(adapter) .execute( - reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, emptyHeaders), eq(OAuthTokenResponse.class), any(), any()); @@ -486,6 +488,7 @@ public void testCatalogBearerTokenWithClientCredential(String oauth2ServerUri) { @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogCredentialWithClientCredential(String oauth2ServerUri) { + Map emptyHeaders = ImmutableMap.of(); Map contextHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=user"); Map catalogHeaders = @@ -516,7 +519,7 @@ public void testCatalogCredentialWithClientCredential(String oauth2ServerUri) { // call client credentials with no initial auth Mockito.verify(adapter) .execute( - reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, emptyHeaders), eq(OAuthTokenResponse.class), any(), any()); @@ -1283,6 +1286,7 @@ public void testTableAuth( @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogTokenRefresh(String oauth2ServerUri) { + Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1327,7 +1331,7 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { // call client credentials with no initial auth Mockito.verify(adapter) .execute( - reqMatcher(HTTPMethod.POST, oauth2ServerUri, Map.of()), + reqMatcher(HTTPMethod.POST, oauth2ServerUri, emptyHeaders), eq(OAuthTokenResponse.class), any(), any()); @@ -1388,6 +1392,7 @@ public void testCatalogTokenRefresh(String oauth2ServerUri) { @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { + Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1444,7 +1449,7 @@ public void testCatalogRefreshedTokenIsUsed(String oauth2ServerUri) { reqMatcher( HTTPMethod.POST, oauth2ServerUri, - Map.of(), + emptyHeaders, Map.of(), clientCredentialsRequest), eq(OAuthTokenResponse.class), @@ -1517,6 +1522,7 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 // expires at epoch second = 1 String token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjF9.gQADTbdEv-rpDWKSkGLbmafyB5UUjTdm9B_1izpuZ6E"; + Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1552,7 +1558,7 @@ public void testCatalogExpiredBearerTokenIsRefreshedWithCredential(String oauth2 Mockito.verify(adapter) .execute( reqMatcher( - HTTPMethod.POST, oauth2ServerUri, Map.of(), Map.of(), clientCredentialsRequest), + HTTPMethod.POST, oauth2ServerUri, emptyHeaders, Map.of(), clientCredentialsRequest), eq(OAuthTokenResponse.class), any(), any()); @@ -1652,6 +1658,7 @@ public void testCatalogValidBearerTokenIsNotRefreshed() { @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth2ServerUri) { + Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1731,7 +1738,7 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth reqMatcher( HTTPMethod.POST, oauth2ServerUri, - Map.of(), + emptyHeaders, Map.of(), clientCredentialsRequest), eq(OAuthTokenResponse.class), @@ -1755,6 +1762,7 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth eq(OAuthTokenResponse.class), eq(catalogHeaders), any()); + // here we make sure that the basic auth header is used after token refresh retries // failed Mockito.verify(adapter) @@ -1783,6 +1791,7 @@ public void testCatalogTokenRefreshFailsAndUsesCredentialForRefresh(String oauth @ParameterizedTest @ValueSource(strings = {"v1/oauth/tokens", "https://auth-server.com/token"}) public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { + Map emptyHeaders = ImmutableMap.of(); Map catalogHeaders = ImmutableMap.of("Authorization", "Bearer client-credentials-token:sub=catalog"); @@ -1833,7 +1842,7 @@ public void testCatalogWithCustomTokenScope(String oauth2ServerUri) { eq(oauth2ServerUri), anyMap(), eq(OAuthTokenResponse.class), - eq(Map.of()), + eq(emptyHeaders), any()); // use the client credential token for config @@ -1952,7 +1961,7 @@ public void testCatalogTokenRefreshDisabledWithCredential(String oauth2ServerUri eq(oauth2ServerUri), argThat(fetchTokenFromCredential::equals), eq(OAuthTokenResponse.class), - anyMap(), + eq(ImmutableMap.of()), any()); Mockito.verify(adapter)