Skip to content

Commit

Permalink
update tensorflow/python/
Browse files Browse the repository at this point in the history
  • Loading branch information
ScXfjiang committed Oct 1, 2024
1 parent 1086598 commit 069a82c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tensorflow/python/lib/core/ndarray_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,13 @@ Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
} else if (pyarray_type == custom_dtypes.float8_e4m3fn) {
*out_tf_datatype = TF_FLOAT8_E4M3FN;
break;
} else if (pyarray_type == custom_dtypes.int4) {
} else if (pyarray_type == custom_dtypes.float8_e5m2fnuz) {
*out_tf_datatype = TF_FLOAT8_E5M2FNUZ;
break;
} else if (pyarray_type == custom_dtypes.float8_e4m3fnuz) {
*out_tf_datatype = TF_FLOAT8_E4M3FNUZ;
break;
}else if (pyarray_type == custom_dtypes.int4) {
*out_tf_datatype = TF_INT4;
break;
} else if (pyarray_type == custom_dtypes.uint4) {
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/python/lib/core/ndarray_tensor_bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
case TF_FLOAT8_E4M3FN:
*out_pyarray_type = custom_dtypes.float8_e4m3fn;
break;
case TF_FLOAT8_E5M2FNUZ:
*out_pyarray_type = custom_dtypes.float8_e5m2fnuz;
break;
case TF_FLOAT8_E4M3FNUZ:
*out_pyarray_type = custom_dtypes.float8_e4m3fnuz;
break;
case TF_INT4:
*out_pyarray_type = custom_dtypes.int4;
break;
Expand Down

0 comments on commit 069a82c

Please sign in to comment.