Skip to content

Commit

Permalink
tensorflow/compiler/
Browse files Browse the repository at this point in the history
  • Loading branch information
ScXfjiang committed Oct 1, 2024
1 parent 18fc45a commit b4b5462
Show file tree
Hide file tree
Showing 12 changed files with 136 additions and 54 deletions.
31 changes: 24 additions & 7 deletions tensorflow/compiler/jit/xla_gpu_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "xla/stream_executor/gpu/gpu_init.h"
#include "xla/stream_executor/platform_manager.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
#include "xla/stream_executor/gpu/gpu_init.h"
#include "xla/stream_executor/platform_manager.h"

namespace tensorflow {

Expand Down Expand Up @@ -155,11 +155,28 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);

// Kernel registrations

constexpr std::array<DataType, 20> kAllXlaGpuTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8,
DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, DT_INT4, DT_UINT4}};
// Data types handed to the XLA launch/compile kernel registrations for
// DEVICE_XLA_GPU that follow. The std::array extent (22) must be kept in sync
// with the number of entries listed here; the list includes the FNUZ float8
// variants (DT_FLOAT8_E5M2FNUZ, DT_FLOAT8_E4M3FNUZ) alongside the existing
// float8 types.
constexpr std::array<DataType, 22> kAllXlaGpuTypes = {{DT_UINT8,
DT_QUINT8,
DT_UINT16,
DT_INT8,
DT_QINT8,
DT_INT16,
DT_INT32,
DT_QINT32,
DT_INT64,
DT_HALF,
DT_FLOAT,
DT_DOUBLE,
DT_COMPLEX64,
DT_COMPLEX128,
DT_BOOL,
DT_BFLOAT16,
DT_FLOAT8_E5M2,
DT_FLOAT8_E4M3FN,
DT_FLOAT8_E5M2FNUZ,
DT_FLOAT8_E4M3FNUZ,
DT_INT4,
DT_UINT4}};

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h"

#include <gtest/gtest.h>

#include <memory>

#include <gtest/gtest.h>
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/compiler/mlir/register_common_dialects.h"

namespace mlir::quant::stablehlo {
Expand All @@ -36,6 +37,7 @@ TEST(IsLargeFloatTypeTest, scalars) {
auto context = CreateContext();

EXPECT_FALSE(IsLargeFloatType(Float8E4M3FNType::get(context.get())));
EXPECT_FALSE(IsLargeFloatType(Float8E4M3FNUZType::get(context.get())));
EXPECT_FALSE(IsLargeFloatType(Float16Type::get(context.get())));
EXPECT_FALSE(IsLargeFloatType(BFloat16Type::get(context.get())));
EXPECT_TRUE(IsLargeFloatType(Float32Type::get(context.get())));
Expand All @@ -52,6 +54,8 @@ TEST(IsLargeFloatTypeTest, tensors) {

EXPECT_FALSE(IsLargeFloatType(
RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get()))));
EXPECT_FALSE(IsLargeFloatType(
RankedTensorType::get({2, 2}, Float8E4M3FNUZType::get(context.get()))));
EXPECT_FALSE(IsLargeFloatType(
RankedTensorType::get({2, 2}, Float16Type::get(context.get()))));
EXPECT_FALSE(IsLargeFloatType(
Expand All @@ -76,6 +80,8 @@ TEST(ToBfloat16TypeTest, scalars) {

EXPECT_EQ(ToBfloat16Type(Float8E4M3FNType::get(context.get())),
Float8E4M3FNType::get(context.get()));
EXPECT_EQ(ToBfloat16Type(Float8E4M3FNUZType::get(context.get())),
Float8E4M3FNUZType::get(context.get()));
EXPECT_EQ(ToBfloat16Type(Float16Type::get(context.get())),
Float16Type::get(context.get()));
EXPECT_EQ(ToBfloat16Type(BFloat16Type::get(context.get())),
Expand All @@ -102,6 +108,10 @@ TEST(ToBfloat16TypeTest, tensors) {
ToBfloat16Type(
RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get()))),
RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get())));
EXPECT_EQ(
ToBfloat16Type(RankedTensorType::get(
{2, 2}, Float8E4M3FNUZType::get(context.get()))),
RankedTensorType::get({2, 2}, Float8E4M3FNUZType::get(context.get())));
EXPECT_EQ(ToBfloat16Type(
RankedTensorType::get({2, 2}, Float16Type::get(context.get()))),
RankedTensorType::get({2, 2}, Float16Type::get(context.get())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,31 @@ limitations under the License.

#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstdint>
#include <memory>
#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/register_common_dialects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/tsl/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/tsl/framework/numeric_types.h"

namespace mlir::quant::tensorflow {
namespace {
Expand Down Expand Up @@ -182,6 +183,7 @@ TEST(IsTFQintTypeTest, ValidTFQintTypeSucceeds) {

EXPECT_FALSE(IsTFQintType(TF::Int8RefType::get(context.get())));
EXPECT_FALSE(IsTFQintType(TF::Float8E5M2RefType::get(context.get())));
EXPECT_FALSE(IsTFQintType(TF::Float8E5M2FNUZRefType::get(context.get())));
}

TEST(GetIntTypeFromTFQintTest, ChecksIntTypesFromTFQint) {
Expand Down
8 changes: 7 additions & 1 deletion tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;
def TF_Float8E4M3FNRef : TF_TensorFlowType<"Float8E4M3FNRef", "float8e4m3fnref">;
def TF_Float8E5M2Ref : TF_TensorFlowType<"Float8E5M2Ref", "float8e5m2ref">;
def TF_Float8E4M3FNUZRef : TF_TensorFlowType<"Float8E4M3FNUZRef", "float8e4m3fnuzref">;
def TF_Float8E5M2FNUZRef : TF_TensorFlowType<"Float8E5M2FNUZRef", "float8e5m2fnuzref">;

// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
Expand Down Expand Up @@ -443,12 +445,14 @@ def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;
def TF_Float8E4M3FN : AnyTypeOf<[F8E4M3FN, TF_Float8E4M3FNRef], "float8e4m3fn">;
def TF_Float8E5M2 : AnyTypeOf<[F8E5M2, TF_Float8E5M2Ref], "float8e5m2">;
def TF_Float8E4M3FNUZ : AnyTypeOf<[F8E4M3FNUZ, TF_Float8E4M3FNUZRef], "float8e4m3fnuz">;
def TF_Float8E5M2FNUZ : AnyTypeOf<[F8E5M2FNUZ, TF_Float8E5M2FNUZRef], "float8e5m2fnuz">;

def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

def TF_Float : AnyTypeOf<
[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16, TF_Float8E4M3FN,
TF_Float8E5M2],
TF_Float8E5M2, TF_Float8E4M3FNUZ, TF_Float8E5M2FNUZ],
"floating-point">;

// Tensor types
Expand All @@ -460,6 +464,8 @@ def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
def TF_Float8E4M3FNTensor : TensorOf<[TF_Float8E4M3FN]>;
def TF_Float8E5M2Tensor : TensorOf<[TF_Float8E5M2]>;
def TF_Float8E4M3FNUZTensor : TensorOf<[TF_Float8E4M3FNUZ]>;
def TF_Float8E5M2FNUZTensor : TensorOf<[TF_Float8E5M2FNUZ]>;

//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_types.def
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref")
HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref")
HANDLE_TF_REF_TYPE(Float8E4M3FNRef, FLOAT8_E4M3FN_REF, "float8e4m3fnref")
HANDLE_TF_REF_TYPE(Float8E5M2Ref, FLOAT8_E5M2_REF, "float8e5m2ref")
HANDLE_TF_REF_TYPE(Float8E4M3FNUZRef, FLOAT8_E4M3FNUZ_REF, "float8e4m3fnuzref")
HANDLE_TF_REF_TYPE(Float8E5M2FNUZRef, FLOAT8_E5M2FNUZ_REF, "float8e5m2fnuzref")

#ifndef HANDLE_LAST_TF_TYPE
#define HANDLE_LAST_TF_TYPE(class, enumerant, name) \
Expand Down
22 changes: 16 additions & 6 deletions tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ limitations under the License.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/DebugStringHelper.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
Expand Down Expand Up @@ -149,6 +149,8 @@ absl::StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
case DT_HALF:
case DT_FLOAT8_E5M2:
case DT_FLOAT8_E4M3FN:
case DT_FLOAT8_E5M2FNUZ:
case DT_FLOAT8_E4M3FNUZ:
return ConvertTensorOfCustomFloatType(input_tensor, type);
case DT_STRING:
return ConvertStringTensor(input_tensor, type);
Expand Down Expand Up @@ -466,6 +468,14 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
ConvertFloat8ElementsAttr<tsl::float8_e4m3fn>(
dense_attr, output->mutable_float8_val());
break;
case DT_FLOAT8_E5M2FNUZ:
ConvertFloat8ElementsAttr<tsl::float8_e5m2fnuz>(
dense_attr, output->mutable_float8_val());
break;
case DT_FLOAT8_E4M3FNUZ:
ConvertFloat8ElementsAttr<tsl::float8_e4m3fnuz>(
dense_attr, output->mutable_float8_val());
break;
case tensorflow::DT_INT4:
ConvertIntElementsAttr<int, tsl::int4>(dense_attr,
output->mutable_int_val(),
Expand Down
20 changes: 13 additions & 7 deletions tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ limitations under the License.
#include <cstring>
#include <initializer_list>

#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h"
#include "xla/test.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
#include "tsl/platform/ml_dtypes.h"
#include "xla/test.h"

namespace tensorflow {
namespace {
Expand All @@ -42,7 +42,7 @@ using ::testing::Eq;
using ::testing::IsFalse;
using ::testing::IsTrue;

static void RegisterDialects(mlir::MLIRContext &context) {
// Loads the TensorFlow MLIR dialect (mlir::TF::TensorFlowDialect) into
// `context`.
static void RegisterDialects(mlir::MLIRContext& context) {
context.loadDialect<mlir::TF::TensorFlowDialect>();
}

Expand Down Expand Up @@ -148,6 +148,12 @@ TEST_F(ConvertTensorTest, Simple) {
ASSERT_NO_FATAL_FAILURE(VerifyConversion<tsl::float8_e4m3fn>(
{tsl::float8_e4m3fn{1.0}, tsl::float8_e4m3fn{-1.0}}, DT_FLOAT8_E4M3FN,
mlir::FloatType::getFloat8E4M3FN(&context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<tsl::float8_e5m2fnuz>(
{tsl::float8_e5m2fnuz{1.0}, tsl::float8_e5m2fnuz{-1.0}},
DT_FLOAT8_E5M2FNUZ, mlir::FloatType::getFloat8E5M2FNUZ(&context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<tsl::float8_e4m3fnuz>(
{tsl::float8_e4m3fnuz{1.0}, tsl::float8_e4m3fnuz{-1.0}},
DT_FLOAT8_E4M3FNUZ, mlir::FloatType::getFloat8E4M3FNUZ(&context)));

ASSERT_NO_FATAL_FAILURE(VerifyConversion<int4>(
{static_cast<int4>(1), static_cast<int4>(-1)}, DT_INT4,
Expand Down
14 changes: 13 additions & 1 deletion tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
case tensorflow::DT_FLOAT8_E5M2:
*type = builder.getFloat8E5M2Type();
return absl::OkStatus();
case tensorflow::DT_FLOAT8_E4M3FNUZ:
*type = builder.getFloat8E4M3FNUZType();
return absl::OkStatus();
case tensorflow::DT_FLOAT8_E5M2FNUZ:
*type = builder.getFloat8E5M2FNUZType();
return absl::OkStatus();
case DT_INT4:
*type = builder.getIntegerType(4, /*isSigned=*/true);
return absl::OkStatus();
Expand Down Expand Up @@ -125,7 +131,13 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
} else if (type.isFloat8E5M2()) {
*dtype = DT_FLOAT8_E5M2;
return absl::OkStatus();
} else if (auto itype = mlir::dyn_cast<mlir::IntegerType>(type)) {
} else if (type.isFloat8E4M3FNUZ()) {
*dtype = DT_FLOAT8_E4M3FNUZ;
return absl::OkStatus();
} else if (type.isFloat8E5M2FNUZ()) {
*dtype = DT_FLOAT8_E5M2FNUZ;
return absl::OkStatus();
}else if (auto itype = mlir::dyn_cast<mlir::IntegerType>(type)) {
switch (itype.getWidth()) {
case 1:
*dtype = DT_BOOL;
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/tests/const_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def testConst(self):
dtypes.float64,
dtypes.float8_e5m2,
dtypes.float8_e4m3fn,
dtypes.float8_e5m2fnuz,
dtypes.float8_e4m3fnuz,
}
for dtype in types:
with self.subTest(dtype=dtype):
Expand Down
13 changes: 7 additions & 6 deletions tensorflow/compiler/tests/unary_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def quantize_and_dequantize_v2_round_half_up(x):
-128.0 / 127,
1,
],
dtype=dtype))
dtype=dtype))

def quantize_and_dequantize_v2_round_half_to_even(x):
return array_ops.quantize_and_dequantize(
Expand All @@ -601,7 +601,7 @@ def quantize_and_dequantize_v2_round_half_to_even(x):
-128.0 / 127,
1,
],
dtype=dtype))
dtype=dtype))

def testComplexOps(self):
for dtype in self.complex_types:
Expand Down Expand Up @@ -891,7 +891,8 @@ def testCastFp8(self):
# TODO(b/271327511): Fix issue where casts to FP8 very rarely result in
# NaN on Mac
self.skipTest("Casts to FP8 sometimes result in NaN on Mac")
fp8_types = {dtypes.float8_e5m2, dtypes.float8_e4m3fn}
fp8_types = {dtypes.float8_e5m2, dtypes.float8_e4m3fn,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz}
other_types = {
dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64,
dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64
Expand Down Expand Up @@ -1021,7 +1022,7 @@ def invert_twice(x):
expected=np.array([1, 2, 0], dtype=np_dtype))

def testRank(self):
rank_op = lambda x: array_ops.rank_internal(x, optimize=False)
def rank_op(x): return array_ops.rank_internal(x, optimize=False)
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
rank_op, dtype(7), expected=np.int32(0))
Expand All @@ -1037,7 +1038,7 @@ def testRank(self):
expected=np.int32(2))

def testShape(self):
shape_op = lambda x: array_ops.shape_internal(x, optimize=False)
def shape_op(x): return array_ops.shape_internal(x, optimize=False)
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
shape_op, dtype(7), expected=np.array([], dtype=np.int32))
Expand All @@ -1059,7 +1060,7 @@ def testShape(self):
expected=np.array([3, 1], dtype=np.int32))

def testSize(self):
size_op = lambda x: array_ops.size_internal(x, optimize=False)
def size_op(x): return array_ops.size_internal(x, optimize=False)
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
size_op, dtype(7), expected=np.int32(1))
Expand Down
Loading

0 comments on commit b4b5462

Please sign in to comment.