Commit a95e36a4 authored by Matthieu Tristram's avatar Matthieu Tristram
Browse files

Duplicate plot_correlation

parent c3886afb
...@@ -207,9 +207,8 @@ def cov2cor( mat): ...@@ -207,9 +207,8 @@ def cov2cor( mat):
######################################################################## ########################################################################
#CONVERGENCE #CONVERGENCE
######################################################################## ########################################################################
def GelmanRubin( chains, gap=10000, length_max=None, new=False): def GelmanRubin( chains, gap=10000, length_min=1000, length_max=None, new=False):
length_min = 1000 #plot GelmanRubin test skipping first parameter=chi2
nchain = len(chains) nchain = len(chains)
nsamples = min( [len(c[0]) for c in chains]) nsamples = min( [len(c[0]) for c in chains])
...@@ -227,11 +226,11 @@ def GelmanRubin( chains, gap=10000, length_max=None, new=False): ...@@ -227,11 +226,11 @@ def GelmanRubin( chains, gap=10000, length_max=None, new=False):
if new: if new:
#Gelman-Rubin stat on the last "gap" samples for each iteration #Gelman-Rubin stat on the last "gap" samples for each iteration
n = gap n = gap
tmp = [data[:,isamp:isamp+gap] for data in chains] tmp = [data[1:,isamp:isamp+gap] for data in chains]
else: else:
#Gelman-Runbin stat on the last half samples for each iteration #Gelman-Runbin stat on the last half samples for each iteration
n = (isamp - length_min)/2 n = (isamp - length_min)/2
tmp = [data[:,length_min+n:isamp] for data in chains] tmp = [data[1:,length_min+n:isamp] for data in chains]
#within Chain #within Chain
mchain = np.mean( tmp,2) mchain = np.mean( tmp,2)
...@@ -788,17 +787,20 @@ def hist1d(x, bins=50, weights=None, smooth=0., normmax=True, **kwargs): ...@@ -788,17 +787,20 @@ def hist1d(x, bins=50, weights=None, smooth=0., normmax=True, **kwargs):
def plot_correlation( chain, par, **kwargs): def MCcorrelation( chain, par, **kwargs):
data = [chain[p] for p in par]
matC = np.corrcoef( data)
plot_correlation( matC, par, **kwargs)
def plot_correlation( matC, par, **kwargs):
vmin = kwargs.pop("vmin", -1) vmin = kwargs.pop("vmin", -1)
vmax = kwargs.pop("vmax", 1) vmax = kwargs.pop("vmax", 1)
data = [chain[p] for p in par]
plt.figure( figsize=(14,10)) plt.figure( figsize=(14,10))
#covmat = np.triu(np.corrcoef( chain.values()[1:]), k=1) #covmat = np.triu(np.corrcoef( chain.values()[1:]), k=1)
matC = np.corrcoef( data)
covmat = np.tril( matC, k=-1) covmat = np.tril( matC, k=-1)
covmat[covmat==0] = np.NaN covmat[covmat==0] = np.NaN
plt.imshow( covmat, vmin=vmin, vmax=vmax, **kwargs) plt.imshow( covmat, vmin=vmin, vmax=vmax, **kwargs)
...@@ -809,6 +811,8 @@ def plot_correlation( chain, par, **kwargs): ...@@ -809,6 +811,8 @@ def plot_correlation( chain, par, **kwargs):
# text( i, i+0.25, parname.get(par[i+1],par[i+1]),rotation=-90) # text( i, i+0.25, parname.get(par[i+1],par[i+1]),rotation=-90)
plt.text( i-0.25, i+0.25, parname.get(par[i],par[i]),rotation='vertical',verticalalignment='bottom') plt.text( i-0.25, i+0.25, parname.get(par[i],par[i]),rotation='vertical',verticalalignment='bottom')
######################################################################## ########################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment