Skip to content

Commit

Permalink
Fixed a couple of dynamic shape detection issues (#996)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #996

This diff fixes a couple of issues for dynamic shape detection:

* handle cases where sample tensors may not have any dynamic dimension

* added two lowering configs to guide dynamic shape detection

  (1) can_non_first_dim_be_dynamic: specifies whether any dimension other
      than the first can be dynamic
  (2) can_dim_value_one_be_dynamic: specifies whether a dimension whose
      value is one is allowed to be treated as dynamic

Reviewed By: hl475

Differential Revision: D54831007

fbshipit-source-id: cc17db2dd386748fec5e4d491fc58ae975b16b7f
  • Loading branch information
chenyang78 authored and facebook-github-bot committed Mar 19, 2024
1 parent 0390a00 commit d2dc957
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions fx2ait/fx2ait/tensor_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,12 @@ def from_input_list_with_batch_size_jagged_tensor(

@classmethod
# pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any`
def find_batch_size_dim(cls, inputs: Any) -> []:
def find_batch_size_dim(
cls,
inputs: Any,
can_non_first_dim_be_dynamic: bool = True,
can_dim_value_one_be_dynamic: bool = True,
) -> []:
if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
return [0]
shapes = [i.shape for i in inputs]
Expand All @@ -484,7 +489,9 @@ def find_batch_size_dim(cls, inputs: Any) -> []:
# Dedup shape value for single tensor
first_dims.add(shape[0])
seen_dims = set()
for i, dim in enumerate(shape):
valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
for i in range(valid_len):
dim = shape[i]
if dim not in seen_dims:
frequency_map[dim] = frequency_map.get(dim, 0) + 1
position_scores[dim] = position_scores.get(dim, 0) + i
Expand All @@ -501,7 +508,18 @@ def find_batch_size_dim(cls, inputs: Any) -> []:
frequency_map.items(),
key=lambda x: (-x[1], position_scores[x[0]]),
)
batch_size = sorted_frequency[0][0]
if len(sorted_frequency) > 1:
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
# It's often that dim value one indicates a non-dynamic dimension.
# If the user says so, we pick the second most frequent value.
batch_size = sorted_frequency[1][0]
else:
batch_size = sorted_frequency[0][0]
else:
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
batch_size = -1
else:
batch_size = sorted_frequency[0][0]
else:
# no dims to sort: no batch_size
batch_size = -1
Expand All @@ -511,6 +529,8 @@ def find_batch_size_dim(cls, inputs: Any) -> []:
# Default batch size dim = -1, indicate no batch_size
dim = -1
for index, val in enumerate(i.shape):
if not can_non_first_dim_be_dynamic and index > 0:
break
if val == batch_size:
dim = index
break
Expand Down

0 comments on commit d2dc957

Please sign in to comment.