Skip to content

Commit

Permalink
Fix float64 dtype error
Browse files Browse the repository at this point in the history
  • Loading branch information
brkirch committed Sep 5, 2023
1 parent 3d17b11 commit 5a35f9d
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions modules/sd_disable_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):

if param.is_meta:
dtype = sd_param.dtype if sd_param is not None else param.dtype
if dtype == torch.float64 and device.type == 'mps':
dtype = torch.float32
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

for name in module._buffers:
Expand Down

0 comments on commit 5a35f9d

Please sign in to comment.