Introduce Model Library Format export format #7533

Merged · 12 commits · Mar 10, 2021
1 change: 1 addition & 0 deletions python/tvm/micro/__init__.py
@@ -23,6 +23,7 @@
from .debugger import GdbRemoteDebugger
from .micro_library import MicroLibrary
from .micro_binary import MicroBinary
from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError
from .session import (
    create_local_graph_runtime,
    create_local_debug_runtime,
171 changes: 171 additions & 0 deletions python/tvm/micro/model_library_format.py
@@ -0,0 +1,171 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Defines functions for exporting to Model Library Format."""

import datetime
import json
import os
import re
import tarfile

from ..contrib import utils
from ..relay.backend import graph_runtime_factory
from ..relay import param_dict


class UnsupportedInModelLibraryFormatError(Exception):
    """Raised when export_model_library_format does not support the given Module tree."""


def _populate_codegen_dir(mod, codegen_dir: str):
    """Populate the codegen sub-directory as part of a Model Library Format export.

    Parameters
    ----------
    mod : tvm.runtime.Module
        Module which should be written to codegen_dir.
    codegen_dir : str
        Path to the codegen directory on disk.
    """
    dso_modules = mod._collect_dso_modules()
    non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules)
    if non_dso_modules:
        raise UnsupportedInModelLibraryFormatError(
            f"Don't know how to export non-c or non-llvm modules; found: {non_dso_modules!r}"
        )

    mod_indices = {"lib": 0, "src": 0}
    host_codegen_dir = os.path.join(codegen_dir, "host")
    for dso_mod in dso_modules:
        if dso_mod.type_key == "c":
            index = mod_indices["src"]
            mod_indices["src"] += 1
            parent_dir = os.path.join(host_codegen_dir, "src")
            file_name = os.path.join(parent_dir, f"lib{index}.c")
        elif dso_mod.type_key == "llvm":
            index = mod_indices["lib"]
            mod_indices["lib"] += 1
            parent_dir = os.path.join(host_codegen_dir, "lib")
            file_name = os.path.join(parent_dir, f"lib{index}.o")
        else:
            assert (
                False
            ), f"do not expect module with type_key={dso_mod.type_key} from _collect_dso_modules"

        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)
        dso_mod.save(file_name)
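
# For orientation: with one C-source module and one LLVM object module in the
# import tree, the layout produced above would look roughly like this
# (illustrative, not exhaustive):
#
#   codegen/
#   └── host/
#       ├── src/
#       │   └── lib0.c
#       └── lib/
#           └── lib0.o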


def _build_memory_map(graph_json):
    """Build a simplified memory map from the graph JSON.

    Parameters
    ----------
    graph_json : str
        String representation of the graph_json created from tvm.relay.build().

    Returns
    -------
    list :
        A list with one entry per storage id describing that memory.
    """
    graph = json.loads(graph_json)

    seen_storage_ids = set()
    memory_map = []
    for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
        if storage_id in seen_storage_ids:
            continue

        seen_storage_ids.add(storage_id)
        num_elements = 1
        # The "shape" and "dltype" attrs are parallel arrays indexed by node
        # entry id, not by storage id, so index them with node_id here.
        for dim in graph["attrs"]["shape"][1][node_id]:
            num_elements *= dim

        dltype = graph["attrs"]["dltype"][1][node_id]
        m = re.match(r"^[a-zA-Z]+([0-9]+)$", dltype)
        assert m, f"Exported graph contains unknown dltype {dltype}"

        elem_bits = int(m.group(1))

        map_entry = {
            "storage_id": storage_id,
            "size_bytes": (num_elements * elem_bits + 7) // 8,
        }
        if node_id in graph["arg_nodes"]:
            map_entry["input_binding"] = graph["nodes"][node_id]["name"]

        memory_map.append(map_entry)

    return memory_map
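
# Worked example of the size computation above (hypothetical values): a float32
# buffer of shape (1, 3) gives
#     num_elements = 1 * 3             # product over the shape
#     elem_bits = 32                   # parsed from dltype "float32"
#     size_bytes = (3 * 32 + 7) // 8   # == 12; bits rounded up to whole bytes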


def export_model_library_format(mod: graph_runtime_factory.GraphRuntimeFactoryModule, file_name):
    """Export the build artifact in Model Library Format.

    This function creates a .tar archive containing the build artifacts in a standardized
    layout. It's intended to allow downstream automation to build TVM artifacts against the C
    runtime.

    Parameters
    ----------
    mod : tvm.relay.backend.graph_runtime_factory.GraphRuntimeFactoryModule
        The return value of tvm.relay.build, which will be exported into Model Library Format.
    file_name : str
        Path to the .tar archive to generate.
    """
    tempdir = utils.tempdir()
    metadata = {
        "version": 1,
        "model_name": mod.libmod_name,
        # Use UTC so the "Z" suffix in the timestamp is accurate.
        "export_datetime": datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%SZ"),
        "memory": _build_memory_map(mod.graph_json),
        "target": {int(k): str(v) for k, v in mod.target.items()},
        "runtimes": ["graph"],
    }
    with open(tempdir.relpath("metadata.json"), "w") as json_f:
        json.dump(metadata, json_f, indent=2, sort_keys=True)

    codegen_dir_path = tempdir.relpath("codegen")
    os.mkdir(codegen_dir_path)
    _populate_codegen_dir(mod.lib, codegen_dir_path)

    parameters_dir_path = tempdir.relpath("parameters")
    os.mkdir(parameters_dir_path)
    param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params")
    with open(param_filename, "wb") as f:
        f.write(param_dict.save_param_dict(mod.params))

    with open(tempdir.relpath("relay.txt"), "w") as f:
        f.write(str(mod.ir_mod))

    graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph"))
    os.makedirs(graph_config_dir_path)
    with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f:
        f.write(mod.graph_json)

    with tarfile.open(file_name, "w") as tar_f:

        def reset(tarinfo):
            # Normalize ownership so archives are reproducible across machines.
            tarinfo.uid = tarinfo.gid = 0
            tarinfo.uname = tarinfo.gname = "root"
            return tarinfo

        tar_f.add(tempdir.temp_dir, arcname=".", filter=reset)
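
A minimal end-to-end usage sketch of the new export path (assumes a trivial one-operator model; the model name "sine_model", the micro-style target string, and the output path are illustrative choices, not fixed by the API):

import tvm
from tvm import micro, relay

# Build a tiny model for the C runtime.
x = relay.var("x", shape=(1, 3), dtype="float32")
ir_mod = tvm.IRModule.from_expr(relay.Function([x], relay.sin(x)))
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
    factory = relay.build(ir_mod, target="c --runtime=c", mod_name="sine_model")

# Write the Model Library Format archive. Per the code above, it contains
# metadata.json, codegen/, parameters/sine_model.params, relay.txt, and
# runtime-config/graph/graph.json.
micro.export_model_library_format(factory, "sine_model.tar")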
12 changes: 8 additions & 4 deletions python/tvm/relay/backend/graph_runtime_factory.py
@@ -16,9 +16,9 @@
# under the License.
"""Graph runtime factory."""
import warnings
from tvm._ffi.base import string_types
from tvm._ffi.registry import get_global_func
from tvm.runtime import ndarray
from ..._ffi.base import string_types
from ..._ffi.registry import get_global_func
from ...runtime import ndarray


class GraphRuntimeFactoryModule:
@@ -31,6 +31,8 @@ class GraphRuntimeFactoryModule:
        The graph to be deployed in json format output by graph compiler.
        The graph can contain operator(tvm_op) that points to the name of
        PackedFunc in the libmod.
    target : tvm.Target
        The Target used to build this module.
    libmod : tvm.Module
        The module of the corresponding function
    libmod_name : str
@@ -39,13 +39,15 @@
        The parameters of module
    """

    def __init__(self, graph_json_str, libmod, libmod_name, params):
    def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params):
        assert isinstance(graph_json_str, string_types)
        fcreate = get_global_func("tvm.graph_runtime_factory.create")
        args = []
        for k, v in params.items():
            args.append(k)
            args.append(ndarray.array(v))
        self.ir_mod = ir_mod
        self.target = target
        self.module = fcreate(graph_json_str, libmod, libmod_name, *args)
        self.graph_json = graph_json_str
        self.lib = libmod
20 changes: 11 additions & 9 deletions python/tvm/relay/build_module.py
@@ -208,14 +208,14 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo
    return build(mod, target, target_host, params, mod_name).module


def build(mod, target=None, target_host=None, params=None, mod_name="default"):
def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"):
    # fmt: off
    # pylint: disable=line-too-long
    """Helper function that builds a Relay function to run on TVM graph runtime.

    Parameters
    ----------
    mod : :py:class:`~tvm.IRModule`
    ir_mod : :py:class:`~tvm.IRModule`
        The IR module to build. Using relay.Function is deprecated.

    target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional
@@ -251,13 +251,13 @@ def build(mod, target=None, target_host=None, params=None, mod_name="default"):
"""
# pylint: enable=line-too-long
# fmt: on
if not isinstance(mod, (IRModule, _function.Function)):
if not isinstance(ir_mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

if isinstance(mod, _function.Function):
if isinstance(ir_mod, _function.Function):
if params:
mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod)
ir_mod = bind_params_by_name(ir_mod, params)
ir_mod = IRModule.from_expr(ir_mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter mod (tvm.relay.function.Function)",
@@ -280,9 +280,11 @@ def build(mod, target=None, target_host=None, params=None, mod_name="default"):

    with tophub_context:
        bld_mod = BuildModule()
        graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
        mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params)
        return mod
        graph_json, runtime_mod, params = bld_mod.build(ir_mod, target, target_host, params)
        runtime_mod = _graph_runtime_factory.GraphRuntimeFactoryModule(
            ir_mod, target, graph_json, runtime_mod, mod_name, params
        )
        return runtime_mod
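
# With ir_mod and target now threaded through to the factory module, a
# downstream exporter such as export_model_library_format can read them back
# (sketch; the "llvm" target and device-type key 1 are illustrative):
#
#     factory = relay.build(ir_mod, target="llvm")
#     print(str(factory.ir_mod))                                  # compiled Relay IR
#     print({int(k): str(v) for k, v in factory.target.items()})  # e.g. {1: "llvm"}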


def optimize(mod, target=None, params=None):
26 changes: 21 additions & 5 deletions python/tvm/runtime/module.py
@@ -105,6 +105,9 @@ def __getitem__(self, name):
            raise ValueError("Can only take string as function name")
        return self.get_function(name)

    def __eq__(self, other):
        return self.handle.value == other.handle.value

    def __hash__(self):
        # Defining __eq__ without __hash__ would make Module unhashable and
        # break the `visited` set in _collect_from_import_tree, so hash the
        # same underlying handle used for equality.
        return hash(self.handle.value)

    def __call__(self, *args):
        if self._entry:
            return self._entry(*args)
@@ -233,24 +236,37 @@ def evaluator(*args):
        except NameError:
            raise NameError("time_evaluate is only supported when RPC is enabled")

    def _collect_dso_modules(self):
        """Helper function to collect dso modules, then return it."""
    def _collect_from_import_tree(self, filter_func):
        """Collect modules from the import tree that match filter_func, and return them.

        Parameters
        ----------
        filter_func : Callable[[Module], bool]
            A function which is invoked for each Module discovered in the import tree (including
            self).

        Returns
        -------
        list[Module] :
            The list of matching Modules.
        """
        visited, stack, matching = set(), [], []
        # Seed the traversal with the root module.
        visited.add(self)
        stack.append(self)
        while stack:
            module = stack.pop()
            if module._dso_exportable():
            if filter_func(module):
                matching.append(module)
            for m in module.imported_modules:
                if m not in visited:
                    visited.add(m)
                    stack.append(m)
        return matching

    def _dso_exportable(self):
        return self.type_key == "llvm" or self.type_key == "c"
    def _collect_dso_modules(self):
        is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key == "c")
        return self._collect_from_import_tree(is_dso_exportable)
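
    # The filter hook composes with other predicates too; for example,
    # _populate_codegen_dir in model_library_format.py effectively collects the
    # complement to detect unsupported modules (sketch, equivalent predicate):
    #
    #     non_dso = mod._collect_from_import_tree(lambda m: not m._dso_exportable())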

    def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
        """Export the module and its imported device code into one library.
3 changes: 2 additions & 1 deletion src/runtime/graph/graph_runtime_factory.cc
@@ -156,7 +156,8 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args
"graph_runtime_factory.create needs at least 3, "
"but it has "
<< args.num_args;
// The argument order is graph_json, module, module_name, params.
// The argument order is graph_json, module, module_name, param0_name, param0_tensor,
// [param1_name, param1_tensor], ...
ICHECK_EQ((args.size() - 3) % 2, 0);
std::unordered_map<std::string, tvm::runtime::NDArray> params;
for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) {