summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2017-10-20 17:04:26 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2017-10-20 17:04:26 +0100
commit903175ed67f7645fa35edf4623b27999d6cb990f (patch)
tree17da596ffa3e4d139af15fb9454719ceac875734
parentd2ce1b74b4ecad5cdecb29207181e09ef0f6013a (diff)
downloadregularization-903175ed67f7645fa35edf4623b27999d6cb990f.tar.gz
regularization-903175ed67f7645fa35edf4623b27999d6cb990f.tar.bz2
regularization-903175ed67f7645fa35edf4623b27999d6cb990f.tar.xz
regularization-903175ed67f7645fa35edf4623b27999d6cb990f.zip
Further development
-rw-r--r--src/Python/ccpi/fista/FISTAReconstructor.py24
-rw-r--r--src/Python/test_reconstructor-os.py112
2 files changed, 81 insertions, 55 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
index fda9cf0..85bfac5 100644
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ b/src/Python/ccpi/fista/FISTAReconstructor.py
@@ -583,3 +583,27 @@ class FISTAReconstructor():
string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
print (string.format(i,Resid_error[i], self.objective[i]))
return (X , X_t, t)
+
+ def os_iterate(self, Xin=None):
+ print ("FISTA Reconstructor: iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+
+ # some useful constants
+ proj_geom , vol_geom, sino , \
+ SlicesZ, weights , alpha_ring ,
+ lambdaR_L1 , L_const = self.getParameter(
+ ['projector_geometry' , 'output_geometry',
+ 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' ,
+ 'ring_lambda_R_L1', 'Lipschitz_constant'])
diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py
index f6d7d4b..aee70a4 100644
--- a/src/Python/test_reconstructor-os.py
+++ b/src/Python/test_reconstructor-os.py
@@ -122,10 +122,13 @@ if True:
X_t = X.copy()
print ("initialized")
proj_geom , vol_geom, sino , \
- SlicesZ = fistaRecon.getParameter(['projector_geometry' ,
- 'output_geometry',
- 'input_sinogram',
- 'SlicesZ'])
+ SlicesZ, weights , alpha_ring = fistaRecon.getParameter(
+ ['projector_geometry' , 'output_geometry',
+ 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha'])
+ lambdaR_L1 , alpha_ring , weights , L_const= \
+ fistaRecon.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant'])
#fistaRecon.setParameter(number_of_iterations = 3)
iterFISTA = fistaRecon.getParameter('number_of_iterations')
@@ -136,12 +139,13 @@ if True:
t = 1
+
## additional for
proj_geomSUB = proj_geom.copy()
fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram']))
residual2 = fistaRecon.residual2
- sino_updt_FULL = residual.copy()
+ sino_updt_FULL = fistaRecon.residual.copy()
print ("starting iterations")
## % Outer FISTA iterations loop
@@ -156,7 +160,8 @@ if True:
# hence additional work is required one solution is to work with a full
# sinogram at times
-
+ r_old = fistaRecon.r.copy()
+ t_old = t
SlicesZ, anglesNumb, Detectors = \
numpy.shape(fistaRecon.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
if (i > 1 and lambdaR_L1 > 0) :
@@ -167,8 +172,8 @@ if True:
(sino[:,kkk,:]).squeeze() -\
(alpha_ring * r_x)
)
- r_old = fistaRecon.r.copy()
- vec = residual.sum(axis = 1)
+
+ vec = fistaRecon.residual.sum(axis = 1)
#if SlicesZ > 1:
# vec = vec[:,1,:] # 1 or 0?
r_x = fistaRecon.r_x
@@ -227,56 +232,53 @@ if True:
- ## RING REMOVAL
- residual = fistaRecon.residual
-
- lambdaR_L1 , alpha_ring , weights , L_const= \
- fistaRecon.getParameter(['ring_lambda_R_L1',
- 'ring_alpha' , 'weights',
- 'Lipschitz_constant'])
- if lambdaR_L1 > 0 :
- print ("ring removal")
- residualSub = numpy.zeros(shape)
-## for a chosen subset
-## for kkk = 1:numProjSub
-## indC = CurrSubIndeces(kkk);
-## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
-## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram
-## end
- for kkk in range(numProjSub):
- indC = CurrSubIndices[kkk]
- residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
- (sino_updt_Sub[:,kkk,:].squeeze() - \
- sino[:,indC,:].squeeze() - alpha_ring * r_x)
- # filling the full sinogram
- sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
+ ## RING REMOVAL
+ residual = fistaRecon.residual
+
+
+ if lambdaR_L1 > 0 :
+ print ("ring removal")
+ residualSub = numpy.zeros(shape)
+ ## for a chosen subset
+ ## for kkk = 1:numProjSub
+ ## indC = CurrSubIndeces(kkk);
+ ## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
+ ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram
+ ## end
+ for kkk in range(numProjSub):
+ indC = CurrSubIndices[kkk]
+ residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
+ (sino_updt_Sub[:,kkk,:].squeeze() - \
+ sino[:,indC,:].squeeze() - alpha_ring * r_x)
+ # filling the full sinogram
+ sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
- else:
- #PWLS model
- residualSub = weights[:,CurrSubIndices,:] * \
- ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() )
- objective[i] = 0.5 * numpy.linalg.norm(residualSub)
+ else:
+ #PWLS model
+ residualSub = weights[:,CurrSubIndices,:] * \
+ ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() )
+ objective[i] = 0.5 * numpy.linalg.norm(residualSub)
- if geometry_type == 'parallel' or \
- geometry_type == 'fanflat' or \
- geometry_type == 'fanflat_vec' :
- # if geometry is 2D use slice-by-slice projection-backprojection
- # routine
- x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
- for kkk in range(SlicesZ):
-
- x_id, x_temp[kkk] = \
- astra.creators.create_backprojection3d_gpu(
- residualSub[kkk:kkk+1],
- proj_geomSUB, vol_geom)
-
- else:
- x_id, x_temp = \
- astra.creators.create_backprojection3d_gpu(
- residualSub, proj_geomSUB, vol_geom)
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+ # if geometry is 2D use slice-by-slice projection-backprojection
+ # routine
+ x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub[kkk:kkk+1],
+ proj_geomSUB, vol_geom)
+
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub, proj_geomSUB, vol_geom)
- astra.matlab.data3d('delete', x_id)
- X = X_t - (1/L_const) * x_temp
+ astra.matlab.data3d('delete', x_id)
+ X = X_t - (1/L_const) * x_temp