diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientBase.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientBase.java index d13931767ef2f..38ded579a9ef8 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientBase.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientBase.java @@ -22,6 +22,7 @@ import com.azure.core.util.ClientOptions; import com.azure.core.util.Configuration; import com.azure.core.util.CoreUtils; +import com.azure.core.util.SharedExecutorService; import com.azure.core.util.UserAgentUtil; import com.azure.core.util.builder.ClientBuilderUtil; import com.azure.core.util.logging.ClientLogger; @@ -280,6 +281,8 @@ ConfidentialClientApplication getConfidentialClient(boolean enableCae) { if (options.getExecutorService() != null) { applicationBuilder.executorService(options.getExecutorService()); + } else { + applicationBuilder.executorService(SharedExecutorService.getInstance()); } TokenCachePersistenceOptions tokenCachePersistenceOptions = options.getTokenCacheOptions(); @@ -341,6 +344,8 @@ PublicClientApplication getPublicClient(boolean sharedTokenCacheCredential, bool if (options.getExecutorService() != null) { builder.executorService(options.getExecutorService()); + } else { + builder.executorService(SharedExecutorService.getInstance()); } if (enableCae) { @@ -457,6 +462,8 @@ ConfidentialClientApplication getManagedIdentityConfidentialClient() { if (options.getExecutorService() != null) { applicationBuilder.executorService(options.getExecutorService()); + } else { + applicationBuilder.executorService(SharedExecutorService.getInstance()); } return applicationBuilder.build(); @@ -495,6 +502,8 @@ ManagedIdentityApplication getManagedIdentityMsalApplication() { if (options.getExecutorService() != null) { miBuilder.executorService(options.getExecutorService()); + } else { + miBuilder.executorService(SharedExecutorService.getInstance()); } return miBuilder.build(); @@ -537,6 +546,8 @@ ConfidentialClientApplication getWorkloadIdentityConfidentialClient() { if (options.getExecutorService() != null) { applicationBuilder.executorService(options.getExecutorService()); + } else { + applicationBuilder.executorService(SharedExecutorService.getInstance()); } return applicationBuilder.build(); diff --git a/sdk/monitor/azure-monitor-ingestion/src/main/java/com/azure/monitor/ingestion/LogsIngestionClient.java b/sdk/monitor/azure-monitor-ingestion/src/main/java/com/azure/monitor/ingestion/LogsIngestionClient.java index 25f56ec715845..1bc0d05cc07cb 100644 --- a/sdk/monitor/azure-monitor-ingestion/src/main/java/com/azure/monitor/ingestion/LogsIngestionClient.java +++ b/sdk/monitor/azure-monitor-ingestion/src/main/java/com/azure/monitor/ingestion/LogsIngestionClient.java @@ -17,6 +17,7 @@ import com.azure.core.http.rest.Response; import com.azure.core.util.BinaryData; import com.azure.core.util.Context; +import com.azure.core.util.SharedExecutorService; import com.azure.core.util.logging.ClientLogger; import com.azure.monitor.ingestion.implementation.Batcher; import com.azure.monitor.ingestion.implementation.IngestionUsingDataCollectionRulesClient; @@ -29,16 +30,13 @@ import java.util.List; import java.util.Objects; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.Stream; import static com.azure.monitor.ingestion.implementation.Utils.GZIP; -import static com.azure.monitor.ingestion.implementation.Utils.createThreadPool; import static com.azure.monitor.ingestion.implementation.Utils.getConcurrency; import static com.azure.monitor.ingestion.implementation.Utils.gzipRequest; -import static com.azure.monitor.ingestion.implementation.Utils.registerShutdownHook; /** *

This class provides a synchronous client for uploading custom logs to an Azure Monitor Log Analytics workspace. @@ -99,10 +97,6 @@ public final class LogsIngestionClient implements AutoCloseable { private static final ClientLogger LOGGER = new ClientLogger(LogsIngestionClient.class); private final IngestionUsingDataCollectionRulesClient client; - // dynamic thread pool that scales up and down on demand. - private final ExecutorService threadPool; - private final Thread shutdownHook; - /** * Creates a {@link LogsIngestionClient} that sends requests to the data collection endpoint. * @@ -110,8 +104,6 @@ public final class LogsIngestionClient implements AutoCloseable { */ LogsIngestionClient(IngestionUsingDataCollectionRulesClient client) { this.client = client; - this.threadPool = createThreadPool(); - this.shutdownHook = registerShutdownHook(this.threadPool, 5); } /** @@ -251,7 +243,7 @@ private Stream submit(Stream } try { - return threadPool.submit(() -> responseStream).get(); + return SharedExecutorService.getInstance().submit(() -> responseStream).get(); } catch (InterruptedException | ExecutionException e) { throw LOGGER.logExceptionAsError(new RuntimeException(e)); } @@ -335,7 +327,5 @@ public Response uploadWithResponse(String ruleId, String streamName, Binar @Override public void close() { - threadPool.shutdown(); - Runtime.getRuntime().removeShutdownHook(shutdownHook); } } diff --git a/sdk/monitor/azure-monitor-ingestion/src/main/java/com/azure/monitor/ingestion/implementation/Utils.java b/sdk/monitor/azure-monitor-ingestion/src/main/java/com/azure/monitor/ingestion/implementation/Utils.java index 54f11b7985baa..4049568e9ab19 100644 --- a/sdk/monitor/azure-monitor-ingestion/src/main/java/com/azure/monitor/ingestion/implementation/Utils.java +++ b/sdk/monitor/azure-monitor-ingestion/src/main/java/com/azure/monitor/ingestion/implementation/Utils.java @@ -9,10 +9,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; import java.util.zip.GZIPOutputStream; public final class Utils { @@ -20,8 +16,6 @@ public final class Utils { public static final String GZIP = "gzip"; private static final ClientLogger LOGGER = new ClientLogger(Utils.class); - // similarly to Schedulers.DEFAULT_BOUNDED_ELASTIC_SIZE, just puts a limit depending on logical processors count. - private static final int MAX_CONCURRENCY = 10 * Runtime.getRuntime().availableProcessors(); private Utils() { } @@ -50,41 +44,4 @@ public static int getConcurrency(LogsUploadOptions options) { return 1; } - - /** - * Creates cached (that supports scaling) thread pool with shutdown hook to do best-effort graceful termination within timeout. - * - * @return {@link ExecutorService} instance. - */ - public static ExecutorService createThreadPool() { - return new ThreadPoolExecutor(0, MAX_CONCURRENCY, 60L, TimeUnit.SECONDS, new SynchronousQueue<>()); - } - - /** - * Registers {@link ExecutorService} shutdown hook which will be called when JVM terminates. - * First, stops accepting new tasks, then awaits their completion for - * half of timeout, cancels remaining tasks and waits another half of timeout for them to get cancelled. - * - * @param threadPool Thread pool to shut down. - * @param timeoutSec Timeout in seconds to wait for tasks to complete or terminate after JVM starting to shut down. - * @return hook thread instance that can be used to unregister hook. - */ - public static Thread registerShutdownHook(ExecutorService threadPool, int timeoutSec) { - // based on https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/ExecutorService.html - long halfTimeoutNanos = TimeUnit.SECONDS.toNanos(timeoutSec) / 2; - Thread hook = new Thread(() -> { - try { - threadPool.shutdown(); - if (!threadPool.awaitTermination(halfTimeoutNanos, TimeUnit.NANOSECONDS)) { - threadPool.shutdownNow(); - threadPool.awaitTermination(halfTimeoutNanos, TimeUnit.NANOSECONDS); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - threadPool.shutdownNow(); - } - }); - Runtime.getRuntime().addShutdownHook(hook); - return hook; - } } diff --git a/sdk/monitor/azure-monitor-ingestion/src/test/java/com/azure/monitor/ingestion/implementation/ConcurrencyLimitingSpliteratorTest.java b/sdk/monitor/azure-monitor-ingestion/src/test/java/com/azure/monitor/ingestion/implementation/ConcurrencyLimitingSpliteratorTest.java index ea478eed406e0..1a6eb0dc276e7 100644 --- a/sdk/monitor/azure-monitor-ingestion/src/test/java/com/azure/monitor/ingestion/implementation/ConcurrencyLimitingSpliteratorTest.java +++ b/sdk/monitor/azure-monitor-ingestion/src/test/java/com/azure/monitor/ingestion/implementation/ConcurrencyLimitingSpliteratorTest.java @@ -3,6 +3,7 @@ package com.azure.monitor.ingestion.implementation; +import com.azure.core.util.SharedExecutorService; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.api.parallel.ExecutionMode; @@ -13,8 +14,6 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -31,7 +30,6 @@ @Execution(ExecutionMode.SAME_THREAD) public class ConcurrencyLimitingSpliteratorTest { private static final int TEST_TIMEOUT_SEC = 30; - private static final ExecutorService TEST_THREAD_POOL = Executors.newCachedThreadPool(); @Test public void invalidParams() { @@ -53,7 +51,7 @@ public void concurrentCalls(int concurrency) throws ExecutionException, Interrup int effectiveConcurrency = Math.min(list.size(), concurrency); CountDownLatch latch = new CountDownLatch(effectiveConcurrency); - List processed = TEST_THREAD_POOL.submit(() -> stream.map(r -> { + List processed = SharedExecutorService.getInstance().submit(() -> stream.map(r -> { latch.countDown(); try { Thread.sleep(10); @@ -78,7 +76,7 @@ public void concurrencyHigherThanItemsCount() throws ExecutionException, Interru AtomicInteger parallel = new AtomicInteger(0); AtomicInteger maxParallel = new AtomicInteger(0); - List processed = TEST_THREAD_POOL.submit(() -> stream.map(r -> { + List processed = SharedExecutorService.getInstance().submit(() -> stream.map(r -> { int cur = parallel.incrementAndGet(); int curMax = maxParallel.get(); while (cur > curMax && !maxParallel.compareAndSet(curMax, cur)) { diff --git a/sdk/monitor/azure-monitor-ingestion/src/test/java/com/azure/monitor/ingestion/implementation/UtilsTest.java b/sdk/monitor/azure-monitor-ingestion/src/test/java/com/azure/monitor/ingestion/implementation/UtilsTest.java deleted file mode 100644 index 41e1b80a36daf..0000000000000 --- a/sdk/monitor/azure-monitor-ingestion/src/test/java/com/azure/monitor/ingestion/implementation/UtilsTest.java +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.monitor.ingestion.implementation; - -import org.junit.jupiter.api.Test; - -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.RejectedExecutionException; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import java.util.stream.Stream; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -public class UtilsTest { - @Test - public void shutdownHookTerminatesPool() throws InterruptedException, ExecutionException { - int timeoutSec = 2; - ExecutorService threadPool = Executors.newFixedThreadPool(1); - Thread hook = Utils.registerShutdownHook(threadPool, timeoutSec); - - Stream stream = IntStream.of(100, 4000).boxed().parallel().map(this::task); - - Future> tasks = threadPool.submit(() -> stream.collect(Collectors.toList())); - - hook.run(); - - assertTrue(threadPool.isShutdown()); - assertTrue(threadPool.isTerminated()); - assertArrayEquals(new Integer[] { 100, -1 }, tasks.get().toArray()); - - assertThrows(RejectedExecutionException.class, - () -> threadPool.submit(() -> stream.collect(Collectors.toList()))); - } - - @Test - public void shutdownHookRegistered() { - int timeoutSec = 2; - ExecutorService threadPool = Executors.newFixedThreadPool(1); - Thread hook = Utils.registerShutdownHook(threadPool, timeoutSec); - assertTrue(Runtime.getRuntime().removeShutdownHook(hook)); - } - - private int task(int sleepMs) { - try { - Thread.sleep(sleepMs); - return sleepMs; - } catch (InterruptedException e) { - return -1; - } - } -} diff --git a/sdk/search/azure-search-documents/src/main/java/com/azure/search/documents/implementation/batching/SearchIndexingPublisher.java b/sdk/search/azure-search-documents/src/main/java/com/azure/search/documents/implementation/batching/SearchIndexingPublisher.java index 72df7ab125f32..df837f5e10c54 100644 --- a/sdk/search/azure-search-documents/src/main/java/com/azure/search/documents/implementation/batching/SearchIndexingPublisher.java +++ b/sdk/search/azure-search-documents/src/main/java/com/azure/search/documents/implementation/batching/SearchIndexingPublisher.java @@ -7,6 +7,7 @@ import com.azure.core.http.rest.Response; import com.azure.core.util.Context; import com.azure.core.util.CoreUtils; +import com.azure.core.util.SharedExecutorService; import com.azure.core.util.logging.ClientLogger; import com.azure.core.util.serializer.JsonSerializer; import com.azure.search.documents.implementation.SearchIndexClientImpl; @@ -29,8 +30,6 @@ import java.util.List; import java.util.Objects; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; @@ -54,7 +53,6 @@ */ public final class SearchIndexingPublisher { private static final ClientLogger LOGGER = new ClientLogger(SearchIndexingPublisher.class); - private static final ExecutorService EXECUTOR = getThreadPoolWithShutdownHook(); private final SearchIndexClientImpl restClient; private final JsonSerializer serializer; @@ -154,7 +152,8 @@ public void flush(boolean awaitLock, boolean isClose, Duration timeout, Context private void flushLoop(boolean isClosed, Duration timeout, Context context) { if (timeout != null && !timeout.isNegative() && !timeout.isZero()) { final AtomicReference>> batchActions = new AtomicReference<>(); - Future future = EXECUTOR.submit(() -> flushLoopHelper(isClosed, context, batchActions)); + Future future = SharedExecutorService.getInstance() + .submit(() -> flushLoopHelper(isClosed, context, batchActions)); try { CoreUtils.getResultWithTimeout(future, timeout); @@ -361,8 +360,4 @@ private static void sleep(long millis) { } catch (InterruptedException ignored) { } } - - private static ExecutorService getThreadPoolWithShutdownHook() { - return CoreUtils.addShutdownHookSafely(Executors.newCachedThreadPool(), Duration.ofSeconds(5)); - } } diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/specialized/HttpFaultInjectingTests.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/specialized/HttpFaultInjectingTests.java index fd4f775729127..9a83cfa09f0c8 100644 --- a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/specialized/HttpFaultInjectingTests.java +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/specialized/HttpFaultInjectingTests.java @@ -16,6 +16,7 @@ import com.azure.core.util.Context; import com.azure.core.util.CoreUtils; import com.azure.core.util.HttpClientOptions; +import com.azure.core.util.SharedExecutorService; import com.azure.core.util.UrlBuilder; import com.azure.core.util.logging.ClientLogger; import com.azure.storage.blob.BlobClient; @@ -49,8 +50,7 @@ import java.util.Locale; import java.util.Set; import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -130,8 +130,8 @@ public void downloadToFileWithFaultInjection() throws IOException, InterruptedEx StandardOpenOption.TRUNCATE_EXISTING, // If the file already exists and it is opened for WRITE access, then its length is truncated to 0. StandardOpenOption.READ, StandardOpenOption.WRITE)); - ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); - executorService.invokeAll(files.stream().map(it -> (Callable) () -> { + CountDownLatch countDownLatch = new CountDownLatch(500); + SharedExecutorService.getInstance().invokeAll(files.stream().map(it -> (Callable) () -> { try { downloadClient.downloadToFileWithResponse(new BlobDownloadToFileOptions(it.getAbsolutePath()) .setOpenOptions(overwriteOptions) @@ -148,13 +148,14 @@ public void downloadToFileWithFaultInjection() throws IOException, InterruptedEx LOGGER.atWarning() .addKeyValue("downloadFile", it.getAbsolutePath()) .log("Failed to complete download.", ex); + } finally { + countDownLatch.countDown(); } return null; }).collect(Collectors.toList())); - executorService.shutdown(); - executorService.awaitTermination(10, TimeUnit.MINUTES); + countDownLatch.await(10, TimeUnit.MINUTES); assertTrue(successCount.get() >= 450); // cleanup diff --git a/sdk/tables/azure-data-tables/src/main/java/com/azure/data/tables/TableClient.java b/sdk/tables/azure-data-tables/src/main/java/com/azure/data/tables/TableClient.java index d310af33af2ad..1ea280a59ac6d 100644 --- a/sdk/tables/azure-data-tables/src/main/java/com/azure/data/tables/TableClient.java +++ b/sdk/tables/azure-data-tables/src/main/java/com/azure/data/tables/TableClient.java @@ -65,17 +65,15 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeoutException; import java.util.function.BiConsumer; import java.util.function.Supplier; import java.util.stream.Collectors; -import static com.azure.core.util.CoreUtils.getResultWithTimeout; import static com.azure.data.tables.implementation.TableUtils.callIterableWithOptionalTimeout; import static com.azure.data.tables.implementation.TableUtils.callWithOptionalTimeout; -import static com.azure.data.tables.implementation.TableUtils.hasTimeout; import static com.azure.data.tables.implementation.TableUtils.mapThrowableToTableServiceException; +import static com.azure.data.tables.implementation.TableUtils.requestWithOptionalTimeout; import static com.azure.data.tables.implementation.TableUtils.toTableServiceError; /** @@ -289,8 +287,6 @@ */ @ServiceClient(builder = TableClientBuilder.class) public final class TableClient { - - private static final ExecutorService THREAD_POOL = TableUtils.getThreadPoolWithShutdownHook(); private final ClientLogger logger = new ClientLogger(TableClient.class); private final String tableName; private final AzureTableImpl tablesImplementation; @@ -473,7 +469,7 @@ public Response createTableWithResponse(Duration timeout, Context con ResponseFormat.RETURN_NO_CONTENT, null, context), TableItemAccessHelper.createItem(new TableResponseProperties().setTableName(tableName))); - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } /** @@ -524,8 +520,7 @@ public Response deleteTableWithResponse(Duration timeout, Context context) .deleteWithResponse(tableName, null, context), null); try { - return hasTimeout(timeout) - ? getResultWithTimeout(THREAD_POOL.submit(callable::get), timeout) : callable.get(); + return requestWithOptionalTimeout(callable, timeout); } catch (InterruptedException | ExecutionException | TimeoutException ex) { throw logger.logExceptionAsError(new RuntimeException(ex)); } catch (RuntimeException ex) { @@ -620,7 +615,7 @@ public Response createEntityWithResponse(TableEntity entity, Duration time return new SimpleResponse<>(response.getRequest(), response.getStatusCode(), response.getHeaders(), null); }; - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } /** @@ -715,7 +710,7 @@ public Response upsertEntityWithResponse(TableEntity entity, TableEntityUp } }; - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } /** @@ -855,7 +850,7 @@ public Response updateEntityWithResponse(TableEntity entity, TableEntityUp } }; - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } /** @@ -963,8 +958,7 @@ private Response deleteEntityWithResponse(String partitionKey, String rowK null, null, null, context); try { - return hasTimeout(timeout) - ? getResultWithTimeout(THREAD_POOL.submit(callable::get), timeout) : callable.get(); + return requestWithOptionalTimeout(callable, timeout); } catch (InterruptedException | ExecutionException | TimeoutException ex) { throw logger.logExceptionAsError(new RuntimeException(ex)); } catch (RuntimeException ex) { @@ -1051,7 +1045,7 @@ public PagedIterable listEntities(ListEntitiesOptions options, Dura () -> listEntitiesFirstPage(context, options, TableEntity.class), token -> listEntitiesNextPage(token, context, options, TableEntity.class)); - return callIterableWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callIterableWithOptionalTimeout(callable, timeout, logger); } private PagedResponse listEntitiesFirstPage(Context context, ListEntitiesOptions options, @@ -1220,7 +1214,7 @@ public Response getEntityWithResponse(String partitionKey, String r EntityHelper.convertToSubclass(entity, TableEntity.class, logger)); }; - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } /** @@ -1295,7 +1289,7 @@ public Response getAccessPoliciesWithResponse(Duration time new TableAccessPolicies(response.getValue() == null ? null : response.getValue().items())); }; - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } @@ -1405,7 +1399,7 @@ public Response setAccessPoliciesWithResponse(List return new SimpleResponse<>(response, response.getValue()); }; - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } @@ -1655,8 +1649,7 @@ public Response submitTransactionWithResponse(List createTableWithResponse(String tableName, Duration timeout, Context context) { Supplier> callable = () -> createTableWithResponse(tableName, context); - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } Response createTableWithResponse(String tableName, Context context) { @@ -477,7 +473,7 @@ public TableClient createTableIfNotExists(String tableName) { public Response createTableIfNotExistsWithResponse(String tableName, Duration timeout, Context context) { Supplier> callable = () -> createTableIfNotExistsWithResponse(tableName, context); - Response returnedResponse = callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger, true); + Response returnedResponse = callWithOptionalTimeout(callable, timeout, logger, true); return returnedResponse.getValue() == null ? new SimpleResponse<>(returnedResponse.getRequest(), returnedResponse.getStatusCode(), returnedResponse.getHeaders(), getTableClient(tableName)) : returnedResponse; } @@ -542,8 +538,7 @@ public void deleteTable(String tableName) { public Response deleteTableWithResponse(String tableName, Duration timeout, Context context) { Supplier> callable = () -> deleteTableWithResponse(tableName, context); try { - return hasTimeout(timeout) - ? getResultWithTimeout(THREAD_POOL.submit(callable::get), timeout) : callable.get(); + return requestWithOptionalTimeout(callable, timeout); } catch (InterruptedException | ExecutionException | TimeoutException e) { throw logger.logExceptionAsError(new RuntimeException(e)); } catch (RuntimeException e) { @@ -618,7 +613,7 @@ public PagedIterable listTables() { @ServiceMethod(returns = ReturnType.COLLECTION) public PagedIterable listTables(ListTablesOptions options, Duration timeout, Context context) { Supplier> callable = () -> listTables(options, context); - return callIterableWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callIterableWithOptionalTimeout(callable, timeout, logger); } private PagedIterable listTables(ListTablesOptions options, Context context) { @@ -718,7 +713,7 @@ public TableServiceProperties getProperties() { @ServiceMethod(returns = ReturnType.SINGLE) public Response getPropertiesWithResponse(Duration timeout, Context context) { Supplier> callable = () -> getPropertiesWithResponse(context); - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } Response getPropertiesWithResponse(Context context) { @@ -811,7 +806,7 @@ public void setProperties(TableServiceProperties tableServiceProperties) { public Response setPropertiesWithResponse(TableServiceProperties tableServiceProperties, Duration timeout, Context context) { Supplier> callable = () -> setPropertiesWithResponse(tableServiceProperties, context); - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } Response setPropertiesWithResponse(TableServiceProperties tableServiceProperties, Context context) { @@ -874,7 +869,7 @@ public TableServiceStatistics getStatistics() { @ServiceMethod(returns = ReturnType.SINGLE) public Response getStatisticsWithResponse(Duration timeout, Context context) { Supplier> callable = () -> getStatisticsWithResponse(context); - return callWithOptionalTimeout(callable, THREAD_POOL, timeout, logger); + return callWithOptionalTimeout(callable, timeout, logger); } diff --git a/sdk/tables/azure-data-tables/src/main/java/com/azure/data/tables/implementation/TableUtils.java b/sdk/tables/azure-data-tables/src/main/java/com/azure/data/tables/implementation/TableUtils.java index be7391d25cfe2..ad714bb602ab7 100644 --- a/sdk/tables/azure-data-tables/src/main/java/com/azure/data/tables/implementation/TableUtils.java +++ b/sdk/tables/azure-data-tables/src/main/java/com/azure/data/tables/implementation/TableUtils.java @@ -8,6 +8,7 @@ import com.azure.core.http.rest.Response; import com.azure.core.http.rest.SimpleResponse; import com.azure.core.util.CoreUtils; +import com.azure.core.util.SharedExecutorService; import com.azure.core.util.logging.ClientLogger; import com.azure.data.tables.implementation.models.TableServiceErrorException; import com.azure.data.tables.implementation.models.TableServiceJsonError; @@ -28,8 +29,6 @@ import java.util.Map; import java.util.TreeMap; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.TimeoutException; import java.util.function.Function; import java.util.function.Supplier; @@ -43,7 +42,6 @@ public final class TableUtils { private static final String UTF8_CHARSET = "UTF-8"; private static final String DELIMITER_CONTINUATION_TOKEN = ";"; - private static final long THREADPOOL_SHUTDOWN_HOOK_TIMEOUT_SECONDS = 5; private TableUtils() { throw new UnsupportedOperationException("Cannot instantiate TablesUtils"); @@ -131,23 +129,6 @@ public static Flux applyOptionalTimeout(Flux publisher, Duration timeo return timeout == null ? publisher : publisher.timeout(timeout); } - /** - * Blocks an asynchronous response with an optional timeout. - * - * @param response Asynchronous response to block. - * @param timeout Optional timeout. - * @param Return type of the asynchronous response. - * @return The value of the asynchronous response. - * @throws RuntimeException If the asynchronous response doesn't complete before the timeout expires. - */ - public static T blockWithOptionalTimeout(Mono response, Duration timeout) { - if (timeout == null) { - return response.block(); - } else { - return response.block(timeout); - } - } - /** * Deserializes a given {@link Response HTTP response} including headers to a given class. * @@ -157,7 +138,8 @@ public static T blockWithOptionalTimeout(Mono response, Duration timeout) * @param The class of the exception to swallow. * @return A {@link Mono} that contains the deserialized response. */ - public static Mono> swallowExceptionForStatusCode(int statusCode, E httpResponseException, ClientLogger logger) { + public static Mono> swallowExceptionForStatusCode(int statusCode, + E httpResponseException, ClientLogger logger) { HttpResponse httpResponse = httpResponseException.getResponse(); if (httpResponse.getStatusCode() == statusCode) { @@ -168,10 +150,6 @@ public static Mono> swallowExce return monoError(logger, httpResponseException); } - public static boolean hasTimeout(Duration timeout) { - return timeout != null && !timeout.isZero() && !timeout.isNegative(); - } - /** * Parses the query string into a key-value pair map that maintains key, query parameter key, order. The value is * stored as a parsed array (ex. key=[val1, val2, val3] instead of key=val1,val2,val3). @@ -278,7 +256,7 @@ public static String urlEncode(final String stringToEncode) { return null; } - if (stringToEncode.length() == 0) { + if (stringToEncode.isEmpty()) { return ""; } @@ -318,11 +296,6 @@ private static String encode(final String stringToEncode) { } } - public static ExecutorService getThreadPoolWithShutdownHook() { - return CoreUtils.addShutdownHookSafely(Executors.newCachedThreadPool(), - Duration.ofSeconds(THREADPOOL_SHUTDOWN_HOOK_TIMEOUT_SECONDS)); - } - // Single quotes in OData queries should be escaped by using two consecutive single quotes characters. // Source: http://docs.oasis-open.org/odata/odata/v4.01/odata-v4.01-part2-url-conventions.html#sec_URLSyntax. public static String escapeSingleQuotes(String input) { @@ -361,13 +334,15 @@ public static String[] getKeysFromToken(String token) { return keys; } - public static Response callWithOptionalTimeout(Supplier> callable, ExecutorService threadPool, Duration timeout, ClientLogger logger) { - return callWithOptionalTimeout(callable, threadPool, timeout, logger, false); + public static Response callWithOptionalTimeout(Supplier> callable, Duration timeout, + ClientLogger logger) { + return callWithOptionalTimeout(callable, timeout, logger, false); } - public static Response callWithOptionalTimeout(Supplier> callable, ExecutorService threadPool, Duration timeout, ClientLogger logger, boolean skip409Logging) { + public static Response callWithOptionalTimeout(Supplier> callable, Duration timeout, + ClientLogger logger, boolean skip409Logging) { try { - return callHandler(callable, threadPool, timeout, logger); + return callHandler(callable, timeout, logger); } catch (Throwable thrown) { Throwable exception = mapThrowableToTableServiceException(thrown); if (exception instanceof TableServiceException) { @@ -383,19 +358,37 @@ public static Response callWithOptionalTimeout(Supplier> call } } - public static PagedIterable callIterableWithOptionalTimeout(Supplier> callable, ExecutorService threadPool, Duration timeout, ClientLogger logger) { + public static PagedIterable callIterableWithOptionalTimeout(Supplier> callable, + Duration timeout, ClientLogger logger) { try { - return callHandler(callable, threadPool, timeout, logger); + return callHandler(callable, timeout, logger); } catch (Exception thrown) { Throwable exception = mapThrowableToTableServiceException(thrown); throw logger.logExceptionAsError((RuntimeException) exception); } } - private static T callHandler(Supplier callable, ExecutorService threadPool, Duration timeout, ClientLogger logger) throws Exception { + public static T requestWithOptionalTimeout(Supplier request, Duration timeout) + throws ExecutionException, InterruptedException, TimeoutException { + return hasTimeout(timeout) + ? getResultWithTimeout(SharedExecutorService.getInstance().submit(request::get), timeout) + : request.get(); + } + + /** + * Checks whether the timeout exists (is not null and has a positive duration). + * + * @param timeout The timeout to check. + * @return Whether the timeout exists (is not null and has a positive duration). + */ + private static boolean hasTimeout(Duration timeout) { + return timeout != null && (timeout.getSeconds() | timeout.getNano()) > 0; + } + + private static T callHandler(Supplier callable, Duration timeout, ClientLogger logger) throws Exception { try { return hasTimeout(timeout) - ? getResultWithTimeout(threadPool.submit(callable::get), timeout) + ? getResultWithTimeout(SharedExecutorService.getInstance().submit(callable::get), timeout) : callable.get(); } catch (ExecutionException | InterruptedException | TimeoutException ex) {