Skip to content

Commit

Permalink
[prim] generate static prim api
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxly committed Feb 11, 2023
1 parent fd0d4fa commit 4a3d5d7
Show file tree
Hide file tree
Showing 17 changed files with 659 additions and 200 deletions.
6 changes: 5 additions & 1 deletion paddle/fluid/operators/generator/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import itertools
import re
from typing import Dict, List
from typing import Dict, List, Sequence

from type_mapping import (
attr_types_map,
Expand Down Expand Up @@ -80,6 +80,10 @@ def to_sr_output_type(s):
return sr_output_types_map[s]


def filter_intermediate(items: Sequence):
return tuple([item for item in items if not item.get('intermediate')])


# -------------- transform argument names from yaml to opmaker ------------
def to_opmaker_name(s):
if s.endswith("_grad"):
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/operators/generator/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def is_scalar(s):
return re.match(r"Scalar(\(\w+\))*", s) is not None


def is_intarray(s):
return s == 'IntArray'


def is_datatype(s):
return s == 'DataType'


def is_initializer_list(s):
return s == "{}"

Expand All @@ -63,3 +71,7 @@ def supports_no_need_buffer(op):
if input["no_need_buffer"]:
return True
return False


def is_tensor_list(s):
return s == 'Tensor[]'
5 changes: 5 additions & 0 deletions paddle/fluid/prim/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ add_subdirectory(manual_prim)
add_subdirectory(generated_prim)

if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(eager_prim_api DEPS generated_eager_prim_api manual_eager_prim_api)
cc_library(static_prim_api DEPS generated_static_prim_api
manual_static_prim_api)
cc_library(
prim_api
SRCS all.cc
DEPS static_utils static_prim_api eager_prim_api eager_api)
else()
cc_library(static_prim_api DEPS generated_static_prim_api
manual_static_prim_api)
cc_library(
prim_api
SRCS all.cc
Expand Down
25 changes: 25 additions & 0 deletions paddle/fluid/prim/api/api.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
- unsqueeze
- pow
- exp
- scale
- multiply
- matmul
- expand
- divide
- sum
- add
- abs
- assign
- concat
- elementwise_pow
- floor
- gather_nd
- log
- max
- maximum
- minimum
- prod
- roll
- scatter
- scatter_nd_add
- tile
53 changes: 45 additions & 8 deletions paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,73 @@ set(api_yaml_path
set(legacy_api_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
)
set(api_compat_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml")
set(api_prim_yaml_path "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/api.yaml")
set(api_version_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml")
set(tmp_eager_prim_api_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_eager_prim_api.cc"
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc.tmp"
)
set(tmp_static_prim_api_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/static_prim_api.cc.tmp"
)
set(tmp_prim_api_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_prim_generated_api.h"
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h.tmp"
)
set(eager_prim_api_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc"
)
set(static_prim_api_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/static_prim_api.cc"
)
set(prim_api_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
)
set(prim_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/prim_gen.py)
set(static_prim_api_template_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.tpl"
)
set(eager_prim_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/eager_gen.py)
set(static_prim_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/static_gen.py
)

message("prim api Code gen")
message("Eager prim api code generator")
execute_process(
WORKING_DIRECTORY
${CMAKE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated
COMMAND
${PYTHON_EXECUTABLE} ${prim_api_gen_file} --api_yaml_path
${PYTHON_EXECUTABLE} ${eager_prim_api_gen_file} --api_yaml_path
${legacy_api_yaml_path} ${api_yaml_path} --prim_api_header_path
${tmp_prim_api_h_path} --eager_prim_api_source_path
${tmp_eager_prim_api_cc_path}
${tmp_eager_prim_api_cc_path} --api_prim_yaml_path ${api_prim_yaml_path}
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "prim api genrate failed, exiting.")
message(FATAL_ERROR "Eager prim api generate failed, exiting.")
endif()
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
${tmp_prim_api_h_path} ${prim_api_h_path})
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
${tmp_eager_prim_api_cc_path} ${eager_prim_api_cc_path})
message("copy tmp_xxx_prim_api to xxx_prim_api")

message("Static prim api code generator")
execute_process(
WORKING_DIRECTORY
${CMAKE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated
COMMAND
${PYTHON_EXECUTABLE} ${static_prim_api_gen_file} --api_phi_yaml_path
${api_yaml_path} --api_phi_legacy_yaml_path ${legacy_api_yaml_path}
--api_compat_yaml_path ${api_compat_yaml_path} --api_version_yaml_path
${api_version_yaml_path} --api_prim_yaml_path ${api_prim_yaml_path}
--template_path ${static_prim_api_template_path} --output_path
${tmp_static_prim_api_cc_path}
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "Static prim api generate failed, exiting.")
endif()
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_static_prim_api_cc_path}
${static_prim_api_cc_path})
message("copy tmp_xxx_prim_api to xxx_prim_api")
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def api_namespace():
)


def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
def generate_api(
api_yaml_path, header_file_path, eager_prim_source_file_path, api_prim_path
):
apis = []

for each_api_yaml in api_yaml_path:
Expand All @@ -76,8 +78,11 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
eager_prim_source_file.write(eager_source_include())
eager_prim_source_file.write(namespace[0])

with open(api_prim_path, 'rt') as f:
api_prims = yaml.safe_load(f)

for api in apis:
prim_api = EagerPrimAPI(api)
prim_api = EagerPrimAPI(api, api_prims)
if prim_api.is_prim_api:
header_file.write(prim_api.gene_prim_api_declaration())
eager_prim_source_file.write(prim_api.gene_eager_prim_api_code())
Expand Down Expand Up @@ -112,16 +117,24 @@ def main():
default='paddle/fluid/prim/api/generated_prim/eager_prim_api.cc',
)

parser.add_argument(
'--api_prim_yaml_path',
help='Primitive API list yaml file.',
default='paddle/fluid/prim/api/auto_code_generated/api.yaml',
)

options = parser.parse_args()

api_yaml_path = options.api_yaml_path
prim_api_header_file_path = options.prim_api_header_path
eager_prim_api_source_file_path = options.eager_prim_api_source_path
api_prim_yaml_path = options.api_prim_yaml_path

generate_api(
api_yaml_path,
prim_api_header_file_path,
eager_prim_api_source_file_path,
api_prim_yaml_path,
)


Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/prim/api/auto_code_generated/prim_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@


class BaseAPI:
def __init__(self, api_item_yaml):
def __init__(self, api_item_yaml, prims=tuple()):
# self.api = api_item_yaml['op']
self.api = api_item_yaml['name']

self.is_prim_api = False
if api_item_yaml['name'] in white_ops_list:
if api_item_yaml['name'] in prims:
self.is_prim_api = True

#######################################
Expand Down Expand Up @@ -253,8 +253,8 @@ def parse_output(self, outputs_list):


class EagerPrimAPI(BaseAPI):
def __init__(self, api_item_yaml):
super().__init__(api_item_yaml)
def __init__(self, api_item_yaml, prims=tuple()):
super().__init__(api_item_yaml, prims)

def get_api__func_name(self):
api_func_name = self.api
Expand Down
Loading

0 comments on commit 4a3d5d7

Please sign in to comment.