Skip to content

Commit

Permalink
added the possibility to specify the number of CPUs used by emcee
Browse files Browse the repository at this point in the history
  • Loading branch information
gaetanoscandariato committed Nov 18, 2024
1 parent f9ea4dc commit 1e9c861
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
6 changes: 3 additions & 3 deletions SLOPpy/subroutines/bayesian_emcee.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def emcee_lines_fit_functions(model_case,
lines_center,
jitter_index,
priors_dict,
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin):
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin, ncpus):

os.environ["OMP_NUM_THREADS"] = "1"

Expand Down Expand Up @@ -173,7 +173,7 @@ def emcee_lines_fit_functions(model_case,

try:
#from pyde.de import DiffEvol
from pytransit.utils.de import DiffEvol
from pytransit.utils.de import DiffEvol
use_pyde = True
except ImportError:
print(' Warnign: PyDE is not installed, random initialization point')
Expand Down Expand Up @@ -231,7 +231,7 @@ def emcee_lines_fit_functions(model_case,
point_start[0, :] = theta_start

start = time.time()
with Pool() as pool:
with Pool(processes=ncpus) as pool:

sampler = emcee.EnsembleSampler(nwalkers,
ndim,
Expand Down
6 changes: 4 additions & 2 deletions SLOPpy/transmission_binned_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ def compute_transmission_binned_mcmc(config_in, lines_label, reference='planetRF
nthin = sampler_pams.get('n_thin', 50)
nsteps = sampler_pams.get('n_steps', 20000)
nburnin = sampler_pams.get('n_burnin', 10000)
ncpus = sampler_pams.get('n_cpus', 10)
ndata = np.size(wave_meshgrid)

if pams_dict.get('rp_factor', False):
Expand Down Expand Up @@ -809,7 +810,7 @@ def compute_transmission_binned_mcmc(config_in, lines_label, reference='planetRF
lines_center,
jitter_index,
prior_dict,
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin)
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin, ncpus)

flat_chain, flat_lnprob, chain_med, chain_MAP, lnprob_med, lnprob_MAP = \
emcee_flatten_median(population, sampler_chain,
Expand Down Expand Up @@ -1130,6 +1131,7 @@ def compute_transmission_binned_mcmc(config_in, lines_label, reference='planetRF
nthin = sampler_pams.get('n_thin', 50)
nsteps = sampler_pams.get('n_steps', 20000)
nburnin = sampler_pams.get('n_burnin', 10000)
nburnin = sampler_pams.get('n_cpus', 10)
ndata = np.size(all_wave_meshgrid)

if pams_dict.get('rp_factor', False):
Expand Down Expand Up @@ -1158,7 +1160,7 @@ def compute_transmission_binned_mcmc(config_in, lines_label, reference='planetRF
lines_center,
all_jitter_index,
prior_dict,
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin)
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin, ncpus)

flat_chain, flat_lnprob, chain_med, chain_MAP, lnprob_med, lnprob_MAP = \
emcee_flatten_median(population, sampler_chain,
Expand Down
5 changes: 3 additions & 2 deletions SLOPpy/transmission_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def compute_transmission_mcmc(config_in, lines_label, reference='planetRF', pca_
nthin = sampler_pams.get('n_thin', 50)
nsteps = sampler_pams.get('n_steps', 20000)
nburnin = sampler_pams.get('n_burnin', 5000)
ncpus = sampler_pams.get('n_cpus', 10)
ndata = np.size(wave_array)

if pams_dict.get('rp_factor', False):
Expand Down Expand Up @@ -720,7 +721,7 @@ def compute_transmission_mcmc(config_in, lines_label, reference='planetRF', pca_
lines_center,
jitter_index,
prior_dict,
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin)
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin, ncpus)

flat_chain, flat_lnprob, chain_med, chain_MAP, lnprob_med, lnprob_MAP = \
emcee_flatten_median(population, sampler_chain,
Expand Down Expand Up @@ -1006,7 +1007,7 @@ def compute_transmission_mcmc(config_in, lines_label, reference='planetRF', pca_
lines_center,
all_jitter_index,
prior_dict,
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin)
theta_start, boundaries, ndim, nwalkers, ngen, nsteps, nthin, ncpus)

flat_chain, flat_lnprob, chain_med, chain_MAP, lnprob_med, lnprob_MAP = \
emcee_flatten_median(population, sampler_chain,
Expand Down

0 comments on commit 1e9c861

Please sign in to comment.