diff options
author | epapoutsellis <epapoutsellis@gmail.com> | 2019-06-26 19:23:35 +0100 |
---|---|---|
committer | epapoutsellis <epapoutsellis@gmail.com> | 2019-06-26 19:23:35 +0100 |
commit | 62ae90eff4ef83c2384dcf856b593b1d117c7e49 (patch) | |
tree | 26f379aeab4712d656b80681cace853557f266dd /Wrappers/Python | |
parent | b25bec22c9fc71d7416a799ca5a3fdd87f76d654 (diff) | |
download | framework-62ae90eff4ef83c2384dcf856b593b1d117c7e49.tar.gz framework-62ae90eff4ef83c2384dcf856b593b1d117c7e49.tar.bz2 framework-62ae90eff4ef83c2384dcf856b593b1d117c7e49.tar.xz framework-62ae90eff4ef83c2384dcf856b593b1d117c7e49.zip |
fix cgls
Diffstat (limited to 'Wrappers/Python')
3 files changed, 26 insertions, 29 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py index 6cf4f06..1695a73 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py @@ -161,12 +161,12 @@ class CGLS(Algorithm): def should_stop(self): + self.update_objective() flag = (self.norms <= self.norms0 * self.tolerance) or (self.normx * self.tolerance >= 1); #if self.gamma<=self.tolerance: if flag == 1 or self.max_iteration_stop_cryterion(): - print('Tolerance is reached: Iter: {}'.format(self.iteration)) - self.update_objective() + print('Tolerance is reached: Iter: {}'.format(self.iteration)) return True diff --git a/Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py b/Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py index e8f4063..2ac050f 100644 --- a/Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py +++ b/Wrappers/Python/demos/CGLS_examples/CGLS_Tomography.py @@ -50,7 +50,7 @@ import os # Load Shepp-Logan phantom model = 1 # select a model number from the library -N = 128 # set dimension of the phantom +N = 64 # set dimension of the phantom path = os.path.dirname(tomophantom.__file__) path_library2D = os.path.join(path, "Phantom2DLibrary.dat") phantom_2D = TomoP2D.Model(model, N, path_library2D) @@ -75,7 +75,6 @@ sin = Aop.direct(data) np.random.seed(10) noisy_data = AcquisitionData( sin.as_array() + np.random.normal(0,1,ag.shape)) -#noisy_data = AcquisitionData( sin.as_array() ) # Show Ground Truth and Noisy Data plt.figure(figsize=(10,10)) @@ -89,7 +88,6 @@ plt.title('Noisy Data') plt.colorbar() plt.show() - # Setup and run the simple CGLS algorithm x_init = ig.allocate() @@ -98,7 +96,6 @@ cgls1.max_iteration = 20 cgls1.update_objective_interval = 5 cgls1.run(20, verbose = True) - # Setup and run the regularised CGLS algorithm (Tikhonov with Identity) x_init = ig.allocate() diff --git a/Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py b/Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py index 568df38..09e350b 100644 --- a/Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py +++ b/Wrappers/Python/demos/CompareAlgorithms/CGLS_FISTA_PDHG_LeastSquares.py @@ -20,8 +20,8 @@ #========================================================================= """ -Compare solutions of FISTA & PDHG - & CGLS & Astra Built-in algorithms for Least Squares +Compare solutions of FISTA & PDHG & CGLS + & Astra Built-in algorithms for Least Squares Problem: min_x || A x - g ||_{2}^{2} @@ -32,7 +32,7 @@ Problem: min_x || A x - g ||_{2}^{2} """ -from ccpi.framework import ImageData, AcquisitionData, ImageGeometry, AcquisitionGeometry +from ccpi.framework import ImageData, ImageGeometry, AcquisitionGeometry import numpy as np import numpy @@ -71,10 +71,9 @@ else: dev = 'cpu' Aop = AstraProjectorSimple(ig, ag, dev) -sin = Aop.direct(data) +sinogram = Aop.direct(data) + -np.random.seed(10) -noisy_data = AcquisitionData( sin.as_array() + np.random.normal(0,1,ag.shape)) ############################################################################### # Setup and run Astra CGLS algorithm @@ -82,10 +81,10 @@ vol_geom = astra.create_vol_geom(N, N) proj_geom = astra.create_proj_geom('parallel', 1.0, detectors, angles) proj_id = astra.create_projector('line', proj_geom, vol_geom) -# Create a sinogram from a phantom -sinogram_id = astra.data2d.create('-sino', proj_geom, noisy_data.as_array()) +# Create a sinogram id +sinogram_id = astra.data2d.create('-sino', proj_geom, sinogram.as_array()) -# Create a data object for the reconstruction +# Create a data id rec_id = astra.data2d.create('-vol', vol_geom) cgls_astra = astra.astra_dict('CGLS') @@ -96,7 +95,7 @@ cgls_astra['ProjectorId'] = proj_id # Create the algorithm object from the configuration structure alg_id = astra.algorithm.create(cgls_astra) -astra.algorithm.run(alg_id, 25) +astra.algorithm.run(alg_id, 500) recon_cgls_astra = ImageData(astra.data2d.get(rec_id)) @@ -105,41 +104,42 @@ recon_cgls_astra = ImageData(astra.data2d.get(rec_id)) # Setup and run the simple CGLS algorithm x_init = ig.allocate() -cgls = CGLS(x_init = x_init, operator = Aop, data = noisy_data) -cgls.max_iteration = 25 -cgls.update_objective_interval = 5 -cgls.run(25, verbose = True) +cgls = CGLS(x_init = x_init, operator = Aop, data = sinogram) +cgls.max_iteration = 500 +cgls.update_objective_interval = 100 +cgls.run(500, verbose = True) #%% ############################################################################### # Setup and run the PDHG algorithm operator = Aop -f = L2NormSquared(b = noisy_data) +f = L2NormSquared(b = sinogram) g = ZeroFunction() ## Compute operator Norm normK = operator.norm() ## Primal & dual stepsizes -sigma = 1 +sigma = 0.02 tau = 1/(sigma*normK**2) + pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True) -pdhg.max_iteration = 400 -pdhg.update_objective_interval = 20 -pdhg.run(400, verbose=True) +pdhg.max_iteration = 1000 +pdhg.update_objective_interval = 100 +pdhg.run(1000, verbose=True) #%% ############################################################################### # Setup and run the FISTA algorithm -fidelity = FunctionOperatorComposition(L2NormSquared(b=noisy_data), Aop) +fidelity = FunctionOperatorComposition(L2NormSquared(b=sinogram), Aop) regularizer = ZeroFunction() fista = FISTA(x_init=x_init , f=fidelity, g=regularizer) -fista.max_iteration = 20 -fista.update_objective_interval = 5 -fista.run(20, verbose = True) +fista.max_iteration = 500 +fista.update_objective_interval = 100 +fista.run(500, verbose = True) #%% Show results |