From 5b13760c01bdbb2ab15be072c582adb2a8792f23 Mon Sep 17 00:00:00 2001
From: Ryan Blue
Date: Tue, 21 Jan 2025 14:46:20 -0800
Subject: [PATCH] Spark 3.3: Backport support for default values (#11988)

---
 .../org/apache/iceberg/spark/SparkUtil.java | 66 +++
 .../iceberg/spark/data/SparkAvroReader.java | 12 +
 .../spark/data/SparkParquetReaders.java | 26 +-
 .../spark/data/SparkPlannedAvroReader.java | 192 ++++++++
 .../iceberg/spark/data/SparkValueReaders.java | 38 ++
 .../VectorizedSparkParquetReaders.java | 9 +-
 .../iceberg/spark/source/BaseReader.java | 62 +--
 .../iceberg/spark/source/BaseRowReader.java | 4 +-
 .../iceberg/spark/data/AvroDataTest.java | 329 +++++++++++++-
 .../iceberg/spark/data/TestHelpers.java | 110 ++++-
 .../spark/data/TestSparkAvroEnums.java | 2 +-
 .../spark/data/TestSparkAvroReader.java | 30 +-
 .../spark/data/TestSparkOrcReader.java | 15 +-
 .../spark/data/TestSparkParquetReader.java | 79 ++--
 .../data/TestSparkRecordOrcReaderWriter.java | 32 +-
 ...rquetDictionaryEncodedVectorizedReads.java | 27 +-
 ...allbackToPlainEncodingVectorizedReads.java | 8 +-
 .../TestParquetVectorizedReads.java | 162 +++---
 .../spark/source/DataFrameWriteTestBase.java | 140 ++++++
 .../iceberg/spark/source/ScanTestBase.java | 126 ++++++
 .../spark/source/TestAvroDataFrameWrite.java | 33 ++
 .../iceberg/spark/source/TestAvroScan.java | 65 +--
 .../spark/source/TestDataFrameWrites.java | 421 ------------------
 .../spark/source/TestORCDataFrameWrite.java | 33 ++
 .../source/TestParquetDataFrameWrite.java | 33 ++
 .../iceberg/spark/source/TestParquetScan.java | 105 ++---
 .../source/TestParquetVectorizedScan.java | 26 ++
 27 files changed, 1369 insertions(+), 816 deletions(-)
 create mode 100644 spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkPlannedAvroReader.java
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/DataFrameWriteTestBase.java
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ScanTestBase.java
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroDataFrameWrite.java
 delete mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestORCDataFrameWrite.java
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetDataFrameWrite.java
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetVectorizedScan.java

diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
index a1fc9b23b1d0..fb083f768615 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
@@ -18,6 +18,8 @@
  */
 package org.apache.iceberg.spark;
 
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
 import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.List;
@@ -25,28 +27,36 @@
 import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.stream.Collectors;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.util.Utf8;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.iceberg.PartitionField;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.Schema;
+import org.apache.iceberg.StructLike;
 import
org.apache.iceberg.relocated.com.google.common.base.Joiner; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.transforms.Transform; import org.apache.iceberg.transforms.UnknownTransform; +import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; import org.apache.iceberg.util.Pair; import org.apache.spark.sql.RuntimeConfig; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.catalyst.expressions.BoundReference; import org.apache.spark.sql.catalyst.expressions.EqualTo; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; import org.joda.time.DateTime; public class SparkUtil { @@ -282,4 +292,60 @@ public static String toColumnName(NamedReference ref) { public static boolean caseSensitive(SparkSession spark) { return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive")); } + + /** + * Converts a value to pass into Spark from Iceberg's internal object model. + * + * @param type an Iceberg type + * @param value a value that is an instance of {@link Type.TypeID#javaClass()} + * @return the value converted for Spark + */ + public static Object internalToSpark(Type type, Object value) { + if (value == null) { + return null; + } + + switch (type.typeId()) { + case DECIMAL: + return Decimal.apply((BigDecimal) value); + case UUID: + case STRING: + if (value instanceof Utf8) { + Utf8 utf8 = (Utf8) value; + return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength()); + } + return UTF8String.fromString(value.toString()); + case FIXED: + if (value instanceof byte[]) { + return value; + } else if (value instanceof GenericData.Fixed) { + return ((GenericData.Fixed) value).bytes(); + } + return ByteBuffers.toByteArray((ByteBuffer) value); + case BINARY: + return ByteBuffers.toByteArray((ByteBuffer) value); + case STRUCT: + Types.StructType structType = (Types.StructType) type; + + if (structType.fields().isEmpty()) { + return new GenericInternalRow(); + } + + List fields = structType.fields(); + Object[] values = new Object[fields.size()]; + StructLike struct = (StructLike) value; + + for (int index = 0; index < fields.size(); index++) { + Types.NestedField field = fields.get(index); + Type fieldType = field.type(); + values[index] = + internalToSpark(fieldType, struct.get(index, fieldType.typeId().javaClass())); + } + + return new GenericInternalRow(values); + default: + } + + return value; + } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java index 4622d2928ac4..7d92d963a9f4 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java @@ -37,16 +37,28 @@ import org.apache.iceberg.types.Types; import org.apache.spark.sql.catalyst.InternalRow; +/** + * @deprecated will be 
removed in 1.8.0; use SparkPlannedAvroReader instead. + */ +@Deprecated public class SparkAvroReader implements DatumReader, SupportsRowPosition { private final Schema readSchema; private final ValueReader reader; private Schema fileSchema = null; + /** + * @deprecated will be removed in 1.8.0; use SparkPlannedAvroReader instead. + */ + @Deprecated public SparkAvroReader(org.apache.iceberg.Schema expectedSchema, Schema readSchema) { this(expectedSchema, readSchema, ImmutableMap.of()); } + /** + * @deprecated will be removed in 1.8.0; use SparkPlannedAvroReader instead. + */ + @Deprecated @SuppressWarnings("unchecked") public SparkAvroReader( org.apache.iceberg.Schema expectedSchema, Schema readSchema, Map constants) { diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java index af16d9bbc290..3ce54d2d9ffa 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java @@ -44,6 +44,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.spark.SparkUtil; import org.apache.iceberg.types.Type.TypeID; import org.apache.iceberg.types.Types; import org.apache.iceberg.util.UUIDUtil; @@ -165,6 +166,7 @@ public ParquetValueReader struct( int defaultMaxDefinitionLevel = type.getMaxDefinitionLevel(currentPath()); for (Types.NestedField field : expectedFields) { int id = field.fieldId(); + ParquetValueReader reader = readersById.get(id); if (idToConstant.containsKey(id)) { // containsKey is used because the constant may be null int fieldMaxDefinitionLevel = @@ -178,15 +180,21 @@ public ParquetValueReader struct( } else if (id == MetadataColumns.IS_DELETED.fieldId()) { reorderedFields.add(ParquetValueReaders.constant(false)); types.add(null); + } else if (reader != null) { + reorderedFields.add(reader); + types.add(typesById.get(id)); + } else if (field.initialDefault() != null) { + reorderedFields.add( + ParquetValueReaders.constant( + SparkUtil.internalToSpark(field.type(), field.initialDefault()), + maxDefinitionLevelsById.getOrDefault(id, defaultMaxDefinitionLevel))); + types.add(typesById.get(id)); + } else if (field.isOptional()) { + reorderedFields.add(ParquetValueReaders.nulls()); + types.add(null); } else { - ParquetValueReader reader = readersById.get(id); - if (reader != null) { - reorderedFields.add(reader); - types.add(typesById.get(id)); - } else { - reorderedFields.add(ParquetValueReaders.nulls()); - types.add(null); - } + throw new IllegalArgumentException( + String.format("Missing required field: %s", field.name())); } } @@ -250,7 +258,7 @@ public ParquetValueReader primitive( if (expected != null && expected.typeId() == Types.LongType.get().typeId()) { return new IntAsLongReader(desc); } else { - return new UnboxedReader(desc); + return new UnboxedReader<>(desc); } case DATE: case INT_64: diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkPlannedAvroReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkPlannedAvroReader.java new file mode 100644 index 000000000000..7bcd8881c10b --- /dev/null +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkPlannedAvroReader.java @@ -0,0 
+1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.data; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.Decoder; +import org.apache.iceberg.avro.AvroWithPartnerVisitor; +import org.apache.iceberg.avro.SupportsRowPosition; +import org.apache.iceberg.avro.ValueReader; +import org.apache.iceberg.avro.ValueReaders; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; +import org.apache.spark.sql.catalyst.InternalRow; + +public class SparkPlannedAvroReader implements DatumReader, SupportsRowPosition { + + private final Types.StructType expectedType; + private final Map idToConstant; + private ValueReader reader; + + public static SparkPlannedAvroReader create(org.apache.iceberg.Schema schema) { + return create(schema, ImmutableMap.of()); + } + + public static SparkPlannedAvroReader create( + org.apache.iceberg.Schema schema, Map constants) { + return new SparkPlannedAvroReader(schema, constants); + } + + private SparkPlannedAvroReader( + org.apache.iceberg.Schema expectedSchema, Map constants) { + this.expectedType = expectedSchema.asStruct(); + this.idToConstant = constants; + } + + @Override + @SuppressWarnings("unchecked") + public void setSchema(Schema fileSchema) { + this.reader = + (ValueReader) + AvroWithPartnerVisitor.visit( + expectedType, + fileSchema, + new ReadBuilder(idToConstant), + AvroWithPartnerVisitor.FieldIDAccessors.get()); + } + + @Override + public InternalRow read(InternalRow reuse, Decoder decoder) throws IOException { + return reader.read(decoder, reuse); + } + + @Override + public void setRowPositionSupplier(Supplier posSupplier) { + if (reader instanceof SupportsRowPosition) { + ((SupportsRowPosition) reader).setRowPositionSupplier(posSupplier); + } + } + + private static class ReadBuilder extends AvroWithPartnerVisitor> { + private final Map idToConstant; + + private ReadBuilder(Map idToConstant) { + this.idToConstant = idToConstant; + } + + @Override + public ValueReader record(Type partner, Schema record, List> fieldReaders) { + if (partner == null) { + return ValueReaders.skipStruct(fieldReaders); + } + + Types.StructType expected = partner.asStructType(); + List>> readPlan = + ValueReaders.buildReadPlan( + expected, record, fieldReaders, idToConstant, SparkUtil::internalToSpark); + + // TODO: should this pass expected so that struct.get can reuse 
containers? + return SparkValueReaders.struct(readPlan, expected.fields().size()); + } + + @Override + public ValueReader union(Type partner, Schema union, List> options) { + return ValueReaders.union(options); + } + + @Override + public ValueReader array(Type partner, Schema array, ValueReader elementReader) { + return SparkValueReaders.array(elementReader); + } + + @Override + public ValueReader arrayMap( + Type partner, Schema map, ValueReader keyReader, ValueReader valueReader) { + return SparkValueReaders.arrayMap(keyReader, valueReader); + } + + @Override + public ValueReader map(Type partner, Schema map, ValueReader valueReader) { + return SparkValueReaders.map(SparkValueReaders.strings(), valueReader); + } + + @Override + public ValueReader primitive(Type partner, Schema primitive) { + LogicalType logicalType = primitive.getLogicalType(); + if (logicalType != null) { + switch (logicalType.getName()) { + case "date": + // Spark uses the same representation + return ValueReaders.ints(); + + case "timestamp-millis": + // adjust to microseconds + ValueReader longs = ValueReaders.longs(); + return (ValueReader) (decoder, ignored) -> longs.read(decoder, null) * 1000L; + + case "timestamp-micros": + // Spark uses the same representation + return ValueReaders.longs(); + + case "decimal": + return SparkValueReaders.decimal( + ValueReaders.decimalBytesReader(primitive), + ((LogicalTypes.Decimal) logicalType).getScale()); + + case "uuid": + return SparkValueReaders.uuids(); + + default: + throw new IllegalArgumentException("Unknown logical type: " + logicalType); + } + } + + switch (primitive.getType()) { + case NULL: + return ValueReaders.nulls(); + case BOOLEAN: + return ValueReaders.booleans(); + case INT: + if (partner != null && partner.typeId() == Type.TypeID.LONG) { + return ValueReaders.intsAsLongs(); + } + return ValueReaders.ints(); + case LONG: + return ValueReaders.longs(); + case FLOAT: + if (partner != null && partner.typeId() == Type.TypeID.DOUBLE) { + return ValueReaders.floatsAsDoubles(); + } + return ValueReaders.floats(); + case DOUBLE: + return ValueReaders.doubles(); + case STRING: + return SparkValueReaders.strings(); + case FIXED: + return ValueReaders.fixed(primitive.getFixedSize()); + case BYTES: + return ValueReaders.bytes(); + case ENUM: + return SparkValueReaders.enums(primitive.getEnumSymbols()); + default: + throw new IllegalArgumentException("Unsupported type: " + primitive); + } + } + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java index 3cbf38d88bf4..7e65535f5ecb 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java @@ -32,6 +32,7 @@ import org.apache.iceberg.avro.ValueReaders; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; import org.apache.iceberg.util.UUIDUtil; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -74,6 +75,11 @@ static ValueReader map(ValueReader keyReader, ValueReader< return new MapReader(keyReader, valueReader); } + static ValueReader struct( + List>> readPlan, int numFields) { + return new PlannedStructReader(readPlan, numFields); + } + static ValueReader struct( List> readers, Types.StructType struct, 
Map idToConstant) { return new StructReader(readers, struct, idToConstant); @@ -249,6 +255,38 @@ public ArrayBasedMapData read(Decoder decoder, Object reuse) throws IOException } } + static class PlannedStructReader extends ValueReaders.PlannedStructReader { + private final int numFields; + + protected PlannedStructReader(List>> readPlan, int numFields) { + super(readPlan); + this.numFields = numFields; + } + + @Override + protected InternalRow reuseOrCreate(Object reuse) { + if (reuse instanceof GenericInternalRow + && ((GenericInternalRow) reuse).numFields() == numFields) { + return (InternalRow) reuse; + } + return new GenericInternalRow(numFields); + } + + @Override + protected Object get(InternalRow struct, int pos) { + return null; + } + + @Override + protected void set(InternalRow struct, int pos, Object value) { + if (value != null) { + struct.update(pos, value); + } else { + struct.setNullAt(pos); + } + } + } + static class StructReader extends ValueReaders.StructReader { private final int numFields; diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java index e47152c79398..636ad3be7dcc 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java @@ -27,6 +27,7 @@ import org.apache.iceberg.data.DeleteFilter; import org.apache.iceberg.parquet.TypeWithSchemaVisitor; import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.iceberg.spark.SparkUtil; import org.apache.parquet.schema.MessageType; import org.apache.spark.sql.catalyst.InternalRow; import org.slf4j.Logger; @@ -112,7 +113,13 @@ private static class ReaderBuilder extends VectorizedReaderBuilder { Map idToConstant, Function>, VectorizedReader> readerFactory, DeleteFilter deleteFilter) { - super(expectedSchema, parquetSchema, setArrowValidityVector, idToConstant, readerFactory); + super( + expectedSchema, + parquetSchema, + setArrowValidityVector, + idToConstant, + readerFactory, + SparkUtil::internalToSpark); this.deleteFilter = deleteFilter; } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java index 40d907e12c08..70b37558fa54 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java @@ -20,15 +20,11 @@ import java.io.Closeable; import java.io.IOException; -import java.math.BigDecimal; -import java.nio.ByteBuffer; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; -import org.apache.avro.generic.GenericData; -import org.apache.avro.util.Utf8; import org.apache.iceberg.ContentFile; import org.apache.iceberg.ContentScanTask; import org.apache.iceberg.DeleteFile; @@ -51,16 +47,11 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.spark.SparkSchemaUtil; -import org.apache.iceberg.types.Type; -import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.spark.SparkUtil; import org.apache.iceberg.types.Types.StructType; 
-import org.apache.iceberg.util.ByteBuffers; import org.apache.iceberg.util.PartitionUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.types.UTF8String; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -201,59 +192,12 @@ private EncryptedInputFile toEncryptedInputFile(ContentFile file) { protected Map constantsMap(ContentScanTask task, Schema readSchema) { if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) { StructType partitionType = Partitioning.partitionType(table); - return PartitionUtil.constantsMap(task, partitionType, BaseReader::convertConstant); + return PartitionUtil.constantsMap(task, partitionType, SparkUtil::internalToSpark); } else { - return PartitionUtil.constantsMap(task, BaseReader::convertConstant); + return PartitionUtil.constantsMap(task, SparkUtil::internalToSpark); } } - protected static Object convertConstant(Type type, Object value) { - if (value == null) { - return null; - } - - switch (type.typeId()) { - case DECIMAL: - return Decimal.apply((BigDecimal) value); - case STRING: - if (value instanceof Utf8) { - Utf8 utf8 = (Utf8) value; - return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength()); - } - return UTF8String.fromString(value.toString()); - case FIXED: - if (value instanceof byte[]) { - return value; - } else if (value instanceof GenericData.Fixed) { - return ((GenericData.Fixed) value).bytes(); - } - return ByteBuffers.toByteArray((ByteBuffer) value); - case BINARY: - return ByteBuffers.toByteArray((ByteBuffer) value); - case STRUCT: - StructType structType = (StructType) type; - - if (structType.fields().isEmpty()) { - return new GenericInternalRow(); - } - - List fields = structType.fields(); - Object[] values = new Object[fields.size()]; - StructLike struct = (StructLike) value; - - for (int index = 0; index < fields.size(); index++) { - NestedField field = fields.get(index); - Type fieldType = field.type(); - values[index] = - convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass())); - } - - return new GenericInternalRow(values); - default: - } - return value; - } - protected class SparkDeleteFilter extends DeleteFilter { private final InternalRowWrapper asStructLike; diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java index 927084caea1c..2d51992dd96a 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java @@ -32,9 +32,9 @@ import org.apache.iceberg.orc.ORC; import org.apache.iceberg.parquet.Parquet; import org.apache.iceberg.relocated.com.google.common.collect.Sets; -import org.apache.iceberg.spark.data.SparkAvroReader; import org.apache.iceberg.spark.data.SparkOrcReader; import org.apache.iceberg.spark.data.SparkParquetReaders; +import org.apache.iceberg.spark.data.SparkPlannedAvroReader; import org.apache.iceberg.types.TypeUtil; import org.apache.spark.sql.catalyst.InternalRow; @@ -77,7 +77,7 @@ private CloseableIterable newAvroIterable( .reuseContainers() .project(projection) .split(start, length) - .createReaderFunc(readSchema -> new SparkAvroReader(projection, readSchema, idToConstant)) + .createResolvingReader(schema -> 
SparkPlannedAvroReader.create(schema, idToConstant)) .withNameMapping(nameMapping()) .build(); } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java index db0d7336f161..d7ecef758c47 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java @@ -20,30 +20,56 @@ import static org.apache.iceberg.types.Types.NestedField.optional; import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.file.Path; import java.util.Map; +import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; import org.apache.iceberg.Schema; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.ListType; import org.apache.iceberg.types.Types.LongType; import org.apache.iceberg.types.Types.MapType; import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.DateTimeUtil; import org.apache.spark.sql.internal.SQLConf; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.assertj.core.api.Assumptions; +import org.assertj.core.api.Condition; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; public abstract class AvroDataTest { protected abstract void writeAndValidate(Schema schema) throws IOException; + protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException { + throw new UnsupportedEncodingException( + "Cannot run test, writeAndValidate(Schema, Schema) is not implemented"); + } + + protected boolean supportsDefaultValues() { + return false; + } + + protected boolean supportsNestedTypes() { + return true; + } + protected static final StructType SUPPORTED_PRIMITIVES = StructType.of( required(100, "id", LongType.get()), @@ -65,7 +91,7 @@ public abstract class AvroDataTest { required(117, "dec_38_10", Types.DecimalType.of(38, 10)) // Spark's maximum precision ); - @Rule public TemporaryFolder temp = new TemporaryFolder(); + @TempDir protected Path temp; @Test public void testSimpleStruct() throws IOException { @@ -90,12 +116,16 @@ public void testStructWithOptionalFields() throws IOException { @Test public void testNestedStruct() throws IOException { + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + writeAndValidate( TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES)))); } @Test public void testArray() throws IOException { + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + Schema schema = new Schema( required(0, "id", LongType.get()), @@ -106,6 +136,8 @@ public void testArray() throws IOException { @Test public void 
testArrayOfStructs() throws IOException { + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + Schema schema = TypeUtil.assignIncreasingFreshIds( new Schema( @@ -117,6 +149,8 @@ public void testArrayOfStructs() throws IOException { @Test public void testMap() throws IOException { + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + Schema schema = new Schema( required(0, "id", LongType.get()), @@ -130,6 +164,8 @@ public void testMap() throws IOException { @Test public void testNumericMapKey() throws IOException { + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + Schema schema = new Schema( required(0, "id", LongType.get()), @@ -141,6 +177,8 @@ public void testNumericMapKey() throws IOException { @Test public void testComplexMapKey() throws IOException { + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + Schema schema = new Schema( required(0, "id", LongType.get()), @@ -160,6 +198,8 @@ public void testComplexMapKey() throws IOException { @Test public void testMapOfStructs() throws IOException { + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + Schema schema = TypeUtil.assignIncreasingFreshIds( new Schema( @@ -174,6 +214,8 @@ public void testMapOfStructs() throws IOException { @Test public void testMixedTypes() throws IOException { + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + StructType structType = StructType.of( required(0, "id", LongType.get()), @@ -244,6 +286,285 @@ public void testTimestampWithoutZone() throws IOException { }); } + @Test + public void testMissingRequiredWithoutDefault() { + Assumptions.assumeThat(supportsDefaultValues()).isTrue(); + + Schema writeSchema = new Schema(required(1, "id", Types.LongType.get())); + + Schema expectedSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.required("missing_str") + .withId(6) + .ofType(Types.StringType.get()) + .withDoc("Missing required field with no default") + .build()); + + assertThatThrownBy(() -> writeAndValidate(writeSchema, expectedSchema)) + .has( + new Condition<>( + t -> + IllegalArgumentException.class.isInstance(t) + || IllegalArgumentException.class.isInstance(t.getCause()), + "Expecting a throwable or cause that is an instance of IllegalArgumentException")) + .hasMessageContaining("Missing required field: missing_str"); + } + + @Test + public void testDefaultValues() throws IOException { + Assumptions.assumeThat(supportsDefaultValues()).isTrue(); + + Schema writeSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .withDoc("Should not produce default value") + .build()); + + Schema expectedSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .build(), + Types.NestedField.required("missing_str") + .withId(6) + .ofType(Types.StringType.get()) + .withInitialDefault("orange") + .build(), + Types.NestedField.optional("missing_int") + .withId(7) + .ofType(Types.IntegerType.get()) + .withInitialDefault(34) + .build()); + + writeAndValidate(writeSchema, expectedSchema); + } + + @Test + public void testNullDefaultValue() throws IOException { + Assumptions.assumeThat(supportsDefaultValues()).isTrue(); + + Schema writeSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + 
.withInitialDefault("wrong!") + .withDoc("Should not produce default value") + .build()); + + Schema expectedSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .build(), + Types.NestedField.optional("missing_date") + .withId(3) + .ofType(Types.DateType.get()) + .build()); + + writeAndValidate(writeSchema, expectedSchema); + } + + @Test + public void testNestedDefaultValue() throws IOException { + Assumptions.assumeThat(supportsDefaultValues()).isTrue(); + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + + Schema writeSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .withDoc("Should not produce default value") + .build(), + Types.NestedField.optional("nested") + .withId(3) + .ofType(Types.StructType.of(required(4, "inner", Types.StringType.get()))) + .withDoc("Used to test nested field defaults") + .build()); + + Schema expectedSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .build(), + Types.NestedField.optional("nested") + .withId(3) + .ofType( + Types.StructType.of( + required(4, "inner", Types.StringType.get()), + Types.NestedField.optional("missing_inner_float") + .withId(5) + .ofType(Types.FloatType.get()) + .withInitialDefault(-0.0F) + .build())) + .withDoc("Used to test nested field defaults") + .build()); + + writeAndValidate(writeSchema, expectedSchema); + } + + @Test + public void testMapNestedDefaultValue() throws IOException { + Assumptions.assumeThat(supportsDefaultValues()).isTrue(); + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + + Schema writeSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .withDoc("Should not produce default value") + .build(), + Types.NestedField.optional("nested_map") + .withId(3) + .ofType( + Types.MapType.ofOptional( + 4, + 5, + Types.StringType.get(), + Types.StructType.of(required(6, "value_str", Types.StringType.get())))) + .withDoc("Used to test nested map value field defaults") + .build()); + + Schema expectedSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .build(), + Types.NestedField.optional("nested_map") + .withId(3) + .ofType( + Types.MapType.ofOptional( + 4, + 5, + Types.StringType.get(), + Types.StructType.of( + required(6, "value_str", Types.StringType.get()), + Types.NestedField.optional("value_int") + .withId(7) + .ofType(Types.IntegerType.get()) + .withInitialDefault(34) + .build()))) + .withDoc("Used to test nested field defaults") + .build()); + + writeAndValidate(writeSchema, expectedSchema); + } + + @Test + public void testListNestedDefaultValue() throws IOException { + Assumptions.assumeThat(supportsDefaultValues()).isTrue(); + Assumptions.assumeThat(supportsNestedTypes()).isTrue(); + + Schema writeSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .withDoc("Should not produce default value") + .build(), + 
Types.NestedField.optional("nested_list") + .withId(3) + .ofType( + Types.ListType.ofOptional( + 4, Types.StructType.of(required(5, "element_str", Types.StringType.get())))) + .withDoc("Used to test nested field defaults") + .build()); + + Schema expectedSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("data") + .withId(2) + .ofType(Types.StringType.get()) + .withInitialDefault("wrong!") + .build(), + Types.NestedField.optional("nested_list") + .withId(3) + .ofType( + Types.ListType.ofOptional( + 4, + Types.StructType.of( + required(5, "element_str", Types.StringType.get()), + Types.NestedField.optional("element_int") + .withId(7) + .ofType(Types.IntegerType.get()) + .withInitialDefault(34) + .build()))) + .withDoc("Used to test nested field defaults") + .build()); + + writeAndValidate(writeSchema, expectedSchema); + } + + private static Stream primitiveTypesAndDefaults() { + return Stream.of( + Arguments.of(Types.BooleanType.get(), false), + Arguments.of(Types.IntegerType.get(), 34), + Arguments.of(Types.LongType.get(), 4900000000L), + Arguments.of(Types.FloatType.get(), 12.21F), + Arguments.of(Types.DoubleType.get(), -0.0D), + Arguments.of(Types.DateType.get(), DateTimeUtil.isoDateToDays("2024-12-17")), + // Arguments.of(Types.TimeType.get(), DateTimeUtil.isoTimeToMicros("23:59:59.999999")), + Arguments.of( + Types.TimestampType.withZone(), + DateTimeUtil.isoTimestamptzToMicros("2024-12-17T23:59:59.999999+00:00")), + // Arguments.of( + // Types.TimestampType.withoutZone(), + // DateTimeUtil.isoTimestampToMicros("2024-12-17T23:59:59.999999")), + Arguments.of(Types.StringType.get(), "iceberg"), + Arguments.of(Types.UUIDType.get(), UUID.randomUUID()), + Arguments.of( + Types.FixedType.ofLength(4), ByteBuffer.wrap(new byte[] {0x0a, 0x0b, 0x0c, 0x0d})), + Arguments.of(Types.BinaryType.get(), ByteBuffer.wrap(new byte[] {0x0a, 0x0b})), + Arguments.of(Types.DecimalType.of(9, 2), new BigDecimal("12.34"))); + } + + @ParameterizedTest + @MethodSource("primitiveTypesAndDefaults") + public void testPrimitiveTypeDefaultValues(Type.PrimitiveType type, Object defaultValue) + throws IOException { + Assumptions.assumeThat(supportsDefaultValues()).isTrue(); + + Schema writeSchema = new Schema(required(1, "id", Types.LongType.get())); + + Schema readSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.optional("col_with_default") + .withId(2) + .ofType(type) + .withInitialDefault(defaultValue) + .build()); + + writeAndValidate(writeSchema, readSchema); + } + protected void withSQLConf(Map conf, Action action) throws IOException { SQLConf sqlConf = SQLConf.get(); diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java index 4bd3531ab954..b3313412d6c9 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java @@ -41,6 +41,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.stream.StreamSupport; +import org.apache.avro.Schema.Field; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericData.Record; import org.apache.iceberg.DataFile; @@ -58,6 +59,7 @@ import org.apache.iceberg.spark.SparkSchemaUtil; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.ByteBuffers; import 
org.apache.orc.storage.serde2.io.DateWritable; import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; @@ -75,6 +77,8 @@ import org.apache.spark.sql.types.MapType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType$; import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; @@ -91,11 +95,20 @@ public static void assertEqualsSafe(Types.StructType struct, List recs, public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row) { List fields = struct.fields(); - for (int i = 0; i < fields.size(); i += 1) { - Type fieldType = fields.get(i).type(); - - Object expectedValue = rec.get(i); - Object actualValue = row.get(i); + for (int readPos = 0; readPos < fields.size(); readPos += 1) { + Types.NestedField field = fields.get(readPos); + Field writeField = rec.getSchema().getField(field.name()); + + Type fieldType = field.type(); + Object actualValue = row.get(readPos); + + Object expectedValue; + if (writeField != null) { + int writePos = writeField.pos(); + expectedValue = rec.get(writePos); + } else { + expectedValue = field.initialDefault(); + } assertEqualsSafe(fieldType, expectedValue, actualValue); } @@ -104,13 +117,25 @@ public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row public static void assertEqualsBatch( Types.StructType struct, Iterator expected, ColumnarBatch batch) { for (int rowId = 0; rowId < batch.numRows(); rowId++) { - List fields = struct.fields(); InternalRow row = batch.getRow(rowId); Record rec = expected.next(); - for (int i = 0; i < fields.size(); i += 1) { - Type fieldType = fields.get(i).type(); - Object expectedValue = rec.get(i); - Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType)); + + List fields = struct.fields(); + for (int readPos = 0; readPos < fields.size(); readPos += 1) { + Types.NestedField field = fields.get(readPos); + Field writeField = rec.getSchema().getField(field.name()); + + Type fieldType = field.type(); + Object actualValue = row.isNullAt(readPos) ? null : row.get(readPos, convert(fieldType)); + + Object expectedValue; + if (writeField != null) { + int writePos = writeField.pos(); + expectedValue = rec.get(writePos); + } else { + expectedValue = field.initialDefault(); + } + assertEqualsUnsafe(fieldType, expectedValue, actualValue); } } @@ -189,10 +214,21 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual) Assert.assertEquals("UUID string representation should match", expected.toString(), actual); break; case FIXED: - assertThat(expected).as("Should expect a Fixed").isInstanceOf(GenericData.Fixed.class); + // generated data is written using Avro or Parquet/Avro so generated rows use + // GenericData.Fixed, but default values are converted from Iceberg's internal + // representation so the expected value may be either class. 
+ byte[] expectedBytes; + if (expected instanceof ByteBuffer) { + expectedBytes = ByteBuffers.toByteArray((ByteBuffer) expected); + } else if (expected instanceof GenericData.Fixed) { + expectedBytes = ((GenericData.Fixed) expected).bytes(); + } else { + throw new IllegalStateException( + "Invalid expected value, not byte[] or Fixed: " + expected); + } + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); - Assert.assertArrayEquals( - "Bytes should match", ((GenericData.Fixed) expected).bytes(), (byte[]) actual); + assertThat(actual).as("Bytes should match").isEqualTo(expectedBytes); break; case BINARY: assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); @@ -214,7 +250,7 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual) assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class); List asList = seqAsJavaListConverter((Seq) actual).asJava(); - assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList); + assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList); break; case MAP: assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class); @@ -231,11 +267,20 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual) public static void assertEqualsUnsafe(Types.StructType struct, Record rec, InternalRow row) { List fields = struct.fields(); - for (int i = 0; i < fields.size(); i += 1) { - Type fieldType = fields.get(i).type(); - - Object expectedValue = rec.get(i); - Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType)); + for (int readPos = 0; readPos < fields.size(); readPos += 1) { + Types.NestedField field = fields.get(readPos); + Field writeField = rec.getSchema().getField(field.name()); + + Type fieldType = field.type(); + Object actualValue = row.isNullAt(readPos) ? null : row.get(readPos, convert(fieldType)); + + Object expectedValue; + if (writeField != null) { + int writePos = writeField.pos(); + expectedValue = rec.get(writePos); + } else { + expectedValue = field.initialDefault(); + } assertEqualsUnsafe(fieldType, expectedValue, actualValue); } @@ -314,10 +359,21 @@ private static void assertEqualsUnsafe(Type type, Object expected, Object actual "UUID string representation should match", expected.toString(), actual.toString()); break; case FIXED: - assertThat(expected).as("Should expect a Fixed").isInstanceOf(GenericData.Fixed.class); + // generated data is written using Avro or Parquet/Avro so generated rows use + // GenericData.Fixed, but default values are converted from Iceberg's internal + // representation so the expected value may be either class. 
+ byte[] expectedBytes; + if (expected instanceof ByteBuffer) { + expectedBytes = ByteBuffers.toByteArray((ByteBuffer) expected); + } else if (expected instanceof GenericData.Fixed) { + expectedBytes = ((GenericData.Fixed) expected).bytes(); + } else { + throw new IllegalStateException( + "Invalid expected value, not byte[] or Fixed: " + expected); + } + assertThat(actual).as("Should be a byte[]").isInstanceOf(byte[].class); - Assert.assertArrayEquals( - "Bytes should match", ((GenericData.Fixed) expected).bytes(), (byte[]) actual); + assertThat(actual).as("Bytes should match").isEqualTo(expectedBytes); break; case BINARY: assertThat(expected).as("Should expect a ByteBuffer").isInstanceOf(ByteBuffer.class); @@ -341,12 +397,12 @@ private static void assertEqualsUnsafe(Type type, Object expected, Object actual assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class); assertThat(actual).as("Should be an ArrayData").isInstanceOf(ArrayData.class); assertEqualsUnsafe( - type.asNestedType().asListType(), (Collection) expected, (ArrayData) actual); + type.asNestedType().asListType(), (Collection) expected, (ArrayData) actual); break; case MAP: assertThat(expected).as("Should expect a Map").isInstanceOf(Map.class); assertThat(actual).as("Should be an ArrayBasedMapData").isInstanceOf(MapData.class); - assertEqualsUnsafe(type.asNestedType().asMapType(), (Map) expected, (MapData) actual); + assertEqualsUnsafe(type.asNestedType().asMapType(), (Map) expected, (MapData) actual); break; case TIME: default: @@ -703,6 +759,12 @@ private static void assertEquals( for (int i = 0; i < actual.numFields(); i += 1) { StructField field = struct.fields()[i]; DataType type = field.dataType(); + // ColumnarRow.get doesn't support TimestampNTZType, causing tests to fail. the representation + // is identical to TimestampType so this uses that type to validate. + if (type instanceof TimestampNTZType) { + type = TimestampType$.MODULE$; + } + assertEquals( context + "." 
+ field.name(), type, diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java index 6f05a9ed7c1f..1f4e798a4ae7 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroEnums.java @@ -78,7 +78,7 @@ public void writeAndValidateEnums() throws IOException { List rows; try (AvroIterable reader = Avro.read(Files.localInput(testFile)) - .createReaderFunc(SparkAvroReader::new) + .createResolvingReader(SparkPlannedAvroReader::create) .project(schema) .build()) { rows = Lists.newArrayList(reader); diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java index 6d1ef3db3657..922af5b26a89 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroReader.java @@ -20,29 +20,32 @@ import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; -import java.io.File; import java.io.IOException; import java.util.List; import org.apache.avro.generic.GenericData.Record; -import org.apache.iceberg.Files; import org.apache.iceberg.Schema; import org.apache.iceberg.avro.Avro; import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.inmemory.InMemoryOutputFile; import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.OutputFile; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.spark.sql.catalyst.InternalRow; -import org.junit.Assert; public class TestSparkAvroReader extends AvroDataTest { @Override protected void writeAndValidate(Schema schema) throws IOException { - List expected = RandomData.generateList(schema, 100, 0L); + writeAndValidate(schema, schema); + } + + @Override + protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException { + List expected = RandomData.generateList(writeSchema, 100, 0L); - File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); + OutputFile outputFile = new InMemoryOutputFile(); try (FileAppender writer = - Avro.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + Avro.write(outputFile).schema(writeSchema).named("test").build()) { for (Record rec : expected) { writer.add(rec); } @@ -50,15 +53,20 @@ protected void writeAndValidate(Schema schema) throws IOException { List rows; try (AvroIterable reader = - Avro.read(Files.localInput(testFile)) - .createReaderFunc(SparkAvroReader::new) - .project(schema) + Avro.read(outputFile.toInputFile()) + .createResolvingReader(SparkPlannedAvroReader::create) + .project(expectedSchema) .build()) { rows = Lists.newArrayList(reader); } for (int i = 0; i < expected.size(); i += 1) { - assertEqualsUnsafe(schema.asStruct(), expected.get(i), rows.get(i)); + assertEqualsUnsafe(expectedSchema.asStruct(), expected.get(i), rows.get(i)); } } + + @Override + protected boolean supportsDefaultValues() { + return true; + } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java index b23fe729a187..a032e504c2b7 100644 --- 
a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java @@ -20,6 +20,7 @@ import static org.apache.iceberg.spark.data.TestHelpers.assertEquals; import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; import java.io.File; import java.io.IOException; @@ -37,8 +38,7 @@ import org.apache.iceberg.types.Types; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.vectorized.ColumnarBatch; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class TestSparkOrcReader extends AvroDataTest { @Override @@ -62,8 +62,7 @@ public void writeAndValidateRepeatingRecords() throws IOException { private void writeAndValidateRecords(Schema schema, Iterable expected) throws IOException { - final File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); + final File testFile = temp.resolve("test").toFile(); try (FileAppender writer = ORC.write(Files.localOutput(testFile)) @@ -81,10 +80,10 @@ private void writeAndValidateRecords(Schema schema, Iterable expect final Iterator actualRows = reader.iterator(); final Iterator expectedRows = expected.iterator(); while (expectedRows.hasNext()) { - Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + assertThat(actualRows.hasNext()).as("Should have expected number of rows").isTrue(); assertEquals(schema, expectedRows.next(), actualRows.next()); } - Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + assertThat(actualRows.hasNext()).as("Should not have extra rows").isFalse(); } try (CloseableIterable reader = @@ -97,10 +96,10 @@ private void writeAndValidateRecords(Schema schema, Iterable expect final Iterator actualRows = batchesToRows(reader.iterator()); final Iterator expectedRows = expected.iterator(); while (expectedRows.hasNext()) { - Assert.assertTrue("Should have expected number of rows", actualRows.hasNext()); + assertThat(actualRows.hasNext()).as("Should have expected number of rows").isTrue(); assertEquals(schema, expectedRows.next(), actualRows.next()); } - Assert.assertFalse("Should not have extra rows", actualRows.hasNext()); + assertThat(actualRows.hasNext()).as("Should not have extra rows").isFalse(); } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java index 77ad638542f5..ca2d010b7537 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java @@ -21,8 +21,9 @@ import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; import static org.apache.iceberg.types.Types.NestedField.required; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assumptions.assumeThat; -import java.io.File; import java.io.IOException; import java.util.Iterator; import java.util.List; @@ -39,9 +40,11 @@ import org.apache.iceberg.data.IcebergGenerics; import org.apache.iceberg.data.Record; import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.inmemory.InMemoryOutputFile; import org.apache.iceberg.io.CloseableIterable; import 
org.apache.iceberg.io.FileAppender; import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFile; import org.apache.iceberg.parquet.Parquet; import org.apache.iceberg.parquet.ParquetUtil; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; @@ -57,44 +60,53 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class TestSparkParquetReader extends AvroDataTest { @Override protected void writeAndValidate(Schema schema) throws IOException { - Assume.assumeTrue( - "Parquet Avro cannot write non-string map keys", - null - == TypeUtil.find( - schema, - type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())); + writeAndValidate(schema, schema); + } + + @Override + protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException { + assumeThat( + null + == TypeUtil.find( + writeSchema, + type -> + type.isMapType() && type.asMapType().keyType() != Types.StringType.get())) + .as("Parquet Avro cannot write non-string map keys") + .isTrue(); - List expected = RandomData.generateList(schema, 100, 0L); + List expected = RandomData.generateList(writeSchema, 100, 0L); - File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); + OutputFile outputFile = new InMemoryOutputFile(); try (FileAppender writer = - Parquet.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + Parquet.write(outputFile).schema(writeSchema).named("test").build()) { writer.addAll(expected); } try (CloseableIterable reader = - Parquet.read(Files.localInput(testFile)) - .project(schema) - .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type)) + Parquet.read(outputFile.toInputFile()) + .project(expectedSchema) + .createReaderFunc(type -> SparkParquetReaders.buildReader(expectedSchema, type)) .build()) { Iterator rows = reader.iterator(); for (GenericData.Record record : expected) { - Assert.assertTrue("Should have expected number of rows", rows.hasNext()); - assertEqualsUnsafe(schema.asStruct(), record, rows.next()); + assertThat(rows.hasNext()).as("Should have expected number of rows").isTrue(); + assertEqualsUnsafe(expectedSchema.asStruct(), record, rows.next()); } - Assert.assertFalse("Should not have extra rows", rows.hasNext()); + assertThat(rows.hasNext()).as("Should not have extra rows").isFalse(); } } + @Override + protected boolean supportsDefaultValues() { + return true; + } + protected List rowsFromFile(InputFile inputFile, Schema schema) throws IOException { try (CloseableIterable reader = Parquet.read(inputFile) @@ -112,7 +124,7 @@ protected Table tableFromInputFile(InputFile inputFile, Schema schema) throws IO schema, PartitionSpec.unpartitioned(), ImmutableMap.of(), - temp.newFolder().getCanonicalPath()); + temp.resolve("table").toFile().getCanonicalPath()); table .newAppend() @@ -130,8 +142,7 @@ protected Table tableFromInputFile(InputFile inputFile, Schema schema) throws IO @Test public void testInt96TimestampProducedBySparkIsReadCorrectly() throws IOException { - String outputFilePath = - String.format("%s/%s", temp.getRoot().getAbsolutePath(), "parquet_int96.parquet"); + String outputFilePath = temp.resolve("parquet_int96.parquet").toString(); HadoopOutputFile outputFile = HadoopOutputFile.fromPath( new 
org.apache.hadoop.fs.Path(outputFilePath), new Configuration()); @@ -157,7 +168,7 @@ public void testInt96TimestampProducedBySparkIsReadCorrectly() throws IOExceptio InputFile parquetInputFile = Files.localInput(outputFilePath); List readRows = rowsFromFile(parquetInputFile, schema); - Assert.assertEquals(rows.size(), readRows.size()); + assertThat(rows.size()).isEqualTo(readRows.size()); assertThat(readRows).isEqualTo(rows); // Now we try to import that file as an Iceberg table to make sure Iceberg can read @@ -165,7 +176,7 @@ public void testInt96TimestampProducedBySparkIsReadCorrectly() throws IOExceptio Table int96Table = tableFromInputFile(parquetInputFile, schema); List tableRecords = Lists.newArrayList(IcebergGenerics.read(int96Table).build()); - Assert.assertEquals(rows.size(), tableRecords.size()); + assertThat(rows.size()).isEqualTo(tableRecords.size()); for (int i = 0; i < tableRecords.size(); i++) { GenericsHelpers.assertEqualsUnsafe(schema.asStruct(), tableRecords.get(i), rows.get(i)); @@ -203,4 +214,22 @@ protected WriteSupport getWriteSupport(Configuration configuration) return new org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport(); } } + + @Test + public void testMissingRequiredWithoutDefault() { + Schema writeSchema = new Schema(required(1, "id", Types.LongType.get())); + + Schema expectedSchema = + new Schema( + required(1, "id", Types.LongType.get()), + Types.NestedField.required("missing_str") + .withId(6) + .ofType(Types.StringType.get()) + .withDoc("Missing required field with no default") + .build()); + + assertThatThrownBy(() -> writeAndValidate(writeSchema, expectedSchema)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Missing required field: missing_str"); + } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java index d10e7f5a19e3..e644c7ead4c1 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkRecordOrcReaderWriter.java @@ -19,6 +19,7 @@ package org.apache.iceberg.spark.data; import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; import java.io.File; import java.io.IOException; @@ -38,15 +39,13 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.types.Types; import org.apache.spark.sql.catalyst.InternalRow; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class TestSparkRecordOrcReaderWriter extends AvroDataTest { private static final int NUM_RECORDS = 200; private void writeAndValidate(Schema schema, List expectedRecords) throws IOException { - final File originalFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", originalFile.delete()); + final File originalFile = temp.resolve("original").toFile(); // Write few generic records into the original test file. 
try (FileAppender writer = @@ -68,8 +67,7 @@ private void writeAndValidate(Schema schema, List expectedRecords) throw assertEqualsUnsafe(schema.asStruct(), expectedRecords, reader, expectedRecords.size()); } - final File anotherFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", anotherFile.delete()); + final File anotherFile = temp.resolve("another").toFile(); // Write those spark InternalRows into a new file again. try (FileAppender writer = @@ -130,12 +128,14 @@ private static void assertRecordEquals( Iterator expectedIter = expected.iterator(); Iterator actualIter = actual.iterator(); for (int i = 0; i < size; i += 1) { - Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext()); - Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext()); - Assert.assertEquals("Should have same rows.", expectedIter.next(), actualIter.next()); + assertThat(expectedIter.hasNext()).as("Expected iterator should have more rows").isTrue(); + assertThat(actualIter.hasNext()).as("Actual iterator should have more rows").isTrue(); + assertThat(expectedIter.next()).isEqualTo(actualIter.next()); } - Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext()); - Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext()); + assertThat(expectedIter.hasNext()) + .as("Expected iterator should not have any extra rows") + .isFalse(); + assertThat(actualIter.hasNext()).as("Actual iterator should not have any extra rows").isFalse(); } private static void assertEqualsUnsafe( @@ -143,11 +143,13 @@ private static void assertEqualsUnsafe( Iterator expectedIter = expected.iterator(); Iterator actualIter = actual.iterator(); for (int i = 0; i < size; i += 1) { - Assert.assertTrue("Expected iterator should have more rows", expectedIter.hasNext()); - Assert.assertTrue("Actual iterator should have more rows", actualIter.hasNext()); + assertThat(expectedIter.hasNext()).as("Expected iterator should have more rows").isTrue(); + assertThat(actualIter.hasNext()).as("Actual iterator should have more rows").isTrue(); GenericsHelpers.assertEqualsUnsafe(struct, expectedIter.next(), actualIter.next()); } - Assert.assertFalse("Expected iterator should not have any extra rows.", expectedIter.hasNext()); - Assert.assertFalse("Actual iterator should not have any extra rows.", actualIter.hasNext()); + assertThat(expectedIter.hasNext()) + .as("Expected iterator should not have any extra rows") + .isFalse(); + assertThat(actualIter.hasNext()).as("Actual iterator should not have any extra rows").isFalse(); } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java index 93080e17db35..a198ca7cdad2 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java @@ -23,6 +23,7 @@ import java.io.File; import java.io.IOException; import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; import org.apache.iceberg.Schema; import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.parquet.Parquet; @@ -32,9 +33,7 @@ import 
org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.iceberg.spark.data.RandomData; -import org.junit.Assert; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class TestParquetDictionaryEncodedVectorizedReads extends TestParquetVectorizedReads { @@ -52,32 +51,32 @@ Iterable generateData( @Test @Override - @Ignore // Ignored since this code path is already tested in TestParquetVectorizedReads - public void testVectorizedReadsWithNewContainers() throws IOException {} + public void testVectorizedReadsWithNewContainers() throws IOException { + // Disabled since this code path is already tested in TestParquetVectorizedReads + } @Test public void testMixedDictionaryNonDictionaryReads() throws IOException { Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); - File dictionaryEncodedFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", dictionaryEncodedFile.delete()); + File dictionaryEncodedFile = temp.resolve("dictionary.parquet").toFile(); Iterable dictionaryEncodableData = RandomData.generateDictionaryEncodableData( schema, 10000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE); - try (FileAppender writer = parquetWriter(schema, dictionaryEncodedFile)) { + try (FileAppender writer = + parquetWriter(schema, Files.localOutput(dictionaryEncodedFile))) { writer.addAll(dictionaryEncodableData); } - File plainEncodingFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", plainEncodingFile.delete()); + File plainEncodingFile = temp.resolve("plain.parquet").toFile(); Iterable nonDictionaryData = RandomData.generate(schema, 10000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE); - try (FileAppender writer = parquetWriter(schema, plainEncodingFile)) { + try (FileAppender writer = + parquetWriter(schema, Files.localOutput(plainEncodingFile))) { writer.addAll(nonDictionaryData); } int rowGroupSize = PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT; - File mixedFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", mixedFile.delete()); + File mixedFile = temp.resolve("mixed.parquet").toFile(); Parquet.concat( ImmutableList.of(dictionaryEncodedFile, plainEncodingFile, dictionaryEncodedFile), mixedFile, @@ -88,7 +87,7 @@ public void testMixedDictionaryNonDictionaryReads() throws IOException { schema, 30000, FluentIterable.concat(dictionaryEncodableData, nonDictionaryData, dictionaryEncodableData), - mixedFile, + Files.localInput(mixedFile), true, BATCH_SIZE); } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java index f8b2040c4512..03a96b474713 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java @@ -18,13 +18,12 @@ */ package org.apache.iceberg.spark.data.parquet.vectorized; -import java.io.File; import java.io.IOException; import org.apache.avro.generic.GenericData; -import org.apache.iceberg.Files; import org.apache.iceberg.Schema; import org.apache.iceberg.TableProperties; import org.apache.iceberg.io.FileAppender; +import 
org.apache.iceberg.io.OutputFile; import org.apache.iceberg.parquet.Parquet; import org.apache.iceberg.relocated.com.google.common.base.Function; import org.apache.iceberg.relocated.com.google.common.collect.Iterables; @@ -54,8 +53,9 @@ Iterable generateData( } @Override - FileAppender parquetWriter(Schema schema, File testFile) throws IOException { - return Parquet.write(Files.localOutput(testFile)) + FileAppender parquetWriter(Schema schema, OutputFile outputFile) + throws IOException { + return Parquet.write(outputFile) .schema(schema) .named("test") .set(TableProperties.PARQUET_DICT_SIZE_BYTES, "512000") diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java index e3854bfeb529..5b00c4f86671 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java @@ -22,16 +22,17 @@ import static org.apache.iceberg.types.Types.NestedField.required; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.security.SecureRandom; import java.util.Iterator; import org.apache.avro.generic.GenericData; -import org.apache.iceberg.Files; import org.apache.iceberg.Schema; +import org.apache.iceberg.inmemory.InMemoryOutputFile; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFile; import org.apache.iceberg.parquet.Parquet; import org.apache.iceberg.relocated.com.google.common.base.Function; import org.apache.iceberg.relocated.com.google.common.base.Strings; @@ -51,7 +52,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; import org.junit.Assert; import org.junit.Assume; -import org.junit.Ignore; import org.junit.Test; public class TestParquetVectorizedReads extends AvroDataTest { @@ -64,18 +64,42 @@ public class TestParquetVectorizedReads extends AvroDataTest { @Override protected void writeAndValidate(Schema schema) throws IOException { - writeAndValidate(schema, getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, true); + writeAndValidate(schema, schema); + } + + @Override + protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException { + writeAndValidate( + writeSchema, + expectedSchema, + getNumRows(), + 29714278L, + RandomData.DEFAULT_NULL_PERCENTAGE, + true, + BATCH_SIZE, + IDENTITY); + } + + @Override + protected boolean supportsDefaultValues() { + return true; + } + + @Override + protected boolean supportsNestedTypes() { + return false; } private void writeAndValidate( Schema schema, int numRecords, long seed, float nullPercentage, boolean reuseContainers) throws IOException { writeAndValidate( - schema, numRecords, seed, nullPercentage, reuseContainers, BATCH_SIZE, IDENTITY); + schema, schema, numRecords, seed, nullPercentage, reuseContainers, BATCH_SIZE, IDENTITY); } private void writeAndValidate( - Schema schema, + Schema writeSchema, + Schema expectedSchema, int numRecords, long seed, float nullPercentage, @@ -88,28 +112,35 @@ private void writeAndValidate( "Parquet Avro cannot write non-string map keys", null == TypeUtil.find( - schema, + writeSchema, type -> type.isMapType() && 
type.asMapType().keyType() != Types.StringType.get())); Iterable expected = - generateData(schema, numRecords, seed, nullPercentage, transform); + generateData(writeSchema, numRecords, seed, nullPercentage, transform); - // write a test parquet file using iceberg writer - File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); + OutputFile outputFile = new InMemoryOutputFile(); - try (FileAppender writer = parquetWriter(schema, testFile)) { + try (FileAppender writer = parquetWriter(writeSchema, outputFile)) { writer.addAll(expected); } - assertRecordsMatch(schema, numRecords, expected, testFile, reuseContainers, batchSize); + assertRecordsMatch( + expectedSchema, numRecords, expected, outputFile.toInputFile(), reuseContainers, batchSize); // With encryption - testFile.delete(); - try (FileAppender writer = encryptedParquetWriter(schema, testFile)) { + OutputFile encryptedOutputFile = new InMemoryOutputFile(); + try (FileAppender writer = + encryptedParquetWriter(writeSchema, encryptedOutputFile)) { writer.addAll(expected); } - assertRecordsMatch(schema, numRecords, expected, testFile, reuseContainers, batchSize, true); + assertRecordsMatch( + expectedSchema, + numRecords, + expected, + encryptedOutputFile.toInputFile(), + reuseContainers, + batchSize, + true); } protected int getNumRows() { @@ -127,16 +158,17 @@ Iterable generateData( return transform == IDENTITY ? data : Iterables.transform(data, transform); } - FileAppender parquetWriter(Schema schema, File testFile) throws IOException { - return Parquet.write(Files.localOutput(testFile)).schema(schema).named("test").build(); + FileAppender parquetWriter(Schema schema, OutputFile outputFile) + throws IOException { + return Parquet.write(outputFile).schema(schema).named("test").build(); } - FileAppender encryptedParquetWriter(Schema schema, File testFile) + FileAppender encryptedParquetWriter(Schema schema, OutputFile outputFile) throws IOException { SecureRandom rand = new SecureRandom(); rand.nextBytes(FILE_DEK.array()); rand.nextBytes(AAD_PREFIX.array()); - return Parquet.write(Files.localOutput(testFile)) + return Parquet.write(outputFile) .schema(schema) .withFileEncryptionKey(FILE_DEK) .withAADPrefix(AAD_PREFIX) @@ -144,21 +176,21 @@ FileAppender encryptedParquetWriter(Schema schema, File test .build(); } - FileAppender parquetV2Writer(Schema schema, File testFile) + FileAppender parquetV2Writer(Schema schema, OutputFile outputFile) throws IOException { - return Parquet.write(Files.localOutput(testFile)) + return Parquet.write(outputFile) .schema(schema) .named("test") .writerVersion(ParquetProperties.WriterVersion.PARQUET_2_0) .build(); } - FileAppender encryptedParquetV2Writer(Schema schema, File testFile) + FileAppender encryptedParquetV2Writer(Schema schema, OutputFile outputFile) throws IOException { SecureRandom rand = new SecureRandom(); rand.nextBytes(FILE_DEK.array()); rand.nextBytes(AAD_PREFIX.array()); - return Parquet.write(Files.localOutput(testFile)) + return Parquet.write(outputFile) .schema(schema) .withFileEncryptionKey(FILE_DEK) .withAADPrefix(AAD_PREFIX) @@ -171,24 +203,25 @@ void assertRecordsMatch( Schema schema, int expectedSize, Iterable expected, - File testFile, + InputFile inputFile, boolean reuseContainers, int batchSize) throws IOException { - assertRecordsMatch(schema, expectedSize, expected, testFile, reuseContainers, batchSize, false); + assertRecordsMatch( + schema, expectedSize, expected, inputFile, reuseContainers, batchSize, false); } void 
assertRecordsMatch( Schema schema, int expectedSize, Iterable expected, - File testFile, + InputFile inputFile, boolean reuseContainers, int batchSize, boolean encrypted) throws IOException { Parquet.ReadBuilder readBuilder = - Parquet.read(Files.localInput(testFile)) + Parquet.read(inputFile) .project(schema) .recordsPerBatch(batchSize) .createBatchedReaderFunc( @@ -217,41 +250,6 @@ void assertRecordsMatch( } } - @Override - @Test - @Ignore - public void testArray() {} - - @Override - @Test - @Ignore - public void testArrayOfStructs() {} - - @Override - @Test - @Ignore - public void testMap() {} - - @Override - @Test - @Ignore - public void testNumericMapKey() {} - - @Override - @Test - @Ignore - public void testComplexMapKey() {} - - @Override - @Test - @Ignore - public void testMapOfStructs() {} - - @Override - @Test - @Ignore - public void testMixedTypes() {} - @Test @Override public void testNestedStruct() { @@ -303,10 +301,13 @@ public void testVectorizedReadsWithNewContainers() throws IOException { public void testVectorizedReadsWithReallocatedArrowBuffers() throws IOException { // With a batch size of 2, 256 bytes are allocated in the VarCharVector. By adding strings of // length 512, the vector will need to be reallocated for storing the batch. - writeAndValidate( + Schema schema = new Schema( Lists.newArrayList( - SUPPORTED_PRIMITIVES.field("id"), SUPPORTED_PRIMITIVES.field("data"))), + SUPPORTED_PRIMITIVES.field("id"), SUPPORTED_PRIMITIVES.field("data"))); + writeAndValidate( + schema, + schema, 10, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, @@ -331,11 +332,10 @@ public void testReadsForTypePromotedColumns() throws Exception { optional(102, "float_data", Types.FloatType.get()), optional(103, "decimal_data", Types.DecimalType.of(10, 5))); - File dataFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", dataFile.delete()); + OutputFile outputFile = new InMemoryOutputFile(); Iterable data = generateData(writeSchema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); - try (FileAppender writer = parquetWriter(writeSchema, dataFile)) { + try (FileAppender writer = parquetWriter(writeSchema, outputFile)) { writer.addAll(data); } @@ -346,7 +346,7 @@ public void testReadsForTypePromotedColumns() throws Exception { optional(102, "float_data", Types.DoubleType.get()), optional(103, "decimal_data", Types.DecimalType.of(25, 5))); - assertRecordsMatch(readSchema, 30000, data, dataFile, true, BATCH_SIZE); + assertRecordsMatch(readSchema, 30000, data, outputFile.toInputFile(), true, BATCH_SIZE); } @Test @@ -360,22 +360,23 @@ public void testSupportedReadsForParquetV2() throws Exception { optional(103, "double_data", Types.DoubleType.get()), optional(104, "decimal_data", Types.DecimalType.of(25, 5))); - File dataFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", dataFile.delete()); + OutputFile outputFile = new InMemoryOutputFile(); Iterable data = generateData(schema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); - try (FileAppender writer = parquetV2Writer(schema, dataFile)) { + try (FileAppender writer = parquetV2Writer(schema, outputFile)) { writer.addAll(data); } - assertRecordsMatch(schema, 30000, data, dataFile, true, BATCH_SIZE); + assertRecordsMatch(schema, 30000, data, outputFile.toInputFile(), true, BATCH_SIZE); // With encryption - dataFile.delete(); - try (FileAppender writer = encryptedParquetV2Writer(schema, dataFile)) { + OutputFile encryptedOutputFile = new InMemoryOutputFile(); + try (FileAppender writer = + 
encryptedParquetV2Writer(schema, encryptedOutputFile)) { writer.addAll(data); } - assertRecordsMatch(schema, 30000, data, dataFile, true, BATCH_SIZE, true); + assertRecordsMatch( + schema, 30000, data, encryptedOutputFile.toInputFile(), true, BATCH_SIZE, true); } @Test @@ -383,14 +384,15 @@ public void testUnsupportedReadsForParquetV2() throws Exception { // Longs, ints, string types etc use delta encoding and which are not supported for vectorized // reads Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); - File dataFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", dataFile.delete()); + OutputFile outputFile = new InMemoryOutputFile(); Iterable data = generateData(schema, 30000, 0L, RandomData.DEFAULT_NULL_PERCENTAGE, IDENTITY); - try (FileAppender writer = parquetV2Writer(schema, dataFile)) { + try (FileAppender writer = parquetV2Writer(schema, outputFile)) { writer.addAll(data); } - assertThatThrownBy(() -> assertRecordsMatch(schema, 30000, data, dataFile, true, BATCH_SIZE)) + assertThatThrownBy( + () -> + assertRecordsMatch(schema, 30000, data, outputFile.toInputFile(), true, BATCH_SIZE)) .as("Vectorized reads not supported") .isInstanceOf(UnsupportedOperationException.class) .hasMessageContaining("Cannot support vectorized reads for column"); diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/DataFrameWriteTestBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/DataFrameWriteTestBase.java new file mode 100644 index 000000000000..756370eec0da --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/DataFrameWriteTestBase.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.iceberg.spark.SparkSchemaUtil.convert; +import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; +import static org.apache.iceberg.types.Types.NestedField.required; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Iterator; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.iceberg.Files; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.Tables; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.SparkPlannedAvroReader; +import org.apache.iceberg.types.Types; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.DataFrameWriter; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public abstract class DataFrameWriteTestBase extends ScanTestBase { + @TempDir private Path temp; + + @Override + protected boolean supportsDefaultValues() { + // disable default value tests because this tests the write path + return false; + } + + @Override + protected void writeRecords(Table table, List records) throws IOException { + Schema tableSchema = table.schema(); // use the table schema because ids are reassigned + + Dataset df = createDataset(records, tableSchema); + DataFrameWriter writer = df.write().format("iceberg").mode("append"); + + writer.save(table.location()); + + // refresh the in-memory table state to pick up Spark's write + table.refresh(); + } + + private Dataset createDataset(List records, Schema schema) + throws IOException { + // this uses the SparkAvroReader to create a DataFrame from the list of records + // it assumes that SparkAvroReader is correct + File testFile = File.createTempFile("junit", null, temp.toFile()); + assertThat(testFile.delete()).as("Delete should succeed").isTrue(); + + try (FileAppender writer = + Avro.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { + for (GenericData.Record rec : records) { + writer.add(rec); + } + } + + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createResolvingReader(SparkPlannedAvroReader::create) + .project(schema) + .build()) { + rows = Lists.newArrayList(reader); + } + + // verify that the dataframe matches + assertThat(rows.size()).isEqualTo(records.size()); + Iterator recordIter = records.iterator(); + for (InternalRow row : rows) { + assertEqualsUnsafe(schema.asStruct(), recordIter.next(), row); + } + + JavaRDD rdd = sc.parallelize(rows); + return spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), convert(schema), false); + } + + @Test + public void testAlternateLocation() throws IOException { + Schema schema = new Schema(required(1, "id", Types.LongType.get())); + + File location = temp.resolve("table_location").toFile(); + File altLocation = temp.resolve("alt_location").toFile(); + + Tables tables = new HadoopTables(spark.sessionState().newHadoopConf()); + Table 
table = tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); + + // override the table's data location + table + .updateProperties() + .set(TableProperties.WRITE_DATA_LOCATION, altLocation.getAbsolutePath()) + .commit(); + + writeRecords(table, RandomData.generateList(table.schema(), 100, 87112L)); + + table + .currentSnapshot() + .addedDataFiles(table.io()) + .forEach( + dataFile -> + assertThat(dataFile.location()) + .as( + String.format( + "File should have the parent directory %s, but has: %s.", + altLocation, dataFile.location())) + .startsWith(altLocation + "/")); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ScanTestBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ScanTestBase.java new file mode 100644 index 000000000000..3a269740b709 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/ScanTestBase.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import org.apache.avro.generic.GenericData; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.hadoop.HadoopTables; +import org.apache.iceberg.spark.data.AvroDataTest; +import org.apache.iceberg.spark.data.RandomData; +import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.types.TypeUtil; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.io.TempDir; + +/** An AvroDataScan test that validates data by reading through Spark */ +public abstract class ScanTestBase extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + protected static SparkSession spark = null; + protected static JavaSparkContext sc = null; + + @BeforeAll + public static void startSpark() { + ScanTestBase.spark = SparkSession.builder().master("local[2]").getOrCreate(); + ScanTestBase.sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @AfterAll + public static void stopSpark() { + SparkSession currentSpark = ScanTestBase.spark; + ScanTestBase.spark = null; + ScanTestBase.sc = null; + currentSpark.stop(); + } + + @TempDir private 
Path temp; + + protected void configureTable(Table table) {} + + protected abstract void writeRecords(Table table, List records) + throws IOException; + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + writeAndValidate(schema, schema); + } + + @Override + protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException { + File parent = temp.resolve("scan_test").toFile(); + File location = new File(parent, "test"); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(writeSchema, PartitionSpec.unpartitioned(), location.toString()); + + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. + List expected = RandomData.generateList(table.schema(), 100, 1L); + + writeRecords(table, expected); + + // update the table schema to the expected schema + if (!expectedSchema.sameSchema(table.schema())) { + Schema expectedSchemaWithTableIds = + TypeUtil.reassignOrRefreshIds(expectedSchema, table.schema()); + int highestFieldId = + Math.max(table.schema().highestFieldId(), expectedSchema.highestFieldId()); + + // don't use the table API because tests cover incompatible update cases + TableOperations ops = ((BaseTable) table).operations(); + TableMetadata builder = + TableMetadata.buildFrom(ops.current()) + .upgradeFormatVersion(3) + .setCurrentSchema(expectedSchemaWithTableIds, highestFieldId) + .build(); + ops.commit(ops.current(), builder); + } + + Dataset df = spark.read().format("iceberg").load(table.location()); + + List rows = df.collectAsList(); + assertThat(rows).as("Should contain 100 rows").hasSize(100); + + for (int i = 0; i < expected.size(); i += 1) { + TestHelpers.assertEqualsSafe(table.schema().asStruct(), expected.get(i), rows.get(i)); + } + } + + @Override + protected boolean supportsDefaultValues() { + return true; + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroDataFrameWrite.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroDataFrameWrite.java new file mode 100644 index 000000000000..110428d0a20c --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroDataFrameWrite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; + +public class TestAvroDataFrameWrite extends DataFrameWriteTestBase { + @Override + protected void configureTable(Table table) { + table + .updateProperties() + .set(TableProperties.DEFAULT_FILE_FORMAT, FileFormat.AVRO.toString()) + .commit(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java index 9491adde4605..3e67282644b5 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestAvroScan.java @@ -25,87 +25,34 @@ import java.util.List; import java.util.UUID; import org.apache.avro.generic.GenericData.Record; -import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.DataFile; import org.apache.iceberg.DataFiles; import org.apache.iceberg.FileFormat; import org.apache.iceberg.PartitionSpec; -import org.apache.iceberg.Schema; import org.apache.iceberg.Table; import org.apache.iceberg.avro.Avro; -import org.apache.iceberg.hadoop.HadoopTables; import org.apache.iceberg.io.FileAppender; -import org.apache.iceberg.spark.data.AvroDataTest; -import org.apache.iceberg.spark.data.RandomData; -import org.apache.iceberg.spark.data.TestHelpers; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; - -public class TestAvroScan extends AvroDataTest { - private static final Configuration CONF = new Configuration(); - - @Rule public TemporaryFolder temp = new TemporaryFolder(); - - private static SparkSession spark = null; - - @BeforeClass - public static void startSpark() { - TestAvroScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); - } - - @AfterClass - public static void stopSpark() { - SparkSession currentSpark = TestAvroScan.spark; - TestAvroScan.spark = null; - currentSpark.stop(); - } +public class TestAvroScan extends ScanTestBase { @Override - protected void writeAndValidate(Schema schema) throws IOException { - File parent = temp.newFolder("avro"); - File location = new File(parent, "test"); - File dataFolder = new File(location, "data"); - dataFolder.mkdirs(); + protected void writeRecords(Table table, List records) throws IOException { + File dataFolder = new File(table.location(), "data"); File avroFile = new File(dataFolder, FileFormat.AVRO.addExtension(UUID.randomUUID().toString())); - HadoopTables tables = new HadoopTables(CONF); - Table table = tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); - - // Important: use the table's schema for the rest of the test - // When tables are created, the column ids are reassigned. 
- Schema tableSchema = table.schema(); - - List expected = RandomData.generateList(tableSchema, 100, 1L); - try (FileAppender writer = - Avro.write(localOutput(avroFile)).schema(tableSchema).build()) { - writer.addAll(expected); + Avro.write(localOutput(avroFile)).schema(table.schema()).build()) { + writer.addAll(records); } DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) - .withRecordCount(100) .withFileSizeInBytes(avroFile.length()) + .withRecordCount(records.size()) .withPath(avroFile.toString()) .build(); table.newAppend().appendFile(file).commit(); - - Dataset df = spark.read().format("iceberg").load(location.toString()); - - List rows = df.collectAsList(); - Assert.assertEquals("Should contain 100 rows", 100, rows.size()); - - for (int i = 0; i < expected.size(); i += 1) { - TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expected.get(i), rows.get(i)); - } } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java deleted file mode 100644 index b5c055925bf7..000000000000 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWrites.java +++ /dev/null @@ -1,421 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.iceberg.spark.source; - -import static org.apache.iceberg.spark.SparkSchemaUtil.convert; -import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsSafe; -import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import java.io.File; -import java.io.IOException; -import java.net.URI; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Random; -import org.apache.avro.generic.GenericData.Record; -import org.apache.hadoop.conf.Configuration; -import org.apache.iceberg.Files; -import org.apache.iceberg.PartitionSpec; -import org.apache.iceberg.Schema; -import org.apache.iceberg.Snapshot; -import org.apache.iceberg.Table; -import org.apache.iceberg.TableProperties; -import org.apache.iceberg.avro.Avro; -import org.apache.iceberg.avro.AvroIterable; -import org.apache.iceberg.hadoop.HadoopTables; -import org.apache.iceberg.io.FileAppender; -import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.iceberg.spark.SparkSQLProperties; -import org.apache.iceberg.spark.SparkSchemaUtil; -import org.apache.iceberg.spark.SparkWriteOptions; -import org.apache.iceberg.spark.data.AvroDataTest; -import org.apache.iceberg.spark.data.RandomData; -import org.apache.iceberg.spark.data.SparkAvroReader; -import org.apache.iceberg.types.Types; -import org.apache.spark.SparkException; -import org.apache.spark.TaskContext; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.MapPartitionsFunction; -import org.apache.spark.sql.DataFrameWriter; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.encoders.RowEncoder; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; - -@RunWith(Parameterized.class) -public class TestDataFrameWrites extends AvroDataTest { - private static final Configuration CONF = new Configuration(); - - private final String format; - - @Parameterized.Parameters(name = "format = {0}") - public static Object[] parameters() { - return new Object[] {"parquet", "avro", "orc"}; - } - - public TestDataFrameWrites(String format) { - this.format = format; - } - - private static SparkSession spark = null; - private static JavaSparkContext sc = null; - - private Map tableProperties; - - private final org.apache.spark.sql.types.StructType sparkSchema = - new org.apache.spark.sql.types.StructType( - new org.apache.spark.sql.types.StructField[] { - new org.apache.spark.sql.types.StructField( - "optionalField", - org.apache.spark.sql.types.DataTypes.StringType, - true, - org.apache.spark.sql.types.Metadata.empty()), - new org.apache.spark.sql.types.StructField( - "requiredField", - org.apache.spark.sql.types.DataTypes.StringType, - false, - org.apache.spark.sql.types.Metadata.empty()) - }); - - private final Schema icebergSchema = - new Schema( - Types.NestedField.optional(1, "optionalField", Types.StringType.get()), - Types.NestedField.required(2, "requiredField", Types.StringType.get())); - - 
private final List data0 = - Arrays.asList( - "{\"optionalField\": \"a1\", \"requiredField\": \"bid_001\"}", - "{\"optionalField\": \"a2\", \"requiredField\": \"bid_002\"}"); - private final List data1 = - Arrays.asList( - "{\"optionalField\": \"d1\", \"requiredField\": \"bid_101\"}", - "{\"optionalField\": \"d2\", \"requiredField\": \"bid_102\"}", - "{\"optionalField\": \"d3\", \"requiredField\": \"bid_103\"}", - "{\"optionalField\": \"d4\", \"requiredField\": \"bid_104\"}"); - - @BeforeClass - public static void startSpark() { - TestDataFrameWrites.spark = SparkSession.builder().master("local[2]").getOrCreate(); - TestDataFrameWrites.sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); - } - - @AfterClass - public static void stopSpark() { - SparkSession currentSpark = TestDataFrameWrites.spark; - TestDataFrameWrites.spark = null; - TestDataFrameWrites.sc = null; - currentSpark.stop(); - } - - @Override - protected void writeAndValidate(Schema schema) throws IOException { - File location = createTableFolder(); - Table table = createTable(schema, location); - writeAndValidateWithLocations(table, location, new File(location, "data")); - } - - @Test - public void testWriteWithCustomDataLocation() throws IOException { - File location = createTableFolder(); - File tablePropertyDataLocation = temp.newFolder("test-table-property-data-dir"); - Table table = createTable(new Schema(SUPPORTED_PRIMITIVES.fields()), location); - table - .updateProperties() - .set(TableProperties.WRITE_DATA_LOCATION, tablePropertyDataLocation.getAbsolutePath()) - .commit(); - writeAndValidateWithLocations(table, location, tablePropertyDataLocation); - } - - private File createTableFolder() throws IOException { - File parent = temp.newFolder("parquet"); - File location = new File(parent, "test"); - Assert.assertTrue("Mkdir should succeed", location.mkdirs()); - return location; - } - - private Table createTable(Schema schema, File location) { - HadoopTables tables = new HadoopTables(CONF); - return tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); - } - - private void writeAndValidateWithLocations(Table table, File location, File expectedDataDir) - throws IOException { - Schema tableSchema = table.schema(); // use the table schema because ids are reassigned - - table.updateProperties().set(TableProperties.DEFAULT_FILE_FORMAT, format).commit(); - - Iterable expected = RandomData.generate(tableSchema, 100, 0L); - writeData(expected, tableSchema, location.toString()); - - table.refresh(); - - List actual = readTable(location.toString()); - - Iterator expectedIter = expected.iterator(); - Iterator actualIter = actual.iterator(); - while (expectedIter.hasNext() && actualIter.hasNext()) { - assertEqualsSafe(tableSchema.asStruct(), expectedIter.next(), actualIter.next()); - } - Assert.assertEquals( - "Both iterators should be exhausted", expectedIter.hasNext(), actualIter.hasNext()); - - table - .currentSnapshot() - .addedDataFiles(table.io()) - .forEach( - dataFile -> - Assert.assertTrue( - String.format( - "File should have the parent directory %s, but has: %s.", - expectedDataDir.getAbsolutePath(), dataFile.location()), - URI.create(dataFile.location()) - .getPath() - .startsWith(expectedDataDir.getAbsolutePath()))); - } - - private List readTable(String location) { - Dataset result = spark.read().format("iceberg").load(location); - - return result.collectAsList(); - } - - private void writeData(Iterable records, Schema schema, String location) - throws IOException { - Dataset df = 
createDataset(records, schema); - DataFrameWriter writer = df.write().format("iceberg").mode("append"); - writer.save(location); - } - - private void writeDataWithFailOnPartition( - Iterable records, Schema schema, String location) throws IOException, SparkException { - final int numPartitions = 10; - final int partitionToFail = new Random().nextInt(numPartitions); - MapPartitionsFunction failOnFirstPartitionFunc = - input -> { - int partitionId = TaskContext.getPartitionId(); - - if (partitionId == partitionToFail) { - throw new SparkException( - String.format("Intended exception in partition %d !", partitionId)); - } - return input; - }; - - Dataset df = - createDataset(records, schema) - .repartition(numPartitions) - .mapPartitions(failOnFirstPartitionFunc, RowEncoder.apply(convert(schema))); - // This trick is needed because Spark 3 handles decimal overflow in RowEncoder which "changes" - // nullability of the column to "true" regardless of original nullability. - // Setting "check-nullability" option to "false" doesn't help as it fails at Spark analyzer. - Dataset convertedDf = df.sqlContext().createDataFrame(df.rdd(), convert(schema)); - DataFrameWriter writer = convertedDf.write().format("iceberg").mode("append"); - writer.save(location); - } - - private Dataset createDataset(Iterable records, Schema schema) throws IOException { - // this uses the SparkAvroReader to create a DataFrame from the list of records - // it assumes that SparkAvroReader is correct - File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); - - try (FileAppender writer = - Avro.write(Files.localOutput(testFile)).schema(schema).named("test").build()) { - for (Record rec : records) { - writer.add(rec); - } - } - - // make sure the dataframe matches the records before moving on - List rows = Lists.newArrayList(); - try (AvroIterable reader = - Avro.read(Files.localInput(testFile)) - .createReaderFunc(SparkAvroReader::new) - .project(schema) - .build()) { - - Iterator recordIter = records.iterator(); - Iterator readIter = reader.iterator(); - while (recordIter.hasNext() && readIter.hasNext()) { - InternalRow row = readIter.next(); - assertEqualsUnsafe(schema.asStruct(), recordIter.next(), row); - rows.add(row); - } - Assert.assertEquals( - "Both iterators should be exhausted", recordIter.hasNext(), readIter.hasNext()); - } - - JavaRDD rdd = sc.parallelize(rows); - return spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), convert(schema), false); - } - - @Test - public void testNullableWithWriteOption() throws IOException { - Assume.assumeTrue( - "Spark 3 rejects writing nulls to a required column", spark.version().startsWith("2")); - - File location = new File(temp.newFolder("parquet"), "test"); - String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location); - String targetPath = String.format("%s/nullable_poc/targetFolder/", location); - - tableProperties = ImmutableMap.of(TableProperties.WRITE_DATA_LOCATION, targetPath); - - // read this and append to iceberg dataset - spark - .read() - .schema(sparkSchema) - .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data1)) - .write() - .parquet(sourcePath); - - // this is our iceberg dataset to which we will append data - new HadoopTables(spark.sessionState().newHadoopConf()) - .create( - icebergSchema, - PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(), - tableProperties, - targetPath); - - // this is the initial data inside the iceberg dataset - spark - .read() - 
.schema(sparkSchema) - .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data0)) - .write() - .format("iceberg") - .mode(SaveMode.Append) - .save(targetPath); - - // read from parquet and append to iceberg w/ nullability check disabled - spark - .read() - .schema(SparkSchemaUtil.convert(icebergSchema)) - .parquet(sourcePath) - .write() - .format("iceberg") - .option(SparkWriteOptions.CHECK_NULLABILITY, false) - .mode(SaveMode.Append) - .save(targetPath); - - // read all data - List rows = spark.read().format("iceberg").load(targetPath).collectAsList(); - Assert.assertEquals("Should contain 6 rows", 6, rows.size()); - } - - @Test - public void testNullableWithSparkSqlOption() throws IOException { - Assume.assumeTrue( - "Spark 3 rejects writing nulls to a required column", spark.version().startsWith("2")); - - File location = new File(temp.newFolder("parquet"), "test"); - String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location); - String targetPath = String.format("%s/nullable_poc/targetFolder/", location); - - tableProperties = ImmutableMap.of(TableProperties.WRITE_DATA_LOCATION, targetPath); - - // read this and append to iceberg dataset - spark - .read() - .schema(sparkSchema) - .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data1)) - .write() - .parquet(sourcePath); - - SparkSession newSparkSession = - SparkSession.builder() - .master("local[2]") - .appName("NullableTest") - .config(SparkSQLProperties.CHECK_NULLABILITY, false) - .getOrCreate(); - - // this is our iceberg dataset to which we will append data - new HadoopTables(newSparkSession.sessionState().newHadoopConf()) - .create( - icebergSchema, - PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(), - tableProperties, - targetPath); - - // this is the initial data inside the iceberg dataset - newSparkSession - .read() - .schema(sparkSchema) - .json(JavaSparkContext.fromSparkContext(spark.sparkContext()).parallelize(data0)) - .write() - .format("iceberg") - .mode(SaveMode.Append) - .save(targetPath); - - // read from parquet and append to iceberg - newSparkSession - .read() - .schema(SparkSchemaUtil.convert(icebergSchema)) - .parquet(sourcePath) - .write() - .format("iceberg") - .mode(SaveMode.Append) - .save(targetPath); - - // read all data - List rows = newSparkSession.read().format("iceberg").load(targetPath).collectAsList(); - Assert.assertEquals("Should contain 6 rows", 6, rows.size()); - } - - @Test - public void testFaultToleranceOnWrite() throws IOException { - File location = createTableFolder(); - Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields()); - Table table = createTable(schema, location); - - Iterable records = RandomData.generate(schema, 100, 0L); - writeData(records, schema, location.toString()); - - table.refresh(); - - Snapshot snapshotBeforeFailingWrite = table.currentSnapshot(); - List resultBeforeFailingWrite = readTable(location.toString()); - - Iterable records2 = RandomData.generate(schema, 100, 0L); - - assertThatThrownBy(() -> writeDataWithFailOnPartition(records2, schema, location.toString())) - .isInstanceOf(SparkException.class); - - table.refresh(); - - Snapshot snapshotAfterFailingWrite = table.currentSnapshot(); - List resultAfterFailingWrite = readTable(location.toString()); - - Assert.assertEquals(snapshotAfterFailingWrite, snapshotBeforeFailingWrite); - Assert.assertEquals(resultAfterFailingWrite, resultBeforeFailingWrite); - } -} diff --git 
a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestORCDataFrameWrite.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestORCDataFrameWrite.java new file mode 100644 index 000000000000..35be6423ee23 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestORCDataFrameWrite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; + +public class TestORCDataFrameWrite extends DataFrameWriteTestBase { + @Override + protected void configureTable(Table table) { + table + .updateProperties() + .set(TableProperties.DEFAULT_FILE_FORMAT, FileFormat.ORC.toString()) + .commit(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetDataFrameWrite.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetDataFrameWrite.java new file mode 100644 index 000000000000..90a9ac48a486 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetDataFrameWrite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; + +public class TestParquetDataFrameWrite extends DataFrameWriteTestBase { + @Override + protected void configureTable(Table table) { + table + .updateProperties() + .set(TableProperties.DEFAULT_FILE_FORMAT, FileFormat.PARQUET.toString()) + .commit(); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java index f585ed360f95..6b9ec85b7f0b 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java @@ -19,13 +19,13 @@ package org.apache.iceberg.spark.source; import static org.apache.iceberg.Files.localOutput; +import static org.assertj.core.api.Assumptions.assumeThat; import java.io.File; import java.io.IOException; import java.util.List; import java.util.UUID; import org.apache.avro.generic.GenericData; -import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.DataFile; import org.apache.iceberg.DataFiles; import org.apache.iceberg.FileFormat; @@ -33,108 +33,55 @@ import org.apache.iceberg.Schema; import org.apache.iceberg.Table; import org.apache.iceberg.TableProperties; -import org.apache.iceberg.hadoop.HadoopTables; import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.parquet.Parquet; -import org.apache.iceberg.spark.data.AvroDataTest; -import org.apache.iceberg.spark.data.RandomData; -import org.apache.iceberg.spark.data.TestHelpers; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -@RunWith(Parameterized.class) -public class TestParquetScan extends AvroDataTest { - private static final Configuration CONF = new Configuration(); - - private static SparkSession spark = null; - - @BeforeClass - public static void startSpark() { - TestParquetScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); - } - - @AfterClass - public static void stopSpark() { - SparkSession currentSpark = TestParquetScan.spark; - TestParquetScan.spark = null; - currentSpark.stop(); - } - - @Rule public TemporaryFolder temp = new TemporaryFolder(); - - @Parameterized.Parameters(name = "vectorized = {0}") - public static Object[] parameters() { - return new Object[] {false, true}; +public class TestParquetScan extends ScanTestBase { + protected boolean vectorized() { + return false; } - private final boolean vectorized; - - public TestParquetScan(boolean vectorized) { - this.vectorized = vectorized; + @Override + protected void configureTable(Table table) { + table + .updateProperties() + .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized())) + .commit(); } @Override - protected void writeAndValidate(Schema schema) throws IOException { - Assume.assumeTrue( - "Cannot handle non-string map keys in parquet-avro", - null - == TypeUtil.find( - schema, - type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())); - - 
File parent = temp.newFolder("parquet"); - File location = new File(parent, "test"); - File dataFolder = new File(location, "data"); - dataFolder.mkdirs(); + protected void writeRecords(Table table, List records) throws IOException { + File dataFolder = new File(table.location(), "data"); File parquetFile = new File(dataFolder, FileFormat.PARQUET.addExtension(UUID.randomUUID().toString())); - HadoopTables tables = new HadoopTables(CONF); - Table table = tables.create(schema, PartitionSpec.unpartitioned(), location.toString()); - - // Important: use the table's schema for the rest of the test - // When tables are created, the column ids are reassigned. - Schema tableSchema = table.schema(); - - List expected = RandomData.generateList(tableSchema, 100, 1L); - try (FileAppender writer = - Parquet.write(localOutput(parquetFile)).schema(tableSchema).build()) { - writer.addAll(expected); + Parquet.write(localOutput(parquetFile)).schema(table.schema()).build()) { + writer.addAll(records); } DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) .withFileSizeInBytes(parquetFile.length()) .withPath(parquetFile.toString()) - .withRecordCount(100) + .withRecordCount(records.size()) .build(); table.newAppend().appendFile(file).commit(); - table - .updateProperties() - .set(TableProperties.PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)) - .commit(); - - Dataset df = spark.read().format("iceberg").load(location.toString()); - - List rows = df.collectAsList(); - Assert.assertEquals("Should contain 100 rows", 100, rows.size()); + } - for (int i = 0; i < expected.size(); i += 1) { - TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expected.get(i), rows.get(i)); - } + @Override + protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException { + assumeThat( + TypeUtil.find( + writeSchema, + type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get())) + .as("Cannot handle non-string map keys in parquet-avro") + .isNull(); + + super.writeAndValidate(writeSchema, expectedSchema); } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetVectorizedScan.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetVectorizedScan.java new file mode 100644 index 000000000000..a6b5166b3a4e --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetVectorizedScan.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +public class TestParquetVectorizedScan extends TestParquetScan { + @Override + protected boolean vectorized() { + return true; + } +}
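
Note on the reworked reader tests above: they drop JUnit 4 temp files in favor of Iceberg's in-memory file I/O and AssertJ assertions. Below is a minimal, self-contained sketch of the round trip that the updated TestSparkParquetReader.writeAndValidate performs, not the test itself; the schema, record count, and class name are illustrative, and the real test compares rows with TestHelpers.assertEqualsUnsafe rather than just counting them.

import static org.apache.iceberg.types.Types.NestedField.optional;
import static org.apache.iceberg.types.Types.NestedField.required;
import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import org.apache.avro.generic.GenericData;
import org.apache.iceberg.Schema;
import org.apache.iceberg.inmemory.InMemoryOutputFile;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.io.OutputFile;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.spark.data.RandomData;
import org.apache.iceberg.spark.data.SparkParquetReaders;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;

public class ParquetRoundTripSketch {
  // Illustrative schema; the real tests use the schemas defined in AvroDataTest.
  private static final Schema SCHEMA =
      new Schema(
          required(1, "id", Types.LongType.get()),
          optional(2, "data", Types.StringType.get()));

  public static void main(String[] args) throws IOException {
    List<GenericData.Record> expected = RandomData.generateList(SCHEMA, 100, 0L);

    // Write to an in-memory file instead of a deleted-then-recreated temp file.
    OutputFile outputFile = new InMemoryOutputFile();
    try (FileAppender<GenericData.Record> writer =
        Parquet.write(outputFile).schema(SCHEMA).named("test").build()) {
      writer.addAll(expected);
    }

    // Read back through the Spark Parquet reader and walk both sides in step.
    try (CloseableIterable<InternalRow> reader =
        Parquet.read(outputFile.toInputFile())
            .project(SCHEMA)
            .createReaderFunc(fileSchema -> SparkParquetReaders.buildReader(SCHEMA, fileSchema))
            .build()) {
      Iterator<InternalRow> rows = reader.iterator();
      for (GenericData.Record record : expected) {
        assertThat(rows.hasNext()).as("Should have expected number of rows").isTrue();
        rows.next(); // the real test asserts field-by-field equality here
      }
      assertThat(rows.hasNext()).as("Should not have extra rows").isFalse();
    }
  }
}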
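
The mechanical part of the migration above repeats the same few substitutions in every test class: the JUnit 4 TemporaryFolder rule becomes a JUnit 5 @TempDir Path that is resolved to child files, and Assert/Assume calls become AssertJ assertThat/assumeThat chains. A schematic example of the target style, with the class and file names invented for illustration:

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assumptions.assumeThat;

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

public class TempDirMigrationSketch {
  // JUnit Jupiter injects a fresh temporary directory; there is no need to call
  // newFile() and delete the placeholder the way TemporaryFolder required.
  @TempDir Path temp;

  @Test
  public void resolvesChildPaths() {
    // Mirrors the old Assume.assumeTrue(message, condition) pattern.
    assumeThat(true).as("Example precondition").isTrue();

    // Old: File testFile = temp.newFile(); Assert.assertTrue("Delete should succeed", testFile.delete());
    File testFile = temp.resolve("test.parquet").toFile();

    assertThat(temp).as("JUnit creates the directory up front").isDirectory();
    assertThat(testFile).as("Child files are created lazily by the writer under test").doesNotExist();
  }
}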
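
The new two-schema writeAndValidate(writeSchema, expectedSchema) hook is what the backported default-value coverage hangs off: data is written with one schema and read back while projecting a wider expected schema, so the reader has to materialize columns the file never contained. A standalone sketch of that projection through the planned Avro reader used by DataFrameWriteTestBase; the column names here are made up, and only the no-default case (null) is shown, whereas the AvroDataTest cases also exercise non-null defaults.

import static org.apache.iceberg.types.Types.NestedField.optional;
import static org.apache.iceberg.types.Types.NestedField.required;
import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.util.List;
import org.apache.avro.generic.GenericData;
import org.apache.iceberg.Schema;
import org.apache.iceberg.avro.Avro;
import org.apache.iceberg.avro.AvroIterable;
import org.apache.iceberg.inmemory.InMemoryOutputFile;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.io.OutputFile;
import org.apache.iceberg.spark.data.RandomData;
import org.apache.iceberg.spark.data.SparkPlannedAvroReader;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;

public class PlannedAvroProjectionSketch {
  public static void main(String[] args) throws IOException {
    // The file is written with only the id column...
    Schema writeSchema = new Schema(required(1, "id", Types.LongType.get()));
    // ...and read back projecting an extra optional column the file never had.
    Schema expectedSchema =
        new Schema(
            required(1, "id", Types.LongType.get()),
            optional(2, "added_str", Types.StringType.get()));

    List<GenericData.Record> records = RandomData.generateList(writeSchema, 10, 0L);

    OutputFile outputFile = new InMemoryOutputFile();
    try (FileAppender<GenericData.Record> writer =
        Avro.write(outputFile).schema(writeSchema).named("test").build()) {
      writer.addAll(records);
    }

    // The resolving reader matches the projection to the file schema by field id,
    // so the missing optional column comes back as null when no default is set.
    try (AvroIterable<InternalRow> reader =
        Avro.read(outputFile.toInputFile())
            .createResolvingReader(SparkPlannedAvroReader::create)
            .project(expectedSchema)
            .build()) {
      for (InternalRow row : reader) {
        assertThat(row.isNullAt(1)).as("added_str should be read as null").isTrue();
      }
    }
  }
}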
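
TestParquetVectorizedReads above also switches its writer helpers from File to OutputFile so the same code path can target local or in-memory files, and it keeps separate helpers for Parquet v2 and encrypted output. A small sketch of what those helpers build, assuming a 16-byte data encryption key and AAD prefix (the test's actual FILE_DEK and AAD_PREFIX buffers are declared elsewhere in the class and are not shown in this hunk):

import static org.apache.iceberg.types.Types.NestedField.required;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import org.apache.avro.generic.GenericData;
import org.apache.iceberg.Schema;
import org.apache.iceberg.inmemory.InMemoryOutputFile;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.io.OutputFile;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.spark.data.RandomData;
import org.apache.iceberg.types.Types;
import org.apache.parquet.column.ParquetProperties;

public class EncryptedParquetV2WriterSketch {
  public static void main(String[] args) throws IOException {
    Schema schema = new Schema(required(1, "id", Types.LongType.get()));

    // Assumed sizes: AES keys may be 16, 24, or 32 bytes; 16 is used here for illustration.
    ByteBuffer fileDek = ByteBuffer.allocate(16);
    ByteBuffer aadPrefix = ByteBuffer.allocate(16);
    SecureRandom rand = new SecureRandom();
    rand.nextBytes(fileDek.array());
    rand.nextBytes(aadPrefix.array());

    OutputFile outputFile = new InMemoryOutputFile();
    try (FileAppender<GenericData.Record> writer =
        Parquet.write(outputFile)
            .schema(schema)
            .named("test")
            .withFileEncryptionKey(fileDek)
            .withAADPrefix(aadPrefix)
            .writerVersion(ParquetProperties.WriterVersion.PARQUET_2_0)
            .build()) {
      writer.addAll(RandomData.generateList(schema, 100, 0L));
    }

    // The tests read this back through assertRecordsMatch with the same key supplied to the
    // read builder; only the write side is sketched here.
  }
}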