diff options
author | algol <dkazanc@hotmail.com> | 2018-02-06 15:17:45 +0000 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2018-02-06 16:24:46 +0000 |
commit | 60be5543cd748a64f085c600d0fd62926ded2f62 (patch) | |
tree | 25830f774530d43487f6f139de24e375d5736ca8 /Wrappers/Python | |
parent | c98c688fecda489f7396ccd66a1487e9c0f6989f (diff) | |
download | regularization-60be5543cd748a64f085c600d0fd62926ded2f62.tar.gz regularization-60be5543cd748a64f085c600d0fd62926ded2f62.tar.bz2 regularization-60be5543cd748a64f085c600d0fd62926ded2f62.tar.xz regularization-60be5543cd748a64f085c600d0fd62926ded2f62.zip |
Updated demo for CPU regularizers, bug fixed in fista_module.cpp for FGP_TV
Diffstat (limited to 'Wrappers/Python')
-rw-r--r-- | Wrappers/Python/demo/test_cpu_regularizers.py | 66 | ||||
-rw-r--r-- | Wrappers/Python/src/fista_module.cpp | 16 |
2 files changed, 50 insertions, 32 deletions
diff --git a/Wrappers/Python/demo/test_cpu_regularizers.py b/Wrappers/Python/demo/test_cpu_regularizers.py index 9713baa..03d650f 100644 --- a/Wrappers/Python/demo/test_cpu_regularizers.py +++ b/Wrappers/Python/demo/test_cpu_regularizers.py @@ -31,6 +31,11 @@ def nrmse(im1, im2): max_val = max(np.max(im1), np.max(im2)) min_val = min(np.min(im1), np.min(im2)) return 1 - (rmse / (max_val - min_val)) + +def rmse(im1, im2): + a, b = im1.shape + rmse = np.sqrt(np.sum((im1 - im2) ** 2) / float(a * b)) + return rmse ############################################################################### def printParametersToString(pars): txt = r'' @@ -63,7 +68,9 @@ filename = os.path.join(".." , ".." , ".." , "data" ,"lena_gray_512.tif") Im = plt.imread(filename) Im = np.asarray(Im, dtype='float32') -perc = 0.15 +Im = Im/255 + +perc = 0.075 u0 = Im + np.random.normal(loc = Im , scale = perc * Im , size = np.shape(Im)) @@ -76,7 +83,7 @@ fig = plt.figure() a=fig.add_subplot(2,3,1) a.set_title('noise') -imgplot = plt.imshow(u0#,cmap="gray" +imgplot = plt.imshow(u0,cmap="gray" ) reg_output = [] @@ -89,10 +96,10 @@ reg_output = [] start_time = timeit.default_timer() pars = {'algorithm' : SplitBregman_TV , \ 'input' : u0, - 'regularization_parameter':10. , \ -'number_of_iterations' :35 ,\ -'tolerance_constant':0.0001 , \ -'TV_penalty': 0 + 'regularization_parameter':15. , \ + 'number_of_iterations' :40 ,\ + 'tolerance_constant':0.0001 , \ + 'TV_penalty': 0 } out = SplitBregman_TV (pars['input'], pars['regularization_parameter'], @@ -100,6 +107,8 @@ out = SplitBregman_TV (pars['input'], pars['regularization_parameter'], pars['tolerance_constant'], pars['TV_penalty']) splitbregman = out[0] +rms = rmse(Im, splitbregman) +pars['rmse'] = rms txtstr = printParametersToString(pars) txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) print (txtstr) @@ -114,7 +123,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) imgplot = plt.imshow(splitbregman,\ - #cmap="gray" + cmap="gray" ) ###################### FGP_TV ######################################### @@ -122,9 +131,9 @@ imgplot = plt.imshow(splitbregman,\ start_time = timeit.default_timer() pars = {'algorithm' : FGP_TV , \ 'input' : u0, - 'regularization_parameter':5e-4, \ - 'number_of_iterations' :10 ,\ - 'tolerance_constant':0.001,\ + 'regularization_parameter':0.05, \ + 'number_of_iterations' :200 ,\ + 'tolerance_constant':1e-4,\ 'TV_penalty': 0 } @@ -135,6 +144,9 @@ out = FGP_TV (pars['input'], pars['TV_penalty']) fgp = out[0] +rms = rmse(Im, fgp) +pars['rmse'] = rms + txtstr = printParametersToString(pars) txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) print (txtstr) @@ -146,7 +158,7 @@ a=fig.add_subplot(2,3,3) props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords imgplot = plt.imshow(fgp, \ - #cmap="gray" + cmap="gray" ) # place a text box in upper left in axes coords a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14, @@ -158,10 +170,10 @@ start_time = timeit.default_timer() pars = {'algorithm': LLT_model , \ 'input' : u0, - 'regularization_parameter': 25,\ - 'time_step':0.0003, \ - 'number_of_iterations' :300,\ - 'tolerance_constant':0.001,\ + 'regularization_parameter': 5,\ + 'time_step':0.00035, \ + 'number_of_iterations' :350,\ + 'tolerance_constant':0.0001,\ 'restrictive_Z_smoothing': 0 } out = LLT_model(pars['input'], @@ -172,6 +184,9 @@ out = LLT_model(pars['input'], pars['restrictive_Z_smoothing'] ) llt = out[0] +rms = rmse(Im, out[0]) +pars['rmse'] = rms + txtstr = printParametersToString(pars) txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) print (txtstr) @@ -183,7 +198,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) imgplot = plt.imshow(llt,\ - #cmap="gray" + cmap="gray" ) @@ -192,6 +207,7 @@ imgplot = plt.imshow(llt,\ # # Im = double(imread('lena_gray_256.tif'))/255; % loading image # # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise # # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + start_time = timeit.default_timer() pars = {'algorithm': PatchBased_Regul , \ @@ -199,7 +215,7 @@ pars = {'algorithm': PatchBased_Regul , \ 'regularization_parameter': 0.05,\ 'searching_window_ratio':3, \ 'similarity_window_ratio':1,\ - 'PB_filtering_parameter': 0.08 + 'PB_filtering_parameter': 0.06 } out = PatchBased_Regul(pars['input'], pars['regularization_parameter'], @@ -207,6 +223,9 @@ out = PatchBased_Regul(pars['input'], pars['similarity_window_ratio'] , pars['PB_filtering_parameter']) pbr = out[0] +rms = rmse(Im, out[0]) +pars['rmse'] = rms + txtstr = printParametersToString(pars) txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) print (txtstr) @@ -219,8 +238,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(pbr #,cmap="gray" - ) +imgplot = plt.imshow(pbr ,cmap="gray") # ###################### TGV_PD ######################################### @@ -233,7 +251,7 @@ start_time = timeit.default_timer() pars = {'algorithm': TGV_PD , \ 'input' : u0,\ - 'regularization_parameter':0.05,\ + 'regularization_parameter':0.07,\ 'first_order_term': 1.3,\ 'second_order_term': 1, \ 'number_of_iterations': 550 @@ -244,6 +262,9 @@ out = TGV_PD(pars['input'], pars['second_order_term'] , pars['number_of_iterations']) tgv = out[0] +rms = rmse(Im, out[0]) +pars['rmse'] = rms + txtstr = printParametersToString(pars) txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) print (txtstr) @@ -254,10 +275,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(tgv #, cmap="gray") - ) - - +imgplot = plt.imshow(tgv, cmap="gray") plt.show() ################################################################################ diff --git a/Wrappers/Python/src/fista_module.cpp b/Wrappers/Python/src/fista_module.cpp index cef3ecc..e311570 100644 --- a/Wrappers/Python/src/fista_module.cpp +++ b/Wrappers/Python/src/fista_module.cpp @@ -412,13 +412,13 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me /*Taking a step towards minus of the gradient*/ Grad_func2D(P1, P2, D, R1, R2, lambda, dimX, dimY); - - - - - /*updating R and t*/ - tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f; - Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY); + + /* projection step */ + Proj_func2D(P1, P2, methTV, dimX, dimY); + + /*updating R and t*/ + tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f; + Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY); /* calculate norm */ re = 0.0f; re1 = 0.0f; @@ -429,7 +429,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me } re = sqrt(re) / sqrt(re1); if (re < epsil) count++; - if (count > 3) { + if (count > 4) { Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); funcval = 0.0f; for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); |