This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

[EXE-1528] Send jdbcOptions as MergedParameter to driver #10

Merged
9 commits merged on Apr 29, 2023
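In short, this PR changes JDBCWrapper.getConnector to take the full MergedParameters object instead of a driver class, JDBC URL, and optional user/password pair, so that every entry in the merged parameter map reaches the driver. Below is a minimal sketch of the new call shape, assuming code that (like the connector's own classes) sits inside the io.github.spark_redshift_community.spark.redshift package; all values are placeholders:

import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters

// Build the merged parameter map directly, as the updated integration suite below does.
val params = MergedParameters(Map(
  "url"      -> "jdbc:redshift://example-cluster:5439/dev?ssl=true",
  "user"     -> "redshift_user",
  "password" -> "redshift_password"
))

// Old call shape, removed by this PR:
//   jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
// New call shape: the whole MergedParameters instance is handed to the wrapper, which
// copies every key/value pair into the java.util.Properties it gives to driver.connect().
val conn = DefaultJDBCWrapper.getConnector(params)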
@@ -26,6 +26,8 @@ import org.apache.spark.sql._
import org.apache.spark.sql.types.StructType
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}

import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters

import scala.util.Random


@@ -62,6 +64,16 @@ trait IntegrationSuiteBase
s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD&ssl=true"
}

protected def param: MergedParameters = {
MergedParameters(
Map(
"url" -> jdbcUrlNoUserPassword,
"user" -> AWS_REDSHIFT_USER,
"password" -> AWS_REDSHIFT_PASSWORD
)
)
}

protected def jdbcUrlNoUserPassword: String = {
s"$AWS_REDSHIFT_JDBC_URL?ssl=true"
}
@@ -91,7 +103,7 @@ trait IntegrationSuiteBase
sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
sc.hadoopConfiguration.set("fs.s3a.access.key", AWS_ACCESS_KEY_ID)
sc.hadoopConfiguration.set("fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY)
conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None)
conn = DefaultJDBCWrapper.getConnector(param)
}

override def afterAll(): Unit = {
@@ -30,7 +30,7 @@ class PostgresDriverIntegrationSuite extends IntegrationSuiteBase {

// TODO (luca|issue #9) Fix tests when using postgresql driver
ignore("postgresql driver takes precedence for jdbc:postgresql:// URIs") {
val conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None)
val conn = DefaultJDBCWrapper.getConnector(param)
try {
assert(conn.getClass.getName === "org.postgresql.jdbc4.Jdbc4Connection")
} finally {
@@ -97,7 +97,7 @@ class DefaultSource(
}

def tableExists: Boolean = {
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
Collaborator:

I'm curious, did you look into the history of why they ignore all the other params?

Even on latest master, only these 3 params are used and the others are ignored: /~https://github.com/spark-redshift-community/spark-redshift/blob/master/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftRelation.scala#L111

ShaoFuWu (Collaborator, Author), Apr 29, 2023:

Looks like they were fixing the same invalid-character issue 8 years ago that we hit last week 😅. The way they fixed it was to pass user/password as part of the options in this PR.
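For illustration, here is a minimal sketch, with placeholder values and assuming code inside the connector's package, of the approach described above and mirrored by this change: credentials ride along in the option map instead of being appended to the JDBC URL, so characters that are invalid in a URL no longer corrupt the connection string.

// A password with characters that would be awkward to embed in
// "...?user=...&password=..." (placeholder value).
val awkwardPassword = "p@ss&word?/with#odd:chars"

// Carried as ordinary map entries, the credentials end up in the Properties object
// passed to driver.connect(url, properties) and never touch the URL itself.
val paramsWithCreds = MergedParameters(Map(
  "url"      -> "jdbc:redshift://example-cluster:5439/dev?ssl=true",
  "user"     -> "redshift_user",
  "password" -> awkwardPassword
))
val conn = DefaultJDBCWrapper.getConnector(paramsWithCreds)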

val conn = jdbcWrapper.getConnector(params)
try {
jdbcWrapper.tableExists(conn, table.toString)
} finally {
@@ -17,17 +17,17 @@

package io.github.spark_redshift_community.spark.redshift

import java.sql.{ResultSet, PreparedStatement, Connection, Driver, DriverManager, ResultSetMetaData, SQLException}
import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters

import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{ThreadFactory, Executors}

import java.util.concurrent.{Executors, ThreadFactory}
import scala.collection.JavaConverters._
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.Try
import scala.util.control.NonFatal

import org.apache.spark.SPARK_VERSION
import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry
import org.apache.spark.sql.types._
@@ -196,10 +196,9 @@ private[redshift] class JDBCWrapper {
* discover the appropriate driver class.
* @param url the JDBC url to connect to.
*/
def getConnector(
userProvidedDriverClass: Option[String],
url: String,
credentials: Option[(String, String)]) : Connection = {
def getConnector(mergedParameters: MergedParameters) : Connection = {
val url = mergedParameters.jdbcUrl
val userProvidedDriverClass = mergedParameters.jdbcDriver
val subprotocol = url.stripPrefix("jdbc:").split(":")(0)
val driverClass: String = getDriverClass(subprotocol, userProvidedDriverClass)
DriverRegistry.register(driverClass)
@@ -225,9 +224,8 @@
throw new IllegalArgumentException(s"Did not find registered driver with class $driverClass")
}
val properties = new Properties()
credentials.foreach { case(user, password) =>
properties.setProperty("user", user)
properties.setProperty("password", password)
mergedParameters.parameters.foreach { case (key, value) =>
properties.setProperty(key, value)
}
driver.connect(url, properties)
}
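A consequence of the property-copying loop above is that any additional option placed in the parameter map is forwarded to the driver verbatim, which is what the PR title refers to. A small sketch, again with placeholder values; the extra keys (ssl, loginTimeout) are examples of driver-level options, not options defined by this connector:

// Placeholder parameter map; the last two entries stand in for arbitrary JDBC
// options a caller may now pass through.
val paramsWithOptions = MergedParameters(Map(
  "url"          -> "jdbc:redshift://example-cluster:5439/dev",
  "user"         -> "redshift_user",
  "password"     -> "redshift_password",
  "ssl"          -> "true",
  "loginTimeout" -> "30"
))

// Per the diff above, getConnector now does roughly:
//   val properties = new java.util.Properties()
//   paramsWithOptions.parameters.foreach { case (k, v) => properties.setProperty(k, v) }
//   driver.connect(paramsWithOptions.jdbcUrl, properties)
// so all five entries reach the driver, not only user/password.
val conn = DefaultJDBCWrapper.getConnector(paramsWithOptions)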
@@ -63,7 +63,7 @@ private[redshift] case class RedshiftRelation(
userSchema.getOrElse {
val tableNameOrSubquery =
params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
val conn = jdbcWrapper.getConnector(params)
try {
jdbcWrapper.resolveTable(conn, tableNameOrSubquery)
} finally {
@@ -89,7 +89,7 @@
}

private def executeCountQuery(query: String): RDD[InternalRow] = {
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
val conn = jdbcWrapper.getConnector(params)
val queryWithTag = RedshiftPushDownSqlStatement.appendTagsToQuery(jdbcOptions, query)
try {
val results = jdbcWrapper.executeQueryInterruptibly(conn.prepareStatement(queryWithTag))
@@ -116,7 +116,7 @@
// Unload data from Redshift into a temporary directory in S3:
val tempDir = params.createPerQueryTempDir()
val unloadSql = buildUnloadStmt(query, tempDir, creds, params.sseKmsKey)
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
val conn = jdbcWrapper.getConnector(params)
try {
jdbcWrapper.executeInterruptibly(conn.prepareStatement(unloadSql))
} finally {
@@ -18,7 +18,6 @@ package io.github.spark_redshift_community.spark.redshift

import java.net.URI
import java.sql.{Connection, Date, SQLException, Timestamp}

import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.s3.AmazonS3Client
import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters
@@ -29,7 +28,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.util.control.NonFatal

/**
@@ -403,7 +401,7 @@ private[redshift] class RedshiftWriter(
tempDir = params.createPerQueryTempDir(),
tempFormat = params.tempFormat,
nullString = params.nullString)
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
val conn = jdbcWrapper.getConnector(params)
conn.setAutoCommit(false)
try {
val table: TableName = params.table.get
@@ -16,8 +16,9 @@

package io.github.spark_redshift_community.spark.redshift

import java.sql.{Connection, PreparedStatement, ResultSet, SQLException}
import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters

import java.sql.{Connection, PreparedStatement, ResultSet, SQLException}
import org.apache.spark.sql.types.StructType
import org.mockito.Matchers._
import org.mockito.Mockito._
@@ -69,7 +70,7 @@ class MockRedshift(
doAnswer(new Answer[Connection] {
override def answer(invocation: InvocationOnMock): Connection = createMockConnection()
}).when(jdbcWrapper)
.getConnector(any[Option[String]](), same(jdbcUrl), any[Option[(String, String)]]())
.getConnector(any[MergedParameters]())

doAnswer(new Answer[Boolean] {
override def answer(invocation: InvocationOnMock): Boolean = {
2 changes: 1 addition & 1 deletion version.sbt
@@ -1 +1 @@
version in ThisBuild := "5.0.7-aiq6"
version in ThisBuild := "5.0.7-aiq14"