From b4b54621e201deb9cd253b2f8bea33b2e3013c5c Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Tue, 1 Oct 2024 14:44:53 +0000 Subject: [PATCH] tensorflow/compiler/ --- tensorflow/compiler/jit/xla_gpu_device.cc | 31 ++++++++++++++---- .../stablehlo/utils/bfloat16_type_test.cc | 14 ++++++-- .../stablehlo/utils/tf_type_utils_test.cc | 24 +++++++------- .../compiler/mlir/tensorflow/ir/tf_op_base.td | 8 ++++- .../compiler/mlir/tensorflow/ir/tf_types.def | 2 ++ .../mlir/tensorflow/utils/convert_tensor.cc | 22 +++++++++---- .../tensorflow/utils/convert_tensor_test.cc | 20 ++++++++---- .../mlir/tensorflow/utils/convert_type.cc | 14 +++++++- tensorflow/compiler/tests/const_test.py | 2 ++ tensorflow/compiler/tests/unary_ops_test.py | 13 ++++---- tensorflow/compiler/tf2xla/type_util.cc | 8 +++++ tensorflow/compiler/tf2xla/xla_op_registry.h | 32 +++++++++++-------- 12 files changed, 136 insertions(+), 54 deletions(-) diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index a16415ececc035..f8e9a429fd15b9 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -30,10 +30,10 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/stream_executor/gpu/gpu_init.h" -#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/core/status.h" +#include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/stream_executor/platform_manager.h" namespace tensorflow { @@ -155,11 +155,28 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, - DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, - DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, DT_INT4, DT_UINT4}}; +constexpr std::array kAllXlaGpuTypes = {{DT_UINT8, + DT_QUINT8, + DT_UINT16, + DT_INT8, + DT_QINT8, + DT_INT16, + DT_INT32, + DT_QINT32, + DT_INT64, + DT_HALF, + DT_FLOAT, + DT_DOUBLE, + DT_COMPLEX64, + DT_COMPLEX128, + DT_BOOL, + DT_BFLOAT16, + DT_FLOAT8_E5M2, + DT_FLOAT8_E4M3FN, + DT_FLOAT8_E5M2FNUZ, + DT_FLOAT8_E4M3FNUZ, + DT_INT4, + DT_UINT4}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type_test.cc index 45fb47565ea9e3..7cfd0a38d930b2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type_test.cc @@ -14,11 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h" +#include + #include -#include #include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" namespace mlir::quant::stablehlo { @@ -36,6 +37,7 @@ TEST(IsLargeFloatTypeTest, scalars) { auto context = CreateContext(); EXPECT_FALSE(IsLargeFloatType(Float8E4M3FNType::get(context.get()))); + EXPECT_FALSE(IsLargeFloatType(Float8E4M3FNUZType::get(context.get()))); EXPECT_FALSE(IsLargeFloatType(Float16Type::get(context.get()))); EXPECT_FALSE(IsLargeFloatType(BFloat16Type::get(context.get()))); EXPECT_TRUE(IsLargeFloatType(Float32Type::get(context.get()))); @@ -52,6 +54,8 @@ TEST(IsLargeFloatTypeTest, tensors) { EXPECT_FALSE(IsLargeFloatType( RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get())))); + EXPECT_FALSE(IsLargeFloatType( + RankedTensorType::get({2, 2}, Float8E4M3FNUZType::get(context.get())))); EXPECT_FALSE(IsLargeFloatType( RankedTensorType::get({2, 2}, Float16Type::get(context.get())))); EXPECT_FALSE(IsLargeFloatType( @@ -76,6 +80,8 @@ TEST(ToBfloat16TypeTest, scalars) { EXPECT_EQ(ToBfloat16Type(Float8E4M3FNType::get(context.get())), Float8E4M3FNType::get(context.get())); + EXPECT_EQ(ToBfloat16Type(Float8E4M3FNUZType::get(context.get())), + Float8E4M3FNUZType::get(context.get())); EXPECT_EQ(ToBfloat16Type(Float16Type::get(context.get())), Float16Type::get(context.get())); EXPECT_EQ(ToBfloat16Type(BFloat16Type::get(context.get())), @@ -102,6 +108,10 @@ TEST(ToBfloat16TypeTest, tensors) { ToBfloat16Type( RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get()))), RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get()))); + EXPECT_EQ( + ToBfloat16Type(RankedTensorType::get( + {2, 2}, Float8E4M3FNUZType::get(context.get()))), + RankedTensorType::get({2, 2}, Float8E4M3FNUZType::get(context.get()))); EXPECT_EQ(ToBfloat16Type( RankedTensorType::get({2, 2}, Float16Type::get(context.get()))), RankedTensorType::get({2, 2}, Float16Type::get(context.get()))); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc index 62abb400ca5b34..733979b0ebc825 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc @@ -15,30 +15,31 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.h" +#include +#include + #include #include #include -#include -#include #include "llvm/Support/Casting.h" -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/tsl/framework/numeric_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/ir/types/dialect.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/tsl/framework/numeric_types.h" namespace mlir::quant::tensorflow { namespace { @@ -182,6 +183,7 @@ TEST(IsTFQintTypeTest, ValidTFQintTypeSucceeds) { EXPECT_FALSE(IsTFQintType(TF::Int8RefType::get(context.get()))); EXPECT_FALSE(IsTFQintType(TF::Float8E5M2RefType::get(context.get()))); + EXPECT_FALSE(IsTFQintType(TF::Float8E5M2FNUZRefType::get(context.get()))); } TEST(GetIntTypeFromTFQintTest, ChecksIntTypesFromTFQint) { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index a5bb0051cc8fe4..88901ddd5aa2ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -327,6 +327,8 @@ def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">; def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">; def TF_Float8E4M3FNRef : TF_TensorFlowType<"Float8E4M3FNRef", "float8e4m3fnref">; def TF_Float8E5M2Ref : TF_TensorFlowType<"Float8E5M2Ref", "float8e5m2ref">; +def TF_Float8E4M3FNUZRef : TF_TensorFlowType<"Float8E4M3FNUZRef", "float8e4m3fnuzref">; +def TF_Float8E5M2FNUZRef : TF_TensorFlowType<"Float8E5M2FNUZRef", "float8e5m2fnuzref">; // Complex reference types def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">; @@ -443,12 +445,14 @@ def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">; def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">; def TF_Float8E4M3FN : AnyTypeOf<[F8E4M3FN, TF_Float8E4M3FNRef], "float8e4m3fn">; def TF_Float8E5M2 : AnyTypeOf<[F8E5M2, TF_Float8E5M2Ref], "float8e5m2">; +def TF_Float8E4M3FNUZ : AnyTypeOf<[F8E4M3FNUZ, TF_Float8E4M3FNUZRef], "float8e4m3fnuz">; +def TF_Float8E5M2FNUZ : AnyTypeOf<[F8E5M2FNUZ, TF_Float8E5M2FNUZRef], "float8e5m2fnuz">; def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">; def TF_Float : AnyTypeOf< [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16, TF_Float8E4M3FN, - TF_Float8E5M2], + TF_Float8E5M2, TF_Float8E4M3FNUZ, TF_Float8E5M2FNUZ], "floating-point">; // Tensor types @@ -460,6 +464,8 @@ def TF_Float64Tensor : TensorOf<[TF_Float64]>; def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>; def TF_Float8E4M3FNTensor : TensorOf<[TF_Float8E4M3FN]>; def TF_Float8E5M2Tensor : TensorOf<[TF_Float8E5M2]>; +def TF_Float8E4M3FNUZTensor : TensorOf<[TF_Float8E4M3FNUZ]>; +def TF_Float8E5M2FNUZTensor : TensorOf<[TF_Float8E5M2FNUZ]>; //===----------------------------------------------------------------------===// // Complex types (including corresponding reference types) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def index 17daa6afdcaf4b..80caf094cfe7ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def @@ -68,6 +68,8 @@ HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref") HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref") HANDLE_TF_REF_TYPE(Float8E4M3FNRef, FLOAT8_E4M3FN_REF, "float8e4m3fnref") HANDLE_TF_REF_TYPE(Float8E5M2Ref, FLOAT8_E5M2_REF, "float8e5m2ref") +HANDLE_TF_REF_TYPE(Float8E4M3FNUZRef, FLOAT8_E4M3FNUZ_REF, "float8e4m3fnuzref") +HANDLE_TF_REF_TYPE(Float8E5M2FNUZRef, FLOAT8_E5M2FNUZ_REF, "float8e5m2fnuzref") #ifndef HANDLE_LAST_TF_TYPE #define HANDLE_LAST_TF_TYPE(class, enumerant, name) \ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index b9fef486428977..c23590ab2bd397 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -28,13 +28,13 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -149,6 +149,8 @@ absl::StatusOr ConvertTensor(const Tensor& input_tensor, case DT_HALF: case DT_FLOAT8_E5M2: case DT_FLOAT8_E4M3FN: + case DT_FLOAT8_E5M2FNUZ: + case DT_FLOAT8_E4M3FNUZ: return ConvertTensorOfCustomFloatType(input_tensor, type); case DT_STRING: return ConvertStringTensor(input_tensor, type); @@ -466,6 +468,14 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { ConvertFloat8ElementsAttr( dense_attr, output->mutable_float8_val()); break; + case DT_FLOAT8_E5M2FNUZ: + ConvertFloat8ElementsAttr( + dense_attr, output->mutable_float8_val()); + break; + case DT_FLOAT8_E4M3FNUZ: + ConvertFloat8ElementsAttr( + dense_attr, output->mutable_float8_val()); + break; case tensorflow::DT_INT4: ConvertIntElementsAttr(dense_attr, output->mutable_int_val(), diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index 3feed8904fab0e..dbbfc48a434e98 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -18,15 +18,14 @@ limitations under the License. #include #include -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "xla/test.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.pb.h" @@ -34,6 +33,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/ml_dtypes.h" +#include "xla/test.h" namespace tensorflow { namespace { @@ -42,7 +42,7 @@ using ::testing::Eq; using ::testing::IsFalse; using ::testing::IsTrue; -static void RegisterDialects(mlir::MLIRContext &context) { +static void RegisterDialects(mlir::MLIRContext& context) { context.loadDialect(); } @@ -148,6 +148,12 @@ TEST_F(ConvertTensorTest, Simple) { ASSERT_NO_FATAL_FAILURE(VerifyConversion( {tsl::float8_e4m3fn{1.0}, tsl::float8_e4m3fn{-1.0}}, DT_FLOAT8_E4M3FN, mlir::FloatType::getFloat8E4M3FN(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e5m2fnuz{1.0}, tsl::float8_e5m2fnuz{-1.0}}, + DT_FLOAT8_E5M2FNUZ, mlir::FloatType::getFloat8E5M2FNUZ(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e4m3fnuz{1.0}, tsl::float8_e4m3fnuz{-1.0}}, + DT_FLOAT8_E4M3FNUZ, mlir::FloatType::getFloat8E4M3FNUZ(&context))); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {static_cast(1), static_cast(-1)}, DT_INT4, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index e3404d613c9f83..6ab13690705172 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -88,6 +88,12 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { case tensorflow::DT_FLOAT8_E5M2: *type = builder.getFloat8E5M2Type(); return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E4M3FNUZ: + *type = builder.getFloat8E4M3FNUZType(); + return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E5M2FNUZ: + *type = builder.getFloat8E5M2FNUZType(); + return absl::OkStatus(); case DT_INT4: *type = builder.getIntegerType(4, /*isSigned=*/true); return absl::OkStatus(); @@ -125,7 +131,13 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } else if (type.isFloat8E5M2()) { *dtype = DT_FLOAT8_E5M2; return absl::OkStatus(); - } else if (auto itype = mlir::dyn_cast(type)) { + } else if (type.isFloat8E4M3FNUZ()) { + *dtype = DT_FLOAT8_E4M3FNUZ; + return absl::OkStatus(); + } else if (type.isFloat8E5M2FNUZ()) { + *dtype = DT_FLOAT8_E5M2FNUZ; + return absl::OkStatus(); + }else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: *dtype = DT_BOOL; diff --git a/tensorflow/compiler/tests/const_test.py b/tensorflow/compiler/tests/const_test.py index bb1f3e23a7306e..5eee656370921e 100644 --- a/tensorflow/compiler/tests/const_test.py +++ b/tensorflow/compiler/tests/const_test.py @@ -48,6 +48,8 @@ def testConst(self): dtypes.float64, dtypes.float8_e5m2, dtypes.float8_e4m3fn, + dtypes.float8_e5m2fnuz, + dtypes.float8_e4m3fnuz, } for dtype in types: with self.subTest(dtype=dtype): diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 99b997561b41c3..189b97c3cc9c95 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -577,7 +577,7 @@ def quantize_and_dequantize_v2_round_half_up(x): -128.0 / 127, 1, ], - dtype=dtype)) + dtype=dtype)) def quantize_and_dequantize_v2_round_half_to_even(x): return array_ops.quantize_and_dequantize( @@ -601,7 +601,7 @@ def quantize_and_dequantize_v2_round_half_to_even(x): -128.0 / 127, 1, ], - dtype=dtype)) + dtype=dtype)) def testComplexOps(self): for dtype in self.complex_types: @@ -891,7 +891,8 @@ def testCastFp8(self): # TODO(b/271327511): Fix issue where casts to FP8 very rarely result in # NaN on Mac self.skipTest("Casts to FP8 sometimes result in NaN on Mac") - fp8_types = {dtypes.float8_e5m2, dtypes.float8_e4m3fn} + fp8_types = {dtypes.float8_e5m2, dtypes.float8_e4m3fn, + dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz} other_types = { dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64 @@ -1021,7 +1022,7 @@ def invert_twice(x): expected=np.array([1, 2, 0], dtype=np_dtype)) def testRank(self): - rank_op = lambda x: array_ops.rank_internal(x, optimize=False) + def rank_op(x): return array_ops.rank_internal(x, optimize=False) for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( rank_op, dtype(7), expected=np.int32(0)) @@ -1037,7 +1038,7 @@ def testRank(self): expected=np.int32(2)) def testShape(self): - shape_op = lambda x: array_ops.shape_internal(x, optimize=False) + def shape_op(x): return array_ops.shape_internal(x, optimize=False) for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( shape_op, dtype(7), expected=np.array([], dtype=np.int32)) @@ -1059,7 +1060,7 @@ def testShape(self): expected=np.array([3, 1], dtype=np.int32)) def testSize(self): - size_op = lambda x: array_ops.size_internal(x, optimize=False) + def size_op(x): return array_ops.size_internal(x, optimize=False) for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( size_op, dtype(7), expected=np.int32(1)) diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 655a2c3cdec160..3f9f4647d35094 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -70,6 +70,12 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_FLOAT8_E4M3FN: *type = xla::F8E4M3FN; return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E5M2FNUZ: + *type = xla::F8E5M2FNUZ; + return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E4M3FNUZ: + *type = xla::F8E4M3FNUZ; + return absl::OkStatus(); case tensorflow::DT_BFLOAT16: *type = xla::BF16; return absl::OkStatus(); @@ -102,6 +108,8 @@ absl::StatusOr EncodePrimitiveTypeAsDataType( {xla::PRED, DT_BOOL}, {xla::F8E5M2, DT_FLOAT8_E5M2}, {xla::F8E4M3FN, DT_FLOAT8_E4M3FN}, + {xla::F8E5M2FNUZ, DT_FLOAT8_E5M2FNUZ}, + {xla::F8E4M3FNUZ, DT_FLOAT8_E4M3FNUZ}, {xla::BF16, DT_BFLOAT16}, {xla::F16, DT_HALF}, {xla::F32, DT_FLOAT}, diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 333a9168f3deda..8f20e9b71f427c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -56,19 +56,25 @@ constexpr std::array kNumericTypes = { DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, - DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, - DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, - DT_INT4, DT_UINT4}}; - -constexpr std::array kGpuAllTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, - DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, - DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, - DT_INT4, DT_UINT4}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, + DT_UINT32, DT_UINT64, DT_INT8, + DT_QINT8, DT_INT16, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, + DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, DT_FLOAT8_E5M2FNUZ, + DT_FLOAT8_E4M3FNUZ, DT_INT4, DT_UINT4}}; + +constexpr std::array kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, + DT_UINT32, DT_UINT64, DT_INT8, + DT_QINT8, DT_INT16, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, + DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, DT_FLOAT8_E5M2FNUZ, + DT_FLOAT8_E4M3FNUZ, DT_INT4, DT_UINT4}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe.