Error when using PTSampler with profiling = True with multiprocessing
Hello! I've encountered an error when trying to execute the code from demo-ptmcmc.ipynb
with the addition of multiprocessing (the full code is at the end). This error occurs only when the PTSampler
is created with the profiling=True
option.
The error message is the following:
Process Process-1:
Traceback (most recent call last):
File "/sps/lisaf/sviatoslav/mcmc_env/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/sps/lisaf/sviatoslav/mcmc_env/lib/python3.9/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/sps/lisaf/sviatoslav/mcmc_env/lib/python3.9/site-packages/m3c2/sampler.py", line 785, in _run_proc
y, Ly, Py = self.chains[num_chain].full_step(ims, chains=self.chains)
File "/sps/lisaf/sviatoslav/mcmc_env/lib/python3.9/site-packages/m3c2/chain.py", line 307, in full_step
prop = self.choose_proposal()
File "/sps/lisaf/sviatoslav/mcmc_env/lib/python3.9/site-packages/m3c2/chain.py", line 249, in choose_proposal
self.jumps[str(p)] += 1
KeyError: '<bound method SCAM.SCAM of <m3c2.proposal.SCAM object at 0x7f8585d84160>>'
As I understand it, this error originates from the keys of the dictionary self.jumps
in the Chain
class. In the newest version these keys are defined as str(p)
, where p is a proposal. However, when using multiprocessing, the memory addresses of the proposal objects (in the code below: SL
and SC
) differ between the spawned worker processes and the parent, and therefore str(p)
no longer matches the keys that were recorded at creation time.
A possible solution (which was used in mc3) is to change the keys back to p.__self__.name
.
Full code below:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import m3c2.proposal as proposal
from m3c2.sampler import PTSampler
from m3c2.tools import get_autocorr_len, get_mcmc_stats
import warnings
import multiprocessing as mp
class TestLogLik:
    """A multivariate Gaussian log-likelihood with a flat prior, used as a test target."""

    def __init__(self, ndim):
        """Draw a random mean and a random symmetric positive-definite covariance.

        Parameters
        ----------
        ndim : int
            Number of dimensions of the Gaussian.
        """
        self.ndim = ndim
        self.param_dic = [f"p{i}" for i in range(ndim)]
        # Random mean vector in [0, 1)^ndim.
        self.mu = np.random.rand(ndim)
        # Build a random symmetric matrix, then square it so the
        # covariance is positive semi-definite.
        m = 0.5 - np.random.rand(ndim ** 2).reshape((ndim, ndim))
        m = np.triu(m)
        m += m.T - np.diag(m.diagonal())
        self.cov = np.dot(m, m)

    def loglik(self, x, **kwargs):
        """Return the Gaussian log-likelihood at point x (normalisation omitted)."""
        r = x - self.mu
        return -0.5 * (r @ np.linalg.solve(self.cov, r))

    def logpi(self, p, **kwargs):
        """Return a flat (constant) log-prior for any point p."""
        return -7.0
def main():
    """Reproduce the profiling=True KeyError on a parallel-tempered MCMC run."""
    # Parallel tempering parameters.
    n_chains = 5
    t_max = 10

    # Target: a 3-D Gaussian test likelihood with flat priors.
    ndim = 3
    target = TestLogLik(ndim)
    priors = np.array([-3, 3] * target.ndim).reshape(target.ndim, 2)
    x0 = [np.random.randn(target.ndim) for _ in range(n_chains)]

    sampler = PTSampler(
        n_chains, target.loglik, target.logpi, target.param_dic,
        Tmax=t_max, profiling=True,
    )
    sampler.set_starting_point(x0)

    # Proposals: slice sampling and SCAM, shared by every chain.
    slice_prop = proposal.Slice(target.param_dic).slice
    scam_prop = proposal.SCAM(target.param_dic).SCAM
    sampler.set_proposals([{slice_prop: 40, scam_prop: 70}] * n_chains)

    print("str(proposals) in the time of creation", str(slice_prop), str(scam_prop))
    print("Available cores:", mp.cpu_count())

    # Run the chains; the KeyError surfaces inside the spawned workers.
    niter = 100
    chains = sampler.run_mcmc(niter, pSwap=0.95, printN=50, multiproc=True, n0_swap=100)
if __name__ == '__main__':
    # 'spawn' starts each worker as a fresh interpreter that re-imports this
    # module, so proposal objects get new addresses in every process — this is
    # what makes the str(p) dictionary keys differ and triggers the KeyError.
    mp.set_start_method('spawn')
    warnings.filterwarnings('ignore')
    mp.freeze_support()
    main()
Edited by Sviatoslav Khukhlaev