diff options
Diffstat (limited to 'src/Python')
-rw-r--r-- | src/Python/ccpi/fista/FISTAReconstructor.py | 165 | ||||
-rw-r--r-- | src/Python/test_reconstructor.py | 25 |
2 files changed, 165 insertions, 25 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 33e67a3..fda9cf0 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -85,6 +85,7 @@ class FISTAReconstructor(): self.pars['detectors'] = detectors self.pars['number_of_angles'] = nangles self.pars['SlicesZ'] = sliceZ + self.pars['output_volume'] = None print (self.pars) # handle optional input parameters (at instantiation) @@ -108,7 +109,11 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha', - 'subsets') + 'subsets', + 'output_volume', + 'os_subsets', + 'os_indices', + 'os_bins') self.acceptedInputKeywords = list(kw) # handle keyworded parameters @@ -176,8 +181,6 @@ class FISTAReconstructor(): ''' for key , value in kwargs.items(): if key in self.acceptedInputKeywords: - if key == 'use_studentt_fidelity': - raise Exception('use_studentt_fidelity Not implemented') self.pars[key] = value else: raise Exception('Wrong parameter {0} for '.format(key) + @@ -382,11 +385,15 @@ class FISTAReconstructor(): counter = counter + binsDiscr[jj] - 1 - - return IndicesReorg + # store the OS in parameters + self.setParameter(os_subsets=subsets, + os_bins=binsDiscr, + os_indices=IndicesReorg) def prepareForIteration(self): + print ("FISTA Reconstructor: prepare for iteration") + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) self.objective = numpy.zeros((self.pars['number_of_iterations'])) @@ -401,19 +408,17 @@ class FISTAReconstructor(): if self.getParameter('Lipschitz_constant') is None: self.pars['Lipschitz_constant'] = \ self.calculateLipschitzConstantWithPowerMethod() + # errors vector (if the ground truth is given) + self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); + # objective function values vector + self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); # prepareForIteration def iterate(self, Xin=None): - # convenience variable storage - proj_geom , vol_geom, sino , \ - SlicesZ = self.getParameter([ 'projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ']) - - t = 1 + print ("FISTA Reconstructor: iterate") + if Xin is None: if self.getParameter('initialize'): X = self.initialize() @@ -423,15 +428,25 @@ class FISTAReconstructor(): else: # copy by reference X = Xin - + # store the output volume in the parameters + self.setParameter(output_volume=X) X_t = X.copy() + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ' ]) + + t = 1 for i in range(self.getParameter('number_of_iterations')): X_old = X.copy() t_old = t r_old = self.r.copy() if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'parallel3d': + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': # if the geometry is parallel use slice-by-slice # projection-backprojection routine #sino_updt = zeros(size(sino),'single'); @@ -439,10 +454,9 @@ class FISTAReconstructor(): proj_geomT['DetectorRowCount'] = 1 vol_geomT = vol_geom.copy() vol_geomT['GridSliceCount'] = 1; - sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) for kkk in range(SlicesZ): - print (kkk) - sino_id, sino_updt[kkk] = \ + sino_id, self.sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( X_t[kkk:kkk+1], proj_geomT, vol_geomT) astra.matlab.data3d('delete', sino_id) @@ -450,11 +464,122 @@ class FISTAReconstructor(): # for divergent 3D geometry (watch the GPU memory overflow in # ASTRA versions < 1.8) #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - sino_id, sino_updt = astra.matlab.create_sino3d_gpu( + sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( X_t, proj_geom, vol_geom) ## RING REMOVAL - + self.ringRemoval(i) + ## Projection/Backprojection Routine + self.projectionBackprojection(X, X_t) + astra.matlab.data3d('delete', sino_id) ## REGULARIZATION + X = self.regularize(X) + ## Update Loop + X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) + self.setParameter(output_volume=X) + return X + ## iterate + + def ringRemoval(self, i): + print ("FISTA Reconstructor: ring removal") + residual = self.residual + lambdaR_L1 , alpha_ring , weights , L_const , sino= \ + self.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant', + 'input_sinogram']) + r_x = self.r_x + sino_updt = self.sino_updt + + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + self.r = (r_x - (1./L_const) * vec).copy() + self.objective[i] = (0.5 * (residual ** 2).sum()) + + def projectionBackprojection(self, X, X_t): + print ("FISTA Reconstructor: projection-backprojection routine") + + # a few useful variables + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + residual = self.residual + proj_geom , vol_geom , L_const = \ + self.getParameter(['projector_geometry' , + 'output_geometry', + 'Lipschitz_constant']) + + + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + #astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + def regularize(self, X): + print ("FISTA Reconstructor: regularize") + + regularizer = self.getParameter('regularizer') + if regularizer is not None: + return regularizer(input=X) + else: + return X + + def updateLoop(self, i, X, X_old, r_old, t, t_old): + print ("FISTA Reconstructor: update loop") + lambdaR_L1 = self.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + self.r = numpy.max( + numpy.abs(self.r) - lambdaR_L1 , 0) * \ + numpy.sign(self.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + self.r_x = self.r + \ + (((t_old-1)/t) * (self.r - r_old)) + + if self.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, self.objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], self.objective[i])) + return (X , X_t, t) diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 2f188b4..07668ba 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -100,7 +100,7 @@ if False: counter = counter + binsDiscr[jj] - 1 -if True: +if False: print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) print ("prepare for iteration") fistaRecon.prepareForIteration() @@ -145,7 +145,8 @@ if True: t_old = t r_old = fistaRecon.r.copy() if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ - fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d': + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : # if the geometry is parallel use slice-by-slice # projection-backprojection routine #sino_updt = zeros(size(sino),'single'); @@ -157,13 +158,13 @@ if True: for kkk in range(SlicesZ): sino_id, sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomT, vol_geomT) + X_t[kkk:kkk+1], proj_geom, vol_geom) astra.matlab.data3d('delete', sino_id) else: # for divergent 3D geometry (watch the GPU memory overflow in # ASTRA versions < 1.8) #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - sino_id, sino_updt = astra.matlab.create_sino3d_gpu( + sino_id, sino_updt = astra.creators.create_sino3d_gpu( X_t, proj_geom, vol_geom) ## RING REMOVAL @@ -206,7 +207,8 @@ if True: # Projection/Backprojection Routine if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ - fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d': + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) print ("Projection/Backprojection Routine") for kkk in range(SlicesZ): @@ -284,3 +286,16 @@ if True: ## else ## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); ## end +else: + fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + fistaRecon.setParameter(number_of_iterations = 12) + fistaRecon.setParameter(Lipschitz_constant = 767893952.0) + fistaRecon.setParameter(ring_alpha = 21) + fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + fistaRecon.prepareForIteration() + X = fistaRecon.iterate(numpy.load("X.npy")) |