summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-06-26 19:23:35 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-06-26 19:23:35 +0100
commit62ae90eff4ef83c2384dcf856b593b1d117c7e49 (patch)
tree26f379aeab4712d656b80681cace853557f266dd
parentb25bec22c9fc71d7416a799ca5a3fdd87f76d654 (diff)
downloadframework-62ae90eff4ef83c2384dcf856b593b1d117c7e49.tar.gz
framework-62ae90eff4ef83c2384dcf856b593b1d117c7e49.tar.bz2
framework-62ae90eff4ef83c2384dcf856b593b1d117c7e49.tar.xz
framework-62ae90eff4ef83c2384dcf856b593b1d117c7e49.zip
fix cgls
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/CGLS.py4
-rw-r--r--Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py5
-rw-r--r--Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py46
3 files changed, 26 insertions, 29 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
index 6cf4f06..1695a73 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
@@ -161,12 +161,12 @@ class CGLS(Algorithm):
def should_stop(self):
+ self.update_objective()
flag = (self.norms <= self.norms0 * self.tolerance) or (self.normx * self.tolerance >= 1);
#if self.gamma<=self.tolerance:
if flag == 1 or self.max_iteration_stop_cryterion():
- print('Tolerance is reached: Iter: {}'.format(self.iteration))
- self.update_objective()
+ print('Tolerance is reached: Iter: {}'.format(self.iteration))
return True
diff --git a/Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py b/Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py
index e8f4063..2ac050f 100644
--- a/Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py
+++ b/Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py
@@ -50,7 +50,7 @@ import os
# Load Shepp-Logan phantom
model = 1 # select a model number from the library
-N = 128 # set dimension of the phantom
+N = 64 # set dimension of the phantom
path = os.path.dirname(tomophantom.__file__)
path_library2D = os.path.join(path, "Phantom2DLibrary.dat")
phantom_2D = TomoP2D.Model(model, N, path_library2D)
@@ -75,7 +75,6 @@ sin = Aop.direct(data)
np.random.seed(10)
noisy_data = AcquisitionData( sin.as_array() + np.random.normal(0,1,ag.shape))
-#noisy_data = AcquisitionData( sin.as_array() )
# Show Ground Truth and Noisy Data
plt.figure(figsize=(10,10))
@@ -89,7 +88,6 @@ plt.title('Noisy Data')
plt.colorbar()
plt.show()
-
# Setup and run the simple CGLS algorithm
x_init = ig.allocate()
@@ -98,7 +96,6 @@ cgls1.max_iteration = 20
cgls1.update_objective_interval = 5
cgls1.run(20, verbose = True)
-
# Setup and run the regularised CGLS algorithm (Tikhonov with Identity)
x_init = ig.allocate()
diff --git a/Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py b/Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py
index 568df38..09e350b 100644
--- a/Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py
+++ b/Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py
@@ -20,8 +20,8 @@
#=========================================================================
"""
-Compare solutions of FISTA & PDHG
- & CGLS & Astra Built-in algorithms for Least Squares
+Compare solutions of FISTA & PDHG & CGLS
+ & Astra Built-in algorithms for Least Squares
Problem: min_x || A x - g ||_{2}^{2}
@@ -32,7 +32,7 @@ Problem: min_x || A x - g ||_{2}^{2}
"""
-from ccpi.framework import ImageData, AcquisitionData, ImageGeometry, AcquisitionGeometry
+from ccpi.framework import ImageData, ImageGeometry, AcquisitionGeometry
import numpy as np
import numpy
@@ -71,10 +71,9 @@ else:
dev = 'cpu'
Aop = AstraProjectorSimple(ig, ag, dev)
-sin = Aop.direct(data)
+sinogram = Aop.direct(data)
+
-np.random.seed(10)
-noisy_data = AcquisitionData( sin.as_array() + np.random.normal(0,1,ag.shape))
###############################################################################
# Setup and run Astra CGLS algorithm
@@ -82,10 +81,10 @@ vol_geom = astra.create_vol_geom(N, N)
proj_geom = astra.create_proj_geom('parallel', 1.0, detectors, angles)
proj_id = astra.create_projector('line', proj_geom, vol_geom)
-# Create a sinogram from a phantom
-sinogram_id = astra.data2d.create('-sino', proj_geom, noisy_data.as_array())
+# Create a sinogram id
+sinogram_id = astra.data2d.create('-sino', proj_geom, sinogram.as_array())
-# Create a data object for the reconstruction
+# Create a data id
rec_id = astra.data2d.create('-vol', vol_geom)
cgls_astra = astra.astra_dict('CGLS')
@@ -96,7 +95,7 @@ cgls_astra['ProjectorId'] = proj_id
# Create the algorithm object from the configuration structure
alg_id = astra.algorithm.create(cgls_astra)
-astra.algorithm.run(alg_id, 25)
+astra.algorithm.run(alg_id, 500)
recon_cgls_astra = ImageData(astra.data2d.get(rec_id))
@@ -105,41 +104,42 @@ recon_cgls_astra = ImageData(astra.data2d.get(rec_id))
# Setup and run the simple CGLS algorithm
x_init = ig.allocate()
-cgls = CGLS(x_init = x_init, operator = Aop, data = noisy_data)
-cgls.max_iteration = 25
-cgls.update_objective_interval = 5
-cgls.run(25, verbose = True)
+cgls = CGLS(x_init = x_init, operator = Aop, data = sinogram)
+cgls.max_iteration = 500
+cgls.update_objective_interval = 100
+cgls.run(500, verbose = True)
#%%
###############################################################################
# Setup and run the PDHG algorithm
operator = Aop
-f = L2NormSquared(b = noisy_data)
+f = L2NormSquared(b = sinogram)
g = ZeroFunction()
## Compute operator Norm
normK = operator.norm()
## Primal & dual stepsizes
-sigma = 1
+sigma = 0.02
tau = 1/(sigma*normK**2)
+
pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
-pdhg.max_iteration = 400
-pdhg.update_objective_interval = 20
-pdhg.run(400, verbose=True)
+pdhg.max_iteration = 1000
+pdhg.update_objective_interval = 100
+pdhg.run(1000, verbose=True)
#%%
###############################################################################
# Setup and run the FISTA algorithm
-fidelity = FunctionOperatorComposition(L2NormSquared(b=noisy_data), Aop)
+fidelity = FunctionOperatorComposition(L2NormSquared(b=sinogram), Aop)
regularizer = ZeroFunction()
fista = FISTA(x_init=x_init , f=fidelity, g=regularizer)
-fista.max_iteration = 20
-fista.update_objective_interval = 5
-fista.run(20, verbose = True)
+fista.max_iteration = 500
+fista.update_objective_interval = 100
+fista.run(500, verbose = True)
#%% Show results