summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2017-08-07 17:21:54 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2017-08-07 17:21:54 +0100
commitdb50cddf2cfe92c652ff16ce51a3bcecca96de68 (patch)
tree33621e652c83d06a33b3757937023cd307646950 /src
parent6589fa197d9f87f7a37f46943aa995d97f50bb46 (diff)
downloadregularization-db50cddf2cfe92c652ff16ce51a3bcecca96de68.tar.gz
regularization-db50cddf2cfe92c652ff16ce51a3bcecca96de68.tar.bz2
regularization-db50cddf2cfe92c652ff16ce51a3bcecca96de68.tar.xz
regularization-db50cddf2cfe92c652ff16ce51a3bcecca96de68.zip
added TGV_PD
Diffstat (limited to 'src')
-rw-r--r--src/Python/setup.py1
-rw-r--r--src/Python/test_regularizers.py195
2 files changed, 168 insertions, 28 deletions
diff --git a/src/Python/setup.py b/src/Python/setup.py
index a4eed14..0468722 100644
--- a/src/Python/setup.py
+++ b/src/Python/setup.py
@@ -53,6 +53,7 @@ setup(
"..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c",
"..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c",
"..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c",
+ "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c",
"..\\..\\main_func\\regularizers_CPU\\utils.c"
],
include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ),
diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py
index 6abfba4..6a34749 100644
--- a/src/Python/test_regularizers.py
+++ b/src/Python/test_regularizers.py
@@ -47,6 +47,8 @@ class Regularizer():
SplitBregman_TV = regularizers.SplitBregman_TV
FGP_TV = regularizers.FGP_TV
LLT_model = regularizers.LLT_model
+ PatchBased_Regul = regularizers.PatchBased_Regul
+ TGV_PD = regularizers.TGV_PD
# Algorithm
class TotalVariationPenalty(Enum):
@@ -55,13 +57,17 @@ class Regularizer():
# TotalVariationPenalty
def __init__(self , algorithm):
-
+ self.setAlgorithm ( algorithm )
+ # __init__
+
+ def setAlgorithm(self, algorithm):
self.algorithm = algorithm
self.pars = self.parsForAlgorithm(algorithm)
- # __init__
+ # setAlgorithm
def parsForAlgorithm(self, algorithm):
pars = dict()
+
if algorithm == Regularizer.Algorithm.SplitBregman_TV :
pars['algorithm'] = algorithm
pars['input'] = None
@@ -69,6 +75,7 @@ class Regularizer():
pars['number_of_iterations'] = 35
pars['tolerance_constant'] = 0.0001
pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+
elif algorithm == Regularizer.Algorithm.FGP_TV :
pars['algorithm'] = algorithm
pars['input'] = None
@@ -76,6 +83,7 @@ class Regularizer():
pars['number_of_iterations'] = 50
pars['tolerance_constant'] = 0.001
pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+
elif algorithm == Regularizer.Algorithm.LLT_model:
pars['algorithm'] = algorithm
pars['input'] = None
@@ -85,6 +93,24 @@ class Regularizer():
pars['tolerance_constant'] = None
pars['restrictive_Z_smoothing'] = 0
+ elif algorithm == Regularizer.Algorithm.PatchBased_Regul:
+ pars['algorithm'] = algorithm
+ pars['input'] = None
+ pars['searching_window_ratio'] = None
+ pars['similarity_window_ratio'] = None
+ pars['PB_filtering_parameter'] = None
+ pars['regularization_parameter'] = None
+
+ elif algorithm == Regularizer.Algorithm.TGV_PD:
+ pars['algorithm'] = algorithm
+ pars['input'] = None
+ pars['first_order_term'] = None
+ pars['second_order_term'] = None
+ pars['number_of_iterations'] = None
+ pars['regularization_parameter'] = None
+
+
+
return pars
# parsForAlgorithm
@@ -98,6 +124,8 @@ class Regularizer():
self.pars['regularization_parameter'] = regularization_parameter
#for key, value in self.pars.items():
# print("{0} = {1}".format(key, value))
+ if None in self.pars:
+ raise Exception("Not all parameters have been provided")
if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
return self.algorithm(input, regularization_parameter,
@@ -112,15 +140,27 @@ class Regularizer():
elif self.algorithm == Regularizer.Algorithm.LLT_model :
#LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
# no default
- if None in self.pars:
- raise Exception("Not all parameters have been provided")
- else:
- return self.algorithm(input,
- regularization_parameter,
- self.pars['time_step'] ,
- self.pars['number_of_iterations'],
- self.pars['tolerance_constant'],
- self.pars['restrictive_Z_smoothing'] )
+ return self.algorithm(input,
+ regularization_parameter,
+ self.pars['time_step'] ,
+ self.pars['number_of_iterations'],
+ self.pars['tolerance_constant'],
+ self.pars['restrictive_Z_smoothing'] )
+ elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :
+ #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+ # no default
+ return self.algorithm(input, regularization_parameter,
+ self.pars['searching_window_ratio'] ,
+ self.pars['similarity_window_ratio'] ,
+ self.pars['PB_filtering_parameter'])
+ elif self.algorithm == Regularizer.Algorithm.TGV_PD :
+ #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+ # no default
+ return self.algorithm(input, regularization_parameter,
+ self.pars['first_order_term'] ,
+ self.pars['second_order_term'] ,
+ self.pars['number_of_iterations'])
+
# __call__
@@ -142,13 +182,40 @@ class Regularizer():
@staticmethod
def LLT_model(input, regularization_parameter , time_step, number_of_iterations,
tolerance_constant, restrictive_Z_smoothing=0):
- reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+ reg = Regularizer(Regularizer.Algorithm.LLT_model)
out = list( reg(input, regularization_parameter, time_step=time_step,
number_of_iterations=number_of_iterations,
tolerance_constant=tolerance_constant,
restrictive_Z_smoothing=restrictive_Z_smoothing) )
out.append(reg.pars)
return out
+
+ @staticmethod
+ def PatchBased_Regul(input, regularization_parameter,
+ searching_window_ratio,
+ similarity_window_ratio,
+ PB_filtering_parameter):
+ reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)
+ out = list( reg(input,
+ regularization_parameter,
+ searching_window_ratio=searching_window_ratio,
+ similarity_window_ratio=similarity_window_ratio,
+ PB_filtering_parameter=PB_filtering_parameter )
+ )
+ out.append(reg.pars)
+ return out
+
+ @staticmethod
+ def TGV_PD(input, regularization_parameter , first_order_term,
+ second_order_term, number_of_iterations):
+
+ reg = Regularizer(Regularizer.Algorithm.TGV_PD)
+ out = list( reg(input, regularization_parameter,
+ first_order_term=first_order_term,
+ second_order_term=second_order_term,
+ number_of_iterations=number_of_iterations) )
+ out.append(reg.pars)
+ return out
#Example:
@@ -171,17 +238,17 @@ u0 = Im + (perc* np.random.normal(size=np.shape(Im)))
f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
u0 = f(u0).astype('float32')
-# plot
+## plot
fig = plt.figure()
-a=fig.add_subplot(2,3,1)
-a.set_title('Original')
-imgplot = plt.imshow(Im)
+#a=fig.add_subplot(3,3,1)
+#a.set_title('Original')
+#imgplot = plt.imshow(Im)
-a=fig.add_subplot(2,3,2)
+a=fig.add_subplot(2,3,1)
a.set_title('noise')
imgplot = plt.imshow(u0)
-
+reg_output = []
##############################################################################
# Call regularizer
@@ -199,8 +266,9 @@ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., numbe
TV_Penalty=Regularizer.TotalVariationPenalty.l1)
out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
pars = out2[2]
+reg_output.append(out2)
-a=fig.add_subplot(2,3,3)
+a=fig.add_subplot(2,3,2)
a.set_title('SplitBregman_TV')
textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'
textstr = textstr % (pars['regularization_parameter'],
@@ -213,7 +281,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
-imgplot = plt.imshow(out2[0])
+imgplot = plt.imshow(reg_output[-1][0])
###################### FGP_TV #########################################
# u = FGP_TV(single(u0), 0.05, 100, 1e-04);
@@ -221,7 +289,9 @@ out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05,
number_of_iterations=10)
pars = out2[-1]
-a=fig.add_subplot(2,3,4)
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,3)
a.set_title('FGP_TV')
textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'
textstr = textstr % (pars['regularization_parameter'],
@@ -234,18 +304,23 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
-imgplot = plt.imshow(out2[0])
+imgplot = plt.imshow(reg_output[-1][0])
###################### LLT_model #########################################
# * u0 = Im + .03*randn(size(Im)); % adding noise
# [Den] = LLT_model(single(u0), 10, 0.1, 1);
-out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10.,
- time_step=0.1,
- tolerance_constant=1e-4,
- number_of_iterations=10)
+#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);
+#input, regularization_parameter , time_step, number_of_iterations,
+# tolerance_constant, restrictive_Z_smoothing=0
+out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25,
+ time_step=0.0003,
+ tolerance_constant=0.0001,
+ number_of_iterations=300)
pars = out2[-1]
-a=fig.add_subplot(2,3,5)
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,4)
a.set_title('LLT_model')
textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f'
textstr = textstr % (pars['regularization_parameter'],
@@ -259,7 +334,71 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
-imgplot = plt.imshow(out2[0])
+imgplot = plt.imshow(reg_output[-1][0])
+
+###################### PatchBased_Regul #########################################
+# Quick 2D denoising example in Matlab:
+# Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);
+
+out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+ searching_window_ratio=3,
+ similarity_window_ratio=1,
+ PB_filtering_parameter=0.08)
+pars = out2[-1]
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,5)
+a.set_title('PatchBased_Regul')
+textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f'
+textstr = textstr % (pars['regularization_parameter'],
+ pars['searching_window_ratio'],
+ pars['similarity_window_ratio'],
+ pars['PB_filtering_parameter'])
+
+
+
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0])
+
+
+###################### TGV_PD #########################################
+# Quick 2D denoising example in Matlab:
+# Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+
+
+out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+ first_order_term=1.3,
+ second_order_term=1,
+ number_of_iterations=550)
+pars = out2[-1]
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,6)
+a.set_title('TGV_PD')
+textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d'
+textstr = textstr % (pars['regularization_parameter'],
+ pars['first_order_term'],
+ pars['second_order_term'],
+ pars['number_of_iterations'])
+
+
+
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0])