summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authoralgol <dkazanc@hotmail.com>2018-02-06 15:17:45 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2018-02-06 16:24:46 +0000
commit60be5543cd748a64f085c600d0fd62926ded2f62 (patch)
tree25830f774530d43487f6f139de24e375d5736ca8 /Wrappers/Python
parentc98c688fecda489f7396ccd66a1487e9c0f6989f (diff)
downloadregularization-60be5543cd748a64f085c600d0fd62926ded2f62.tar.gz
regularization-60be5543cd748a64f085c600d0fd62926ded2f62.tar.bz2
regularization-60be5543cd748a64f085c600d0fd62926ded2f62.tar.xz
regularization-60be5543cd748a64f085c600d0fd62926ded2f62.zip
Updated demo for CPU regularizers, bug fixed in fista_module.cpp for FGP_TV
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/demo/test_cpu_regularizers.py66
-rw-r--r--Wrappers/Python/src/fista_module.cpp16
2 files changed, 50 insertions, 32 deletions
diff --git a/Wrappers/Python/demo/test_cpu_regularizers.py b/Wrappers/Python/demo/test_cpu_regularizers.py
index 9713baa..03d650f 100644
--- a/Wrappers/Python/demo/test_cpu_regularizers.py
+++ b/Wrappers/Python/demo/test_cpu_regularizers.py
@@ -31,6 +31,11 @@ def nrmse(im1, im2):
max_val = max(np.max(im1), np.max(im2))
min_val = min(np.min(im1), np.min(im2))
return 1 - (rmse / (max_val - min_val))
+
+def rmse(im1, im2):
+ a, b = im1.shape
+ rmse = np.sqrt(np.sum((im1 - im2) ** 2) / float(a * b))
+ return rmse
###############################################################################
def printParametersToString(pars):
txt = r''
@@ -63,7 +68,9 @@ filename = os.path.join(".." , ".." , ".." , "data" ,"lena_gray_512.tif")
Im = plt.imread(filename)
Im = np.asarray(Im, dtype='float32')
-perc = 0.15
+Im = Im/255
+
+perc = 0.075
u0 = Im + np.random.normal(loc = Im ,
scale = perc * Im ,
size = np.shape(Im))
@@ -76,7 +83,7 @@ fig = plt.figure()
a=fig.add_subplot(2,3,1)
a.set_title('noise')
-imgplot = plt.imshow(u0#,cmap="gray"
+imgplot = plt.imshow(u0,cmap="gray"
)
reg_output = []
@@ -89,10 +96,10 @@ reg_output = []
start_time = timeit.default_timer()
pars = {'algorithm' : SplitBregman_TV , \
'input' : u0,
- 'regularization_parameter':10. , \
-'number_of_iterations' :35 ,\
-'tolerance_constant':0.0001 , \
-'TV_penalty': 0
+ 'regularization_parameter':15. , \
+ 'number_of_iterations' :40 ,\
+ 'tolerance_constant':0.0001 , \
+ 'TV_penalty': 0
}
out = SplitBregman_TV (pars['input'], pars['regularization_parameter'],
@@ -100,6 +107,8 @@ out = SplitBregman_TV (pars['input'], pars['regularization_parameter'],
pars['tolerance_constant'],
pars['TV_penalty'])
splitbregman = out[0]
+rms = rmse(Im, splitbregman)
+pars['rmse'] = rms
txtstr = printParametersToString(pars)
txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
print (txtstr)
@@ -114,7 +123,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
imgplot = plt.imshow(splitbregman,\
- #cmap="gray"
+ cmap="gray"
)
###################### FGP_TV #########################################
@@ -122,9 +131,9 @@ imgplot = plt.imshow(splitbregman,\
start_time = timeit.default_timer()
pars = {'algorithm' : FGP_TV , \
'input' : u0,
- 'regularization_parameter':5e-4, \
- 'number_of_iterations' :10 ,\
- 'tolerance_constant':0.001,\
+ 'regularization_parameter':0.05, \
+ 'number_of_iterations' :200 ,\
+ 'tolerance_constant':1e-4,\
'TV_penalty': 0
}
@@ -135,6 +144,9 @@ out = FGP_TV (pars['input'],
pars['TV_penalty'])
fgp = out[0]
+rms = rmse(Im, fgp)
+pars['rmse'] = rms
+
txtstr = printParametersToString(pars)
txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
print (txtstr)
@@ -146,7 +158,7 @@ a=fig.add_subplot(2,3,3)
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
imgplot = plt.imshow(fgp, \
- #cmap="gray"
+ cmap="gray"
)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14,
@@ -158,10 +170,10 @@ start_time = timeit.default_timer()
pars = {'algorithm': LLT_model , \
'input' : u0,
- 'regularization_parameter': 25,\
- 'time_step':0.0003, \
- 'number_of_iterations' :300,\
- 'tolerance_constant':0.001,\
+ 'regularization_parameter': 5,\
+ 'time_step':0.00035, \
+ 'number_of_iterations' :350,\
+ 'tolerance_constant':0.0001,\
'restrictive_Z_smoothing': 0
}
out = LLT_model(pars['input'],
@@ -172,6 +184,9 @@ out = LLT_model(pars['input'],
pars['restrictive_Z_smoothing'] )
llt = out[0]
+rms = rmse(Im, out[0])
+pars['rmse'] = rms
+
txtstr = printParametersToString(pars)
txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
print (txtstr)
@@ -183,7 +198,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
imgplot = plt.imshow(llt,\
- #cmap="gray"
+ cmap="gray"
)
@@ -192,6 +207,7 @@ imgplot = plt.imshow(llt,\
# # Im = double(imread('lena_gray_256.tif'))/255; % loading image
# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);
+
start_time = timeit.default_timer()
pars = {'algorithm': PatchBased_Regul , \
@@ -199,7 +215,7 @@ pars = {'algorithm': PatchBased_Regul , \
'regularization_parameter': 0.05,\
'searching_window_ratio':3, \
'similarity_window_ratio':1,\
- 'PB_filtering_parameter': 0.08
+ 'PB_filtering_parameter': 0.06
}
out = PatchBased_Regul(pars['input'],
pars['regularization_parameter'],
@@ -207,6 +223,9 @@ out = PatchBased_Regul(pars['input'],
pars['similarity_window_ratio'] ,
pars['PB_filtering_parameter'])
pbr = out[0]
+rms = rmse(Im, out[0])
+pars['rmse'] = rms
+
txtstr = printParametersToString(pars)
txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
print (txtstr)
@@ -219,8 +238,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
-imgplot = plt.imshow(pbr #,cmap="gray"
- )
+imgplot = plt.imshow(pbr ,cmap="gray")
# ###################### TGV_PD #########################################
@@ -233,7 +251,7 @@ start_time = timeit.default_timer()
pars = {'algorithm': TGV_PD , \
'input' : u0,\
- 'regularization_parameter':0.05,\
+ 'regularization_parameter':0.07,\
'first_order_term': 1.3,\
'second_order_term': 1, \
'number_of_iterations': 550
@@ -244,6 +262,9 @@ out = TGV_PD(pars['input'],
pars['second_order_term'] ,
pars['number_of_iterations'])
tgv = out[0]
+rms = rmse(Im, out[0])
+pars['rmse'] = rms
+
txtstr = printParametersToString(pars)
txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
print (txtstr)
@@ -254,10 +275,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, txtstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
-imgplot = plt.imshow(tgv #, cmap="gray")
- )
-
-
+imgplot = plt.imshow(tgv, cmap="gray")
plt.show()
################################################################################
diff --git a/Wrappers/Python/src/fista_module.cpp b/Wrappers/Python/src/fista_module.cpp
index cef3ecc..e311570 100644
--- a/Wrappers/Python/src/fista_module.cpp
+++ b/Wrappers/Python/src/fista_module.cpp
@@ -412,13 +412,13 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
/*Taking a step towards minus of the gradient*/
Grad_func2D(P1, P2, D, R1, R2, lambda, dimX, dimY);
-
-
-
-
- /*updating R and t*/
- tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f;
- Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY);
+
+ /* projection step */
+ Proj_func2D(P1, P2, methTV, dimX, dimY);
+
+ /*updating R and t*/
+ tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f;
+ Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY);
/* calculate norm */
re = 0.0f; re1 = 0.0f;
@@ -429,7 +429,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
}
re = sqrt(re) / sqrt(re1);
if (re < epsil) count++;
- if (count > 3) {
+ if (count > 4) {
Obj_func2D(A, D, P1, P2, lambda, dimX, dimY);
funcval = 0.0f;
for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);