diff --git a/fx2ait/fx2ait/tensor_spec.py b/fx2ait/fx2ait/tensor_spec.py
index 1db415571..3c4703594 100644
--- a/fx2ait/fx2ait/tensor_spec.py
+++ b/fx2ait/fx2ait/tensor_spec.py
@@ -470,7 +470,12 @@ def from_input_list_with_batch_size_jagged_tensor(
 
     @classmethod
     # pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any`
-    def find_batch_size_dim(cls, inputs: Any) -> []:
+    def find_batch_size_dim(
+        cls,
+        inputs: Any,
+        can_non_first_dim_be_dynamic: bool = True,
+        can_dim_value_one_be_dynamic: bool = True,
+    ) -> []:
         if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
             return [0]
         shapes = [i.shape for i in inputs]
@@ -484,7 +489,9 @@ def find_batch_size_dim(cls, inputs: Any) -> []:
             # Dedup shape value for single tensor
             first_dims.add(shape[0])
             seen_dims = set()
-            for i, dim in enumerate(shape):
+            valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
+            for i in range(valid_len):
+                dim = shape[i]
                 if dim not in seen_dims:
                     frequency_map[dim] = frequency_map.get(dim, 0) + 1
                     position_scores[dim] = position_scores.get(dim, 0) + i
@@ -501,7 +508,18 @@ def find_batch_size_dim(cls, inputs: Any) -> []:
                 frequency_map.items(),
                 key=lambda x: (-x[1], position_scores[x[0]]),
             )
-            batch_size = sorted_frequency[0][0]
+            if len(sorted_frequency) > 1:
+                if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
+                    # A dim value of one often indicates a non-dynamic dimension.
+                    # If the caller disallows it, pick the second most frequent value.
+                    batch_size = sorted_frequency[1][0]
+                else:
+                    batch_size = sorted_frequency[0][0]
+            else:
+                if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
+                    batch_size = -1
+                else:
+                    batch_size = sorted_frequency[0][0]
         else:
             # no dims to sort: no batch_size
             batch_size = -1
@@ -511,6 +529,8 @@ def find_batch_size_dim(cls, inputs: Any) -> []:
             # Default batch size dim = -1, indicate no batch_size
             dim = -1
             for index, val in enumerate(i.shape):
+                if not can_non_first_dim_be_dynamic and index > 0:
+                    break
                 if val == batch_size:
                     dim = index
                     break
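
For reviewers, below is a condensed, self-contained sketch of the heuristic with the two new flags applied. This is a simplified re-implementation for illustration only, not the fx2ait source: the `TensorSpec` class plumbing and jagged-tensor handling are omitted, the function name `find_batch_size_dim_sketch` is hypothetical, and the patch's nested flag branches are collapsed into an equivalent two-way form.

```python
# Hedged sketch of the patched batch-size-dim heuristic; the real method is a
# classmethod in fx2ait/fx2ait/tensor_spec.py and returns per-input batch dims.
from typing import Any, List

import torch


def find_batch_size_dim_sketch(
    inputs: Any,
    can_non_first_dim_be_dynamic: bool = True,
    can_dim_value_one_be_dynamic: bool = True,
) -> List[int]:
    if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
        return [0]
    frequency_map: dict = {}
    position_scores: dict = {}
    first_dims = set()
    for shape in (t.shape for t in inputs):
        if len(shape) < 2:
            continue  # rank-1 tensors carry no batch-size info
        first_dims.add(shape[0])
        seen_dims = set()
        # With can_non_first_dim_be_dynamic=False, only dim 0 is a candidate.
        valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
        for i in range(valid_len):
            dim = shape[i]
            if dim not in seen_dims:
                frequency_map[dim] = frequency_map.get(dim, 0) + 1
                position_scores[dim] = position_scores.get(dim, 0) + i
                seen_dims.add(dim)
    if len(first_dims) == 1:
        # All inputs share the same first dim: assume it is the batch size.
        batch_size = first_dims.pop()
    elif frequency_map:
        # Most frequent dim value wins; ties break toward leftmost positions.
        sorted_frequency = sorted(
            frequency_map.items(), key=lambda x: (-x[1], position_scores[x[0]])
        )
        if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
            # Dim value 1 is treated as static: fall back to the runner-up,
            # or report no batch size (-1) if there is none.
            batch_size = sorted_frequency[1][0] if len(sorted_frequency) > 1 else -1
        else:
            batch_size = sorted_frequency[0][0]
    else:
        batch_size = -1
    bs_dim = []
    for t in inputs:
        dim = -1  # -1 means "no batch-size dim found for this input"
        for index, val in enumerate(t.shape):
            if not can_non_first_dim_be_dynamic and index > 0:
                break
            if val == batch_size:
                dim = index
                break
        bs_dim.append(dim)
    return bs_dim


inputs = [torch.zeros(8, 4), torch.zeros(8, 3), torch.zeros(3, 8)]
print(find_batch_size_dim_sketch(inputs))  # [0, 0, 1]
print(find_batch_size_dim_sketch(inputs, can_non_first_dim_be_dynamic=False))  # [0, 0, -1]

ones = [torch.zeros(1, 4), torch.zeros(1, 8), torch.zeros(8, 4)]
print(find_batch_size_dim_sketch(ones))  # [0, 0, -1]
print(find_batch_size_dim_sketch(ones, can_dim_value_one_be_dynamic=False))  # [-1, 1, 0]
```

The behavioral change in a sentence: `can_non_first_dim_be_dynamic=False` restricts both candidate collection and the final per-input scan to dim 0, while `can_dim_value_one_be_dynamic=False` skips a winning dim value of 1 in favor of the runner-up (or -1 when there is no runner-up). Note that the `len(first_dims) == 1` fast path is unchanged by the patch, so a shared first dim of 1 is still selected regardless of the new flag.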