summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2017-10-24 11:26:46 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2017-10-24 11:26:46 +0100
commita11c59651ec125e24371a2049606df0f80f458d0 (patch)
tree0f16ebde08333da13490eda663444357c621cd49 /src
parent4fd4f187a70c0e4f56d5194b09ab4a528d20ee51 (diff)
downloadregularization-a11c59651ec125e24371a2049606df0f80f458d0.tar.gz
regularization-a11c59651ec125e24371a2049606df0f80f458d0.tar.bz2
regularization-a11c59651ec125e24371a2049606df0f80f458d0.tar.xz
regularization-a11c59651ec125e24371a2049606df0f80f458d0.zip
latest dev
Diffstat (limited to 'src')
-rw-r--r--src/Python/ccpi/reconstruction/FISTAReconstructor.py599
1 files changed, 427 insertions, 172 deletions
diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
index ea96b53..85bfac5 100644
--- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py
+++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
@@ -21,10 +21,9 @@
import numpy
-import h5py
#from ccpi.reconstruction.parallelbeam import alg
-from ccpi.imaging.Regularizer import Regularizer
+#from ccpi.imaging.Regularizer import Regularizer
from enum import Enum
import astra
@@ -74,18 +73,34 @@ class FISTAReconstructor():
# 3. "A novel tomographic reconstruction method based on the robust
# Student's t function for suppressing data outliers" D. Kazantsev et.al.
# D. Kazantsev, 2016-17
- def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
- self.params = dict()
- self.params['projector_geometry'] = projector_geometry
- self.params['output_geometry'] = output_geometry
- self.params['input_sinogram'] = input_sinogram
- detectors, nangles, sliceZ = numpy.shape(input_sinogram)
- self.params['detectors'] = detectors
- self.params['number_og_angles'] = nangles
- self.params['SlicesZ'] = sliceZ
+ def __init__(self, projector_geometry, output_geometry, input_sinogram,
+ **kwargs):
+ # handle parmeters:
+ # obligatory parameters
+ self.pars = dict()
+ self.pars['projector_geometry'] = projector_geometry # proj_geom
+ self.pars['output_geometry'] = output_geometry # vol_geom
+ self.pars['input_sinogram'] = input_sinogram # sino
+ sliceZ, nangles, detectors = numpy.shape(input_sinogram)
+ self.pars['detectors'] = detectors
+ self.pars['number_of_angles'] = nangles
+ self.pars['SlicesZ'] = sliceZ
+ self.pars['output_volume'] = None
+
+ print (self.pars)
+ # handle optional input parameters (at instantiation)
# Accepted input keywords
- kw = ('number_of_iterations',
+ kw = (
+ # mandatory fields
+ 'projector_geometry',
+ 'output_geometry',
+ 'input_sinogram',
+ 'detectors',
+ 'number_of_angles',
+ 'SlicesZ',
+ # optional fields
+ 'number_of_iterations',
'Lipschitz_constant' ,
'ideal_image' ,
'weights' ,
@@ -93,7 +108,13 @@ class FISTAReconstructor():
'initialize' ,
'regularizer' ,
'ring_lambda_R_L1',
- 'ring_alpha')
+ 'ring_alpha',
+ 'subsets',
+ 'output_volume',
+ 'os_subsets',
+ 'os_indices',
+ 'os_bins')
+ self.acceptedInputKeywords = list(kw)
# handle keyworded parameters
if kwargs is not None:
@@ -110,85 +131,160 @@ class FISTAReconstructor():
if 'weights' in kwargs.keys():
self.pars['weights'] = kwargs['weights']
else:
- self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
+ self.pars['weights'] = \
+ numpy.ones(numpy.shape(
+ self.pars['input_sinogram']))
if 'Lipschitz_constant' in kwargs.keys():
self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
else:
- self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+ self.pars['Lipschitz_constant'] = None
- if not self.pars['ideal_image'] in kwargs.keys():
+ if not 'ideal_image' in kwargs.keys():
self.pars['ideal_image'] = None
- if not self.pars['region_of_interest'] :
+ if not 'region_of_interest'in kwargs.keys() :
if self.pars['ideal_image'] == None:
- pass
+ self.pars['region_of_interest'] = None
else:
- self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
-
- if not self.pars['regularizer'] :
+ ## nonzero if the image is larger than m
+ fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1)
+
+ self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0)
+
+ # the regularizer must be a correctly instantiated object
+ if not 'regularizer' in kwargs.keys() :
self.pars['regularizer'] = None
+
+ #RING REMOVAL
+ if not 'ring_lambda_R_L1' in kwargs.keys():
+ self.pars['ring_lambda_R_L1'] = 0
+ if not 'ring_alpha' in kwargs.keys():
+ self.pars['ring_alpha'] = 1
+
+ # ORDERED SUBSET
+ if not 'subsets' in kwargs.keys():
+ self.pars['subsets'] = 0
else:
- # the regularizer must be a correctly instantiated object
- if not self.pars['ring_lambda_R_L1']:
- self.pars['ring_lambda_R_L1'] = 0
- if not self.pars['ring_alpha']:
- self.pars['ring_alpha'] = 1
+ self.createOrderedSubsets()
+
+ if not 'initialize' in kwargs.keys():
+ self.pars['initialize'] = False
+
+ def setParameter(self, **kwargs):
+ '''set named parameter for the reconstructor engine
+
+ raises Exception if the named parameter is not recognized
+ '''
+ for key , value in kwargs.items():
+ if key in self.acceptedInputKeywords:
+ self.pars[key] = value
+ else:
+ raise Exception('Wrong parameter {0} for '.format(key) +
+ 'reconstructor')
+ # setParameter
+
+ def getParameter(self, key):
+ if type(key) is str:
+ if key in self.acceptedInputKeywords:
+ return self.pars[key]
+ else:
+ raise Exception('Unrecongnised parameter: {0} '.format(key) )
+ elif type(key) is list:
+ outpars = []
+ for k in key:
+ outpars.append(self.getParameter(k))
+ return outpars
+ else:
+ raise Exception('Unhandled input {0}' .format(str(type(key))))
+
+
def calculateLipschitzConstantWithPowerMethod(self):
''' using Power method (PM) to establish L constant'''
- #N = params.vol_geom.GridColCount
- N = self.pars['output_geometry'].GridColCount
- proj_geom = self.params['projector_geometry']
- vol_geom = self.params['output_geometry']
+ N = self.pars['output_geometry']['GridColCount']
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
weights = self.pars['weights']
SlicesZ = self.pars['SlicesZ']
- if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+
+
+ if (proj_geom['type'] == 'parallel') or \
+ (proj_geom['type'] == 'parallel3d'):
#% for parallel geometry we can do just one slice
- #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
- niter = 15;# % number of iteration for the PM
+ #print('Calculating Lipshitz constant for parallel beam geometry...')
+ niter = 5;# % number of iteration for the PM
#N = params.vol_geom.GridColCount;
#x1 = rand(N,N,1);
x1 = numpy.random.rand(1,N,N)
#sqweight = sqrt(weights(:,:,1));
- sqweight = numpy.sqrt(weights.T[0])
+ sqweight = numpy.sqrt(weights[0])
proj_geomT = proj_geom.copy();
- proj_geomT.DetectorRowCount = 1;
+ proj_geomT['DetectorRowCount'] = 1;
vol_geomT = vol_geom.copy();
vol_geomT['GridSliceCount'] = 1;
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+
for i in range(niter):
- if i == 0:
- #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
- y = sqweight * y # element wise multiplication
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id)
+ # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT);
+ # s = norm(x1(:));
+ # x1 = x1/s;
+ # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ # y = sqweight.*y;
+ # astra_mex_data3d('delete', sino_id);
+ # astra_mex_data3d('delete', id);
+ #print ("iteration {0}".format(i))
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geomT,
+ vol_geomT)
+
+ y = (sqweight * y).copy() # element wise multiplication
+
+ #b=fig.add_subplot(2,1,2)
+ #imgplot = plt.imshow(x1[0])
+ #plt.show()
+
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+ del x1
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
+ idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),
+ proj_geomT,
+ vol_geomT)
+ del y
+
+
s = numpy.linalg.norm(x1)
### this line?
- x1 = x1/s;
- ### this line?
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
- y = sqweight*y;
+ x1 = (x1/s).copy();
+
+ # ### this line?
+ # sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ # proj_geomT,
+ # vol_geomT);
+ # y = sqweight * y;
astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
+ astra.matlab.data3d('delete', idx)
+ print ("iteration {0} s= {1}".format(i,s))
+
#end
del proj_geomT
del vol_geomT
+ #plt.show()
else:
#% divergen beam geometry
- #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
+ print('Calculating Lipshitz constant for divergen beam geometry...')
niter = 8; #% number of iteration for PM
x1 = numpy.random.rand(SlicesZ , N , N);
#sqweight = sqrt(weights);
- sqweight = numpy.sqrt(weights.T[0])
+ sqweight = numpy.sqrt(weights[0])
sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
y = sqweight*y;
@@ -217,6 +313,7 @@ class FISTAReconstructor():
#end
#clear x1
del x1
+
return s
@@ -225,130 +322,288 @@ class FISTAReconstructor():
if regularizer is not None:
self.pars['regularizer'] = regularizer
+
+ def initialize(self):
+ # convenience variable storage
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
+ sino = self.pars['input_sinogram']
+
+ # a 'warm start' with SIRT method
+ # Create a data object for the reconstruction
+ rec_id = astra.matlab.data3d('create', '-vol',
+ vol_geom);
+
+ #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino);
+ sinogram_id = astra.matlab.data3d('create', '-proj3d',
+ proj_geom,
+ sino)
+
+ sirt_config = astra.astra_dict('SIRT3D_CUDA')
+ sirt_config['ReconstructionDataId' ] = rec_id
+ sirt_config['ProjectionDataId'] = sinogram_id
+
+ sirt = astra.algorithm.create(sirt_config)
+ astra.algorithm.run(sirt, iterations=35)
+ X = astra.matlab.data3d('get', rec_id)
+
+ # clean up memory
+ astra.matlab.data3d('delete', rec_id)
+ astra.matlab.data3d('delete', sinogram_id)
+ astra.algorithm.delete(sirt)
+
+
+
+ return X
+
+ def createOrderedSubsets(self, subsets=None):
+ if subsets is None:
+ try:
+ subsets = self.getParameter('subsets')
+ except Exception():
+ subsets = 0
+ #return subsets
+
+ angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+ #binEdges = numpy.linspace(angles.min(),
+ # angles.max(),
+ # subsets + 1)
+ binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
+ # get rearranged subset indices
+ IndicesReorg = numpy.zeros((numpy.shape(angles)))
+ counterM = 0
+ for ii in range(binsDiscr.max()):
+ counter = 0
+ for jj in range(subsets):
+ curr_index = ii + jj + counter
+ #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
+ if binsDiscr[jj] > ii:
+ if (counterM < numpy.size(IndicesReorg)):
+ IndicesReorg[counterM] = curr_index
+ counterM = counterM + 1
+
+ counter = counter + binsDiscr[jj] - 1
+
+ # store the OS in parameters
+ self.setParameter(os_subsets=subsets,
+ os_bins=binsDiscr,
+ os_indices=IndicesReorg)
+
+
+ def prepareForIteration(self):
+ print ("FISTA Reconstructor: prepare for iteration")
+
+ self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
+ self.objective = numpy.zeros((self.pars['number_of_iterations']))
+
+ #2D array (for 3D data) of sparse "ring"
+ detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram'])
+ self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float)
+ # another ring variable
+ self.r_x = self.r.copy()
+
+ self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
+
+ if self.getParameter('Lipschitz_constant') is None:
+ self.pars['Lipschitz_constant'] = \
+ self.calculateLipschitzConstantWithPowerMethod()
+ # errors vector (if the ground truth is given)
+ self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations')));
+ # objective function values vector
+ self.objective = numpy.zeros((self.getParameter('number_of_iterations')));
+
+
+ # prepareForIteration
+
+ def iterate(self, Xin=None):
+ print ("FISTA Reconstructor: iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+ # convenience variable storage
+ proj_geom , vol_geom, sino , \
+ SlicesZ = self.getParameter([ 'projector_geometry' ,
+ 'output_geometry',
+ 'input_sinogram',
+ 'SlicesZ' ])
+
+ t = 1
+
+ for i in range(self.getParameter('number_of_iterations')):
+ X_old = X.copy()
+ t_old = t
+ r_old = self.r.copy()
+ if self.getParameter('projector_geometry')['type'] == 'parallel' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+ for kkk in range(SlicesZ):
+ sino_id, self.sino_updt[kkk] = \
+ astra.creators.create_sino3d_gpu(
+ X_t[kkk:kkk+1], proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', sino_id)
+ else:
+ # for divergent 3D geometry (watch the GPU memory overflow in
+ # ASTRA versions < 1.8)
+ #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
+ sino_id, self.sino_updt = astra.creators.create_sino3d_gpu(
+ X_t, proj_geom, vol_geom)
+
+
+ ## RING REMOVAL
+ self.ringRemoval(i)
+ ## Projection/Backprojection Routine
+ self.projectionBackprojection(X, X_t)
+ astra.matlab.data3d('delete', sino_id)
+ ## REGULARIZATION
+ X = self.regularize(X)
+ ## Update Loop
+ X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
+ self.setParameter(output_volume=X)
+ return X
+ ## iterate
-
+ def ringRemoval(self, i):
+ print ("FISTA Reconstructor: ring removal")
+ residual = self.residual
+ lambdaR_L1 , alpha_ring , weights , L_const , sino= \
+ self.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant',
+ 'input_sinogram'])
+ r_x = self.r_x
+ sino_updt = self.sino_updt
+
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ if lambdaR_L1 > 0 :
+ for kkk in range(anglesNumb):
+
+ residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+ vec = residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:].squeeze()
+ self.r = (r_x - (1./L_const) * vec).copy()
+ self.objective[i] = (0.5 * (residual ** 2).sum())
+ def projectionBackprojection(self, X, X_t):
+ print ("FISTA Reconstructor: projection-backprojection routine")
+
+ # a few useful variables
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ residual = self.residual
+ proj_geom , vol_geom , L_const = \
+ self.getParameter(['projector_geometry' ,
+ 'output_geometry',
+ 'Lipschitz_constant'])
+
+
+ if self.getParameter('projector_geometry')['type'] == 'parallel' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
+
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residual[kkk:kkk+1],
+ proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residual, proj_geom, vol_geom)
+
+ X = X_t - (1/L_const) * x_temp
+ #astra.matlab.data3d('delete', sino_id)
+ astra.matlab.data3d('delete', x_id)
+
+ def regularize(self, X):
+ print ("FISTA Reconstructor: regularize")
+
+ regularizer = self.getParameter('regularizer')
+ if regularizer is not None:
+ return regularizer(input=X)
+ else:
+ return X
+
+ def updateLoop(self, i, X, X_old, r_old, t, t_old):
+ print ("FISTA Reconstructor: update loop")
+ lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
+ if lambdaR_L1 > 0:
+ self.r = numpy.max(
+ numpy.abs(self.r) - lambdaR_L1 , 0) * \
+ numpy.sign(self.r)
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+
+ if lambdaR_L1 > 0:
+ self.r_x = self.r + \
+ (((t_old-1)/t) * (self.r - r_old))
+
+ if self.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, self.objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], self.objective[i]))
+ return (X , X_t, t)
-def getEntry(location):
- for item in nx[location].keys():
- print (item)
-
-
-print ("Loading Data")
-
-##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
-####ind = [i * 1049 for i in range(360)]
-#### use only 360 images
-##images = 200
-##ind = [int(i * 1049 / images) for i in range(images)]
-##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
-
-#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
-#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
-##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5"
-##nx = h5py.File(fname, "r")
-##
-### the data are stored in a particular location in the hdf5
-##for item in nx['entry1/tomo_entry/data'].keys():
-## print (item)
-##
-##data = nx.get('entry1/tomo_entry/data/rotation_angle')
-##angles = numpy.zeros(data.shape)
-##data.read_direct(angles)
-##print (angles)
-### angles should be in degrees
-##
-##data = nx.get('entry1/tomo_entry/data/data')
-##stack = numpy.zeros(data.shape)
-##data.read_direct(stack)
-##print (data.shape)
-##
-##print ("Data Loaded")
-##
-##
-### Normalize
-##data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
-##itype = numpy.zeros(data.shape)
-##data.read_direct(itype)
-### 2 is dark field
-##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
-##dark = darks[0]
-##for i in range(1, len(darks)):
-## dark += darks[i]
-##dark = dark / len(darks)
-###dark[0][0] = dark[0][1]
-##
-### 1 is flat field
-##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
-##flat = flats[0]
-##for i in range(1, len(flats)):
-## flat += flats[i]
-##flat = flat / len(flats)
-###flat[0][0] = dark[0][1]
-##
-##
-### 0 is projection data
-##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
-##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
-##angle_proj = numpy.asarray (angle_proj)
-##angle_proj = angle_proj.astype(numpy.float32)
-##
-### normalized data are
-### norm = (projection - dark)/(flat-dark)
-##
-##def normalize(projection, dark, flat, def_val=0.1):
-## a = (projection - dark)
-## b = (flat-dark)
-## with numpy.errstate(divide='ignore', invalid='ignore'):
-## c = numpy.true_divide( a, b )
-## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
-## return c
-##
-##
-##norm = [normalize(projection, dark, flat) for projection in proj]
-##norm = numpy.asarray (norm)
-##norm = norm.astype(numpy.float32)
-
-
-##niterations = 15
-##threads = 3
-##
-##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## iteration_values, False)
-##print ("iteration values %s" % str(iteration_values))
-##
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## numpy.double(1e-5), iteration_values , False)
-##print ("iteration values %s" % str(iteration_values))
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## numpy.double(1e-5), iteration_values , False)
-##print ("iteration values %s" % str(iteration_values))
-##
-##
-####numpy.save("cgls_recon.npy", img_data)
-##import matplotlib.pyplot as plt
-##fig, ax = plt.subplots(1,6,sharey=True)
-##ax[0].imshow(img_cgls[80])
-##ax[0].axis('off') # clear x- and y-axes
-##ax[1].imshow(img_sirt[80])
-##ax[1].axis('off') # clear x- and y-axes
-##ax[2].imshow(img_mlem[80])
-##ax[2].axis('off') # clear x- and y-axesplt.show()
-##ax[3].imshow(img_cgls_conv[80])
-##ax[3].axis('off') # clear x- and y-axesplt.show()
-##ax[4].imshow(img_cgls_tikhonov[80])
-##ax[4].axis('off') # clear x- and y-axesplt.show()
-##ax[5].imshow(img_cgls_TVreg[80])
-##ax[5].axis('off') # clear x- and y-axesplt.show()
-##
-##
-##plt.show()
-##
+ def os_iterate(self, Xin=None):
+ print ("FISTA Reconstructor: iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+ # some useful constants
+ proj_geom , vol_geom, sino , \
+ SlicesZ, weights , alpha_ring ,
+ lambdaR_L1 , L_const = self.getParameter(
+ ['projector_geometry' , 'output_geometry',
+ 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' ,
+ 'ring_lambda_R_L1', 'Lipschitz_constant'])