
Backend paddle: allow enable prim to support more brands of chips and hardware devices #1948

Merged · 1 commit · Feb 25, 2025

Conversation

@lijialin03 (Contributor) commented Feb 18, 2025

'prim' is a feature of the Paddle framework.

This PR adds code that allows users to enable prim mode by setting 'PRIM=1'. With it enabled, the full command becomes 'PRIM=1 DDE_BACKEND=paddle python fractional_diffusion_1d.py'.
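As a rough illustration (not the actual code in this PR), an environment-variable toggle like PRIM=1 is typically parsed along these lines at backend initialization; the helper name `prim_enabled` is hypothetical:

```python
import os

def prim_enabled() -> bool:
    """Hypothetical helper: True when the user ran e.g.
    'PRIM=1 DDE_BACKEND=paddle python fractional_diffusion_1d.py'.
    The Paddle call that actually switches prim mode on is omitted here."""
    return os.environ.get("PRIM", "0").lower() in ("1", "true", "on")
```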

A brief introduction to prim:
The framework contains many operators; some are complex (coarse-grained), like 'log_softmax', while others are simple (fine-grained), like 'max'/'sum'/'sub'.
In short, 'prim' refers to the process of decomposing coarse-grained operators into fine-grained operators.
Take 'log_softmax' as an example: before decomposition the running process is 'input->op(log_softmax)->output'; after decomposition it is 'input->op(max)->op(sub)->op(exp)->op(sum)->op(log)->op(sub)->output' (the backward process is modified accordingly).
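The decomposition chain above can be checked numerically with a small pure-Python sketch (plain lists and the math module stand in for tensor ops):

```python
import math

def log_softmax_direct(xs):
    # Coarse-grained view: one fused log_softmax operator.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def log_softmax_decomposed(xs):
    # Fine-grained chain mirroring the prim decomposition:
    # max -> sub -> exp -> sum -> log -> sub
    m = max(xs)                              # op(max)
    shifted = [x - m for x in xs]            # op(sub)
    exps = [math.exp(s) for s in shifted]    # op(exp)
    total = sum(exps)                        # op(sum)
    log_total = math.log(total)              # op(log)
    return [s - log_total for s in shifted]  # op(sub)
```

Both paths agree to floating-point precision; the decomposed chain also happens to be the numerically stable max-shifted form.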

For details, please see the comments in this PR.


@lululxvi (Owner)

Could you be more specific about the meaning of "high-order (K) operators"?

@lijialin03 force-pushed the develop branch 2 times, most recently from 47f0c71 to 54cfe58 on February 24, 2025 03:48
@lijialin03 (Contributor, Author)

> Could you be more specific about the meaning of "high-order (K) operators"?

Okay, I have updated the description of this PR and explained the prim function in detail.

@lululxvi (Owner)

I understand the example of log_softmax. But is this true?

> In the same way, the fourth-order derivative in the EulerBeam example will also be decomposed into 'input->op(1st-order diff)->op(1st-order diff)->op(1st-order diff)->op(1st-order diff)->output' (the actual process is more complex; this is just a brief example).

In DeepXDE, 4th-order diff is currently implemented by applying first-order diff recursively. Why do we still need PRIM for this?
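For reference, the "recursive first-order diff" idea can be sketched on polynomials represented as coefficient lists (a toy illustration only, not how dde.grad is actually implemented):

```python
def diff(coeffs):
    # First-order derivative of a0 + a1*x + a2*x^2 + ...,
    # given as the coefficient list [a0, a1, a2, ...].
    return [i * c for i, c in enumerate(coeffs)][1:]

def fourth_derivative(coeffs):
    # Fourth order via four first-order passes, mirroring
    # nested first-order autodiff calls.
    for _ in range(4):
        coeffs = diff(coeffs)
    return coeffs
```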

@lululxvi (Owner)

By the way, is prim always faster?

@HydrogenSulfate (Contributor)

> I understand the example of log_softmax. But is this true?
>
> > In the same way, the fourth-order derivative in the EulerBeam example will also be decomposed into 'input->op(1st-order diff)->op(1st-order diff)->op(1st-order diff)->op(1st-order diff)->output' (the actual process is more complex; this is just a brief example).
>
> In DeepXDE, 4th-order diff is currently implemented by applying first-order diff recursively. Why do we still need PRIM for this?

Hello Lulu, in Paddle, “prim” mode refers to a system where we use a set of primitive operators, and any complex operators are decomposed into these primitive operators before execution. For instance, log_softmax is decomposed into 5 primitive operators when prim is enabled: max, sub, exp, sum, and log.

These primitive operators each have forward implementations and first-order backward implementations. Together, they form a complete set of operators, meaning that once any complex operator is decomposed into this set, we no longer need to implement first-order (xxx_backward) or higher-order (xxx_double/triple_backward) derivatives for these complex operators. Their computation graphs are replaced by those formed by the decomposed primitive operators. This decomposition increases the number of operator executions, which may result in slightly slower execution speed. However, the advantage is that it supports differentiation to any order, provided the differentiation relationships exist.
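The point that first-order rules alone suffice for arbitrary-order derivatives can be illustrated with nested forward-mode dual numbers, a standard toy construction that is unrelated to Paddle's actual prim machinery:

```python
class Dual:
    """Forward-mode AD value val + dot*eps. Only first-order rules
    (sum rule, product rule) are implemented, yet nesting Duals
    inside Duals yields derivatives of any order."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    @staticmethod
    def _lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual._lift(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__

    def __mul__(self, other):
        other = Dual._lift(other)
        # Product rule (u*v)' = u*v' + u'*v -- a first-order rule.
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)
    __rmul__ = __mul__

def derivative(f, x):
    return f(Dual(x, 1.0)).dot

def nth_derivative(f, x, n):
    # Higher orders by repeatedly differentiating the derivative.
    for _ in range(n):
        f = (lambda g: (lambda t: derivative(g, t)))(f)
    return f(x)
```

For example, `nth_derivative(lambda x: x * x * x * x, 2.0, 4)` recovers the fourth derivative of x^4 (which is 24 everywhere) without any operator ever defining more than a first-order rule.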

For users, the prim mode serves three main purposes. First, it supports differentiation to any order. Second, on hardware where higher-order differential operators have not been implemented, prim mode allows these operators to be decomposed into primitive operators, enabling them to run. Third, we can use the Paddle compiler (CINN) to optimize the decomposed computation graph, as compilers generally prefer smaller operators over larger ones. This optimization can lead to significantly improved performance: for instance, in the modulus-sym suite we achieved nearly a 70% increase in iterations per second (IPS). We also plan to run tests on DeepXDE to further evaluate the performance gains.

@lijialin03 changed the title from "Backend paddle: allow enable prim to accelerate running" to "Backend paddle: allow enable prim to support more brands of chips and hardware devices" on Feb 25, 2025
@lululxvi lululxvi merged commit e104988 into lululxvi:master Feb 25, 2025
14 checks passed