diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index 4071c676adb0..a6886e439450 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -36,9 +36,12 @@ import com.google.common.base.Supplier; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.io.IOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; @@ -63,6 +66,7 @@ public class VertexAI implements AutoCloseable { private final String location; private final String apiEndpoint; private final Transport transport; + private final HeaderProvider headerProvider; private final CredentialsProvider credentialsProvider; private final transient Supplier predictionClientSupplier; @@ -85,6 +89,7 @@ public VertexAI(String projectId, String location) { location, Transport.GRPC, ImmutableList.of(), + /* customHeaders= */ ImmutableMap.of(), /* credentials= */ Optional.empty(), /* apiEndpoint= */ Optional.empty(), /* predictionClientSupplierOpt= */ Optional.empty(), @@ -108,6 +113,7 @@ public VertexAI() { null, Transport.GRPC, ImmutableList.of(), + /* customHeaders= */ ImmutableMap.of(), /* credentials= */ Optional.empty(), /* apiEndpoint= */ Optional.empty(), /* predictionClientSupplierOpt= */ Optional.empty(), @@ -119,6 +125,7 @@ private VertexAI( String location, Transport transport, List scopes, + Map customHeaders, Optional credentials, Optional apiEndpoint, Optional> predictionClientSupplierOpt, @@ -131,6 +138,15 @@ private VertexAI( this.location = Strings.isNullOrEmpty(location) ? inferLocation() : location; this.transport = transport; + String sdkHeader = + String.format( + "%s/%s", + Constants.USER_AGENT_HEADER, + GaxProperties.getLibraryVersion(PredictionServiceSettings.class)); + Map headers = new HashMap<>(customHeaders); + headers.compute("user-agent", (k, v) -> v == null ? sdkHeader : sdkHeader + " " + v); + this.headerProvider = FixedHeaderProvider.create(headers); + if (credentials.isPresent()) { this.credentialsProvider = FixedCredentialsProvider.create(credentials.get()); } else { @@ -160,6 +176,7 @@ public static class Builder { private String location; private Transport transport = Transport.GRPC; private ImmutableList scopes = ImmutableList.of(); + private ImmutableMap customHeaders = ImmutableMap.of(); private Optional credentials = Optional.empty(); private Optional apiEndpoint = Optional.empty(); @@ -174,6 +191,7 @@ public VertexAI build() { location, transport, scopes, + customHeaders, credentials, apiEndpoint, Optional.ofNullable(predictionClientSupplier), @@ -240,6 +258,14 @@ public Builder setScopes(List scopes) { this.scopes = ImmutableList.copyOf(scopes); return this; } + + @CanIgnoreReturnValue + public Builder setCustomHeaders(Map customHeaders) { + checkNotNull(customHeaders, "customHeaders can't be null"); + + this.customHeaders = ImmutableMap.copyOf(customHeaders); + return this; + } } /** @@ -278,6 +304,15 @@ public String getApiEndpoint() { return apiEndpoint; } + /** + * Returns the headers to use when making API calls. + * + * @return a map of headers to use when making API calls. + */ + public Map getHeaders() { + return headerProvider.getHeaders(); + } + /** * Returns the default credentials to use when making API calls. * diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java index 4d7bfff73fc1..7baddcf7b088 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java @@ -22,12 +22,15 @@ import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; +import com.google.api.gax.core.GaxProperties; import com.google.api.gax.core.GoogleCredentialsProvider; import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.vertexai.api.PredictionServiceClient; import com.google.cloud.vertexai.api.PredictionServiceSettings; import com.google.common.collect.ImmutableList; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; import org.junit.Rule; import org.junit.Test; @@ -397,4 +400,59 @@ public void testInstantiateVertexAI_builderWithTransport_shouldContainRightField assertThat(vertexAi.getTransport()).isEqualTo(Transport.REST); assertThat(vertexAi.getApiEndpoint()).isEqualTo(TEST_DEFAULT_ENDPOINT); } + + @Test + public void testInstantiateVertexAI_builderWithCustomHeaders_shouldContainRightFields() + throws IOException { + Map customHeaders = new HashMap<>(); + customHeaders.put("test_key", "test_value"); + + vertexAi = + new VertexAI.Builder() + .setProjectId(TEST_PROJECT) + .setLocation(TEST_LOCATION) + .setCustomHeaders(customHeaders) + .build(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION); + // headers should include both the sdk header and the custom headers + Map expectedHeaders = new HashMap<>(customHeaders); + expectedHeaders.put( + "user-agent", + String.format( + "%s/%s", + Constants.USER_AGENT_HEADER, + GaxProperties.getLibraryVersion(PredictionServiceSettings.class))); + assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders); + } + + @Test + public void + testInstantiateVertexAI_builderWithCustomHeadersWithSdkReservedKey_shouldContainRightFields() + throws IOException { + Map customHeadersWithSdkReservedKey = new HashMap<>(); + customHeadersWithSdkReservedKey.put("user-agent", "test_value"); + + vertexAi = + new VertexAI.Builder() + .setProjectId(TEST_PROJECT) + .setLocation(TEST_LOCATION) + .setCustomHeaders(customHeadersWithSdkReservedKey) + .build(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION); + // headers should include sdk reserved key with value of both the sdk header and the custom + // headers + Map expectedHeaders = new HashMap<>(); + expectedHeaders.put( + "user-agent", + String.format( + "%s/%s %s", + Constants.USER_AGENT_HEADER, + GaxProperties.getLibraryVersion(PredictionServiceSettings.class), + "test_value")); + assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders); + } }