summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-06-25 16:48:34 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-06-25 16:48:34 +0100
commit7962f16b9f0da905fdb75ee54dc12760fe6994b8 (patch)
tree631a5f06cce98404e38eab45847fe9e178a32776 /Wrappers
parentd30584e78e5507839360278250c55280bbc2a18e (diff)
downloadframework-7962f16b9f0da905fdb75ee54dc12760fe6994b8.tar.gz
framework-7962f16b9f0da905fdb75ee54dc12760fe6994b8.tar.bz2
framework-7962f16b9f0da905fdb75ee54dc12760fe6994b8.tar.xz
framework-7962f16b9f0da905fdb75ee54dc12760fe6994b8.zip
cgls_test
Diffstat (limited to 'Wrappers')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/CGLS.py138
-rw-r--r--Wrappers/Python/demos/CGLS_examples/CGLS_Tikhonov.py71
2 files changed, 136 insertions, 73 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
index 15acc31..28b19f6 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
@@ -46,78 +46,94 @@ class CGLS(Algorithm):
self.set_up(x_init =kwargs['x_init'],
operator=kwargs['operator'],
data =kwargs['data'])
-
+
def set_up(self, x_init, operator , data ):
- self.r = data.copy()
- self.x = x_init * 0
-
- self.operator = operator
- self.d = operator.adjoint(self.r)
-
-
- self.normr2 = self.d.squared_norm()
+ self.x = x_init
+ self.r = data - self.operator.direct(self.x)
+ self.s = self.operator.adjoint(self.r)
- self.s = self.operator.domain_geometry().allocate()
- #if isinstance(self.normr2, Iterable):
- # self.normr2 = sum(self.normr2)
- #self.normr2 = numpy.sqrt(self.normr2)
- #print ("set_up" , self.normr2)
- n = Norm2Sq(operator, self.data)
- self.loss.append(n(x_init))
- self.configured = True
+ self.p = self.s
+ self.norms0 = self.s.norm()
+ self.gamma = self.norms0**2
+ self.normx = self.x.norm()
+ self.xmax = self.normx
+ self.configured = True
+
+# def set_up(self, x_init, operator , data ):
+#
+# self.r = data.copy()
+# self.x = x_init * 0
+#
+# self.operator = operator
+# self.d = operator.adjoint(self.r)
+#
+#
+# self.normr2 = self.d.squared_norm()
+#
+# self.s = self.operator.domain_geometry().allocate()
+# #if isinstance(self.normr2, Iterable):
+# # self.normr2 = sum(self.normr2)
+# #self.normr2 = numpy.sqrt(self.normr2)
+# #print ("set_up" , self.normr2)
+# n = Norm2Sq(operator, self.data)
+# self.loss.append(n(x_init))
+# self.configured = True
def update(self):
self.update_new()
- def update_old(self):
- Ad = self.operator.direct(self.d)
- #norm = (Ad*Ad).sum()
- #if isinstance(norm, Iterable):
- # norm = sum(norm)
- norm = Ad.squared_norm()
- alpha = self.normr2/norm
- self.x += (self.d * alpha)
- self.r -= (Ad * alpha)
- s = self.operator.adjoint(self.r)
-
- normr2_new = s.squared_norm()
- #if isinstance(normr2_new, Iterable):
- # normr2_new = sum(normr2_new)
- #normr2_new = numpy.sqrt(normr2_new)
- #print (normr2_new)
-
- beta = normr2_new/self.normr2
- self.normr2 = normr2_new
- self.d = s + beta*self.d
-
def update_new(self):
-
- Ad = self.operator.direct(self.d)
- norm = Ad.squared_norm()
- if norm == 0.:
- print ('norm = 0, cannot update solution')
- print ("self.d norm", self.d.squared_norm(), self.d.as_array())
- raise StopIteration()
- alpha = self.normr2/norm
- if alpha == 0.:
- print ('alpha = 0, cannot update solution')
- raise StopIteration()
- self.d *= alpha
- Ad *= alpha
- self.r -= Ad
- self.x += self.d
+ self.q = self.operator.direct(self.p)
+ delta = self.q.squared_norm()
+ alpha = self.gamma/delta
+
+ self.x += alpha * self.p
+ self.r -= alpha * self.q
+
+ self.s = self.operator.adjoint(self.r)
+
+ self.norms = self.s.norm()
+ self.gamma1 = self.gamma
+ self.gamma = self.norms**2
+ self.beta = self.gamma/self.gamma1
+ self.p = self.s + self.beta * self.p
+
+ self.normx = self.x.norm()
+ self.xmax = numpy.maximum(self.xmax, self.normx)
- self.operator.adjoint(self.r, out=self.s)
- s = self.s
-
- normr2_new = s.squared_norm()
- beta = normr2_new/self.normr2
- self.normr2 = normr2_new
- self.d *= (beta/alpha)
- self.d += s
+ if self.gamma<=1e-6:
+ raise StopIteration()
+# def update_new(self):
+#
+# Ad = self.operator.direct(self.d)
+# norm = Ad.squared_norm()
+#
+# if norm <= 1e-3:
+# print ('norm = 0, cannot update solution')
+# #print ("self.d norm", self.d.squared_norm(), self.d.as_array())
+# raise StopIteration()
+# alpha = self.normr2/norm
+# if alpha <= 1e-3:
+# print ('alpha = 0, cannot update solution')
+# raise StopIteration()
+# self.d *= alpha
+# Ad *= alpha
+# self.r -= Ad
+#
+# self.x += self.d
+#
+# self.operator.adjoint(self.r, out=self.s)
+# s = self.s
+#
+# normr2_new = s.squared_norm()
+#
+# beta = normr2_new/self.normr2
+# self.normr2 = normr2_new
+# self.d *= (beta/alpha)
+# self.d += s
def update_objective(self):
a = self.r.squared_norm()
diff --git a/Wrappers/Python/demos/CGLS_examples/CGLS_Tikhonov.py b/Wrappers/Python/demos/CGLS_examples/CGLS_Tikhonov.py
index 653e191..63a2254 100644
--- a/Wrappers/Python/demos/CGLS_examples/CGLS_Tikhonov.py
+++ b/Wrappers/Python/demos/CGLS_examples/CGLS_Tikhonov.py
@@ -47,8 +47,8 @@ from ccpi.astra.ops import AstraProjectorSimple
# Load Data
loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi'))
-N = 150
-M = 150
+N = 64
+M = 64
data = loader.load(TestData.SIMPLE_PHANTOM_2D, size=(N,M), scale=(0,1))
ig = data.geometry
@@ -82,20 +82,19 @@ plt.colorbar()
plt.show()
# Setup and run the CGLS algorithm
-#alpha = 50
-#Grad = Gradient(ig)
+alpha = 5
+Grad = Gradient(ig)
#
## Form Tikhonov as a Block CGLS structure
-#op_CGLS = BlockOperator( Aop, alpha * Grad, shape=(2,1))
-#block_data = BlockDataContainer(noisy_data, Grad.range_geometry().allocate())
+op_CGLS = BlockOperator( Aop, alpha * Grad, shape=(2,1))
+block_data = BlockDataContainer(noisy_data, Grad.range_geometry().allocate())
#
-#x_init = ig.allocate()
-#cgls = CGLS(x_init=x_init, operator=op_CGLS, data=block_data)
-#cgls.max_iteration = 1000
-#cgls.update_objective_interval = 200
-#cgls.run(1000,verbose=False)
+x_init = ig.allocate()
+cgls = CGLS(x_init=x_init, operator=op_CGLS, data=block_data)
+cgls.max_iteration = 1000
+cgls.update_objective_interval = 200
+cgls.run(1000,verbose=True)
-#%%
# Show results
plt.figure(figsize=(5,5))
plt.imshow(cgls.get_output().as_array())
@@ -103,4 +102,52 @@ plt.title('CGLS reconstruction')
plt.colorbar()
plt.show()
+#%%
+from ccpi.optimisation.operators import SparseFiniteDiff
+import astra
+import numpy
+
+try:
+ from cvxpy import *
+ cvx_not_installable = True
+except ImportError:
+ cvx_not_installable = False
+
+
+if cvx_not_installable:
+
+
+ ##Construct problem
+ u = Variable(N*M)
+ #q = Variable()
+
+ DY = SparseFiniteDiff(ig, direction=0, bnd_cond='Neumann')
+ DX = SparseFiniteDiff(ig, direction=1, bnd_cond='Neumann')
+
+ regulariser = alpha * sum_squares(norm(vstack([DX.matrix() * vec(u), DY.matrix() * vec(u)]), 2, axis = 0))
+
+ # create matrix representation for Astra operator
+
+ vol_geom = astra.create_vol_geom(N, N)
+ proj_geom = astra.create_proj_geom('parallel', 1.0, detectors, angles)
+
+ proj_id = astra.create_projector('strip', proj_geom, vol_geom)
+
+ matrix_id = astra.projector.matrix(proj_id)
+
+ ProjMat = astra.matrix.get(matrix_id)
+
+ fidelity = sum_squares( ProjMat * u - noisy_data.as_array().ravel())
+
+ solver = SCS
+ obj = Minimize( regulariser + fidelity)
+ prob = Problem(obj, constraints)
+ result = prob.solve(verbose = True, solver = solver)
+
+
+
+
+
+
+