Backend paddle: allow enable prim to support more brands of chips and hardware devices #1948
Conversation
Could you be more specific about the meaning of "high-order (K) operators"?
Okay, I have updated the description of this PR and explained the prim function in detail.
I understand the log_softmax example. But is this true?
In DeepXDE, fourth-order differentiation is currently implemented by applying first-order differentiation recursively. Why do we still need PRIM for this?
By the way, is prim always faster?
Hello Lulu, in Paddle, "prim" mode refers to a system built on a set of primitive operators: any complex operator is decomposed into these primitives before execution. Each primitive operator has a forward implementation and a first-order backward implementation. Together they form a complete set of operators, meaning that once a complex operator is decomposed into this set, we no longer need to implement dedicated first-order (or higher-order) backward kernels for that complex operator.

For users, prim mode serves three main purposes. First, it supports differentiation to any order. Second, on certain hardware where higher-order differential operators have not been implemented, prim mode decomposes those operators into primitive operators so they can still run. Third, we can use the Paddle compiler (CINN) to optimize the decomposed computation graph, as compilers generally prefer smaller operators over larger ones; this optimization can lead to significantly improved performance. For instance, in the modulus-sym suite we achieved nearly a 70% increase in iterations per second (IPS). We also plan to run tests on DeepXDE to further evaluate performance gains.
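To make the "differentiation to any order" point concrete, here is a minimal sketch of obtaining higher-order derivatives by repeated first-order differentiation, which works once every operator in the backward graph itself has a first-order backward rule, exactly what the prim decomposition guarantees. The toy tanh output and tensor shapes are placeholders for illustration, not taken from this PR.

```python
import paddle

# Placeholder input and "network output" for illustration only.
x = paddle.linspace(0.0, 1.0, 16).reshape([16, 1])
x.stop_gradient = False
y = paddle.tanh(x)

# create_graph=True makes the backward graph differentiable again,
# so higher orders are obtained by repeating first-order differentiation.
dy_dx = paddle.grad(y, x, create_graph=True)[0]   # 1st order
d2y_dx2 = paddle.grad(dy_dx, x)[0]                # 2nd order
print(d2y_dx2.shape)  # [16, 1]
```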
'prim' is a feature of the Paddle framework.
This PR adds code that allows users to enable prim mode by setting 'PRIM=1'. With it enabled, the full command becomes 'PRIM=1 DDE_BACKEND=paddle python fractional_diffusion_1d.py'. A sketch of how this switch can be wired up is shown below.
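For illustration, here is a minimal sketch of how such an environment-variable switch might be wired up on the user side. The helper name maybe_enable_prim and the enabling call paddle.incubate.autograd.enable_prim() are assumptions; the PR itself may use a different internal entry point.

```python
import os

import paddle

def maybe_enable_prim():
    # Hypothetical helper: mirrors the behaviour described above, i.e.
    # running `PRIM=1 DDE_BACKEND=paddle python fractional_diffusion_1d.py`
    # turns on operator decomposition before execution.
    if os.getenv("PRIM", "0") == "1":
        # Assumed entry point; the PR may call a different API internally.
        paddle.incubate.autograd.enable_prim()
        print("prim mode enabled")

maybe_enable_prim()
```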
Brief introduction of prim:
There are many operators in the framework. Some are complex (coarse-grained), like 'log_softmax', and some are simple (fine-grained), like 'max'/'sum'/'sub'.
In short, 'prim' refers to the process of decomposing coarse-grained operators into fine-grained operators.
Take 'log_softmax' as an example: before decomposition the execution is 'input->op(log_softmax)->output', and after decomposition it is 'input->op(max)->op(sub)->op(exp)->op(sum)->op(log)->op(sub)->output' (the backward pass is modified accordingly). A sketch of this decomposition is shown below.
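As a concrete illustration of that decomposition, here is a sketch of log_softmax written only in terms of the fine-grained operators listed above (max, sub, exp, sum, log, sub). It mirrors the idea of the decomposition rather than the exact rewrite rule Paddle's prim system applies.

```python
import paddle

def log_softmax_decomposed(x, axis=-1):
    m = paddle.max(x, axis=axis, keepdim=True)    # max (for numerical stability)
    shifted = x - m                               # sub
    e = paddle.exp(shifted)                       # exp
    s = paddle.sum(e, axis=axis, keepdim=True)    # sum
    return shifted - paddle.log(s)                # log, then sub

x = paddle.rand([4, 5])
print(paddle.allclose(log_softmax_decomposed(x),
                      paddle.nn.functional.log_softmax(x, axis=-1)))
```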
For details, please see the comments in this PR.