run ruff format on nf4tensor
msaroufim committed Feb 18, 2025
1 parent 22bc211 commit 71caddb
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions torchao/dtypes/nf4tensor.py
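
The change is purely mechanical: ruff format rewraps every long assert in the file and nothing else, so runtime behavior is unchanged. Reading each hunk below in the usual unified-diff order (deleted lines first, added lines second), the pinned ruff version moves the parentheses from the failure message onto the condition. The pattern, using one assert taken verbatim from the file:

    # Before:
    assert (input_tensor.numel() % block_size) == 0, (
        f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
    )

    # After ruff format:
    assert (
        (input_tensor.numel() % block_size) == 0
    ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"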
@@ -193,9 +193,9 @@ def nf4_split(aten_op, args, kwargs=None):
     attr_to_chunks = {}
     for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
         inner_tensor = getattr(nf4tensor, attr)
-        assert inner_tensor.numel() % num_chunks == 0, (
-            f"{attr}.numel() not divisible by {num_chunks}"
-        )
+        assert (
+            inner_tensor.numel() % num_chunks == 0
+        ), f"{attr}.numel() not divisible by {num_chunks}"
         chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs)
         attr_to_chunks[attr] = chunks
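
The invariant guarded here: every flat inner tensor that participates in sharding must split into num_chunks equal pieces. A minimal standalone sketch of that check and split, outside the tensor-subclass machinery (names are illustrative):

    import torch

    def split_flat_evenly(t: torch.Tensor, num_chunks: int) -> list:
        # Every shard must receive the same number of elements.
        assert t.numel() % num_chunks == 0, f"numel not divisible by {num_chunks}"
        return list(torch.split(t, t.numel() // num_chunks))

    chunks = split_flat_evenly(torch.arange(12), 3)  # three tensors of 4 elements each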

@@ -236,9 +236,9 @@ def nf4_new_zeros(aten_op, args, kwargs=None):
     updated_attrs = {}
     for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
         inner_tensor = getattr(nf4tensor, attr)
-        assert inner_tensor.size(0) % ratio == 0, (
-            f"{attr}.numel() must be divisible by {ratio}"
-        )
+        assert (
+            inner_tensor.size(0) % ratio == 0
+        ), f"{attr}.numel() must be divisible by {ratio}"
         inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs)
         updated_attrs[attr] = inner_tensor
     updated_attrs["size"] = new_size
@@ -473,9 +473,9 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
         torch.Tensor: Tensor of scalers for each block
     """
     assert input_tensor.dim() == 1, "Input tensor must be flattened"
-    assert (input_tensor.numel() % block_size) == 0, (
-        f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
-    )
+    assert (
+        (input_tensor.numel() % block_size) == 0
+    ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
 
     n_blocks = input_tensor.numel() // block_size
     blocks = input_tensor.view(n_blocks, block_size)
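
The surrounding function computes one scaler per block as that block's maximum absolute value. A self-contained sketch matching the view-based layout above:

    import torch

    def block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
        # One scaler per contiguous run of block_size elements.
        assert input_tensor.dim() == 1, "Input tensor must be flattened"
        assert input_tensor.numel() % block_size == 0
        n_blocks = input_tensor.numel() // block_size
        blocks = input_tensor.view(n_blocks, block_size)
        return blocks.abs().max(dim=1).values  # shape: (n_blocks,)
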
@@ -558,12 +558,12 @@ def from_tensor(
         block_size: int,
         scaler_block_size: int,
     ):
-        assert input_tensor.dim() <= 2, (
-            f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
-        )
-        assert input_tensor.numel() % block_size == 0, (
-            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
-        )
+        assert (
+            input_tensor.dim() <= 2
+        ), f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
+        assert (
+            input_tensor.numel() % block_size == 0
+        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
         assert input_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
         # assert not input_tensor.requires_grad, "Input tensor must not require grad"
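
These preconditions run before quantization: at most 2-D input, an element count divisible into blocks, and contiguity. (Note that `input_tensor.is_contiguous` lacks call parentheses, so it asserts on a bound method, which is always truthy; ruff format never changes semantics, so the commit correctly leaves that as-is.) A hypothetical usage sketch, with block sizes chosen for illustration only:

    import torch
    from torchao.dtypes.nf4tensor import NF4Tensor

    weight = torch.randn(512, 512, dtype=torch.bfloat16)
    # 512 * 512 elements divide evenly by block_size, and the resulting
    # scalers divide evenly by scaler_block_size, satisfying the asserts.
    nf4_weight = NF4Tensor.from_tensor(weight, block_size=64, scaler_block_size=256)
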
@@ -644,19 +644,19 @@ def double_quantize_scalers(
             size: (n_scaler_blocks)
         """
         assert input_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (input_tensor.numel() % scaler_block_size) == 0, (
-            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
-        )
+        assert (
+            (input_tensor.numel() % scaler_block_size) == 0
+        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
 
         # First round of quantization
         # Produces: A tensor of size (n_blocks) of input_tensor.dtype
         scalers_1 = get_block_absmax(input_tensor, block_size)
         scalers_1_mean = scalers_1.mean()
         scalers_1 = scalers_1 - scalers_1_mean
         # Second round of quantization
-        assert scalers_1.numel() % scaler_block_size == 0, (
-            f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
-        )
+        assert (
+            scalers_1.numel() % scaler_block_size == 0
+        ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
         n_scaler_blocks = scalers_1.numel() // scaler_block_size
         scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)
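
This is the double-quantization step from the QLoRA scheme: per-block absmax scalers are mean-centered, then quantized block-wise a second time. The final int8 conversion sits below the fold, so its details in this sketch are an assumption based on the scheme, not the file:

    import torch

    def double_quantize_sketch(scalers_1: torch.Tensor, scaler_block_size: int):
        # Center the first-level scalers so they quantize symmetrically around zero.
        mean = scalers_1.mean()
        centered = scalers_1 - mean
        n_scaler_blocks = centered.numel() // scaler_block_size
        blocks = centered.view(n_scaler_blocks, scaler_block_size)
        # Assumed second level: one int8 quantization factor per scaler block.
        factor = 256 / (2 * blocks.abs().max(dim=-1).values)
        quantized = torch.clamp((blocks * factor.unsqueeze(-1)).round(), -128, 127)
        return quantized.flatten().to(torch.int8), factor, mean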

@@ -698,9 +698,9 @@ def dequantize_scalers(
         """
         assert input_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (input_tensor.numel() % scaler_block_size) == 0, (
-            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
-        )
+        assert (
+            (input_tensor.numel() % scaler_block_size) == 0
+        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
         n_scaler_blocks = input_tensor.numel() // scaler_block_size
         input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size)
         dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
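
The inverse of the second-level step: each stored scaler block is divided by its block's quantization factor, flattened, and cast back (the line above continues past the fold with the target dtype). A minimal sketch of the visible math:

    import torch

    def dequantize_scalers_sketch(
        quantized: torch.Tensor,
        quantization_factor: torch.Tensor,
        scaler_block_size: int,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        assert quantized.dim() == 1
        assert quantized.numel() % scaler_block_size == 0
        n_scaler_blocks = quantized.numel() // scaler_block_size
        blocks = quantized.view(n_scaler_blocks, scaler_block_size)
        # Undo the per-block scaling applied during double quantization.
        return (blocks / quantization_factor.unsqueeze(-1)).flatten().to(dtype)
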
@@ -716,9 +716,9 @@ def convert_to_norm_float_weight(
         flattened_tensor = input_tensor.flatten()
         # Since we are using uint8 we will encode 2 entries per byte
         numel = input_tensor.numel()
-        assert numel % 2 == 0, (
-            "Number of elements must be even just to not have to think about the end"
-        )
+        assert (
+            numel % 2 == 0
+        ), "Number of elements must be even just to not have to think about the end"
         # Reshape the flattened tensor into blocks of size self.block_size
         blocks = flattened_tensor.view(n_blocks, block_size)