-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move kernel functions to their algorithm-class folder. #501
Conversation
c300d43
to
8041a3f
Compare
Codecov Report
@@ Coverage Diff @@
## main #501 +/- ##
==========================================
+ Coverage 99.28% 99.29% +0.01%
==========================================
Files 47 47
Lines 1948 1978 +30
==========================================
+ Hits 1934 1964 +30
Misses 14 14
|
not sure what is up with the building of the docs by readthedocs |
Love the overall direction this is going, some high level comments:
|
tests/mcmc/test_sampling.py
Outdated
@@ -449,6 +452,7 @@ def normal_logprob(self, x): | |||
def test_univariate_normal( | |||
self, algorithm, initial_position, parameters, num_sampling_steps, burnin | |||
): | |||
algorithm = eval(algorithm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I understand this change, seems not needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because blackjax.hmc
(or any other kernel constructor) is now a function, it creates an object with a specific location in memory when evaluated. If this object is created outside the test function test_univariate_normal
then calls to the chex variant function inference_loop
have a different call sign for each of the devices/cores (because each call sign has a different memory location), throwing an error when collecting tests. The error looks like ERROR collecting gw1. Different tests were collected between gw0 and gw1
.
tests/smc/test_tempered_smc.py
Outdated
@@ -68,8 +68,8 @@ def logprior_fn(x): | |||
iterates = [] | |||
results = [] # type: List[TemperedSMCState] | |||
|
|||
hmc_kernel = blackjax.hmc.kernel() | |||
hmc_init = blackjax.hmc.init | |||
hmc_kernel = blackjax.mcmc.hmc.kernel() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, in the top level __init__.py
you imported the symbol, so you dont need to make this change
Fixed this temporarily by allowing the build to pass even if there are warnings (we'll need to figure these out later) in #502. |
Ok, so that's why I actually added these classes :/ Not a big fan of passing names as strings. |
Similar issue for when we are writing test: #501 (comment) |
We still have
Agree with both. Our other option is to have the user pass both the |
are we turning our backs to Marx and staying on a class structure? Python does seem to be, inherently, an object oriented language...
|
ed3f6cf
to
cf45d3b
Compare
Back to classes (but still have the function commits if we want to go back to that), also I've made naming on sampling algorithms consistent #280 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this.
Since we need to keep the class implementation for high level API, the devil's advocate is that keeping them in the same file (like what we have originally) makes the implementation pattern clear and easy to reference / compare.
Also to consider, the |
blackjax/mcmc/random_walk.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@junpenglao this file is a good example of various algorithms in one script. Since both additive_step_random_walk
and irmh
are special cases of rmh
it makes sense to keep them all together, wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Only one minor comment: any reason why you are putting the top level API for blackjax.adaptation
before def base(...)
?
Didn't notice this, probably best to keep the set up consistent... |
* Move kernel function constructors to their respective algorithm-class folder * base classes for adaptation algorithms
Closes #492 before people start building new algorithms with the previous structure.
Besides loosing the aliases, the other major breaking change is that for adaptation algorithms the user can no longer pass a class as the
algorithm
to adapt with. Hence, the user now needs to pass a string with thealgorithm_name
for the adaptation kernel to invoke thekernel
andinit
method. This applies towindow_adaptation
andpathfinder_adaptation
.