Add option to use CSV as an intermediate data format during writes
This patch adds new options that allow CSV to be used as the intermediate data format when writing data to Redshift. This can offer a large performance benefit because Redshift's Avro reader can be very slow. This patch is based on #165 by emlyn and incorporates additional changes from me to add documentation, make the new option case-insensitive, improve some error messages, and add tests.

Using CSV for writes also allows writing to tables whose column names are unsupported by Avro, so this patch partially addresses #84.

As a hedge, I've marked this feature as "Experimental" and I'll remove that label after it's been tested in the wild a bit more.
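
For example, a write can opt into the new format like this (a minimal sketch using the options documented in the README changes below; the JDBC URL, table name, and S3 path are placeholders):

    df.write
      .format("com.databricks.spark.redshift")
      .option("url", jdbcUrl)
      .option("dbtable", "my_table")
      .option("tempdir", "s3n://bucket/path/for/temp/data")
      .option("tempformat", "CSV GZIP")  // or "CSV"; the default is still "AVRO"
      .option("csvnullstring", "@NULL@") // sentinel written for nulls in the CSV formats
      .mode(SaveMode.ErrorIfExists)
      .save()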

Fixes #73.

Author: Josh Rosen <joshrosen@databricks.com>
Author: Josh Rosen <rosenville@gmail.com>
Author: Emlyn Corrin <Emlyn.Corrin@microsoft.com>
Author: Emlyn Corrin <emlyn@swiftkey.com>

Closes #288 from JoshRosen/use-csv-for-writes.
JoshRosen committed Oct 25, 2016
1 parent d508d3e commit 6cc49da
Showing 6 changed files with 171 additions and 23 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -431,6 +431,34 @@ for other options).</p>
at the end of the command can be used, but that should cover most possible use cases.</p>
</td>
</tr>
<tr>
<td><tt>tempformat</tt> (Experimental)</td>
<td>No</td>
<td><tt>AVRO</tt></td>
<td>
<p>
The format in which to save temporary files in S3 when writing to Redshift.
Defaults to "AVRO"; the other allowed values are "CSV" and "CSV GZIP" for CSV
and gzipped CSV, respectively.
</p>
<p>
Redshift is significantly faster when loading CSV than when loading Avro files, so
using the CSV <tt>tempformat</tt> may provide a large performance boost when writing
to Redshift.
</p>
</td>
</tr>
<tr>
<td><tt>csvnullstring</tt> (Experimental)</td>
<td>No</td>
<td><tt>@NULL@</tt></td>
<td>
<p>
The String value to write for nulls when using the CSV <tt>tempformat</tt>.
This should be a value which does not appear in your actual data.
</p>
</td>
</tr>
</table>

## Additional configuration options
@@ -200,11 +200,8 @@ trait IntegrationSuiteBase
expectedSchemaAfterLoad: Option[StructType] = None,
saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = {
try {
df.write
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
write(df)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.mode(saveMode)
.save()
// Check that the table exists. It appears that creating a table in one connection then
@@ -215,12 +212,7 @@ trait IntegrationSuiteBase
Thread.sleep(1000)
assert(DefaultJDBCWrapper.tableExists(conn, tableName))
}
val loadedDf = sqlContext.read
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.load()
val loadedDf = read.option("dbtable", tableName).load()
assert(loadedDf.schema === expectedSchemaAfterLoad.getOrElse(df.schema))
checkAnswer(loadedDf, df.collect())
} finally {
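(The write(df) and read helpers used above are defined in a collapsed part of this diff; the following is a rough, hypothetical reconstruction inferred from the call sites, not the verbatim code.)

    // Hypothetical sketch of the IntegrationSuiteBase helpers -- see the collapsed diff for the real code.
    protected def write(df: DataFrame): DataFrameWriter[Row] =
      df.write
        .format("com.databricks.spark.redshift")
        .option("url", jdbcUrl)
        .option("tempdir", tempDir)

    protected def read: DataFrameReader =
      sqlContext.read
        .format("com.databricks.spark.redshift")
        .option("url", jdbcUrl)
        .option("tempdir", tempDir)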
@@ -24,7 +24,12 @@ import org.apache.spark.sql.types._
/**
* End-to-end tests of functionality which involves writing to Redshift via the connector.
*/
class RedshiftWriteSuite extends IntegrationSuiteBase {
abstract class BaseRedshiftWriteSuite extends IntegrationSuiteBase {

protected val tempformat: String

override protected def write(df: DataFrame): DataFrameWriter[Row] =
super.write(df).option("tempformat", tempformat)

test("roundtrip save and load") {
// This test can be simplified once #98 is fixed.
@@ -109,3 +114,54 @@ class RedshiftWriteSuite extends IntegrationSuiteBase {
)
}
}

class AvroRedshiftWriteSuite extends BaseRedshiftWriteSuite {
override protected val tempformat: String = "AVRO"

test("informative error message when saving with column names that contain spaces (#84)") {
intercept[IllegalArgumentException] {
testRoundtripSaveAndLoad(
s"error_when_saving_column_name_with_spaces_$randomSuffix",
sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
StructType(StructField("column name with spaces", IntegerType) :: Nil)))
}
}
}

class CSVRedshiftWriteSuite extends BaseRedshiftWriteSuite {
override protected val tempformat: String = "CSV"

test("save with column names that contain spaces (#84)") {
testRoundtripSaveAndLoad(
s"save_with_column_names_that_contain_spaces_$randomSuffix",
sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
StructType(StructField("column name with spaces", IntegerType) :: Nil)))
}
}

class CSVGZIPRedshiftWriteSuite extends IntegrationSuiteBase {
// Note: we purposely don't inherit from BaseRedshiftWriteSuite because we're only interested in
// testing basic functionality of the GZIP code; the rest of the write path should be unaffected
// by compression here.

override protected def write(df: DataFrame): DataFrameWriter[Row] =
super.write(df).option("tempformat", "CSV GZIP")

test("roundtrip save and load") {
// This test can be simplified once #98 is fixed.
val tableName = s"roundtrip_save_and_load_$randomSuffix"
try {
write(
sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema))
.option("dbtable", tableName)
.mode(SaveMode.ErrorIfExists)
.save()

assert(DefaultJDBCWrapper.tableExists(conn, tableName))
checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData)
} finally {
conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
conn.commit()
}
}
}
22 changes: 22 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -31,20 +31,30 @@ private[redshift] object Parameters {
// * distkey has no default, but is optional unless using diststyle KEY
// * jdbcdriver has no default, but is optional

"tempformat" -> "AVRO",
"csvnullstring" -> "@NULL@",
"overwrite" -> "false",
"diststyle" -> "EVEN",
"usestagingtable" -> "true",
"preactions" -> ";",
"postactions" -> ";"
)

val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP")

/**
* Merge user parameters with the defaults, preferring user parameters if specified
*/
def mergeParameters(userParameters: Map[String, String]): MergedParameters = {
if (!userParameters.contains("tempdir")) {
throw new IllegalArgumentException("'tempdir' is required for all Redshift loads and saves")
}
if (userParameters.contains("tempformat") &&
!VALID_TEMP_FORMATS.contains(userParameters("tempformat").toUpperCase)) {
throw new IllegalArgumentException(
s"""Invalid temp format: ${userParameters("tempformat")}; """ +
s"valid formats are: ${VALID_TEMP_FORMATS.mkString(", ")}")
}
if (!userParameters.contains("url")) {
throw new IllegalArgumentException("A JDBC URL must be provided with 'url' parameter")
}
@@ -84,6 +94,18 @@ private[redshift] object Parameters {
*/
def rootTempDir: String = parameters("tempdir")

/**
* The format in which to save temporary files in S3. Defaults to "AVRO"; the other allowed
* values are "CSV" and "CSV GZIP" for CSV and gzipped CSV, respectively.
*/
def tempFormat: String = parameters("tempformat").toUpperCase

/**
* The String value to write for nulls when using CSV.
* This should be a value which does not appear in your actual data.
*/
def nullString: String = parameters("csvnullstring")

/**
* Creates a per-query subdirectory in the [[rootTempDir]], with a random UUID.
*/
60 changes: 48 additions & 12 deletions src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -94,8 +94,12 @@ private[redshift] class RedshiftWriter(
val credsString: String =
AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials)
val fixedUrl = Utils.fixS3Url(manifestUrl)
val format = params.tempFormat match {
case "AVRO" => "AVRO 'auto'"
case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'"
}
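// For illustration, with a CSV tempformat the generated command looks roughly like:
//   COPY <table> FROM 's3://.../manifest' CREDENTIALS '...' FORMAT AS CSV NULL AS '@NULL@' manifest <extraCopyOptions>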
s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
s"AVRO 'auto' manifest ${params.extraCopyOptions}"
s"${format} manifest ${params.extraCopyOptions}"
}

/**
@@ -205,7 +209,9 @@ private[redshift] class RedshiftWriter(
private def unloadData(
sqlContext: SQLContext,
data: DataFrame,
tempDir: String): Option[String] = {
tempDir: String,
tempFormat: String,
nullString: String): Option[String] = {
// spark-avro does not support Date types. In addition, it converts Timestamps into longs
// (milliseconds since the Unix epoch). Redshift is capable of loading timestamps in
// 'epochmillisecs' format but there's no equivalent format for dates. To work around this, we
@@ -273,10 +279,20 @@
}
)

sqlContext.createDataFrame(convertedRows, convertedSchema)
.write
.format("com.databricks.spark.avro")
.save(tempDir)
val writer = sqlContext.createDataFrame(convertedRows, convertedSchema).write
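// For the CSV formats, setting the escape character to a double quote makes embedded quotes
// come out doubled ("") rather than backslash-escaped, and nulls are written as a sentinel
// string so that the NULL AS clause in the COPY command can turn them back into real NULLs.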
(tempFormat match {
case "AVRO" =>
writer.format("com.databricks.spark.avro")
case "CSV" =>
writer.format("csv")
.option("escape", "\"")
.option("nullValue", nullString)
case "CSV GZIP" =>
writer.format("csv")
.option("escape", "\"")
.option("nullValue", nullString)
.option("compression", "gzip")
}).save(tempDir)

if (nonEmptyPartitions.value.isEmpty) {
None
@@ -285,10 +301,7 @@
// for a description of the manifest file format. The URLs in this manifest must be absolute
// and complete.
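// For illustration, a manifest is a small JSON document of roughly this form:
//   {"entries": [{"url": "s3://bucket/temp-dir/part-r-00000-<uuid>.gz", "mandatory": true}]}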

// The saved filenames depend on the spark-avro version. In spark-avro 1.0.0, the write
// path uses SparkContext.saveAsHadoopFile(), which produces filenames of the form
// part-XXXXX.avro. In spark-avro 2.0.0+, the partition filenames are of the form
// part-r-XXXXX-UUID.avro.
// The partition filenames are of the form part-r-XXXXX-UUID.fileExtension.
val fs = FileSystem.get(URI.create(tempDir), sqlContext.sparkContext.hadoopConfiguration)
val partitionIdRegex = "^part-(?:r-)?(\\d+)[^\\d+].*$".r
val filesToLoad: Seq[String] = {
Expand Down Expand Up @@ -317,7 +330,7 @@ private[redshift] class RedshiftWriter(
}

/**
* Write a DataFrame to a Redshift table, using S3 and Avro serialization
* Write a DataFrame to a Redshift table, using S3 and Avro or CSV serialization
*/
def saveToRedshift(
sqlContext: SQLContext,
Expand Down Expand Up @@ -352,13 +365,36 @@ private[redshift] class RedshiftWriter(
}
}

// When using the Avro tempformat, log an informative error message in case any column names
// are unsupported by Avro's schema validation:
if (params.tempFormat == "AVRO") {
for (fieldName <- data.schema.fieldNames) {
// The following logic is based on Avro's Schema.validateName() method:
val firstChar = fieldName.charAt(0)
val isValid = (firstChar.isLetter || firstChar == '_') && fieldName.tail.forall { c =>
c.isLetterOrDigit || c == '_'
}
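// For example, "sales_2016" passes this check, while "column name with spaces" or "2016sales" does not.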
if (!isValid) {
throw new IllegalArgumentException(
s"The field name '$fieldName' is not supported when using the Avro tempformat. " +
"Try using the CSV tempformat instead. For more details, see " +
"/~https://github.com/databricks/spark-redshift/issues/84")
}
}
}

Utils.assertThatFileSystemIsNotS3BlockFileSystem(
new URI(params.rootTempDir), sqlContext.sparkContext.hadoopConfiguration)

Utils.checkThatBucketHasObjectLifecycleConfiguration(params.rootTempDir, s3ClientFactory(creds))

// Save the table's rows to S3:
val manifestUrl = unloadData(sqlContext, data, params.createPerQueryTempDir())
val manifestUrl = unloadData(
sqlContext,
data,
tempDir = params.createPerQueryTempDir(),
tempFormat = params.tempFormat,
nullString = params.nullString)
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
conn.setAutoCommit(false)
try {
14 changes: 14 additions & 0 deletions src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
@@ -108,4 +108,18 @@ class ParametersSuite extends FunSuite with Matchers {
"query" -> "select * from test_table",
"url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
}

test("tempformat option is case-insensitive") {
val params = Map(
"tempdir" -> "s3://foo/bar",
"dbtable" -> "test_schema.test_table",
"url" -> "jdbc:redshift://foo/bar?user=user&password=password")

Parameters.mergeParameters(params + ("tempformat" -> "csv"))
Parameters.mergeParameters(params + ("tempformat" -> "CSV"))

intercept[IllegalArgumentException] {
Parameters.mergeParameters(params + ("tempformat" -> "invalid-temp-format"))
}
}
}
