Allow different dtypes in binary math ops #1432
Reviewable status: 0 of 1 approvals obtained (waiting on @dsmilkov, @nsthorat, @caisq, and @tafsiri)
src/tensor_util.ts, line 23 at r2 (raw file):
makeTypesMatch
Optional: Maybe name the function makeTwoTypesMatch to be specific.
src/tensor_util.ts, line 24 at r2 (raw file):
const dtype = upcastType(a.dtype, b.dtype);
Does it make sense to have a shortcut path like
if (a.dtype === b.dtype) {
return [a, b];
}
for efficiency? That path should be hit in a majority of the cases.
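A minimal sketch of the shortcut being proposed, under simplifying assumptions. `TensorSketch`, `castSketch`, `upcastSketch`, and `makeTypesMatchSketch` are hypothetical stand-ins that only mimic the shape of the real tensor, cast, `upcastType`, and `makeTypesMatch`; the actual code in src/tensor_util.ts may differ.

```typescript
// Hypothetical, simplified dtype set; the library supports more dtypes.
type DType = 'bool' | 'int32' | 'float32';

// Hypothetical stand-in for a tensor: just a dtype and flat values.
interface TensorSketch {
  readonly dtype: DType;
  readonly values: readonly number[];
}

// Assumed ordering for illustration: bool < int32 < float32.
const DTYPE_RANK: Record<DType, number> = {bool: 0, int32: 1, float32: 2};

// Picks the wider of the two dtypes.
function upcastSketch(a: DType, b: DType): DType {
  return DTYPE_RANK[a] >= DTYPE_RANK[b] ? a : b;
}

// Returns a new object with the requested dtype. Only upcasts happen in
// this sketch, so the values are copied unchanged.
function castSketch(t: TensorSketch, dtype: DType): TensorSketch {
  return {dtype, values: [...t.values]};
}

function makeTypesMatchSketch(
    a: TensorSketch, b: TensorSketch): [TensorSketch, TensorSketch] {
  // Fast path: matching dtypes need no upcast lookup and no cast calls.
  if (a.dtype === b.dtype) {
    return [a, b];
  }
  const dtype = upcastSketch(a.dtype, b.dtype);
  return [castSketch(a, dtype), castSketch(b, dtype)];
}
```

With the fast path, two tensors that already share a dtype are returned as the same objects, so the common case pays only for one comparison.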
Reviewable status: 0 of 1 approvals obtained (waiting on @dsmilkov, @nsthorat, and @tafsiri)
src/ops/binary_ops.ts, line 58 at r2 (raw file):
let $a = convertToTensor(a, 'a', 'add');
let $b = convertToTensor(b, 'b', 'add');
[$a, $b] = makeTypesMatch($a, $b);
Optional: Since this three-line pattern occurs so frequently, you may create a new function that can be called like:
const [$a, $b] = convertTensorsAndMatchDTypes(a, 'a', b, 'b', 'add');
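As a rough illustration of what such a combined helper could look like, here is a self-contained sketch. Everything in it is an assumption made for the example: `TensorSketch`, `convertToTensorSketch`, and `makeTypesMatchSketch` are hypothetical stand-ins for the real tensor type, `convertToTensor`, and `makeTypesMatch`, and only two dtypes are modeled.

```typescript
// All names here are hypothetical stand-ins, not the library's actual code.
type DType = 'int32' | 'float32';

interface TensorSketch {
  readonly dtype: DType;
  readonly values: readonly number[];
}

// Stand-in for convertToTensor: wraps raw numbers, passes tensors through,
// and names the argument and op in the error message.
function convertToTensorSketch(
    x: unknown, argName: string, opName: string): TensorSketch {
  if (typeof x === 'number') {
    return {dtype: 'float32', values: [x]};
  }
  if (Array.isArray(x)) {
    return {dtype: 'float32', values: x as number[]};
  }
  if (typeof x === 'object' && x !== null && 'dtype' in x && 'values' in x) {
    return x as TensorSketch;
  }
  throw new Error(
      `Argument '${argName}' passed to '${opName}' must be a tensor or numbers`);
}

// Stand-in for makeTypesMatch, with a same-dtype fast path.
function makeTypesMatchSketch(
    a: TensorSketch, b: TensorSketch): [TensorSketch, TensorSketch] {
  if (a.dtype === b.dtype) {
    return [a, b];
  }
  // With only two dtypes in this sketch, a mismatch always upcasts to float32.
  return [
    {dtype: 'float32', values: [...a.values]},
    {dtype: 'float32', values: [...b.values]},
  ];
}

// The proposed helper: folds the three-line pattern into one call.
function convertTensorsAndMatchDTypes(
    a: unknown, aName: string, b: unknown, bName: string,
    opName: string): [TensorSketch, TensorSketch] {
  const $a = convertToTensorSketch(a, aName, opName);
  const $b = convertToTensorSketch(b, bName, opName);
  return makeTypesMatchSketch($a, $b);
}
```

The trade-off is one more utility for contributors to learn, in exchange for two fewer lines at each call site.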
Thanks. PTAL.
Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @nsthorat, and @tafsiri)
src/tensor_util.ts, line 23 at r2 (raw file):
Previously, caisq (Shanqing Cai) wrote…
makeTypesMatch
Optional: Maybe name the function makeTwoTypesMatch to be specific.
Acknowledged.
src/tensor_util.ts, line 24 at r2 (raw file):
Previously, caisq (Shanqing Cai) wrote…
const dtype = upcastType(a.dtype, b.dtype);
Does it make sense to have a shortcut path like
if (a.dtype === b.dtype) { return [a, b]; }
for efficiency? That path should be hit in a majority of the cases.
Thanks. Casting a tensor to the same dtype is a no-op, but it does add more function calls to the stack, so I'm adding your suggestion.
src/ops/binary_ops.ts, line 58 at r2 (raw file):
Previously, caisq (Shanqing Cai) wrote…
let $a = convertToTensor(a, 'a', 'add');
let $b = convertToTensor(b, 'b', 'add');
[$a, $b] = makeTypesMatch($a, $b);
Optional: Since this three-line pattern occurs so frequently, you may create a new function that can be called like:
const [$a, $b] = convertTensorsAndMatchDTypes(a, 'a', b, 'b', 'add');
Acknowledged.
Reviewed 3 of 5 files at r1, 5 of 7 files at r2, 1 of 1 files at r3.
Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @nsthorat, and @tafsiri)
Thanks for doing this!
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @caisq, @dsmilkov, and @tafsiri)
src/ops/binary_ops.ts, line 58 at r2 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
Acknowledged.
I still think it's a good idea :) It'll save 100+ LOC and reduce likelihood of programming error in the future. I'd be interested in hearing your reason for not adopting it.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @caisq and @tafsiri)
src/ops/binary_ops.ts, line 58 at r2 (raw file):
Previously, caisq (Shanqing Cai) wrote…
I still think it's a good idea :) It'll save 100+ LOC and reduce likelihood of programming error in the future. I'd be interested in hearing your reason for not adopting it.
I want to keep the number of util methods small: the more specialized util methods there are, the less contributors know about them and the more likely they are to forget to use them, which makes the code harder to read. The saving is only 32 LOC (16 instances * 2 lines saved each), so it's not worth the effort yet, IMO.
Allow users to provide different dtypes in binary arithmetic ops (add/sub/mul/div/...) and matmul, just like in numpy.
The dtype of the result is upcast, e.g. matMul(float32, int32) => float32.
This will result in patch release 0.14.1, which fixes the breakage in 0.14.0 caused by #1408: with the improved dtype inference introduced there, tensor(new Int32Array()) is now inferred to be int32, where it was previously float32.
Fixes tensorflow/tfjs#934, tensorflow/tfjs#966
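The upcasting rule in the description above can be sketched as a small lookup. This is an illustration only: `upcastTypeSketch` and `describeResult` are hypothetical names, the dtype set is reduced to three entries, and the ordering bool < int32 < float32 is an assumption made for the example rather than a copy of the library's `upcastType`.

```typescript
// Hypothetical, simplified dtype set; other dtypes are omitted.
type DType = 'bool' | 'int32' | 'float32';

// Assumed ordering for illustration: a higher rank wins the upcast.
const DTYPE_RANK: Record<DType, number> = {bool: 0, int32: 1, float32: 2};

// Returns the wider of the two dtypes; the argument order does not matter.
function upcastTypeSketch(a: DType, b: DType): DType {
  return DTYPE_RANK[a] >= DTYPE_RANK[b] ? a : b;
}

// Formats the result dtype in the same style as the PR description.
function describeResult(opName: string, a: DType, b: DType): string {
  return `${opName}(${a}, ${b}) => ${upcastTypeSketch(a, b)}`;
}

console.log(describeResult('matMul', 'float32', 'int32'));
// prints "matMul(float32, int32) => float32"
console.log(describeResult('add', 'int32', 'float32'));
// prints "add(int32, float32) => float32"
```

Under this rule an int32 tensor, such as one built from an Int32Array, combined with a float32 tensor yields float32 regardless of operand order.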