diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-24 11:26:46 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-24 11:26:46 +0100 |
commit | a11c59651ec125e24371a2049606df0f80f458d0 (patch) | |
tree | 0f16ebde08333da13490eda663444357c621cd49 | |
parent | 4fd4f187a70c0e4f56d5194b09ab4a528d20ee51 (diff) | |
download | regularization-a11c59651ec125e24371a2049606df0f80f458d0.tar.gz regularization-a11c59651ec125e24371a2049606df0f80f458d0.tar.bz2 regularization-a11c59651ec125e24371a2049606df0f80f458d0.tar.xz regularization-a11c59651ec125e24371a2049606df0f80f458d0.zip |
latest dev
-rw-r--r-- | src/Python/ccpi/reconstruction/FISTAReconstructor.py | 599 |
1 files changed, 427 insertions, 172 deletions
diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index ea96b53..85bfac5 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -21,10 +21,9 @@ import numpy -import h5py #from ccpi.reconstruction.parallelbeam import alg -from ccpi.imaging.Regularizer import Regularizer +#from ccpi.imaging.Regularizer import Regularizer from enum import Enum import astra @@ -74,18 +73,34 @@ class FISTAReconstructor(): # 3. "A novel tomographic reconstruction method based on the robust # Student's t function for suppressing data outliers" D. Kazantsev et.al. # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ + def __init__(self, projector_geometry, output_geometry, input_sinogram, + **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry # proj_geom + self.pars['output_geometry'] = output_geometry # vol_geom + self.pars['input_sinogram'] = input_sinogram # sino + sliceZ, nangles, detectors = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_of_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + self.pars['output_volume'] = None + + print (self.pars) + # handle optional input parameters (at instantiation) # Accepted input keywords - kw = ('number_of_iterations', + kw = ( + # mandatory fields + 'projector_geometry', + 'output_geometry', + 'input_sinogram', + 'detectors', + 'number_of_angles', + 'SlicesZ', + # optional fields + 'number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , 'weights' , @@ -93,7 +108,13 @@ class FISTAReconstructor(): 'initialize' , 'regularizer' , 'ring_lambda_R_L1', - 'ring_alpha') + 'ring_alpha', + 'subsets', + 'output_volume', + 'os_subsets', + 'os_indices', + 'os_bins') + self.acceptedInputKeywords = list(kw) # handle keyworded parameters if kwargs is not None: @@ -110,85 +131,160 @@ class FISTAReconstructor(): if 'weights' in kwargs.keys(): self.pars['weights'] = kwargs['weights'] else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + self.pars['weights'] = \ + numpy.ones(numpy.shape( + self.pars['input_sinogram'])) if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = None - if not self.pars['ideal_image'] in kwargs.keys(): + if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None - if not self.pars['region_of_interest'] : + if not 'region_of_interest'in kwargs.keys() : if self.pars['ideal_image'] == None: - pass + self.pars['region_of_interest'] = None else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : + ## nonzero if the image is larger than m + fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) + + self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) + + # the regularizer must be a correctly instantiated object + if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None + + #RING REMOVAL + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + # ORDERED SUBSET + if not 'subsets' in kwargs.keys(): + self.pars['subsets'] = 0 else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 + self.createOrderedSubsets() + + if not 'initialize' in kwargs.keys(): + self.pars['initialize'] = False + + def setParameter(self, **kwargs): + '''set named parameter for the reconstructor engine + + raises Exception if the named parameter is not recognized + ''' + for key , value in kwargs.items(): + if key in self.acceptedInputKeywords: + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for '.format(key) + + 'reconstructor') + # setParameter + + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars + else: + raise Exception('Unhandled input {0}' .format(str(type(key)))) + + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] weights = self.pars['weights'] SlicesZ = self.pars['SlicesZ'] - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + + + if (proj_geom['type'] == 'parallel') or \ + (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) + sqweight = numpy.sqrt(weights[0]) proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; + proj_geomT['DetectorRowCount'] = 1; vol_geomT = vol_geom.copy(); vol_geomT['GridSliceCount'] = 1; + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + s = numpy.linalg.norm(x1) ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight*y; + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + #end del proj_geomT del vol_geomT + #plt.show() else: #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + print('Calculating Lipshitz constant for divergen beam geometry...') niter = 8; #% number of iteration for PM x1 = numpy.random.rand(SlicesZ , N , N); #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) + sqweight = numpy.sqrt(weights[0]) sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); y = sqweight*y; @@ -217,6 +313,7 @@ class FISTAReconstructor(): #end #clear x1 del x1 + return s @@ -225,130 +322,288 @@ class FISTAReconstructor(): if regularizer is not None: self.pars['regularizer'] = regularizer + + def initialize(self): + # convenience variable storage + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + sino = self.pars['input_sinogram'] + + # a 'warm start' with SIRT method + # Create a data object for the reconstruction + rec_id = astra.matlab.data3d('create', '-vol', + vol_geom); + + #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); + sinogram_id = astra.matlab.data3d('create', '-proj3d', + proj_geom, + sino) + + sirt_config = astra.astra_dict('SIRT3D_CUDA') + sirt_config['ReconstructionDataId' ] = rec_id + sirt_config['ProjectionDataId'] = sinogram_id + + sirt = astra.algorithm.create(sirt_config) + astra.algorithm.run(sirt, iterations=35) + X = astra.matlab.data3d('get', rec_id) + + # clean up memory + astra.matlab.data3d('delete', rec_id) + astra.matlab.data3d('delete', sinogram_id) + astra.algorithm.delete(sirt) + + + + return X + + def createOrderedSubsets(self, subsets=None): + if subsets is None: + try: + subsets = self.getParameter('subsets') + except Exception(): + subsets = 0 + #return subsets + + angles = self.getParameter('projector_geometry')['ProjectionAngles'] + + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + # store the OS in parameters + self.setParameter(os_subsets=subsets, + os_bins=binsDiscr, + os_indices=IndicesReorg) + + + def prepareForIteration(self): + print ("FISTA Reconstructor: prepare for iteration") + + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) + self.objective = numpy.zeros((self.pars['number_of_iterations'])) + + #2D array (for 3D data) of sparse "ring" + detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) + self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) + # another ring variable + self.r_x = self.r.copy() + + self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + + if self.getParameter('Lipschitz_constant') is None: + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() + # errors vector (if the ground truth is given) + self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); + # objective function values vector + self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); + + + # prepareForIteration + + def iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ' ]) + + t = 1 + + for i in range(self.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = self.r.copy() + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + for kkk in range(SlicesZ): + sino_id, self.sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomT, vol_geomT) + astra.matlab.data3d('delete', sino_id) + else: + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + + ## RING REMOVAL + self.ringRemoval(i) + ## Projection/Backprojection Routine + self.projectionBackprojection(X, X_t) + astra.matlab.data3d('delete', sino_id) + ## REGULARIZATION + X = self.regularize(X) + ## Update Loop + X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) + self.setParameter(output_volume=X) + return X + ## iterate - + def ringRemoval(self, i): + print ("FISTA Reconstructor: ring removal") + residual = self.residual + lambdaR_L1 , alpha_ring , weights , L_const , sino= \ + self.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant', + 'input_sinogram']) + r_x = self.r_x + sino_updt = self.sino_updt + + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + self.r = (r_x - (1./L_const) * vec).copy() + self.objective[i] = (0.5 * (residual ** 2).sum()) + def projectionBackprojection(self, X, X_t): + print ("FISTA Reconstructor: projection-backprojection routine") + + # a few useful variables + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + residual = self.residual + proj_geom , vol_geom , L_const = \ + self.getParameter(['projector_geometry' , + 'output_geometry', + 'Lipschitz_constant']) + + + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + #astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + def regularize(self, X): + print ("FISTA Reconstructor: regularize") + + regularizer = self.getParameter('regularizer') + if regularizer is not None: + return regularizer(input=X) + else: + return X + + def updateLoop(self, i, X, X_old, r_old, t, t_old): + print ("FISTA Reconstructor: update loop") + lambdaR_L1 = self.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + self.r = numpy.max( + numpy.abs(self.r) - lambdaR_L1 , 0) * \ + numpy.sign(self.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + self.r_x = self.r + \ + (((t_old-1)/t) * (self.r - r_old)) + + if self.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, self.objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], self.objective[i])) + return (X , X_t, t) -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" -##nx = h5py.File(fname, "r") -## -### the data are stored in a particular location in the hdf5 -##for item in nx['entry1/tomo_entry/data'].keys(): -## print (item) -## -##data = nx.get('entry1/tomo_entry/data/rotation_angle') -##angles = numpy.zeros(data.shape) -##data.read_direct(angles) -##print (angles) -### angles should be in degrees -## -##data = nx.get('entry1/tomo_entry/data/data') -##stack = numpy.zeros(data.shape) -##data.read_direct(stack) -##print (data.shape) -## -##print ("Data Loaded") -## -## -### Normalize -##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -##itype = numpy.zeros(data.shape) -##data.read_direct(itype) -### 2 is dark field -##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -##dark = darks[0] -##for i in range(1, len(darks)): -## dark += darks[i] -##dark = dark / len(darks) -###dark[0][0] = dark[0][1] -## -### 1 is flat field -##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -##flat = flats[0] -##for i in range(1, len(flats)): -## flat += flats[i] -##flat = flat / len(flats) -###flat[0][0] = dark[0][1] -## -## -### 0 is projection data -##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = numpy.asarray (angle_proj) -##angle_proj = angle_proj.astype(numpy.float32) -## -### normalized data are -### norm = (projection - dark)/(flat-dark) -## -##def normalize(projection, dark, flat, def_val=0.1): -## a = (projection - dark) -## b = (flat-dark) -## with numpy.errstate(divide='ignore', invalid='ignore'): -## c = numpy.true_divide( a, b ) -## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 -## return c -## -## -##norm = [normalize(projection, dark, flat) for projection in proj] -##norm = numpy.asarray (norm) -##norm = norm.astype(numpy.float32) - - -##niterations = 15 -##threads = 3 -## -##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## iteration_values, False) -##print ("iteration values %s" % str(iteration_values)) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -## -## -####numpy.save("cgls_recon.npy", img_data) -##import matplotlib.pyplot as plt -##fig, ax = plt.subplots(1,6,sharey=True) -##ax[0].imshow(img_cgls[80]) -##ax[0].axis('off') # clear x- and y-axes -##ax[1].imshow(img_sirt[80]) -##ax[1].axis('off') # clear x- and y-axes -##ax[2].imshow(img_mlem[80]) -##ax[2].axis('off') # clear x- and y-axesplt.show() -##ax[3].imshow(img_cgls_conv[80]) -##ax[3].axis('off') # clear x- and y-axesplt.show() -##ax[4].imshow(img_cgls_tikhonov[80]) -##ax[4].axis('off') # clear x- and y-axesplt.show() -##ax[5].imshow(img_cgls_TVreg[80]) -##ax[5].axis('off') # clear x- and y-axesplt.show() -## -## -##plt.show() -## + def os_iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + # some useful constants + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring , + lambdaR_L1 , L_const = self.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant']) |