summaryrefslogtreecommitdiffstats
path: root/src/Python/ccpi
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2017-10-24 11:31:36 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2017-10-24 11:31:36 +0100
commit546104f8dfea5691801137c1be99d09e1e999d82 (patch)
tree7b66e2ec46c49ea4ff7b872cd8ac602fe2a9b8d7 /src/Python/ccpi
parent909a7bb4d71bdb14d4e68f42c2297f6154a77ed0 (diff)
downloadregularization-546104f8dfea5691801137c1be99d09e1e999d82.tar.gz
regularization-546104f8dfea5691801137c1be99d09e1e999d82.tar.bz2
regularization-546104f8dfea5691801137c1be99d09e1e999d82.tar.xz
regularization-546104f8dfea5691801137c1be99d09e1e999d82.zip
removed fista directory
use the standard package reconstruction directory for the fista code
Diffstat (limited to 'src/Python/ccpi')
-rw-r--r--src/Python/ccpi/fista/FISTAReconstructor.py609
-rw-r--r--src/Python/ccpi/fista/Reconstructor.py425
-rw-r--r--src/Python/ccpi/fista/__init__.py0
3 files changed, 0 insertions, 1034 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
deleted file mode 100644
index 85bfac5..0000000
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ /dev/null
@@ -1,609 +0,0 @@
-# -*- coding: utf-8 -*-
-###############################################################################
-#This work is part of the Core Imaging Library developed by
-#Visual Analytics and Imaging System Group of the Science Technology
-#Facilities Council, STFC
-#
-#Copyright 2017 Edoardo Pasca, Srikanth Nagella
-#Copyright 2017 Daniil Kazantsev
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#http://www.apache.org/licenses/LICENSE-2.0
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-###############################################################################
-
-
-
-import numpy
-#from ccpi.reconstruction.parallelbeam import alg
-
-#from ccpi.imaging.Regularizer import Regularizer
-from enum import Enum
-
-import astra
-
-
-
-class FISTAReconstructor():
- '''FISTA-based reconstruction algorithm using ASTRA-toolbox
-
- '''
- # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
- # ___Input___:
- # params.[] file:
- # - .proj_geom (geometry of the projector) [required]
- # - .vol_geom (geometry of the reconstructed object) [required]
- # - .sino (vectorized in 2D or 3D sinogram) [required]
- # - .iterFISTA (iterations for the main loop, default 40)
- # - .L_const (Lipschitz constant, default Power method) )
- # - .X_ideal (ideal image, if given)
- # - .weights (statisitcal weights, size of the sinogram)
- # - .ROI (Region-of-interest, only if X_ideal is given)
- # - .initialize (a 'warm start' using SIRT method from ASTRA)
- #----------------Regularization choices------------------------
- # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
- # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
- # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
- # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
- # - .Regul_Iterations (iterations for the selected penalty, default 25)
- # - .Regul_tauLLT (time step parameter for LLT term)
- # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
- # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
- #----------------Visualization parameters------------------------
- # - .show (visualize reconstruction 1/0, (0 default))
- # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
- # - .slice (for 3D volumes - slice number to imshow)
- # ___Output___:
- # 1. X - reconstructed image/volume
- # 2. output - a structure with
- # - .Resid_error - residual error (if X_ideal is given)
- # - .objective: value of the objective function
- # - .L_const: Lipshitz constant to avoid recalculations
-
- # References:
- # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
- # Problems" by A. Beck and M Teboulle
- # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
- # 3. "A novel tomographic reconstruction method based on the robust
- # Student's t function for suppressing data outliers" D. Kazantsev et.al.
- # D. Kazantsev, 2016-17
- def __init__(self, projector_geometry, output_geometry, input_sinogram,
- **kwargs):
- # handle parmeters:
- # obligatory parameters
- self.pars = dict()
- self.pars['projector_geometry'] = projector_geometry # proj_geom
- self.pars['output_geometry'] = output_geometry # vol_geom
- self.pars['input_sinogram'] = input_sinogram # sino
- sliceZ, nangles, detectors = numpy.shape(input_sinogram)
- self.pars['detectors'] = detectors
- self.pars['number_of_angles'] = nangles
- self.pars['SlicesZ'] = sliceZ
- self.pars['output_volume'] = None
-
- print (self.pars)
- # handle optional input parameters (at instantiation)
-
- # Accepted input keywords
- kw = (
- # mandatory fields
- 'projector_geometry',
- 'output_geometry',
- 'input_sinogram',
- 'detectors',
- 'number_of_angles',
- 'SlicesZ',
- # optional fields
- 'number_of_iterations',
- 'Lipschitz_constant' ,
- 'ideal_image' ,
- 'weights' ,
- 'region_of_interest' ,
- 'initialize' ,
- 'regularizer' ,
- 'ring_lambda_R_L1',
- 'ring_alpha',
- 'subsets',
- 'output_volume',
- 'os_subsets',
- 'os_indices',
- 'os_bins')
- self.acceptedInputKeywords = list(kw)
-
- # handle keyworded parameters
- if kwargs is not None:
- for key, value in kwargs.items():
- if key in kw:
- #print("{0} = {1}".format(key, value))
- self.pars[key] = value
-
- # set the default values for the parameters if not set
- if 'number_of_iterations' in kwargs.keys():
- self.pars['number_of_iterations'] = kwargs['number_of_iterations']
- else:
- self.pars['number_of_iterations'] = 40
- if 'weights' in kwargs.keys():
- self.pars['weights'] = kwargs['weights']
- else:
- self.pars['weights'] = \
- numpy.ones(numpy.shape(
- self.pars['input_sinogram']))
- if 'Lipschitz_constant' in kwargs.keys():
- self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
- else:
- self.pars['Lipschitz_constant'] = None
-
- if not 'ideal_image' in kwargs.keys():
- self.pars['ideal_image'] = None
-
- if not 'region_of_interest'in kwargs.keys() :
- if self.pars['ideal_image'] == None:
- self.pars['region_of_interest'] = None
- else:
- ## nonzero if the image is larger than m
- fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1)
-
- self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0)
-
- # the regularizer must be a correctly instantiated object
- if not 'regularizer' in kwargs.keys() :
- self.pars['regularizer'] = None
-
- #RING REMOVAL
- if not 'ring_lambda_R_L1' in kwargs.keys():
- self.pars['ring_lambda_R_L1'] = 0
- if not 'ring_alpha' in kwargs.keys():
- self.pars['ring_alpha'] = 1
-
- # ORDERED SUBSET
- if not 'subsets' in kwargs.keys():
- self.pars['subsets'] = 0
- else:
- self.createOrderedSubsets()
-
- if not 'initialize' in kwargs.keys():
- self.pars['initialize'] = False
-
-
-
-
- def setParameter(self, **kwargs):
- '''set named parameter for the reconstructor engine
-
- raises Exception if the named parameter is not recognized
-
- '''
- for key , value in kwargs.items():
- if key in self.acceptedInputKeywords:
- self.pars[key] = value
- else:
- raise Exception('Wrong parameter {0} for '.format(key) +
- 'reconstructor')
- # setParameter
-
- def getParameter(self, key):
- if type(key) is str:
- if key in self.acceptedInputKeywords:
- return self.pars[key]
- else:
- raise Exception('Unrecongnised parameter: {0} '.format(key) )
- elif type(key) is list:
- outpars = []
- for k in key:
- outpars.append(self.getParameter(k))
- return outpars
- else:
- raise Exception('Unhandled input {0}' .format(str(type(key))))
-
-
- def calculateLipschitzConstantWithPowerMethod(self):
- ''' using Power method (PM) to establish L constant'''
-
- N = self.pars['output_geometry']['GridColCount']
- proj_geom = self.pars['projector_geometry']
- vol_geom = self.pars['output_geometry']
- weights = self.pars['weights']
- SlicesZ = self.pars['SlicesZ']
-
-
-
- if (proj_geom['type'] == 'parallel') or \
- (proj_geom['type'] == 'parallel3d'):
- #% for parallel geometry we can do just one slice
- #print('Calculating Lipshitz constant for parallel beam geometry...')
- niter = 5;# % number of iteration for the PM
- #N = params.vol_geom.GridColCount;
- #x1 = rand(N,N,1);
- x1 = numpy.random.rand(1,N,N)
- #sqweight = sqrt(weights(:,:,1));
- sqweight = numpy.sqrt(weights[0])
- proj_geomT = proj_geom.copy();
- proj_geomT['DetectorRowCount'] = 1;
- vol_geomT = vol_geom.copy();
- vol_geomT['GridSliceCount'] = 1;
-
- #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
-
-
- for i in range(niter):
- # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT);
- # s = norm(x1(:));
- # x1 = x1/s;
- # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
- # y = sqweight.*y;
- # astra_mex_data3d('delete', sino_id);
- # astra_mex_data3d('delete', id);
- #print ("iteration {0}".format(i))
-
- sino_id, y = astra.creators.create_sino3d_gpu(x1,
- proj_geomT,
- vol_geomT)
-
- y = (sqweight * y).copy() # element wise multiplication
-
- #b=fig.add_subplot(2,1,2)
- #imgplot = plt.imshow(x1[0])
- #plt.show()
-
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id)
- del x1
-
- idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),
- proj_geomT,
- vol_geomT)
- del y
-
-
- s = numpy.linalg.norm(x1)
- ### this line?
- x1 = (x1/s).copy();
-
- # ### this line?
- # sino_id, y = astra.creators.create_sino3d_gpu(x1,
- # proj_geomT,
- # vol_geomT);
- # y = sqweight * y;
- astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx)
- print ("iteration {0} s= {1}".format(i,s))
-
- #end
- del proj_geomT
- del vol_geomT
- #plt.show()
- else:
- #% divergen beam geometry
- print('Calculating Lipshitz constant for divergen beam geometry...')
- niter = 8; #% number of iteration for PM
- x1 = numpy.random.rand(SlicesZ , N , N);
- #sqweight = sqrt(weights);
- sqweight = numpy.sqrt(weights[0])
-
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
- y = sqweight*y;
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id);
-
- for i in range(niter):
- #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
- proj_geom,
- vol_geom)
- s = numpy.linalg.norm(x1)
- ### this line?
- x1 = x1/s;
- ### this line?
- #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
- sino_id, y = astra.creators.create_sino3d_gpu(x1,
- proj_geom,
- vol_geom);
-
- y = sqweight*y;
- #astra_mex_data3d('delete', sino_id);
- #astra_mex_data3d('delete', id);
- astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
- #end
- #clear x1
- del x1
-
-
- return s
-
-
- def setRegularizer(self, regularizer):
- if regularizer is not None:
- self.pars['regularizer'] = regularizer
-
-
- def initialize(self):
- # convenience variable storage
- proj_geom = self.pars['projector_geometry']
- vol_geom = self.pars['output_geometry']
- sino = self.pars['input_sinogram']
-
- # a 'warm start' with SIRT method
- # Create a data object for the reconstruction
- rec_id = astra.matlab.data3d('create', '-vol',
- vol_geom);
-
- #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino);
- sinogram_id = astra.matlab.data3d('create', '-proj3d',
- proj_geom,
- sino)
-
- sirt_config = astra.astra_dict('SIRT3D_CUDA')
- sirt_config['ReconstructionDataId' ] = rec_id
- sirt_config['ProjectionDataId'] = sinogram_id
-
- sirt = astra.algorithm.create(sirt_config)
- astra.algorithm.run(sirt, iterations=35)
- X = astra.matlab.data3d('get', rec_id)
-
- # clean up memory
- astra.matlab.data3d('delete', rec_id)
- astra.matlab.data3d('delete', sinogram_id)
- astra.algorithm.delete(sirt)
-
-
-
- return X
-
- def createOrderedSubsets(self, subsets=None):
- if subsets is None:
- try:
- subsets = self.getParameter('subsets')
- except Exception():
- subsets = 0
- #return subsets
-
- angles = self.getParameter('projector_geometry')['ProjectionAngles']
-
- #binEdges = numpy.linspace(angles.min(),
- # angles.max(),
- # subsets + 1)
- binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
- # get rearranged subset indices
- IndicesReorg = numpy.zeros((numpy.shape(angles)))
- counterM = 0
- for ii in range(binsDiscr.max()):
- counter = 0
- for jj in range(subsets):
- curr_index = ii + jj + counter
- #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
- if binsDiscr[jj] > ii:
- if (counterM < numpy.size(IndicesReorg)):
- IndicesReorg[counterM] = curr_index
- counterM = counterM + 1
-
- counter = counter + binsDiscr[jj] - 1
-
- # store the OS in parameters
- self.setParameter(os_subsets=subsets,
- os_bins=binsDiscr,
- os_indices=IndicesReorg)
-
-
- def prepareForIteration(self):
- print ("FISTA Reconstructor: prepare for iteration")
-
- self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
- self.objective = numpy.zeros((self.pars['number_of_iterations']))
-
- #2D array (for 3D data) of sparse "ring"
- detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram'])
- self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float)
- # another ring variable
- self.r_x = self.r.copy()
-
- self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
-
- if self.getParameter('Lipschitz_constant') is None:
- self.pars['Lipschitz_constant'] = \
- self.calculateLipschitzConstantWithPowerMethod()
- # errors vector (if the ground truth is given)
- self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations')));
- # objective function values vector
- self.objective = numpy.zeros((self.getParameter('number_of_iterations')));
-
-
- # prepareForIteration
-
- def iterate(self, Xin=None):
- print ("FISTA Reconstructor: iterate")
-
- if Xin is None:
- if self.getParameter('initialize'):
- X = self.initialize()
- else:
- N = vol_geom['GridColCount']
- X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
- else:
- # copy by reference
- X = Xin
- # store the output volume in the parameters
- self.setParameter(output_volume=X)
- X_t = X.copy()
- # convenience variable storage
- proj_geom , vol_geom, sino , \
- SlicesZ = self.getParameter([ 'projector_geometry' ,
- 'output_geometry',
- 'input_sinogram',
- 'SlicesZ' ])
-
- t = 1
-
- for i in range(self.getParameter('number_of_iterations')):
- X_old = X.copy()
- t_old = t
- r_old = self.r.copy()
- if self.getParameter('projector_geometry')['type'] == 'parallel' or \
- self.getParameter('projector_geometry')['type'] == 'fanflat' or \
- self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
- # if the geometry is parallel use slice-by-slice
- # projection-backprojection routine
- #sino_updt = zeros(size(sino),'single');
- proj_geomT = proj_geom.copy()
- proj_geomT['DetectorRowCount'] = 1
- vol_geomT = vol_geom.copy()
- vol_geomT['GridSliceCount'] = 1;
- self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
- for kkk in range(SlicesZ):
- sino_id, self.sino_updt[kkk] = \
- astra.creators.create_sino3d_gpu(
- X_t[kkk:kkk+1], proj_geomT, vol_geomT)
- astra.matlab.data3d('delete', sino_id)
- else:
- # for divergent 3D geometry (watch the GPU memory overflow in
- # ASTRA versions < 1.8)
- #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
- sino_id, self.sino_updt = astra.creators.create_sino3d_gpu(
- X_t, proj_geom, vol_geom)
-
-
- ## RING REMOVAL
- self.ringRemoval(i)
- ## Projection/Backprojection Routine
- self.projectionBackprojection(X, X_t)
- astra.matlab.data3d('delete', sino_id)
- ## REGULARIZATION
- X = self.regularize(X)
- ## Update Loop
- X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
- self.setParameter(output_volume=X)
- return X
- ## iterate
-
- def ringRemoval(self, i):
- print ("FISTA Reconstructor: ring removal")
- residual = self.residual
- lambdaR_L1 , alpha_ring , weights , L_const , sino= \
- self.getParameter(['ring_lambda_R_L1',
- 'ring_alpha' , 'weights',
- 'Lipschitz_constant',
- 'input_sinogram'])
- r_x = self.r_x
- sino_updt = self.sino_updt
-
- SlicesZ, anglesNumb, Detectors = \
- numpy.shape(self.getParameter('input_sinogram'))
- if lambdaR_L1 > 0 :
- for kkk in range(anglesNumb):
-
- residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
- ((sino_updt[:,kkk,:]).squeeze() - \
- (sino[:,kkk,:]).squeeze() -\
- (alpha_ring * r_x)
- )
- vec = residual.sum(axis = 1)
- #if SlicesZ > 1:
- # vec = vec[:,1,:].squeeze()
- self.r = (r_x - (1./L_const) * vec).copy()
- self.objective[i] = (0.5 * (residual ** 2).sum())
-
- def projectionBackprojection(self, X, X_t):
- print ("FISTA Reconstructor: projection-backprojection routine")
-
- # a few useful variables
- SlicesZ, anglesNumb, Detectors = \
- numpy.shape(self.getParameter('input_sinogram'))
- residual = self.residual
- proj_geom , vol_geom , L_const = \
- self.getParameter(['projector_geometry' ,
- 'output_geometry',
- 'Lipschitz_constant'])
-
-
- if self.getParameter('projector_geometry')['type'] == 'parallel' or \
- self.getParameter('projector_geometry')['type'] == 'fanflat' or \
- self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
- # if the geometry is parallel use slice-by-slice
- # projection-backprojection routine
- #sino_updt = zeros(size(sino),'single');
- proj_geomT = proj_geom.copy()
- proj_geomT['DetectorRowCount'] = 1
- vol_geomT = vol_geom.copy()
- vol_geomT['GridSliceCount'] = 1;
- x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
-
- for kkk in range(SlicesZ):
-
- x_id, x_temp[kkk] = \
- astra.creators.create_backprojection3d_gpu(
- residual[kkk:kkk+1],
- proj_geomT, vol_geomT)
- astra.matlab.data3d('delete', x_id)
- else:
- x_id, x_temp = \
- astra.creators.create_backprojection3d_gpu(
- residual, proj_geom, vol_geom)
-
- X = X_t - (1/L_const) * x_temp
- #astra.matlab.data3d('delete', sino_id)
- astra.matlab.data3d('delete', x_id)
-
- def regularize(self, X):
- print ("FISTA Reconstructor: regularize")
-
- regularizer = self.getParameter('regularizer')
- if regularizer is not None:
- return regularizer(input=X)
- else:
- return X
-
- def updateLoop(self, i, X, X_old, r_old, t, t_old):
- print ("FISTA Reconstructor: update loop")
- lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
- if lambdaR_L1 > 0:
- self.r = numpy.max(
- numpy.abs(self.r) - lambdaR_L1 , 0) * \
- numpy.sign(self.r)
- t = (1 + numpy.sqrt(1 + 4 * t**2))/2
- X_t = X + (((t_old -1)/t) * (X - X_old))
-
- if lambdaR_L1 > 0:
- self.r_x = self.r + \
- (((t_old-1)/t) * (self.r - r_old))
-
- if self.getParameter('region_of_interest') is None:
- string = 'Iteration Number {0} | Objective {1} \n'
- print (string.format( i, self.objective[i]))
- else:
- ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
- 'ideal_image')
-
- Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
- string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
- print (string.format(i,Resid_error[i], self.objective[i]))
- return (X , X_t, t)
-
- def os_iterate(self, Xin=None):
- print ("FISTA Reconstructor: iterate")
-
- if Xin is None:
- if self.getParameter('initialize'):
- X = self.initialize()
- else:
- N = vol_geom['GridColCount']
- X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
- else:
- # copy by reference
- X = Xin
- # store the output volume in the parameters
- self.setParameter(output_volume=X)
- X_t = X.copy()
-
- # some useful constants
- proj_geom , vol_geom, sino , \
- SlicesZ, weights , alpha_ring ,
- lambdaR_L1 , L_const = self.getParameter(
- ['projector_geometry' , 'output_geometry',
- 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' ,
- 'ring_lambda_R_L1', 'Lipschitz_constant'])
diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py
deleted file mode 100644
index d29ac0d..0000000
--- a/src/Python/ccpi/fista/Reconstructor.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# -*- coding: utf-8 -*-
-###############################################################################
-#This work is part of the Core Imaging Library developed by
-#Visual Analytics and Imaging System Group of the Science Technology
-#Facilities Council, STFC
-#
-#Copyright 2017 Edoardo Pasca, Srikanth Nagella
-#Copyright 2017 Daniil Kazantsev
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#http://www.apache.org/licenses/LICENSE-2.0
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-###############################################################################
-
-
-
-import numpy
-import h5py
-from ccpi.reconstruction.parallelbeam import alg
-
-from Regularizer import Regularizer
-from enum import Enum
-
-import astra
-
-
-
-class FISTAReconstructor():
- '''FISTA-based reconstruction algorithm using ASTRA-toolbox
-
- '''
- # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
- # ___Input___:
- # params.[] file:
- # - .proj_geom (geometry of the projector) [required]
- # - .vol_geom (geometry of the reconstructed object) [required]
- # - .sino (vectorized in 2D or 3D sinogram) [required]
- # - .iterFISTA (iterations for the main loop, default 40)
- # - .L_const (Lipschitz constant, default Power method) )
- # - .X_ideal (ideal image, if given)
- # - .weights (statisitcal weights, size of the sinogram)
- # - .ROI (Region-of-interest, only if X_ideal is given)
- # - .initialize (a 'warm start' using SIRT method from ASTRA)
- #----------------Regularization choices------------------------
- # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
- # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
- # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
- # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
- # - .Regul_Iterations (iterations for the selected penalty, default 25)
- # - .Regul_tauLLT (time step parameter for LLT term)
- # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
- # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
- #----------------Visualization parameters------------------------
- # - .show (visualize reconstruction 1/0, (0 default))
- # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
- # - .slice (for 3D volumes - slice number to imshow)
- # ___Output___:
- # 1. X - reconstructed image/volume
- # 2. output - a structure with
- # - .Resid_error - residual error (if X_ideal is given)
- # - .objective: value of the objective function
- # - .L_const: Lipshitz constant to avoid recalculations
-
- # References:
- # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
- # Problems" by A. Beck and M Teboulle
- # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
- # 3. "A novel tomographic reconstruction method based on the robust
- # Student's t function for suppressing data outliers" D. Kazantsev et.al.
- # D. Kazantsev, 2016-17
- def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
- self.params = dict()
- self.params['projector_geometry'] = projector_geometry
- self.params['output_geometry'] = output_geometry
- self.params['input_sinogram'] = input_sinogram
- detectors, nangles, sliceZ = numpy.shape(input_sinogram)
- self.params['detectors'] = detectors
- self.params['number_og_angles'] = nangles
- self.params['SlicesZ'] = sliceZ
-
- # Accepted input keywords
- kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' ,
- 'weights' , 'region_of_interest' , 'initialize' ,
- 'regularizer' ,
- 'ring_lambda_R_L1',
- 'ring_alpha')
-
- # handle keyworded parameters
- if kwargs is not None:
- for key, value in kwargs.items():
- if key in kw:
- #print("{0} = {1}".format(key, value))
- self.pars[key] = value
-
- # set the default values for the parameters if not set
- if 'number_of_iterations' in kwargs.keys():
- self.pars['number_of_iterations'] = kwargs['number_of_iterations']
- else:
- self.pars['number_of_iterations'] = 40
- if 'weights' in kwargs.keys():
- self.pars['weights'] = kwargs['weights']
- else:
- self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
- if 'Lipschitz_constant' in kwargs.keys():
- self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
- else:
- self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
-
- if not self.pars['ideal_image'] in kwargs.keys():
- self.pars['ideal_image'] = None
-
- if not self.pars['region_of_interest'] :
- if self.pars['ideal_image'] == None:
- pass
- else:
- self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
-
- if not self.pars['regularizer'] :
- self.pars['regularizer'] = None
- else:
- # the regularizer must be a correctly instantiated object
- if not self.pars['ring_lambda_R_L1']:
- self.pars['ring_lambda_R_L1'] = 0
- if not self.pars['ring_alpha']:
- self.pars['ring_alpha'] = 1
-
-
-
-
- def calculateLipschitzConstantWithPowerMethod(self):
- ''' using Power method (PM) to establish L constant'''
-
- #N = params.vol_geom.GridColCount
- N = self.pars['output_geometry'].GridColCount
- proj_geom = self.params['projector_geometry']
- vol_geom = self.params['output_geometry']
- weights = self.pars['weights']
- SlicesZ = self.pars['SlicesZ']
-
- if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
- #% for parallel geometry we can do just one slice
- #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
- niter = 15;# % number of iteration for the PM
- #N = params.vol_geom.GridColCount;
- #x1 = rand(N,N,1);
- x1 = numpy.random.rand(1,N,N)
- #sqweight = sqrt(weights(:,:,1));
- sqweight = numpy.sqrt(weights.T[0])
- proj_geomT = proj_geom.copy();
- proj_geomT.DetectorRowCount = 1;
- vol_geomT = vol_geom.copy();
- vol_geomT['GridSliceCount'] = 1;
-
-
- for i in range(niter):
- if i == 0:
- #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
- y = sqweight * y # element wise multiplication
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id)
-
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
- s = numpy.linalg.norm(x1)
- ### this line?
- x1 = x1/s;
- ### this line?
- sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
- y = sqweight*y;
- astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
- #end
- del proj_geomT
- del vol_geomT
- else
- #% divergen beam geometry
- #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
- niter = 8; #% number of iteration for PM
- x1 = numpy.random.rand(SlicesZ , N , N);
- #sqweight = sqrt(weights);
- sqweight = numpy.sqrt(weights.T[0])
-
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
- y = sqweight*y;
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id);
-
- for i in range(niter):
- #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
- proj_geom,
- vol_geom)
- s = numpy.linalg.norm(x1)
- ### this line?
- x1 = x1/s;
- ### this line?
- #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
- sino_id, y = astra.creators.create_sino3d_gpu(x1,
- proj_geom,
- vol_geom);
-
- y = sqweight*y;
- #astra_mex_data3d('delete', sino_id);
- #astra_mex_data3d('delete', id);
- astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
- #end
- #clear x1
- del x1
-
- return s
-
-
- def setRegularizer(self, regularizer):
- if regularizer
- self.pars['regularizer'] = regularizer
-
-
-
-
-
-def getEntry(location):
- for item in nx[location].keys():
- print (item)
-
-
-print ("Loading Data")
-
-##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
-####ind = [i * 1049 for i in range(360)]
-#### use only 360 images
-##images = 200
-##ind = [int(i * 1049 / images) for i in range(images)]
-##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
-
-#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
-fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
-nx = h5py.File(fname, "r")
-
-# the data are stored in a particular location in the hdf5
-for item in nx['entry1/tomo_entry/data'].keys():
- print (item)
-
-data = nx.get('entry1/tomo_entry/data/rotation_angle')
-angles = numpy.zeros(data.shape)
-data.read_direct(angles)
-print (angles)
-# angles should be in degrees
-
-data = nx.get('entry1/tomo_entry/data/data')
-stack = numpy.zeros(data.shape)
-data.read_direct(stack)
-print (data.shape)
-
-print ("Data Loaded")
-
-
-# Normalize
-data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
-itype = numpy.zeros(data.shape)
-data.read_direct(itype)
-# 2 is dark field
-darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
-dark = darks[0]
-for i in range(1, len(darks)):
- dark += darks[i]
-dark = dark / len(darks)
-#dark[0][0] = dark[0][1]
-
-# 1 is flat field
-flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
-flat = flats[0]
-for i in range(1, len(flats)):
- flat += flats[i]
-flat = flat / len(flats)
-#flat[0][0] = dark[0][1]
-
-
-# 0 is projection data
-proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
-angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
-angle_proj = numpy.asarray (angle_proj)
-angle_proj = angle_proj.astype(numpy.float32)
-
-# normalized data are
-# norm = (projection - dark)/(flat-dark)
-
-def normalize(projection, dark, flat, def_val=0.1):
- a = (projection - dark)
- b = (flat-dark)
- with numpy.errstate(divide='ignore', invalid='ignore'):
- c = numpy.true_divide( a, b )
- c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
- return c
-
-
-norm = [normalize(projection, dark, flat) for projection in proj]
-norm = numpy.asarray (norm)
-norm = norm.astype(numpy.float32)
-
-#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm,
-# angles = angle_proj, center_of_rotation = 86.2 ,
-# flat_field = flat, dark_field = dark,
-# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
-
-#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj,
-# angles = angle_proj, center_of_rotation = 86.2 ,
-# flat_field = flat, dark_field = dark,
-# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
-#img_cgls = recon.reconstruct()
-#
-#pars = dict()
-#pars['algorithm'] = Reconstructor.Algorithm.SIRT
-#pars['projection_data'] = proj
-#pars['angles'] = angle_proj
-#pars['center_of_rotation'] = numpy.double(86.2)
-#pars['flat_field'] = flat
-#pars['iterations'] = 15
-#pars['dark_field'] = dark
-#pars['resolution'] = 1
-#pars['isLogScale'] = False
-#pars['threads'] = 3
-#
-#img_sirt = recon.reconstruct(pars)
-#
-#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM
-#img_mlem = recon.reconstruct()
-
-############################################################
-############################################################
-#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV
-#recon.pars['regularize'] = numpy.double(0.1)
-#img_cgls_conv = recon.reconstruct()
-
-niterations = 15
-threads = 3
-
-img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-
-iteration_values = numpy.zeros((niterations,))
-img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
- iteration_values, False)
-print ("iteration values %s" % str(iteration_values))
-
-iteration_values = numpy.zeros((niterations,))
-img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
- numpy.double(1e-5), iteration_values , False)
-print ("iteration values %s" % str(iteration_values))
-iteration_values = numpy.zeros((niterations,))
-img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
- numpy.double(1e-5), iteration_values , False)
-print ("iteration values %s" % str(iteration_values))
-
-
-##numpy.save("cgls_recon.npy", img_data)
-import matplotlib.pyplot as plt
-fig, ax = plt.subplots(1,6,sharey=True)
-ax[0].imshow(img_cgls[80])
-ax[0].axis('off') # clear x- and y-axes
-ax[1].imshow(img_sirt[80])
-ax[1].axis('off') # clear x- and y-axes
-ax[2].imshow(img_mlem[80])
-ax[2].axis('off') # clear x- and y-axesplt.show()
-ax[3].imshow(img_cgls_conv[80])
-ax[3].axis('off') # clear x- and y-axesplt.show()
-ax[4].imshow(img_cgls_tikhonov[80])
-ax[4].axis('off') # clear x- and y-axesplt.show()
-ax[5].imshow(img_cgls_TVreg[80])
-ax[5].axis('off') # clear x- and y-axesplt.show()
-
-
-plt.show()
-
-#viewer = edo.CILViewer()
-#viewer.setInputAsNumpy(img_cgls2)
-#viewer.displaySliceActor(0)
-#viewer.startRenderLoop()
-
-import vtk
-
-def NumpyToVTKImageData(numpyarray):
- if (len(numpy.shape(numpyarray)) == 3):
- doubleImg = vtk.vtkImageData()
- shape = numpy.shape(numpyarray)
- doubleImg.SetDimensions(shape[0], shape[1], shape[2])
- doubleImg.SetOrigin(0,0,0)
- doubleImg.SetSpacing(1,1,1)
- doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1)
- #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation())
- doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1)
-
- for i in range(shape[0]):
- for j in range(shape[1]):
- for k in range(shape[2]):
- doubleImg.SetScalarComponentFromDouble(
- i,j,k,0, numpyarray[i][j][k])
- #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) )
- # rescale to appropriate VTK_UNSIGNED_SHORT
- stats = vtk.vtkImageAccumulate()
- stats.SetInputData(doubleImg)
- stats.Update()
- iMin = stats.GetMin()[0]
- iMax = stats.GetMax()[0]
- scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin)
-
- shiftScaler = vtk.vtkImageShiftScale ()
- shiftScaler.SetInputData(doubleImg)
- shiftScaler.SetScale(scale)
- shiftScaler.SetShift(iMin)
- shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT)
- shiftScaler.Update()
- return shiftScaler.GetOutput()
-
-#writer = vtk.vtkMetaImageWriter()
-#writer.SetFileName(alg + "_recon.mha")
-#writer.SetInputData(NumpyToVTKImageData(img_cgls2))
-#writer.Write()
diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/src/Python/ccpi/fista/__init__.py
+++ /dev/null