diff --git a/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryConnection.java b/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryConnection.java index 3c64e3a073f8..382defc9e5cb 100644 --- a/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryConnection.java +++ b/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryConnection.java @@ -61,6 +61,8 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; /** @@ -143,6 +145,8 @@ public class BigQueryConnection extends BigQueryNoOpsConnection { DatabaseMetaData databaseMetaData; Boolean reqGoogleDriveScope; private boolean isReadOnlyTokenUsed = false; + private int queryTaskThreadCount; + private ExecutorService queryTaskExecutor; BigQueryConnection(String url) throws IOException { this(url, DataSource.fromUrl(url)); @@ -268,6 +272,11 @@ public class BigQueryConnection extends BigQueryNoOpsConnection { this.headerProvider = createHeaderProvider(); this.bigQuery = getBigQueryConnection(); + + this.queryTaskThreadCount = ds.getQueryTaskThreadCount(); + this.queryTaskExecutor = + Executors.newFixedThreadPool( + this.queryTaskThreadCount, new BigQueryThreadFactory("BigQuery-query-task-")); } } @@ -630,6 +639,10 @@ int getMetadataFetchThreadCount() { return this.metadataFetchThreadCount; } + public ExecutorService getQueryTaskExecutor() { + return this.queryTaskExecutor; + } + boolean isEnableWriteAPI() { return enableWriteAPI; } @@ -881,6 +894,10 @@ private void closeImpl() throws SQLException { statement.close(); } this.openStatements.clear(); + + if (this.queryTaskExecutor != null) { + this.queryTaskExecutor.shutdown(); + } } catch (ConcurrentModificationException ex) { throw new BigQueryJdbcException("Concurrent modification during close", ex); } catch (InterruptedException e) { diff --git a/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryJdbcUrlUtility.java b/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryJdbcUrlUtility.java index 89a2b8b5cb8c..9a0a2ab88287 100644 --- a/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryJdbcUrlUtility.java +++ b/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryJdbcUrlUtility.java @@ -142,6 +142,8 @@ protected boolean removeEldestEntry(Map.Entry> eldes Pattern.CASE_INSENSITIVE); static final String METADATA_FETCH_THREAD_COUNT_PROPERTY_NAME = "MetaDataFetchThreadCount"; static final int DEFAULT_METADATA_FETCH_THREAD_COUNT_VALUE = 32; + static final String QUERY_TASK_THREAD_COUNT_PROPERTY_NAME = "QueryTaskThreadCount"; + static final int DEFAULT_QUERY_TASK_THREAD_COUNT_VALUE = 16; static final String RETRY_TIMEOUT_IN_SECS_PROPERTY_NAME = "Timeout"; static final long DEFAULT_RETRY_TIMEOUT_IN_SECS_VALUE = 0L; static final String JOB_TIMEOUT_PROPERTY_NAME = "JobTimeout"; @@ -540,6 +542,12 @@ protected boolean removeEldestEntry(Map.Entry> eldes "The number of threads used to call a DatabaseMetaData method.") .setDefaultValue(String.valueOf(DEFAULT_METADATA_FETCH_THREAD_COUNT_VALUE)) .build(), + BigQueryConnectionProperty.newBuilder() + .setName(QUERY_TASK_THREAD_COUNT_PROPERTY_NAME) + .setDescription( + "The number of background threads used for executing queries parallel tasks.") + .setDefaultValue(String.valueOf(DEFAULT_QUERY_TASK_THREAD_COUNT_VALUE)) + .build(), BigQueryConnectionProperty.newBuilder() .setName(ENABLE_WRITE_API_PROPERTY_NAME) .setDescription( diff --git a/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryStatement.java b/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryStatement.java index e2dab7b31678..caeb01798848 100644 --- a/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryStatement.java +++ b/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryStatement.java @@ -72,8 +72,8 @@ import java.util.Random; import java.util.UUID; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.ThreadFactory; import java.util.logging.Level; @@ -86,11 +86,6 @@ * @see ResultSet */ public class BigQueryStatement extends BigQueryNoOpsStatement { - - // TODO (obada): Update this after benchmarking - private static final int MAX_PROCESS_QUERY_THREADS_CNT = 50; - protected static ExecutorService queryTaskExecutor = - Executors.newFixedThreadPool(MAX_PROCESS_QUERY_THREADS_CNT); private final BigQueryJdbcCustomLogger LOG = new BigQueryJdbcCustomLogger(this.toString()); private static final String DEFAULT_DATASET_NAME = "_google_jdbc"; private static final String DEFAULT_TABLE_NAME = "temp_table_"; @@ -610,15 +605,20 @@ void runQuery(String query, QueryJobConfiguration jobConfiguration) try { resetStatementFields(); + + final QueryJobConfiguration finalJobConfiguration = jobConfiguration; + Future statementTypeFuture = + connection.getQueryTaskExecutor().submit(() -> getStatementType(finalJobConfiguration)); + ExecuteResult executeResult = executeJob(jobConfiguration); - StatementType statementType = - executeResult.job == null - ? getStatementType(jobConfiguration) - : ((QueryStatistics) executeResult.job.getStatistics()).getStatementType(); + + StatementType statementType = statementTypeFuture.get(); SqlType queryType = getQueryType(jobConfiguration, statementType); handleQueryResult(query, executeResult.tableResult, queryType); } catch (InterruptedException ex) { - throw new BigQueryJdbcRuntimeException("Interrupted during runQuery", ex); + throw new BigQueryJdbcRuntimeException(ex); + } catch (ExecutionException e) { + throw new BigQueryJdbcException(e.getCause()); } catch (BigQueryException ex) { if (ex.getMessage().contains("Syntax error")) { throw new BigQueryJdbcSqlSyntaxErrorException("BigQueryException during runQuery", ex); @@ -849,7 +849,8 @@ Thread populateArrowBufferedQueue( com.google.api.gax.rpc.ServerStream stream = bqReadClient.readRowsCallable().call(readRowsRequest); for (ReadRowsResponse response : stream) { - if (Thread.currentThread().isInterrupted() || queryTaskExecutor.isShutdown()) { + if (Thread.currentThread().isInterrupted() + || connection.getQueryTaskExecutor().isShutdown()) { break; } @@ -1062,7 +1063,8 @@ Thread runNextPageTaskAsync( try { while (currentPageToken != null) { // do not process further pages and shutdown - if (Thread.currentThread().isInterrupted() || queryTaskExecutor.isShutdown()) { + if (Thread.currentThread().isInterrupted() + || connection.getQueryTaskExecutor().isShutdown()) { LOG.warning( "%s Interrupted @ runNextPageTaskAsync", Thread.currentThread().getName()); break; @@ -1093,7 +1095,8 @@ Thread runNextPageTaskAsync( // completes Uninterruptibles.putUninterruptibly(rpcResponseQueue, Tuple.of(null, false)); } - // We cannot do queryTaskExecutor.shutdownNow() here as populate buffer method may not + // We cannot do connection.getQueryTaskExecutor().shutdownNow() here as populate buffer + // method may not // have finished processing the records and even that will be interrupted }; @@ -1137,7 +1140,7 @@ Thread parseAndPopulateRpcDataAsync( } if (Thread.currentThread().isInterrupted() - || queryTaskExecutor.isShutdown() + || connection.getQueryTaskExecutor().isShutdown() || fieldValueLists == null) { // do not process further pages and shutdown (outerloop) break; @@ -1147,7 +1150,8 @@ Thread parseAndPopulateRpcDataAsync( long results = 0; for (FieldValueList fieldValueList : fieldValueLists) { - if (Thread.currentThread().isInterrupted() || queryTaskExecutor.isShutdown()) { + if (Thread.currentThread().isInterrupted() + || connection.getQueryTaskExecutor().isShutdown()) { // do not process further pages and shutdown (inner loop) break; } diff --git a/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/DataSource.java b/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/DataSource.java index 1c3344b9b6d1..6e90d1d67030 100644 --- a/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/DataSource.java +++ b/java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/DataSource.java @@ -85,6 +85,7 @@ public class DataSource implements javax.sql.DataSource { private Boolean filterTablesOnDefaultDataset; private Integer requestGoogleDriveScope; private Integer metadataFetchThreadCount; + private Integer queryTaskThreadCount; private String sslTrustStorePath; private String sslTrustStorePassword; private Map labels; @@ -247,6 +248,9 @@ public class DataSource implements javax.sql.DataSource { .put( BigQueryJdbcUrlUtility.METADATA_FETCH_THREAD_COUNT_PROPERTY_NAME, (ds, val) -> ds.setMetadataFetchThreadCount(Integer.parseInt(val))) + .put( + BigQueryJdbcUrlUtility.QUERY_TASK_THREAD_COUNT_PROPERTY_NAME, + (ds, val) -> ds.setQueryTaskThreadCount(Integer.parseInt(val))) .put( BigQueryJdbcUrlUtility.SSL_TRUST_STORE_PROPERTY_NAME, DataSource::setSSLTrustStorePath) @@ -558,6 +562,11 @@ private Properties createProperties() { BigQueryJdbcUrlUtility.METADATA_FETCH_THREAD_COUNT_PROPERTY_NAME, String.valueOf(this.metadataFetchThreadCount)); } + if (this.queryTaskThreadCount != null) { + connectionProperties.setProperty( + BigQueryJdbcUrlUtility.QUERY_TASK_THREAD_COUNT_PROPERTY_NAME, + String.valueOf(this.queryTaskThreadCount)); + } if (this.sslTrustStorePath != null) { connectionProperties.setProperty( BigQueryJdbcUrlUtility.SSL_TRUST_STORE_PROPERTY_NAME, @@ -1051,6 +1060,16 @@ public void setMetadataFetchThreadCount(Integer metadataFetchThreadCount) { this.metadataFetchThreadCount = metadataFetchThreadCount; } + public Integer getQueryTaskThreadCount() { + return queryTaskThreadCount != null + ? queryTaskThreadCount + : BigQueryJdbcUrlUtility.DEFAULT_QUERY_TASK_THREAD_COUNT_VALUE; + } + + public void setQueryTaskThreadCount(Integer queryTaskThreadCount) { + this.queryTaskThreadCount = queryTaskThreadCount; + } + public String getSSLTrustStorePath() { return sslTrustStorePath; } diff --git a/java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryConnectionTest.java b/java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryConnectionTest.java index dd6ceb0deceb..89ff7cb5adfc 100644 --- a/java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryConnectionTest.java +++ b/java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryConnectionTest.java @@ -30,6 +30,7 @@ import java.io.InputStream; import java.sql.SQLException; import java.util.Properties; +import java.util.concurrent.ThreadPoolExecutor; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -365,6 +366,36 @@ public void testMetaDataFetchThreadCountProperty() throws SQLException, IOExcept } } + @Test + public void testQueryTaskThreadCountProperty() throws SQLException, IOException { + // Test Case 1: Should use the default value when the property is not provided. + String urlDefault = + "jdbc:bigquery://https://www.googleapis.com/bigquery/v2:443;" + + "OAuthType=2;ProjectId=MyBigQueryProject;" + + "OAuthAccessToken=redactedToken;OAuthClientId=redactedToken;" + + "OAuthClientSecret=redactedToken;"; + try (BigQueryConnection connectionDefault = new BigQueryConnection(urlDefault)) { + assertEquals( + 16, + ((ThreadPoolExecutor) connectionDefault.getQueryTaskExecutor()).getCorePoolSize(), + "Should use the default value of 4 when the property is not provided"); + } + + // Test Case 2: Should use the custom value when a valid integer is provided. + String urlCustom = + "jdbc:bigquery://https://www.googleapis.com/bigquery/v2:443;" + + "OAuthType=2;ProjectId=MyBigQueryProject;" + + "OAuthAccessToken=redactedToken;OAuthClientId=redactedToken;" + + "OAuthClientSecret=redactedToken;" + + "QueryTaskThreadCount=16;"; + try (BigQueryConnection connectionCustom = new BigQueryConnection(urlCustom)) { + assertEquals( + 16, + ((ThreadPoolExecutor) connectionCustom.getQueryTaskExecutor()).getCorePoolSize(), + "Should use the custom value when a valid integer is provided"); + } + } + @Test public void testBigQueryReadClientKeepAliveSettings() throws SQLException, IOException { String url = diff --git a/java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryStatementTest.java b/java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryStatementTest.java index 9fef90c69a4d..daeee2a134ea 100644 --- a/java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryStatementTest.java +++ b/java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryStatementTest.java @@ -60,11 +60,14 @@ import java.util.Map; import java.util.UUID; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -86,6 +89,8 @@ public class BigQueryStatementTest { private BigQueryStatement bigQueryStatement; + private ExecutorService queryTaskExecutor; + private final String query = "select * from test"; private final String jobIdVal = UUID.randomUUID().toString(); @@ -128,6 +133,7 @@ private Job getJobMock( @BeforeEach public void setUp() throws IOException, SQLException { + queryTaskExecutor = Executors.newFixedThreadPool(1); bigQueryConnection = mock(BigQueryConnection.class); rpcFactoryMock = mock(BigQueryRpcFactory.class); bigquery = mock(BigQuery.class); @@ -135,6 +141,8 @@ public void setUp() throws IOException, SQLException { storageReadClient = mock(BigQueryReadClient.class); jobId = JobId.newBuilder().setJob(jobIdVal).build(); + doReturn(queryTaskExecutor).when(bigQueryConnection).getQueryTaskExecutor(); + doReturn(bigquery).when(bigQueryConnection).getBigQuery(); doReturn(10L).when(bigQueryConnection).getJobTimeoutInSeconds(); doReturn(10L).when(bigQueryConnection).getMaxBytesBilled(); @@ -150,7 +158,13 @@ public void setUp() throws IOException, SQLException { .setSerializedSchema(serializeSchema(vectorSchemaRoot.getSchema())) .build(); // bigQueryConnection.addOpenStatements(bigQueryStatement); + } + @AfterEach + public void tearDown() { + if (queryTaskExecutor != null) { + queryTaskExecutor.shutdown(); + } } private VectorSchemaRoot getTestVectorSchemaRoot() { @@ -215,6 +229,7 @@ public void testExecSlowQueryPath() throws SQLException, InterruptedException { Job job = getJobMock(tableResult, queryJobConfiguration, StatementType.SELECT); doReturn(job).when(bigquery).queryWithTimeout(any(), any(), any()); + doReturn(job).when(bigquery).create(any(JobInfo.class)); doReturn(jobIdWrapper) .when(bigQueryStatementSpy) @@ -300,14 +315,16 @@ public void setQueryTimeoutTest() throws Exception { Job job = getJobMock(result, jobConfiguration, StatementType.SELECT); doReturn(job).when(bigquery).queryWithTimeout(any(), any(), any()); + doReturn(job).when(bigquery).create(any(JobInfo.class)); doReturn(jsonResultSet).when(bigQueryStatementSpy).processJsonResultSet(result); - ArgumentCaptor captor = + ArgumentCaptor queryCaptor = ArgumentCaptor.forClass(QueryJobConfiguration.class); bigQueryStatementSpy.runQuery(query, jobConfiguration); - verify(bigquery).queryWithTimeout(captor.capture(), any(), any()); - QueryJobConfiguration jobConfig = captor.getValue(); + verify(bigquery, Mockito.times(1)).create(any(JobInfo.class)); + verify(bigquery, Mockito.times(1)).queryWithTimeout(queryCaptor.capture(), any(), any()); + QueryJobConfiguration jobConfig = queryCaptor.getValue(); assertEquals(3000L, jobConfig.getJobTimeoutMs().longValue()); } @@ -399,13 +416,23 @@ public void testJoblessQuery() throws SQLException, InterruptedException { doReturn(jobMock) .when(bigquery) .queryWithTimeout(any(QueryJobConfiguration.class), any(), any()); + doReturn(jobMock).when(bigquery).create(any(JobInfo.class)); doReturn(mock(BigQueryJsonResultSet.class)) .when(jobfulStatementSpy) .processJsonResultSet(tableResultJobfulMock); jobfulStatementSpy.executeQuery("SELECT 1"); - verify(bigquery).queryWithTimeout(any(QueryJobConfiguration.class), any(), any()); + ArgumentCaptor jobfulCaptor = ArgumentCaptor.forClass(JobInfo.class); + verify(bigquery, Mockito.times(1)).create(jobfulCaptor.capture()); + assertTrue( + jobfulCaptor.getAllValues().stream() + .anyMatch( + jobInfo -> + Boolean.TRUE.equals( + ((QueryJobConfiguration) jobInfo.getConfiguration()).dryRun()))); + verify(bigquery, Mockito.times(1)) + .queryWithTimeout(any(QueryJobConfiguration.class), any(), any()); } @Test