"""Poisson counting-statistics helpers and teaching plots.

Provides Gaussian/Poisson densities, a simulated ride-time confidence-limit
plot, classical and CLs upper limits on a Poisson mean, and a visualisation
of a Neyman construction.
"""
from logging import raiseExceptions  # NOTE(review): unused in this module; kept in case other code imports it from here
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from scipy.special import factorial
import scipy.stats

poisson_cdf = scipy.stats.poisson.cdf


def normal(x, mu, sigma):
    """Return the Gaussian probability density N(x | mu, sigma), element-wise."""
    return np.exp(-(x - mu) ** 2 / (2 * sigma ** 2)) / np.sqrt(2 * np.pi * sigma ** 2)


def poisson(k, mu):
    """Return the Poisson probability mu**k * exp(-mu) / k!, element-wise.

    ``factorial`` is ``scipy.special.factorial`` (floating-point result), so
    very large ``k`` can overflow ``mu**k``.
    """
    return mu ** k * np.exp(-mu) / factorial(k)


def _plot_ride_histogram(ax, entries, bins, n_rides):
    """Draw the ride-time histogram on *ax* and return the centre of its fullest bin."""
    counts, edges, _ = ax.hist(entries, bins=bins, label='{} Test Ride(s)'.format(n_rides))
    # One centre per bin: left edges shifted by half a (uniform) bin width.
    mids = edges[:-1] + (edges[1] - edges[0]) / 2.
    return mids[counts.argmax()]


def tram_time(n_rides=2000, cl=0.95, type='single'):
    """Simulate ride times and illustrate a confidence limit on them.

    ``n_rides`` times are drawn from ``lognormal(2, 0.5) / 3 + 0.5`` and
    negated, so they lie on the negative time axis of the plot.

    Parameters
    ----------
    n_rides : int
        Number of simulated test rides.
    cl : float
        Confidence level, e.g. 0.95.
    type : {'single', 'double', 'ensemble'}
        'single'   -- plot a one-sided lower limit (the ``1 - cl`` quantile).
        'double'   -- plot a central two-sided interval.
        'ensemble' -- no plot; return the ``cl`` quantile of the sample.

    Returns
    -------
    The ``cl`` quantile for ``type='ensemble'``; ``None`` otherwise (the
    figure is drawn and shown with ``plt.show()``).

    Raises
    ------
    NameError
        If ``type`` is not one of the values above (NameError is kept for
        backward compatibility with existing callers).
    """
    # Validate first so an invalid type leaves no stray figure / rcParams change.
    if type not in ('single', 'double', 'ensemble'):
        raise NameError("Choose either 'single', 'double' or 'ensemble' as limit type")

    plt.rcParams['figure.figsize'] = [10, 5]
    entries = -(np.random.lognormal(2, 0.5, n_rides) / 3 + 0.5)
    if type == 'ensemble':
        return np.quantile(entries, q=cl)

    bins = 200
    counts, _ = np.histogram(entries, bins=bins)
    ymax = np.max(counts)  # height used for the limit markers and y-range

    plt.clf()
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.set_xlabel('Time $t$ in Minutes', fontsize=12)
    ax1.set_xlim([-14, 0.5])

    if type == 'double':
        # Central interval: (1 - cl) / 2 probability excluded on each side.
        lower = np.quantile(entries, q=(1 - cl) / 2)
        upper = np.quantile(entries, q=cl + (1 - cl) / 2)
        ax1.vlines(lower, ymin=0, ymax=ymax, color='green', label='Conf. Interval')
        ax1.vlines(upper, ymin=0, ymax=ymax, color='green')
        ax1.text(upper - 0.3, ymax * 1.03, str(round(upper, 2)), color='green', fontsize=12)
        ax1.text(lower - 0.3, ymax * 1.03, str(round(lower, 2)), color='green', fontsize=12)
        arrive = _plot_ride_histogram(ax1, entries, bins, n_rides)
        title_string = 'Time $t$: ${:.2f}^{{+{:.2f}}}_{{-{:.2f}}}$ min. @{:.0f}% CL'.format(
            arrive, upper - arrive, arrive - lower, 100 * cl)
    else:  # 'single'
        lower = np.quantile(entries, q=(1 - cl))
        ax1.vlines(lower, ymin=0, ymax=ymax, color='green',
                   label='Lower Limit: {:.2f} min. @{:.0f}% CL'.format(lower, 100 * cl))
        ax1.text(lower - 0.3, ymax * 1.03, str(round(lower, 2)), color='green', fontsize=12)
        arrive = _plot_ride_histogram(ax1, entries, bins, n_rides)
        title_string = 'Time $t$: {:.2f} min.'.format(arrive)

    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=1, fancybox=True,
               shadow=True, title=title_string, title_fontsize=15)
    ax1.set_ylabel('Test Ride(s)', fontsize=12)
    ax1.set_ylim([0, ymax * 1.1])
    plt.show()


def find_limit(obs, cl=0.95):
    """Return the classical upper limit on a Poisson mean, rounded to 2 decimals.

    Solves ``P(K <= obs | mu) = 1 - cl`` for ``mu`` exactly, using the identity
    ``P(K <= n | mu) = P(chi2 with 2(n+1) dof > 2 mu)``. Unlike a bounded grid
    scan this also works for ``obs = 0`` and for any confidence level.

    Parameters
    ----------
    obs : int
        Observed (non-negative) count; non-integers are floored, as
        ``scipy.stats.poisson.cdf`` does.
    cl : float
        Confidence level.
    """
    n_obs = np.floor(obs)
    mu_up = scipy.stats.chi2.isf(1 - cl, 2 * (n_obs + 1)) / 2
    return round(mu_up, 2)


def find_cls_limit(obs, b, cl=0.95):
    """Return the CLs upper limit on a Poisson signal ``s``, rounded to 2 decimals.

    CLs(s) = P(K <= obs | s + b) / P(K <= obs | b); the limit is the ``s``
    with CLs(s) = 1 - cl. Rearranged, P(K <= obs | s + b) = (1 - cl) * CL_b,
    which is inverted exactly with the Poisson/chi-square identity, so the
    result is not clipped by a finite search range.

    Parameters
    ----------
    obs : int
        Observed (non-negative) count; non-integers are floored.
    b : float
        Expected background.
    cl : float
        Confidence level.
    """
    n_obs = np.floor(obs)
    cl_b = poisson_cdf(k=obs, mu=b)
    mu_sb = scipy.stats.chi2.isf((1 - cl) * cl_b, 2 * (n_obs + 1)) / 2
    return round(mu_sb - b, 2)


def neyman_construction(columns, cl=0.95):
    """Visualise a Neyman construction for a Poisson mean.

    Builds the grid ``probs[i, c] = P(k = c + 1 | mu = 1 + i / 2)`` for
    26 values of mu (1 to 13.5) and k = 1..26, marks for each mu the first
    k at which the cumulative probability exceeds ``1 - cl`` (red step line),
    and shows the grid as an image with the probabilities written on it.

    Parameters
    ----------
    columns : {'one', 'all'}
        'one' zooms on the single row mu = 8.5; 'all' shows every row
        except the last.
    cl : float
        Confidence level.

    Raises
    ------
    ValueError
        If ``columns`` is neither 'one' nor 'all'.
    """
    if columns not in ('one', 'all'):
        raise ValueError('Please use "all" or "one" as input for columns')

    n = 26
    columns_number = 8.5  # mu value shown when columns == 'one'
    plt.rcParams['figure.figsize'] = 30, 10

    mus = np.arange(1, (n + 2) / 2, 0.5)  # 1, 1.5, ..., 13.5 -> n values
    ks = np.arange(1, n + 1)
    probs = poisson(k=ks[np.newaxis, :], mu=mus[:, np.newaxis])

    # First k index per mu whose running sum exceeds 1 - cl. Rows that never
    # exceed it are skipped, as before.
    # NOTE(review): the sum starts at k = 1, i.e. P(k = 0) is not included.
    exceeds = np.cumsum(probs, axis=1) > 1 - cl
    upper = [int(np.argmax(row)) for row in exceeds if row.any()]

    plt.clf()
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    xmin, xmax = 0.5, n - 0.5
    ymin, ymax = 0, n / 2
    ax1.set_xlim([xmin, (xmax / 2) - 0.25])
    # Each image row is 0.5 high starting at y = 0, so mu's row centre is y = mu - 0.75.
    if columns == 'one':
        ax1.set_ylim([columns_number - 0.25 - 0.75, columns_number + 0.25 - 0.75])
        ax1.set_yticks([columns_number - 0.75])
        ax1.set_yticklabels([columns_number])
    else:  # 'all'
        ax1.set_ylim([ymin, ymax - 0.5])
        ax1.set_yticks(mus[:-1] - 0.75)
        ax1.set_yticklabels(mus[:-1])

    im = ax1.imshow(probs, origin='lower', extent=[xmin, xmax + 1, ymin, ymax])
    ax1.step(np.array(upper) + 0.5, np.arange(len(upper)) / 2, color='red')
    cbar = fig.colorbar(im)
    cbar.set_label(label='Probability Density', size=15)
    ax1.set_ylabel(r'Model Parameter $\mu$', fontsize=15)
    ax1.set_xlabel('Observation k', fontsize=15)

    # Write the probabilities into the visible cells only.
    highlighted_row = (columns_number - 1) * 2
    for (j, i), prob in np.ndenumerate(probs):
        if i >= n // 2:
            continue
        if columns == 'one' and j != highlighted_row:
            continue
        ax1.text(i + 1, (j / 2) + 0.25, round(prob, 3), ha='center', va='center', color='w')
    plt.show()