-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
31 lines (23 loc) · 803 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import taichi as ti
from auto_graph import auto_graph
# Graph function wrapped by the ``auto_graph`` decorator; the script entry
# point below calls ``.compile()`` and ``.archive()`` on the wrapped object.
#
# Parameters:
#   arr: 1-D float32 ndarray. It is modified in place by the two
#        ``kernel_update`` launches.
#
# NOTE(review): documented with comments rather than a docstring so the
# statements seen by the decorator stay exactly as written -- confirm whether
# ``auto_graph`` accepts a docstring in the body before converting.
@auto_graph
def fool_graph(arr: ti.types.ndarray(dtype=ti.f32, ndim=1)):
    x, y = 2, 3
    dt = x + y  # scalar increment (2 + 3) handed to kernel_delta
    # Temporary 1-D float32 buffer with as many elements as ``arr``.
    # NOTE(review): kernel_delta reads this buffer before writing to it, so
    # the result relies on the initial contents of a fresh ndarray --
    # presumably zero; verify for the target runtime.
    dt_arr = ti.ndarray(dtype=ti.f32, shape=arr.shape[0])
    kernel_delta(dt, dt_arr)  # dt_arr[i] = dt_arr[i] + dt
    # Add the per-element delta to ``arr`` twice.
    kernel_update(arr, dt_arr)
    kernel_update(arr, dt_arr)
@ti.kernel
def kernel_delta(delta: ti.i32, arr: ti.types.ndarray(dtype=ti.f32, ndim=1)):
    """Add the scalar ``delta`` to every element of ``arr`` in place.

    Args:
        delta: 32-bit integer increment applied to each element.
        arr: 1-D float32 ndarray; each element is read, incremented and
            written back.
    """
    for i in ti.grouped(arr):
        arr[i] = arr[i] + delta
@ti.kernel
def kernel_update(arr: ti.types.ndarray(dtype=ti.f32, ndim=1),
                  delta: ti.types.ndarray(dtype=ti.f32, ndim=1)):
    """Add ``delta`` to ``arr`` element by element, in place.

    Args:
        arr: 1-D float32 ndarray that receives the update.
        delta: 1-D float32 ndarray read at the same indices as ``arr``.
            NOTE(review): assumes ``delta`` has at least as many elements
            as ``arr``; no bounds check is done here.
    """
    for i in ti.grouped(arr):
        arr[i] = arr[i] + delta[i]
if __name__ == '__main__':
    # Initialise Taichi on the Vulkan backend before touching the graph.
    ti.init(arch=ti.vulkan)
    fool_graph.compile()
    # Presumably serialises the compiled graph for the Vulkan arch into the
    # file "auto_graph.tcm" -- behaviour is defined by the auto_graph module.
    fool_graph.archive(ti.vulkan, "auto_graph.tcm")