Skip to content

Commit

Permalink
to revert
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jan 16, 2025
1 parent cf9a51d commit 765e723
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion optimum/exporters/tflite/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,65 @@ def _create_dummy_input_generator_classes(self) -> List["DummyInputGenerator"]:
self._validate_mandatory_axes()
return [cls_(self.task, self._normalized_config, **self._axes) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES]


@property
def values_override(self) -> Optional[Dict[str, Any]]:
"""
Dictionary of keys to override in the model's config before exporting.
Returns:
`Optional[Dict[str, Any]]`: A dictionary specifying the configuration items to override.
"""
if hasattr(self._config, "use_cache"):
return {"use_cache": False}

return None

@property
@abstractmethod
def inputs(self) -> List[str]:
"""
List containing the names of the inputs the exported model should take.
Returns:
`List[str]`: A list of input names.
"""
raise NotImplementedError()

@property
def outputs(self) -> List[str]:
"""
List containing the names of the outputs the exported model should have.
Returns:
`List[str]`: A list of output names.
"""
return self._TASK_TO_COMMON_OUTPUTS[self.task]


def generate_dummy_inputs(self) -> Dict[str, "tf.Tensor"]:
return super().generate_dummy_inputs(framework="tf")
"""
Generates dummy inputs that the exported model should be able to process.
This method is actually used to determine the input specs that are needed for the export.
Returns:
`Dict[str, tf.Tensor]`: A dictionary mapping input names to dummy tensors.
"""
dummy_inputs_generators = self._create_dummy_input_generator_classes()
dummy_inputs = {}

for input_name in self.inputs:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework="tf")
input_was_inserted = True
break
if not input_was_inserted:
raise RuntimeError(
f'Could not generate dummy inputs for "{input_name}". Try adding a proper dummy input generator '
"to the model TFLite config."
)

return dummy_inputs


@property
def inputs_specs(self) -> List["TensorSpec"]:
"""
Expand Down

0 comments on commit 765e723

Please sign in to comment.