[MPS] Allow float16 input to float32 LayerNorm (#96430)
Only for the forward pass.

Subset of #96208.

Create the `constantWithScalar:` tensor using `input_mps_dtype`, and use `reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0 secondaryTensor:`.

Fixes #96113

Pull Request resolved: #96430
Approved by: /~https://github.com/kulinseth
malfet authored and pytorchmergebot committed Mar 9, 2023
1 parent 457396f commit 075a494
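For context, the call pattern this commit enables — a float16 input through a LayerNorm whose parameters stay in the default float32, on the MPS device — is the one exercised by the regression test added below. A minimal sketch, assuming an MPS-capable machine (the availability guard is not part of the change):

    import torch

    if torch.backends.mps.is_available():
        # LayerNorm parameters stay in float32; only the input is half precision.
        ln = torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")
        x = torch.randn(1, 2, 16).to("mps", dtype=torch.float16)
        y = ln(x)  # forward pass only; previously failed, see #96113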
Showing 2 changed files with 9 additions and 13 deletions.
19 changes: 6 additions & 13 deletions aten/src/ATen/native/mps/operations/Normalization.mm
@@ -254,20 +254,16 @@ Check if running mean exists (maybe do this check before making graph)
      // Update saved mean and inverse std tensor
      MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(double)epsilon
                                                             shape:@[@1]
-                                                         dataType:MPSDataTypeFloat32];
+                                                         dataType:input_mps_dtype];

      MPSGraphTensor *varianceEps = [mpsGraph additionWithPrimaryTensor:batchVarianceTensor
                                                        secondaryTensor:epsilonTensor
                                                                   name:@"varianceEps"];

      MPSGraphTensor *sqrtVariance = [mpsGraph squareRootWithTensor:varianceEps
                                                               name:@"sqrtVariance"];
-     float primary = 1.0f;
-     MPSGraphTensor *primaryTensor = [mpsGraph constantWithScalar:primary dataType:MPSDataTypeFloat32];
-
-     scaledInverseSqrtVariance = [mpsGraph divisionWithPrimaryTensor:primaryTensor
-                                                     secondaryTensor:sqrtVariance
-                                                                name:nil];
+     scaledInverseSqrtVariance = [mpsGraph reciprocalWithTensor:sqrtVariance
+                                                           name:nil];
      // Update saved mean and inverse std tensor
      saveMeanTensor = batchMeanTensor;
      saveVarTensor = scaledInverseSqrtVariance;
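The new formulation computes the inverse standard deviation with a single reciprocal instead of dividing a float32 constant `1.0` by the square root, which would otherwise mix dtypes when the input is float16. An illustrative PyTorch-level sketch of the same identity (plain CPU tensors with hypothetical names, not the MPSGraph code itself):

    import torch

    var = torch.rand(8) + 0.5                             # stand-in for batchVarianceTensor
    eps = 1e-5                                            # epsilon, built in input_mps_dtype in the graph

    old_style = 1.0 / torch.sqrt(var + eps)               # previous: divide a constant 1.0 tensor
    new_style = torch.reciprocal(torch.sqrt(var + eps))   # new: a single reciprocal op
    assert torch.allclose(old_style, new_style)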
@@ -678,13 +674,10 @@ string get_mem_string(c10::MemoryFormat memory_format) {

      if(train) {
        // Use save_mean and save_var
-       float primary = 1.0f;
-       MPSGraphTensor *primaryTensor = [mpsGraph constantWithScalar:primary dataType:MPSDataTypeFloat32];
-       MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon dataType:MPSDataTypeFloat32];
+       MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon dataType:input_mps_dtype];
        MPSGraphTensor *revertSaveVarTensor = saveVarTensor;
-       revertSaveVarTensor = [mpsGraph divisionWithPrimaryTensor: primaryTensor
-                                                 secondaryTensor: revertSaveVarTensor
-                                                            name: nil];
+       revertSaveVarTensor = [mpsGraph reciprocalWithTensor: revertSaveVarTensor
+                                                       name: nil];
        revertSaveVarTensor = [mpsGraph multiplicationWithPrimaryTensor: revertSaveVarTensor
                                                        secondaryTensor: revertSaveVarTensor
                                                                   name: nil];
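The same substitution applies on the path that reconstructs the variance from the saved statistics: `saveVarTensor` holds the inverse standard deviation 1/sqrt(var + eps), so taking its reciprocal and squaring recovers var + eps. A PyTorch-level sketch with hypothetical names:

    import torch

    var, eps = torch.rand(8) + 0.5, 1e-5
    save_var = torch.reciprocal(torch.sqrt(var + eps))    # what the forward pass saved

    reverted = torch.reciprocal(save_var)                 # reciprocalWithTensor step
    reverted = reverted * reverted                        # multiplicationWithPrimaryTensor step
    assert torch.allclose(reverted, var + eps)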
3 changes: 3 additions & 0 deletions test/test_mps.py
@@ -2008,6 +2008,9 @@ def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dt
            helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
            helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)

+       # Regression test for /~https://github.com/pytorch/pytorch/issues/96113
+       torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16))
+
    def test_instance_norm(self):
        def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):

