From 73fed4964d81f1f47a0b6ecbe66517f569327b27 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:31:16 +0100 Subject: initial commit of Reconstructor.py --- src/Python/ccpi/reconstruction/Reconstructor.py | 598 ++++++++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/Reconstructor.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/reconstruction/Reconstructor.py b/src/Python/ccpi/reconstruction/Reconstructor.py new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/reconstruction/Reconstructor.py @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() -- cgit v1.2.3 From 105b57a1e98c2bb7b3bf94c43b6c669925ebb1b9 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:01:28 +0100 Subject: added viewer for testing --- src/Python/ccpi/viewer/CILViewer.py | 361 +++++++ src/Python/ccpi/viewer/CILViewer2D.py | 1126 ++++++++++++++++++++ src/Python/ccpi/viewer/QVTKWidget.py | 340 ++++++ src/Python/ccpi/viewer/QVTKWidget2.py | 84 ++ src/Python/ccpi/viewer/__init__.py | 1 + .../viewer/__pycache__/CILViewer.cpython-35.pyc | Bin 0 -> 10542 bytes .../viewer/__pycache__/CILViewer2D.cpython-35.pyc | Bin 0 -> 35633 bytes .../viewer/__pycache__/QVTKWidget.cpython-35.pyc | Bin 0 -> 10099 bytes .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc | Bin 0 -> 1316 bytes .../viewer/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 210 bytes src/Python/ccpi/viewer/embedvtk.py | 75 ++ 11 files changed, 1987 insertions(+) create mode 100644 src/Python/ccpi/viewer/CILViewer.py create mode 100644 src/Python/ccpi/viewer/CILViewer2D.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py create mode 100644 src/Python/ccpi/viewer/__init__.py create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/embedvtk.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py new file mode 100644 index 0000000..efcf8be --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +import math +from vtk.util import numpy_support + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + + + +class CILViewer(): + '''Simple 3D Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600): + '''creates the rendering pipeline''' + + # create a rendering window and renderer + self.ren = vtk.vtkRenderer() + self.renWin = vtk.vtkRenderWindow() + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + # img 3D as slice + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceActor = None + self.voi = None + self.wl = None + self.ia = None + self.sliceActorNo = 0 + # create a renderwindowinteractor + self.iren = vtk.vtkRenderWindowInteractor() + self.iren.SetRenderWindow(self.renWin) + + self.style = vtk.vtkInteractorStyleTrackballCamera() + self.iren.SetInteractorStyle(self.style) + + self.ren.SetBackground(.1, .2, .4) + + self.actors = {} + self.iren.RemoveObservers('MouseWheelForwardEvent') + self.iren.RemoveObservers('MouseWheelBackwardEvent') + + self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0) + self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0) + + self.iren.RemoveObservers('KeyPressEvent') + self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0) + + + self.iren.Initialize() + + + + def getRenderer(self): + '''returns the renderer''' + return self.ren + + def getRenderWindow(self): + '''returns the render window''' + return self.renWin + + def getInteractor(self): + '''returns the render window interactor''' + return self.iren + + def getCamera(self): + '''returns the active camera''' + return self.ren.GetActiveCamera() + + def createPolyDataActor(self, polydata): + '''returns an actor for a given polydata''' + mapper = vtk.vtkPolyDataMapper() + if vtk.VTK_MAJOR_VERSION <= 5: + mapper.SetInput(polydata) + else: + mapper.SetInputData(polydata) + + # actor + actor = vtk.vtkActor() + actor.SetMapper(mapper) + #actor.GetProperty().SetOpacity(0.8) + return actor + + def setPolyDataActor(self, actor): + '''displays the given polydata''' + + self.ren.AddActor(actor) + + self.actors[len(self.actors)+1] = [actor, True] + self.iren.Initialize() + self.renWin.Render() + + def displayPolyData(self, polydata): + self.setPolyDataActor(self.createPolyDataActor(polydata)) + + def hideActor(self, actorno): + '''Hides an actor identified by its number in the list of actors''' + try: + if self.actors[actorno][1]: + self.ren.RemoveActor(self.actors[actorno][0]) + self.actors[actorno][1] = False + except KeyError as ke: + print ("Warning Actor not present") + + def showActor(self, actorno, actor = None): + '''Shows hidden actor identified by its number in the list of actors''' + try: + if not self.actors[actorno][1]: + self.ren.AddActor(self.actors[actorno][0]) + self.actors[actorno][1] = True + return actorno + except KeyError as ke: + # adds it to the actors if not there already + if actor != None: + self.ren.AddActor(actor) + self.actors[len(self.actors)+1] = [actor, True] + return len(self.actors) + + def addActor(self, actor): + '''Adds an actor to the render''' + return self.showActor(0, actor) + + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + + def startRenderLoop(self): + self.iren.Start() + + + def setupObservers(self, interactor): + interactor.RemoveObservers('LeftButtonPressEvent') + interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction) + interactor.Initialize() + + + def mouseInteraction(self, interactor, event): + if event == 'MouseWheelForwardEvent': + maxSlice = self.img3D.GetDimensions()[self.sliceOrientation] + if (self.sliceno + 1 < maxSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno + 1 + self.displaySliceActor(self.sliceno) + else: + minSlice = 0 + if (self.sliceno - 1 > minSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno - 1 + self.displaySliceActor(self.sliceno) + + + def keyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "x": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_YZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "y": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "z": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceno = int(self.img3D.GetDimensions()[2] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + if interactor.GetKeyCode() == "X": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("x") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("y") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("z") + self.keyPress(interactor, event) + else : + print ("Unhandled event %s" % interactor.GetKeyCode()) + + + + def setInput3DData(self, imageData): + self.img3D = imageData + + def setInputAsNumpy(self, numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + + def displaySliceActor(self, sliceno = 0): + self.sliceno = sliceno + first = False + + self.sliceActor , self.voi, self.wl , self.ia = \ + self.getSliceActor(self.img3D, + sliceno, + self.sliceActor, + self.voi, + self.wl, + self.ia) + no = self.showActor(self.sliceActorNo, self.sliceActor) + self.sliceActorNo = no + + self.iren.Initialize() + self.renWin.Render() + + return self.sliceActorNo + + + def getSliceActor(self, + imageData , + sliceno=0, + imageActor=None , + voi=None, + windowLevel=None, + imageAccumulate=None): + '''Slices a 3D volume and then creates an actor to be rendered''' + if (voi==None): + voi = vtk.vtkExtractVOI() + #voi = vtk.vtkImageClip() + voi.SetInputData(imageData) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = sliceno + extent[self.sliceOrientation * 2 + 1] = sliceno + voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + voi.Update() + # set window/level for all slices + if imageAccumulate == None: + imageAccumulate = vtk.vtkImageAccumulate() + + if (windowLevel == None): + windowLevel = vtk.vtkImageMapToWindowLevelColors() + imageAccumulate.SetInputData(imageData) + imageAccumulate.Update() + cmax = imageAccumulate.GetMax()[0] + cmin = imageAccumulate.GetMin()[0] + windowLevel.SetLevel((cmax+cmin)/2) + windowLevel.SetWindow(cmax-cmin) + + windowLevel.SetInputData(voi.GetOutput()) + windowLevel.Update() + + if imageActor == None: + imageActor = vtk.vtkImageActor() + imageActor.SetInputData(windowLevel.GetOutput()) + imageActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + imageActor.Update() + return (imageActor , voi, windowLevel, imageAccumulate) + + + # Set interpolation on + def setInterpolateOn(self): + self.sliceActor.SetInterpolate(True) + self.renWin.Render() + + # Set interpolation off + def setInterpolateOff(self): + self.sliceActor.SetInterpolate(False) + self.renWin.Render() \ No newline at end of file diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py new file mode 100644 index 0000000..c1629af --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer2D.py @@ -0,0 +1,1126 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +from vtk.util import numpy_support , vtkImageImportFromArray +from enum import Enum + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + +CONTROL_KEY = 8 +SHIFT_KEY = 4 +ALT_KEY = -128 + + +# Converter class +class Converter(): + + # Utility functions to transform numpy arrays to vtkImageData and viceversa + @staticmethod + def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Creates a vtkImageImportFromArray object and returns it. + + It handles the different axis order from numpy to VTK''' + importer = vtkImageImportFromArray.vtkImageImportFromArray() + importer.SetArray(numpy.transpose(nparray).copy()) + importer.SetDataSpacing(spacing) + importer.SetDataOrigin(origin) + return importer + + @staticmethod + def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Converts a 3D numpy array to a vtkImageData''' + importer = Converter.numpy2vtkImporter(nparray, spacing, origin) + importer.Update() + return importer.GetOutput() + + @staticmethod + def vtk2numpy(imgdata): + '''Converts the VTK data to 3D numpy array''' + img_data = numpy_support.vtk_to_numpy( + imgdata.GetPointData().GetScalars()) + + dims = imgdata.GetDimensions() + dims = (dims[2],dims[1],dims[0]) + data3d = numpy.reshape(img_data, dims) + + return numpy.transpose(data3d).copy() + + @staticmethod + def tiffStack2numpy(filename, indices, + extent = None , sampleRate = None ,\ + flatField = None, darkField = None): + '''Converts a stack of TIFF files to numpy array. + + filename must contain the whole path. The filename is supposed to be named and + have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif + + indices are the suffix, generally an increasing number + + Optionally extracts only a selection of the 2D images and (optionally) + normalizes. + ''' + + stack = vtk.vtkImageData() + reader = vtk.vtkTIFFReader() + voi = vtk.vtkExtractVOI() + + #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\" + + stack_image = numpy.asarray([]) + nreduced = len(indices) + + for num in range(len(indices)): + fn = filename % indices[num] + print ("resampling %s" % ( fn ) ) + reader.SetFileName(fn) + reader.Update() + print (reader.GetOutput().GetScalarTypeAsString()) + if num == 0: + if (extent == None): + sliced = reader.GetOutput().GetExtent() + stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1) + else: + sliced = extent + voi.SetVOI(extent) + + if sampleRate is not None: + voi.SetSampleRate(sampleRate) + ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int) + print ("ext {0}".format(ext)) + stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1) + else: + stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1) + if (flatField != None and darkField != None): + stack.AllocateScalars(vtk.VTK_FLOAT, 1) + else: + stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1) + print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) )) + stack_image = Converter.vtk2numpy(stack) + print ("Stack shape %s" % str(numpy.shape(stack_image))) + + if extent!=None: + voi.SetInputData(reader.GetOutput()) + voi.Update() + img = voi.GetOutput() + else: + img = reader.GetOutput() + + theSlice = Converter.vtk2numpy(img).T[0] + if darkField != None and flatField != None: + print("Try to normalize") + #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice): + theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01) + print (theSlice.dtype) + + + print ("Slice shape %s" % str(numpy.shape(theSlice))) + stack_image.T[num] = theSlice.copy() + + return stack_image + + @staticmethod + def normalize(projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + + +## Utility functions to transform numpy arrays to vtkImageData and viceversa +#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtkImporter(nparray, spacing, origin) +# +#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtk(nparray, spacing, origin) +# +#def vtk2numpy(imgdata): +# return Converter.vtk2numpy(imgdata) +# +#def tiffStack2numpy(filename, indices): +# return Converter.tiffStack2numpy(filename, indices) + +class ViewerEvent(Enum): + # left button + PICK_EVENT = 0 + # alt + right button + move + WINDOW_LEVEL_EVENT = 1 + # shift + right button + ZOOM_EVENT = 2 + # control + right button + PAN_EVENT = 3 + # control + left button + CREATE_ROI_EVENT = 4 + # alt + left button + DELETE_ROI_EVENT = 5 + # release button + NO_EVENT = -1 + + +#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): +class CILInteractorStyle(vtk.vtkInteractorStyleImage): + + def __init__(self, callback): + vtk.vtkInteractorStyleImage.__init__(self) + self.callback = callback + self._viewer = callback + priority = 1.0 + +# self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority) +# self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority) +# self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority) +# self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority) +# self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority) +# self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority) +# self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority) +# self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority) + + self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority) + self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority) + self.AddObserver('KeyPressEvent', self.OnKeyPress, priority) + self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority) + self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority) + self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority) + self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority) + self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority) + + self.InitialEventPosition = (0,0) + + + def SetInitialEventPosition(self, xy): + self.InitialEventPosition = xy + + def GetInitialEventPosition(self): + return self.InitialEventPosition + + def GetKeyCode(self): + return self.GetInteractor().GetKeyCode() + + def SetKeyCode(self, keycode): + self.GetInteractor().SetKeyCode(keycode) + + def GetControlKey(self): + return self.GetInteractor().GetControlKey() == CONTROL_KEY + + def GetShiftKey(self): + return self.GetInteractor().GetShiftKey() == SHIFT_KEY + + def GetAltKey(self): + return self.GetInteractor().GetAltKey() == ALT_KEY + + def GetEventPosition(self): + return self.GetInteractor().GetEventPosition() + + def GetEventPositionInWorldCoordinates(self): + pass + + def GetDeltaEventPosition(self): + x,y = self.GetInteractor().GetEventPosition() + return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1]) + + def Dolly(self, factor): + self.callback.camera.Dolly(factor) + self.callback.ren.ResetCameraClippingRange() + + def GetDimensions(self): + return self._viewer.img3D.GetDimensions() + + def GetInputData(self): + return self._viewer.img3D + + def GetSliceOrientation(self): + return self._viewer.sliceOrientation + + def SetSliceOrientation(self, orientation): + self._viewer.sliceOrientation = orientation + + def GetActiveSlice(self): + return self._viewer.sliceno + + def SetActiveSlice(self, sliceno): + self._viewer.sliceno = sliceno + + def UpdatePipeline(self, reset = False): + self._viewer.updatePipeline(reset) + + def GetActiveCamera(self): + return self._viewer.ren.GetActiveCamera() + + def SetActiveCamera(self, camera): + self._viewer.ren.SetActiveCamera(camera) + + def ResetCamera(self): + self._viewer.ren.ResetCamera() + + def Render(self): + self._viewer.renWin.Render() + + def UpdateSliceActor(self): + self._viewer.sliceActor.Update() + + def AdjustCamera(self): + self._viewer.AdjustCamera() + + def SaveRender(self, filename): + self._viewer.SaveRender(filename) + + def GetRenderWindow(self): + return self._viewer.renWin + + def GetRenderer(self): + return self._viewer.ren + + def GetROIWidget(self): + return self._viewer.ROIWidget + + def SetViewerEvent(self, event): + self._viewer.event = event + + def GetViewerEvent(self): + return self._viewer.event + + def SetInitialCameraPosition(self, position): + self._viewer.InitialCameraPosition = position + + def GetInitialCameraPosition(self): + return self._viewer.InitialCameraPosition + + def SetInitialLevel(self, level): + self._viewer.InitialLevel = level + + def GetInitialLevel(self): + return self._viewer.InitialLevel + + def SetInitialWindow(self, window): + self._viewer.InitialWindow = window + + def GetInitialWindow(self): + return self._viewer.InitialWindow + + def GetWindowLevel(self): + return self._viewer.wl + + def SetROI(self, roi): + self._viewer.ROI = roi + + def GetROI(self): + return self._viewer.ROI + + def UpdateCornerAnnotation(self, text, corner): + self._viewer.updateCornerAnnotation(text, corner) + + def GetPicker(self): + return self._viewer.picker + + def GetCornerAnnotation(self): + return self._viewer.cornerAnnotation + + def UpdateROIHistogram(self): + self._viewer.updateROIHistogram() + + + ############### Handle events + def OnMouseWheelForward(self, interactor, event): + maxSlice = self.GetDimensions()[self.GetSliceOrientation()] + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + + if (self.GetActiveSlice() + advance < maxSlice): + self.SetActiveSlice(self.GetActiveSlice() + advance) + + self.UpdatePipeline() + else: + print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 )) + + def OnMouseWheelBackward(self, interactor, event): + minSlice = 0 + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + if (self.GetActiveSlice() - advance >= minSlice): + self.SetActiveSlice( self.GetActiveSlice() - advance) + self.UpdatePipeline() + else: + print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 )) + + def OnKeyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "X": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_YZ ) + self.SetActiveSlice( int(self.GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Y": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XZ ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Z": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XY ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) ) + self.UpdatePipeline(True) + if interactor.GetKeyCode() == "x": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("X") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("Y") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,1,0) + self.SetActiveCamera(camera) + self.ResetCamera() + self.Render() + interactor.SetKeyCode("Z") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "a": + # reset color/window + cmax = self._viewer.ia.GetMax()[0] + cmin = self._viewer.ia.GetMin()[0] + + self.SetInitialLevel( (cmax+cmin)/2 ) + self.SetInitialWindow( cmax-cmin ) + + self.GetWindowLevel().SetLevel(self.GetInitialLevel()) + self.GetWindowLevel().SetWindow(self.GetInitialWindow()) + + self.GetWindowLevel().Update() + + self.UpdateSliceActor() + self.AdjustCamera() + self.Render() + + elif interactor.GetKeyCode() == "s": + filename = "current_render" + self.SaveRender(filename) + elif interactor.GetKeyCode() == "q": + print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), )) + interactor.SetKeyCode("e") + self.OnKeyPress(interactor, event) + else : + #print ("Unhandled event %s" % (interactor.GetKeyCode(), ))) + pass + + def OnLeftButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + if ctrl and not (alt and shift): + self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT ) + wsize = self.GetRenderWindow().GetSize() + position = interactor.GetEventPosition() + self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05)) + self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1)) + + self.GetROIWidget().On() + self.SetDisplayHistogram(True) + self.Render() + print ("Event %s is CREATE_ROI_EVENT" % (event)) + elif alt and not (shift and ctrl): + self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT ) + self.GetROIWidget().Off() + self._viewer.updateCornerAnnotation("", 1, False) + self.SetDisplayHistogram(False) + self.Render() + print ("Event %s is DELETE_ROI_EVENT" % (event)) + elif not (ctrl and alt and shift): + self.SetViewerEvent ( ViewerEvent.PICK_EVENT ) + self.HandlePickEvent(interactor, event) + print ("Event %s is PICK_EVENT" % (event)) + + + def SetDisplayHistogram(self, display): + if display: + if (self._viewer.displayHistogram == 0): + self.GetRenderer().AddActor(self._viewer.histogramPlotActor) + self.firstHistogram = 1 + self.Render() + + self._viewer.histogramPlotActor.VisibilityOn() + self._viewer.displayHistogram = True + else: + self._viewer.histogramPlotActor.VisibilityOff() + self._viewer.displayHistogram = False + + + def OnLeftButtonReleaseEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT: + #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate() + #print (bc.GetValue()) + self.OnROIModifiedEvent(interactor, event) + + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def OnRightButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + + if alt and not (ctrl and shift): + self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT ) + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif shift and not (ctrl and alt): + self.SetViewerEvent( ViewerEvent.ZOOM_EVENT ) + self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition()) + print ("Event %s is ZOOM_EVENT" % (event)) + elif ctrl and not (shift and alt): + self.SetViewerEvent (ViewerEvent.PAN_EVENT ) + self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() ) + print ("Event %s is PAN_EVENT" % (event)) + + def OnRightButtonReleaseEvent(self, interactor, event): + print (event) + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + self.SetInitialLevel( self.GetWindowLevel().GetLevel() ) + self.SetInitialWindow ( self.GetWindowLevel().GetWindow() ) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \ + self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.SetInitialCameraPosition( () ) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + + def OnROIModifiedEvent(self, interactor, event): + + #print ("ROI EVENT " + event) + p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate() + p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate() + wsize = self.GetRenderWindow().GetSize() + + #print (p1.GetValue()) + #print (p2.GetValue()) + pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0] + pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0] + vox1 = self.viewport2imageCoordinate(pp1) + vox2 = self.viewport2imageCoordinate(pp2) + + self.SetROI( (vox1 , vox2) ) + roi = self.GetROI() + print ("Pixel1 %d,%d,%d Value %f" % vox1 ) + print ("Pixel2 %d,%d,%d Value %f" % vox2 ) + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][2] - roi[0][2]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + x = abs(roi[1][1] - roi[0][1]) + y = abs(roi[1][2] - roi[0][2]) + + text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.) + print (text) + self.UpdateCornerAnnotation(text, 1) + self.UpdateROIHistogram() + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.GetPicker().GetPickPosition()) + pickPosition[self.GetSliceOrientation()] = \ + self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \ + self.GetInputData().GetOrigin()[self.GetSliceOrientation()] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.GetInputData().GetDimensions() + print (dims) + spac = self.GetInputData().GetSpacing() + orig = self.GetInputData().GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + + def OnMouseMoveEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT: + self.HandleZoomEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.HandlePanEvent(interactor, event) + + + def HandleZoomEvent(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + size = self.GetRenderWindow().GetSize() + dy = - 4 * dy / size[1] + + print ("distance: " + str(self.GetActiveCamera().GetDistance())) + + print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy)) + + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) + newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy ) + #print ("new position " + str(newposition)) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + print ("distance after: " + str(self.GetActiveCamera().GetDistance())) + + def HandlePanEvent(self, interactor, event): + x,y = interactor.GetEventPosition() + x0,y0 = interactor.GetInitialEventPosition() + + ic = self.viewport2imageCoordinate((x,y)) + ic0 = self.viewport2imageCoordinate((x0,y0)) + + dx = 4 *( ic[0] - ic0[0]) + dy = 4* (ic[1] - ic0[1]) + + camera = vtk.vtkCamera() + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + newposition[0] -= dx + newposition[1] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[1] = newposition[1] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[0] -= dx + newposition[2] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[2] = newposition[2] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[1] -= dx + newposition[2] -= dy + newfocalpoint[2] = newposition[2] + newfocalpoint[1] = newposition[1] + #print ("new position " + str(newposition)) + camera.SetFocalPoint(newfocalpoint) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + def HandleWindowLevel(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + print ("Event delta %d %d" % (dx,dy)) + size = self.GetRenderWindow().GetSize() + + dx = 4 * dx / size[0] + dy = 4 * dy / size[1] + window = self.GetInitialWindow() + level = self.GetInitialLevel() + + if abs(window) > 0.01: + dx = dx * window + else: + dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window); + + if abs(level) > 0.01: + dy = dy * level + else: + dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level) + + + # Abs so that direction does not flip + + if window < 0.0: + dx = -1*dx + if level < 0.0: + dy = -1*dy + + # Compute new window level + + newWindow = dx + window + newLevel = level - dy + + # Stay away from zero and really + + if abs(newWindow) < 0.01: + newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow) + + if abs(newLevel) < 0.01: + newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel) + + self.GetWindowLevel().SetWindow(newWindow) + self.GetWindowLevel().SetLevel(newLevel) + + self.GetWindowLevel().Update() + self.UpdateSliceActor() + self.AdjustCamera() + + self.Render() + + def HandlePickEvent(self, interactor, event): + position = interactor.GetEventPosition() + #print ("PICK " + str(position)) + vox = self.viewport2imageCoordinate(position) + #print ("Pixel %d,%d,%d Value %f" % vox ) + self._viewer.cornerAnnotation.VisibilityOn() + self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0) + self.Render() + +############################################################################### + + + +class CILViewer2D(): + '''Simple Interactive Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None): + '''creates the rendering pipeline''' + # create a rendering window and renderer + if ren == None: + self.ren = vtk.vtkRenderer() + else: + self.ren = ren + if renWin == None: + self.renWin = vtk.vtkRenderWindow() + else: + self.renWin = renWin + if iren == None: + self.iren = vtk.vtkRenderWindowInteractor() + else: + self.iren = iren + + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + self.style = CILInteractorStyle(self) + + self.iren.SetInteractorStyle(self.style) + self.iren.SetRenderWindow(self.renWin) + self.iren.Initialize() + self.ren.SetBackground(.1, .2, .4) + + self.camera = vtk.vtkCamera() + self.camera.ParallelProjectionOn() + self.ren.SetActiveCamera(self.camera) + + # data + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + + #Actors + self.sliceActor = vtk.vtkImageActor() + self.voi = vtk.vtkExtractVOI() + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia = vtk.vtkImageAccumulate() + self.sliceActorNo = 0 + + #initial Window/Level + self.InitialLevel = 0 + self.InitialWindow = 0 + + #ViewerEvent + self.event = ViewerEvent.NO_EVENT + + # ROI Widget + self.ROIWidget = vtk.vtkBorderWidget() + self.ROIWidget.SetInteractor(self.iren) + self.ROIWidget.CreateDefaultRepresentation() + self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0) + self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0) + + # edge points of the ROI + self.ROI = () + + #picker + self.picker = vtk.vtkPropPicker() + self.picker.PickFromListOn() + self.picker.AddPickList(self.sliceActor) + + self.iren.SetPicker(self.picker) + + # corner annotation + self.cornerAnnotation = vtk.vtkCornerAnnotation() + self.cornerAnnotation.SetMaximumFontSize(12); + self.cornerAnnotation.PickableOff(); + self.cornerAnnotation.VisibilityOff(); + self.cornerAnnotation.GetTextProperty().ShadowOn(); + self.cornerAnnotation.SetLayerNumber(1); + + + + # cursor doesn't show up + self.cursor = vtk.vtkCursor2D() + self.cursorMapper = vtk.vtkPolyDataMapper2D() + self.cursorActor = vtk.vtkActor2D() + self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0) + self.cursor.SetFocalPoint(0, 0, 0) + self.cursor.AllOff() + self.cursor.AxesOn() + self.cursorActor.PickableOff() + self.cursorActor.VisibilityOn() + self.cursorActor.GetProperty().SetColor(1, 1, 1) + self.cursorActor.SetLayerNumber(1) + self.cursorMapper.SetInputData(self.cursor.GetOutput()) + self.cursorActor.SetMapper(self.cursorMapper) + + # Zoom + self.InitialCameraPosition = () + + # XY Plot actor for histogram + self.displayHistogram = False + self.firstHistogram = 0 + self.roiIA = vtk.vtkImageAccumulate() + self.roiVOI = vtk.vtkExtractVOI() + self.histogramPlotActor = vtk.vtkXYPlotActor() + self.histogramPlotActor.ExchangeAxesOff(); + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetAdjustXLabels(3) + self.histogramPlotActor.SetXTitle( "Level" ) + self.histogramPlotActor.SetYTitle( "N" ) + self.histogramPlotActor.SetXValuesToValue() + self.histogramPlotActor.SetPlotColor(0, (0,1,1) ) + self.histogramPlotActor.SetPosition(0.6,0.6) + self.histogramPlotActor.SetPosition2(0.4,0.4) + + + + def GetInteractor(self): + return self.iren + + def GetRenderer(self): + return self.ren + + def setInput3DData(self, imageData): + self.img3D = imageData + self.installPipeline() + + def setInputAsNumpy(self, numpyarray, origin=(0,0,0), spacing=(1.,1.,1.), + rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT): + importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin) + importer.Update() + + if rescale: + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(importer.GetOutput()) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + if (iMax - iMin == 0): + scale = 1 + else: + if dtype == vtk.VTK_UNSIGNED_SHORT: + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + elif dtype == vtk.VTK_UNSIGNED_INT: + scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(importer.GetOutput()) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(-iMin) + shiftScaler.SetOutputScalarType(dtype) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + else: + self.img3D = importer.GetOutput() + + self.installPipeline() + + def displaySlice(self, sliceno = 0): + self.sliceno = sliceno + + self.updatePipeline() + + self.renWin.Render() + + return self.sliceActorNo + + def updatePipeline(self, resetcamera = False): + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + self.ia.Update() + self.wl.Update() + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + + self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation])) + + if self.displayHistogram: + self.updateROIHistogram() + + self.AdjustCamera(resetcamera) + + self.renWin.Render() + + + def installPipeline(self): + '''Slices a 3D volume and then creates an actor to be rendered''' + + self.ren.AddViewProp(self.cornerAnnotation) + + self.voi.SetInputData(self.img3D) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + # set window/level for current slices + + + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia.SetInputData(self.voi.GetOutput()) + self.ia.Update() + cmax = self.ia.GetMax()[0] + cmin = self.ia.GetMin()[0] + + self.InitialLevel = (cmax+cmin)/2 + self.InitialWindow = cmax-cmin + + + self.wl.SetLevel(self.InitialLevel) + self.wl.SetWindow(self.InitialWindow) + + self.wl.SetInputData(self.voi.GetOutput()) + self.wl.Update() + + self.sliceActor.SetInputData(self.wl.GetOutput()) + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + self.sliceActor.SetInterpolate(False) + self.ren.AddActor(self.sliceActor) + self.ren.ResetCamera() + self.ren.Render() + + self.AdjustCamera() + + self.ren.AddViewProp(self.cursorActor) + self.cursorActor.VisibilityOn() + + self.iren.Initialize() + self.renWin.Render() + #self.iren.Start() + + def AdjustCamera(self, resetcamera = False): + self.ren.ResetCameraClippingRange() + if resetcamera: + self.ren.ResetCamera() + + + def getROI(self): + return self.ROI + + def getROIExtent(self): + p0 = self.ROI[0] + p1 = self.ROI[1] + return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2]) + + ############### Handle events are moved to the interactor style + + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.picker.GetPickPosition()) + pickPosition[self.sliceOrientation] = \ + self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \ + self.img3D.GetOrigin()[self.sliceOrientation] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.img3D.GetDimensions() + print (dims) + spac = self.img3D.GetSpacing() + orig = self.img3D.GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + def GetRenderWindow(self): + return self.renWin + + + def startRenderLoop(self): + self.iren.Start() + + def GetSliceOrientation(self): + return self.sliceOrientation + + def GetActiveSlice(self): + return self.sliceno + + def updateCornerAnnotation(self, text , idx=0, visibility=True): + if visibility: + self.cornerAnnotation.VisibilityOn() + else: + self.cornerAnnotation.VisibilityOff() + + self.cornerAnnotation.SetText(idx, text) + self.iren.Render() + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + def updateROIHistogram(self): + + extent = [0 for i in range(6)] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + extent[4] = self.GetActiveSlice() + extent[5] = self.GetActiveSlice()+1 + #y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + #x = abs(roi[1][0] - roi[0][0]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[2] = self.GetActiveSlice() + extent[3] = self.GetActiveSlice()+1 + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + #x = abs(roi[1][1] - roi[0][1]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[0] = self.GetActiveSlice() + extent[1] = self.GetActiveSlice()+1 + + self.roiVOI.SetVOI(extent) + self.roiVOI.SetInputData(self.img3D) + self.roiVOI.Update() + irange = self.roiVOI.GetOutput().GetScalarRange() + + self.roiIA.SetInputData(self.roiVOI.GetOutput()) + self.roiIA.IgnoreZeroOff() + self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 ) + self.roiIA.SetComponentOrigin( int(irange[0]),0,0 ); + self.roiIA.SetComponentSpacing( 1,0,0 ); + self.roiIA.Update() + + self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort()) + self.histogramPlotActor.SetXRange(irange[0],irange[1]) + + self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() ) + + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py new file mode 100644 index 0000000..906786b --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget.py @@ -0,0 +1,340 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter + +class QVTKWidget(QtWidgets.QWidget): + + """ A QVTKWidget for Python and Qt.""" + + # Map between VTK and Qt cursors. + _CURSOR_MAP = { + 0: QtCore.Qt.ArrowCursor, # VTK_CURSOR_DEFAULT + 1: QtCore.Qt.ArrowCursor, # VTK_CURSOR_ARROW + 2: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZENE + 3: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZENWSE + 4: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZESW + 5: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZESE + 6: QtCore.Qt.SizeVerCursor, # VTK_CURSOR_SIZENS + 7: QtCore.Qt.SizeHorCursor, # VTK_CURSOR_SIZEWE + 8: QtCore.Qt.SizeAllCursor, # VTK_CURSOR_SIZEALL + 9: QtCore.Qt.PointingHandCursor, # VTK_CURSOR_HAND + 10: QtCore.Qt.CrossCursor, # VTK_CURSOR_CROSSHAIR + } + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + # the current button + self._ActiveButton = QtCore.Qt.NoButton + + # private attributes + self.__oldFocus = None + self.__saveX = 0 + self.__saveY = 0 + self.__saveModifiers = QtCore.Qt.NoModifier + self.__saveButtons = QtCore.Qt.NoButton + self.__timeframe = 0 + + # create qt-level widget + QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC) + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D() + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + + self._Iren.Register(self._RenderWindow) + self._Iren.SetRenderWindow(self._RenderWindow) + self._RenderWindow.SetWindowInfo(str(int(self.winId()))) + + # do all the necessary qt setup + self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent) + self.setAttribute(QtCore.Qt.WA_PaintOnScreen) + self.setMouseTracking(True) # get all mouse events + self.setFocusPolicy(QtCore.Qt.WheelFocus) + self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)) + + self._Timer = QtCore.QTimer(self) + #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent) + + self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer) + self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer) + self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent', + self.CursorChangedEvent) + + # Destructor + def __del__(self): + self._Iren.UnRegister(self._RenderWindow) + #QtWidgets.QWidget.__del__(self) + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + + # GetInteractor + def GetInteractor(self): + return self._Iren + + # Display image data + def GetPyveViewer(self): + return self._PyveViewer + + def __getattr__(self, attr): + """Makes the object behave like a vtkGenericRenderWindowInteractor""" + print (attr) + if attr == '__vtk__': + return lambda t=self._Iren: t + elif hasattr(self._Iren, attr): + return getattr(self._Iren, attr) +# else: +# raise AttributeError( self.__class__.__name__ + \ +# " has no attribute named " + attr ) + + def CreateTimer(self, obj, evt): + self._Timer.start(10) + + def DestroyTimer(self, obj, evt): + self._Timer.stop() + return 1 + + def TimerEvent(self): + self._Iren.InvokeEvent("TimerEvent") + + def CursorChangedEvent(self, obj, evt): + """Called when the CursorChangedEvent fires on the render window.""" + # This indirection is needed since when the event fires, the current + # cursor is not yet set so we defer this by which time the current + # cursor should have been set. + QtCore.QTimer.singleShot(0, self.ShowCursor) + + def HideCursor(self): + """Hides the cursor.""" + self.setCursor(QtCore.Qt.BlankCursor) + + def ShowCursor(self): + """Shows the cursor.""" + vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor() + qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor) + self.setCursor(qt_cursor) + + def sizeHint(self): + return QtCore.QSize(400, 400) + + def paintEngine(self): + return None + + def paintEvent(self, ev): + self._RenderWindow.Render() + + def resizeEvent(self, ev): + self._RenderWindow.Render() + w = self.width() + h = self.height() + + self._RenderWindow.SetSize(w, h) + self._Iren.SetSize(w, h) + + def _GetCtrlShiftAlt(self, ev): + ctrl = shift = alt = False + + if hasattr(ev, 'modifiers'): + if ev.modifiers() & QtCore.Qt.ShiftModifier: + shift = True + if ev.modifiers() & QtCore.Qt.ControlModifier: + ctrl = True + if ev.modifiers() & QtCore.Qt.AltModifier: + alt = True + else: + if self.__saveModifiers & QtCore.Qt.ShiftModifier: + shift = True + if self.__saveModifiers & QtCore.Qt.ControlModifier: + ctrl = True + if self.__saveModifiers & QtCore.Qt.AltModifier: + alt = True + + return ctrl, shift, alt + + def enterEvent(self, ev): + if not self.hasFocus(): + self.__oldFocus = self.focusWidget() + self.setFocus() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("EnterEvent") + + def leaveEvent(self, ev): + if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus: + self.__oldFocus.setFocus() + self.__oldFocus = None + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("LeaveEvent") + + def mousePressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + repeat = 0 + if ev.type() == QtCore.QEvent.MouseButtonDblClick: + repeat = 1 + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), repeat, None) + + self._Iren.SetAltKey(alt) + self._ActiveButton = ev.button() + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonPressEvent") + + def mouseReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonReleaseEvent") + + def mouseMoveEvent(self, ev): + self.__saveModifiers = ev.modifiers() + self.__saveButtons = ev.buttons() + self.__saveX = ev.x() + self.__saveY = ev.y() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("MouseMoveEvent") + + def keyPressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = str(ev.text()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyPressEvent") + self._Iren.InvokeEvent("CharEvent") + + def keyReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = chr(ev.key()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyReleaseEvent") + + def wheelEvent(self, ev): + print ("angleDeltaX %d" % ev.angleDelta().x()) + print ("angleDeltaY %d" % ev.angleDelta().y()) + if ev.angleDelta().y() >= 0: + self._Iren.InvokeEvent("MouseWheelForwardEvent") + else: + self._Iren.InvokeEvent("MouseWheelBackwardEvent") + + def GetRenderWindow(self): + return self._RenderWindow + + def Render(self): + self.update() + + +def QVTKExample(): + """A simple example that uses the QVTKWidget class.""" + + # every QT app needs an app + app = QtWidgets.QApplication(['PyVE QVTKWidget Example']) + page_VTK = QtWidgets.QWidget() + page_VTK.resize(500,500) + layout = QtWidgets.QVBoxLayout(page_VTK) + # create the widget + widget = QVTKWidget(parent=None) + layout.addWidget(widget) + + #reader = vtk.vtkPNGReader() + #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png") + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + widget.SetInput(reader.GetOutput()) + + # show the widget + page_VTK.show() + # start event processing + app.exec_() + +if __name__ == "__main__": + QVTKExample() diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py new file mode 100644 index 0000000..e32e1c2 --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget2.py @@ -0,0 +1,84 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor + +class QVTKWidget(QVTKRenderWindowInteractor): + + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + kw = dict() + super().__init__(parent, **kw) + + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D(400,400) + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + kw['iren'] = self._Iren + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + kw['rw'] = self._RenderWindow + + + + + def GetInteractor(self): + return self._Iren + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py new file mode 100644 index 0000000..946188b --- /dev/null +++ b/src/Python/ccpi/viewer/__init__.py @@ -0,0 +1 @@ +from ccpi.viewer.CILViewer import CILViewer \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc new file mode 100644 index 0000000..711f77a Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc new file mode 100644 index 0000000..77c2ca8 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc new file mode 100644 index 0000000..3d11b87 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc new file mode 100644 index 0000000..2fa2eaf Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..fcea537 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py new file mode 100644 index 0000000..b5eb0a7 --- /dev/null +++ b/src/Python/ccpi/viewer/embedvtk.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Jul 27 12:18:58 2017 + +@author: ofn77899 +""" + +#!/usr/bin/env python + +import sys +import vtk +from PyQt5 import QtCore, QtWidgets +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor +import QVTKWidget2 + +class MainWindow(QtWidgets.QMainWindow): + + def __init__(self, parent = None): + QtWidgets.QMainWindow.__init__(self, parent) + + self.frame = QtWidgets.QFrame() + + self.vl = QtWidgets.QVBoxLayout() +# self.vtkWidget = QVTKRenderWindowInteractor(self.frame) + + self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame) + self.iren = self.vtkWidget.GetInteractor() + self.vl.addWidget(self.vtkWidget) + + + + + self.ren = vtk.vtkRenderer() + self.vtkWidget.GetRenderWindow().AddRenderer(self.ren) +# self.iren = self.vtkWidget.GetRenderWindow().GetInteractor() +# +# # Create source +# source = vtk.vtkSphereSource() +# source.SetCenter(0, 0, 0) +# source.SetRadius(5.0) +# +# # Create a mapper +# mapper = vtk.vtkPolyDataMapper() +# mapper.SetInputConnection(source.GetOutputPort()) +# +# # Create an actor +# actor = vtk.vtkActor() +# actor.SetMapper(mapper) +# +# self.ren.AddActor(actor) +# +# self.ren.ResetCamera() +# + self.frame.setLayout(self.vl) + self.setCentralWidget(self.frame) + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + self.vtkWidget.SetInput(reader.GetOutput()) + + #self.vktWidget.Initialize() + #self.vktWidget.Start() + + self.show() + #self.iren.Initialize() + + +if __name__ == "__main__": + + app = QtWidgets.QApplication(sys.argv) + + window = MainWindow() + + sys.exit(app.exec_()) \ No newline at end of file -- cgit v1.2.3 From ad62962697509d977087c25d24a3ff083d9c4308 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:08:32 +0100 Subject: initial revision --- src/Python/ccpi/imaging/Regularizer.py | 322 +++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/ccpi/imaging/Regularizer.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py new file mode 100644 index 0000000..fb9ae08 --- /dev/null +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +from ccpi.imaging import cpu_regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = cpu_regularizers.SplitBregman_TV + FGP_TV = cpu_regularizers.FGP_TV + LLT_model = cpu_regularizers.LLT_model + PatchBased_Regul = cpu_regularizers.PatchBased_Regul + TGV_PD = cpu_regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From 56915cc00ded38d24c23b9ab1a0717d52d430ddd Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:54:59 +0100 Subject: initial revision for testing --- .../ccpi/reconstruction/FISTAReconstructor.py | 354 +++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/FISTAReconstructor.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py new file mode 100644 index 0000000..ea96b53 --- /dev/null +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -0,0 +1,354 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else: + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + -- cgit v1.2.3 From 7111d98258becca09e4c93e3c66edb7d524d6463 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:05 +0100 Subject: initial revision --- src/Python/ccpi/imaging/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/imaging/__init__.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/imaging/__init__.py b/src/Python/ccpi/imaging/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 3f26b1d8ab3a632ceca97bdf04225008f9163684 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:27 +0100 Subject: initial revision --- src/Python/ccpi/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/__init__.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/__init__.py b/src/Python/ccpi/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 391473269674bc98697eabac0b4fb2bd89f5d85e Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 16:55:48 +0100 Subject: Reorganized code with new fista package name --- src/Python/ccpi/fista/FISTAReconstructor.py | 389 ++++++++++++++ src/Python/ccpi/fista/FISTAReconstructor.pyc | Bin 0 -> 3804 bytes src/Python/ccpi/fista/FISTAReconstructor.py~ | 349 ++++++++++++ src/Python/ccpi/fista/Reconstructor.py | 425 +++++++++++++++ src/Python/ccpi/fista/Reconstructor.py~ | 598 +++++++++++++++++++++ src/Python/ccpi/fista/__init__.py | 0 src/Python/ccpi/fista/__init__.pyc | Bin 0 -> 189 bytes .../__pycache__/FISTAReconstructor.cpython-35.pyc | Bin 0 -> 3641 bytes .../ccpi/fista/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 185 bytes 9 files changed, 1761 insertions(+) create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.pyc create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py~ create mode 100644 src/Python/ccpi/fista/Reconstructor.py create mode 100644 src/Python/ccpi/fista/Reconstructor.py~ create mode 100644 src/Python/ccpi/fista/__init__.py create mode 100644 src/Python/ccpi/fista/__init__.pyc create mode 100644 src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc create mode 100644 src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py new file mode 100644 index 0000000..1e76815 --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +#from ccpi.reconstruction.parallelbeam import alg + +#from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry + self.pars['output_geometry'] = output_geometry + self.pars['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_og_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + + print (self.pars) + # handle optional input parameters (at instantiation) + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not 'ideal_image' in kwargs.keys(): + self.pars['ideal_image'] = None + + if not 'region_of_interest'in kwargs.keys() : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not 'regularizer' in kwargs.keys() : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + + for i in range(niter): + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 + + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + + s = numpy.linalg.norm(x1) + ### this line? + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + + #end + del proj_geomT + del vol_geomT + #plt.show() + else: + #% divergen beam geometry + print('Calculating Lipshitz constant for divergen beam geometry...') + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location, nx): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + diff --git a/src/Python/ccpi/fista/FISTAReconstructor.pyc b/src/Python/ccpi/fista/FISTAReconstructor.pyc new file mode 100644 index 0000000..ecc4d7d Binary files /dev/null and b/src/Python/ccpi/fista/FISTAReconstructor.pyc differ diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py~ b/src/Python/ccpi/fista/FISTAReconstructor.py~ new file mode 100644 index 0000000..6c7024d --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py~ @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + + diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py new file mode 100644 index 0000000..d29ac0d --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/Reconstructor.py~ b/src/Python/ccpi/fista/Reconstructor.py~ new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py~ @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/Python/ccpi/fista/__init__.pyc b/src/Python/ccpi/fista/__init__.pyc new file mode 100644 index 0000000..719e264 Binary files /dev/null and b/src/Python/ccpi/fista/__init__.pyc differ diff --git a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc new file mode 100644 index 0000000..84f16e2 Binary files /dev/null and b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc differ diff --git a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..90c23ff Binary files /dev/null and b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc differ -- cgit v1.2.3 From 64c0b9e7a1bcfc54e6ed8b57274d53c3ed9bb950 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 17:03:17 +0100 Subject: use refactore code --- src/Python/ccpi/fista/FISTAReconstructor.pyc | Bin 3804 -> 0 bytes src/Python/ccpi/fista/FISTAReconstructor.py~ | 349 ------------ src/Python/ccpi/fista/Reconstructor.py~ | 598 --------------------- src/Python/ccpi/fista/__init__.pyc | Bin 189 -> 0 bytes .../__pycache__/FISTAReconstructor.cpython-35.pyc | Bin 3641 -> 0 bytes .../ccpi/fista/__pycache__/__init__.cpython-35.pyc | Bin 185 -> 0 bytes 6 files changed, 947 deletions(-) delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.pyc delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py~ delete mode 100644 src/Python/ccpi/fista/Reconstructor.py~ delete mode 100644 src/Python/ccpi/fista/__init__.pyc delete mode 100644 src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc delete mode 100644 src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.pyc b/src/Python/ccpi/fista/FISTAReconstructor.pyc deleted file mode 100644 index ecc4d7d..0000000 Binary files a/src/Python/ccpi/fista/FISTAReconstructor.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py~ b/src/Python/ccpi/fista/FISTAReconstructor.py~ deleted file mode 100644 index 6c7024d..0000000 --- a/src/Python/ccpi/fista/FISTAReconstructor.py~ +++ /dev/null @@ -1,349 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -#from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - - diff --git a/src/Python/ccpi/fista/Reconstructor.py~ b/src/Python/ccpi/fista/Reconstructor.py~ deleted file mode 100644 index ba67327..0000000 --- a/src/Python/ccpi/fista/Reconstructor.py~ +++ /dev/null @@ -1,598 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - -class Reconstructor: - - class Algorithm(Enum): - CGLS = alg.cgls - CGLS_CONV = alg.cgls_conv - SIRT = alg.sirt - MLEM = alg.mlem - CGLS_TICHONOV = alg.cgls_tikhonov - CGLS_TVREG = alg.cgls_TVreg - FISTA = 'fista' - - def __init__(self, algorithm = None, projection_data = None, - angles = None, center_of_rotation = None , - flat_field = None, dark_field = None, - iterations = None, resolution = None, isLogScale = False, threads = None, - normalized_projection = None): - - self.pars = dict() - self.pars['algorithm'] = algorithm - self.pars['projection_data'] = projection_data - self.pars['normalized_projection'] = normalized_projection - self.pars['angles'] = angles - self.pars['center_of_rotation'] = numpy.double(center_of_rotation) - self.pars['flat_field'] = flat_field - self.pars['iterations'] = iterations - self.pars['dark_field'] = dark_field - self.pars['resolution'] = resolution - self.pars['isLogScale'] = isLogScale - self.pars['threads'] = threads - if (iterations != None): - self.pars['iterationValues'] = numpy.zeros((iterations)) - - if projection_data != None and dark_field != None and flat_field != None: - norm = self.normalize(projection_data, dark_field, flat_field, 0.1) - self.pars['normalized_projection'] = norm - - - def setPars(self, parameters): - keys = ['algorithm','projection_data' ,'normalized_projection', \ - 'angles' , 'center_of_rotation' , 'flat_field', \ - 'iterations','dark_field' , 'resolution', 'isLogScale' , \ - 'threads' , 'iterationValues', 'regularize'] - - for k in keys: - if k not in parameters.keys(): - self.pars[k] = None - else: - self.pars[k] = parameters[k] - - - def sanityCheck(self): - projection_data = self.pars['projection_data'] - dark_field = self.pars['dark_field'] - flat_field = self.pars['flat_field'] - angles = self.pars['angles'] - - if projection_data != None and dark_field != None and \ - angles != None and flat_field != None: - data_shape = numpy.shape(projection_data) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - - if data_shape[1:] != numpy.shape(flat_field): - #raise Exception('Projection and flat field dimensions do not match') - return (False , 'Projection and flat field dimensions do not match') - if data_shape[1:] != numpy.shape(dark_field): - #raise Exception('Projection and dark field dimensions do not match') - return (False , 'Projection and dark field dimensions do not match') - - return (True , '' ) - elif self.pars['normalized_projection'] != None: - data_shape = numpy.shape(self.pars['normalized_projection']) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - else: - return (True , '' ) - else: - return (False , 'Not enough data') - - def reconstruct(self, parameters = None): - if parameters != None: - self.setPars(parameters) - - go , reason = self.sanityCheck() - if go: - return self._reconstruct() - else: - raise Exception(reason) - - - def _reconstruct(self, parameters=None): - if parameters!=None: - self.setPars(parameters) - parameters = self.pars - - if parameters['algorithm'] != None and \ - parameters['normalized_projection'] != None and \ - parameters['angles'] != None and \ - parameters['center_of_rotation'] != None and \ - parameters['iterations'] != None and \ - parameters['resolution'] != None and\ - parameters['threads'] != None and\ - parameters['isLogScale'] != None: - - - if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, - Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): - #store parameters - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['isLogScale'] - ) - return result - elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, - Reconstructor.Algorithm.CGLS_TICHONOV, - Reconstructor.Algorithm.CGLS_TVREG) : - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['regularize'], - numpy.zeros((parameters['iterations'])), - parameters['isLogScale'] - ) - - elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: - pass - - else: - if parameters['projection_data'] != None and \ - parameters['dark_field'] != None and \ - parameters['flat_field'] != None: - norm = self.normalize(parameters['projection_data'], - parameters['dark_field'], - parameters['flat_field'], 0.1) - self.pars['normalized_projection'] = norm - return self._reconstruct(parameters) - - - - def _normalize(self, projection, dark, flat, def_val=0): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - def normalize(self, projections, dark, flat, def_val=0): - norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] - return numpy.asarray (norm, dtype=numpy.float32) - - - -class FISTA(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - return shiftScaler.GetOutput() - -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") -#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.pyc b/src/Python/ccpi/fista/__init__.pyc deleted file mode 100644 index 719e264..0000000 Binary files a/src/Python/ccpi/fista/__init__.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc deleted file mode 100644 index 84f16e2..0000000 Binary files a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc deleted file mode 100644 index 90c23ff..0000000 Binary files a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc and /dev/null differ -- cgit v1.2.3 From 49c4a595c58d296c3a4b2f7fd480e9c64f638897 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 13 Oct 2017 16:48:24 +0100 Subject: Added setParameter minor beautification of code --- src/Python/ccpi/fista/FISTAReconstructor.py | 164 ++++++---------------------- 1 file changed, 34 insertions(+), 130 deletions(-) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 1e76815..cbd27da 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -73,7 +73,8 @@ class FISTAReconstructor(): # 3. "A novel tomographic reconstruction method based on the robust # Student's t function for suppressing data outliers" D. Kazantsev et.al. # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + def __init__(self, projector_geometry, output_geometry, input_sinogram, + **kwargs): # handle parmeters: # obligatory parameters self.pars = dict() @@ -98,6 +99,7 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha') + self.acceptedInputKeywords = kw # handle keyworded parameters if kwargs is not None: @@ -114,11 +116,14 @@ class FISTAReconstructor(): if 'weights' in kwargs.keys(): self.pars['weights'] = kwargs['weights'] else: - self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram'])) + self.pars['weights'] = \ + numpy.ones(numpy.shape( + self.pars['input_sinogram'])) if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None @@ -127,7 +132,8 @@ class FISTAReconstructor(): if self.pars['ideal_image'] == None: pass else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + self.pars['region_of_interest'] = numpy.nonzero( + self.pars['ideal_image']>0.0) if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None @@ -140,7 +146,29 @@ class FISTAReconstructor(): + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.acceptedInputKeywords.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for '.format(key) + + 'Reconstruction algorithm') + # setParameter + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' @@ -152,7 +180,8 @@ class FISTAReconstructor(): - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + if (proj_geom['type'] == 'parallel') or \ + (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice #print('Calculating Lipshitz constant for parallel beam geometry...') niter = 5;# % number of iteration for the PM @@ -262,128 +291,3 @@ class FISTAReconstructor(): - - -def getEntry(location, nx): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" -##nx = h5py.File(fname, "r") -## -### the data are stored in a particular location in the hdf5 -##for item in nx['entry1/tomo_entry/data'].keys(): -## print (item) -## -##data = nx.get('entry1/tomo_entry/data/rotation_angle') -##angles = numpy.zeros(data.shape) -##data.read_direct(angles) -##print (angles) -### angles should be in degrees -## -##data = nx.get('entry1/tomo_entry/data/data') -##stack = numpy.zeros(data.shape) -##data.read_direct(stack) -##print (data.shape) -## -##print ("Data Loaded") -## -## -### Normalize -##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -##itype = numpy.zeros(data.shape) -##data.read_direct(itype) -### 2 is dark field -##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -##dark = darks[0] -##for i in range(1, len(darks)): -## dark += darks[i] -##dark = dark / len(darks) -###dark[0][0] = dark[0][1] -## -### 1 is flat field -##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -##flat = flats[0] -##for i in range(1, len(flats)): -## flat += flats[i] -##flat = flat / len(flats) -###flat[0][0] = dark[0][1] -## -## -### 0 is projection data -##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = numpy.asarray (angle_proj) -##angle_proj = angle_proj.astype(numpy.float32) -## -### normalized data are -### norm = (projection - dark)/(flat-dark) -## -##def normalize(projection, dark, flat, def_val=0.1): -## a = (projection - dark) -## b = (flat-dark) -## with numpy.errstate(divide='ignore', invalid='ignore'): -## c = numpy.true_divide( a, b ) -## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 -## return c -## -## -##norm = [normalize(projection, dark, flat) for projection in proj] -##norm = numpy.asarray (norm) -##norm = norm.astype(numpy.float32) - - -##niterations = 15 -##threads = 3 -## -##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## iteration_values, False) -##print ("iteration values %s" % str(iteration_values)) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -## -## -####numpy.save("cgls_recon.npy", img_data) -##import matplotlib.pyplot as plt -##fig, ax = plt.subplots(1,6,sharey=True) -##ax[0].imshow(img_cgls[80]) -##ax[0].axis('off') # clear x- and y-axes -##ax[1].imshow(img_sirt[80]) -##ax[1].axis('off') # clear x- and y-axes -##ax[2].imshow(img_mlem[80]) -##ax[2].axis('off') # clear x- and y-axesplt.show() -##ax[3].imshow(img_cgls_conv[80]) -##ax[3].axis('off') # clear x- and y-axesplt.show() -##ax[4].imshow(img_cgls_tikhonov[80]) -##ax[4].axis('off') # clear x- and y-axesplt.show() -##ax[5].imshow(img_cgls_TVreg[80]) -##ax[5].axis('off') # clear x- and y-axesplt.show() -## -## -##plt.show() -## - -- cgit v1.2.3 From 3c2815ec1d0ddd9d00a5c1f454fcecc060126623 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 17 Oct 2017 09:35:11 +0100 Subject: Added many methods --- src/Python/ccpi/fista/FISTAReconstructor.py | 184 ++++++++++++++++++++++++---- 1 file changed, 160 insertions(+), 24 deletions(-) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index cbd27da..8318ea6 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -78,19 +78,28 @@ class FISTAReconstructor(): # handle parmeters: # obligatory parameters self.pars = dict() - self.pars['projector_geometry'] = projector_geometry - self.pars['output_geometry'] = output_geometry - self.pars['input_sinogram'] = input_sinogram + self.pars['projector_geometry'] = projector_geometry # proj_geom + self.pars['output_geometry'] = output_geometry # vol_geom + self.pars['input_sinogram'] = input_sinogram # sino detectors, nangles, sliceZ = numpy.shape(input_sinogram) self.pars['detectors'] = detectors - self.pars['number_og_angles'] = nangles + self.pars['number_of_angles'] = nangles self.pars['SlicesZ'] = sliceZ print (self.pars) # handle optional input parameters (at instantiation) # Accepted input keywords - kw = ('number_of_iterations', + kw = ( + # mandatory fields + 'projector_geometry', + 'output_geometry', + 'input_sinogram', + 'detectors', + 'number_of_angles', + 'SlicesZ', + # optional fields + 'number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , 'weights' , @@ -98,8 +107,9 @@ class FISTAReconstructor(): 'initialize' , 'regularizer' , 'ring_lambda_R_L1', - 'ring_alpha') - self.acceptedInputKeywords = kw + 'ring_alpha', + 'subsets') + self.acceptedInputKeywords = list(kw) # handle keyworded parameters if kwargs is not None: @@ -122,8 +132,7 @@ class FISTAReconstructor(): if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = \ - self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = None if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None @@ -143,31 +152,44 @@ class FISTAReconstructor(): self.pars['ring_lambda_R_L1'] = 0 if not 'ring_alpha' in kwargs.keys(): self.pars['ring_alpha'] = 1 - + + if not 'subsets' in kwargs.keys(): + self.pars['subsets'] = 0 + else: + self.createOrderedSubsets() + + if not 'initialize' in kwargs.keys(): + self.pars['initialize'] = False def setParameter(self, **kwargs): - '''set named parameter for the regularization engine + '''set named parameter for the reconstructor engine raises Exception if the named parameter is not recognized - Typical usage is: - - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0) - reg.setParameter(regularization_parameter=10.) - it can be also used as - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0 , regularization_parameter=10.) ''' - for key , value in kwargs.items(): - if key in self.acceptedInputKeywords.keys(): + if key in self.acceptedInputKeywords: self.pars[key] = value else: - raise Exception('Wrong parameter {0} for '.format(key) + - 'Reconstruction algorithm') + raise Exception('Wrong parameter {0} for '.format(key) + + 'reconstructor') # setParameter + + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars + else: + raise Exception('Unhandled input {0}' .format(str(type(key)))) + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' @@ -289,5 +311,119 @@ class FISTAReconstructor(): if regularizer is not None: self.pars['regularizer'] = regularizer + + def initialize(self): + # convenience variable storage + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + sino = self.pars['input_sinogram'] + + # a 'warm start' with SIRT method + # Create a data object for the reconstruction + rec_id = astra.matlab.data3d('create', '-vol', + vol_geom); + + #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); + sinogram_id = astra.matlab.data3d('create', '-proj3d', + proj_geom, + sino) + + sirt_config = astra.astra_dict('SIRT3D_CUDA') + sirt_config['ReconstructionDataId' ] = rec_id + sirt_config['ProjectionDataId'] = sinogram_id + + sirt = astra.algorithm.create(sirt_config) + astra.algorithm.run(sirt, iterations=35) + X = astra.matlab.data3d('get', rec_id) + + # clean up memory + astra.matlab.data3d('delete', rec_id) + astra.matlab.data3d('delete', sinogram_id) + astra.algorithm.delete(sirt) + + + + return X + + def createOrderedSubsets(self, subsets=None): + if subsets is None: + try: + subsets = self.getParameter('subsets') + except Exception(): + subsets = 0 + #return subsets + + angles = self.getParameter('projector_geometry')['ProjectionAngles'] + + + + + + + def prepareForIteration(self): + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) + self.objective = numpy.zeros((self.pars['number_of_iterations'])) + + #2D array (for 3D data) of sparse "ring" + detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) + self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) + # another ring variable + self.rx = self.r.copy() + + self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + + if self.getParameter('Lipschitz_constant') is None: + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() + + # prepareForIteration + + def iterate(self, Xin=None): + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter(['projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) + + t = 1 + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + X = Xin.copy() + + X_t = X.copy() + + for i in range(self.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = self.r.copy() + if self.pars['projector_geometry']['type'] == 'parallel' or \ + self.pars['projector_geometry']['type'] == 'parallel3d': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + + #for kkk = 1:SlicesZ + # [sino_id, sino_updt(:,:,kkk)] = + # astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); + # astra_mex_data3d('delete', sino_id); + for kkk in range(SlicesZ): + sino_id, sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk], proj_geomT, vol_geomT) + + else: + # for divergent 3D geometry (watch GPU memory overflow in + # Astra < 1.8 + sino_id, y = astra.creators.create_sino3d_gpu(X_t, + proj_geom, + vol_geom) - + + -- cgit v1.2.3 From dd30175d2a198a44c92cdbdb40c3512f15a637e8 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 17 Oct 2017 09:37:25 +0100 Subject: Squashing 2 commits: Added and removed hdf5 (too big) Added data in hdf5 format removed hdf5 data --- src/Python/ccpi/reconstruction/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/reconstruction/__init__.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/reconstruction/__init__.py b/src/Python/ccpi/reconstruction/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 2014650ab9fbf5a7d1c7334fa54ac0b1c5908915 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 17 Oct 2017 17:01:12 +0100 Subject: Progress in pythonization --- src/Python/ccpi/fista/FISTAReconstructor.py | 104 +++++++++++++++++++--------- 1 file changed, 71 insertions(+), 33 deletions(-) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 8318ea6..87dd2c0 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -81,7 +81,7 @@ class FISTAReconstructor(): self.pars['projector_geometry'] = projector_geometry # proj_geom self.pars['output_geometry'] = output_geometry # vol_geom self.pars['input_sinogram'] = input_sinogram # sino - detectors, nangles, sliceZ = numpy.shape(input_sinogram) + sliceZ, nangles, detectors = numpy.shape(input_sinogram) self.pars['detectors'] = detectors self.pars['number_of_angles'] = nangles self.pars['SlicesZ'] = sliceZ @@ -108,7 +108,9 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha', - 'subsets') + 'subsets', + 'use_studentt_fidelity', + 'studentt') self.acceptedInputKeywords = list(kw) # handle keyworded parameters @@ -143,16 +145,18 @@ class FISTAReconstructor(): else: self.pars['region_of_interest'] = numpy.nonzero( self.pars['ideal_image']>0.0) - + + # the regularizer must be a correctly instantiated object if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not 'ring_lambda_R_L1' in kwargs.keys(): - self.pars['ring_lambda_R_L1'] = 0 - if not 'ring_alpha' in kwargs.keys(): - self.pars['ring_alpha'] = 1 + #RING REMOVAL + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + # ORDERED SUBSET if not 'subsets' in kwargs.keys(): self.pars['subsets'] = 0 else: @@ -160,6 +164,15 @@ class FISTAReconstructor(): if not 'initialize' in kwargs.keys(): self.pars['initialize'] = False + + if not 'use_studentt_fidelity' in kwargs.keys(): + self.setParameter(studentt=False) + else: + print ("studentt {0}".format(kwargs['use_studentt_fidelity'])) + if kwargs['use_studentt_fidelity']: + raise Exception('Not implemented') + + self.setParameter(studentt=kwargs['use_studentt_fidelity']) def setParameter(self, **kwargs): @@ -170,6 +183,8 @@ class FISTAReconstructor(): ''' for key , value in kwargs.items(): if key in self.acceptedInputKeywords: + if key == 'use_studentt_fidelity': + raise Exception('use_studentt_fidelity Not implemented') self.pars[key] = value else: raise Exception('Wrong parameter {0} for '.format(key) + @@ -354,10 +369,28 @@ class FISTAReconstructor(): #return subsets angles = self.getParameter('projector_geometry')['ProjectionAngles'] - - - + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + + return IndicesReorg def prepareForIteration(self): @@ -368,23 +401,24 @@ class FISTAReconstructor(): detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) # another ring variable - self.rx = self.r.copy() + self.r_x = self.r.copy() self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) if self.getParameter('Lipschitz_constant') is None: self.pars['Lipschitz_constant'] = \ self.calculateLipschitzConstantWithPowerMethod() + # prepareForIteration def iterate(self, Xin=None): # convenience variable storage proj_geom , vol_geom, sino , \ - SlicesZ = self.getParameter(['projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ']) + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) t = 1 if Xin is None: @@ -394,7 +428,8 @@ class FISTAReconstructor(): N = vol_geom['GridColCount'] X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) else: - X = Xin.copy() + # copy by reference + X = Xin X_t = X.copy() @@ -402,28 +437,31 @@ class FISTAReconstructor(): X_old = X.copy() t_old = t r_old = self.r.copy() - if self.pars['projector_geometry']['type'] == 'parallel' or \ - self.pars['projector_geometry']['type'] == 'parallel3d': + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'parallel3d': # if the geometry is parallel use slice-by-slice # projection-backprojection routine #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) - - #for kkk = 1:SlicesZ - # [sino_id, sino_updt(:,:,kkk)] = - # astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); - # astra_mex_data3d('delete', sino_id); for kkk in range(SlicesZ): + print (kkk) sino_id, sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( - X_t[kkk], proj_geomT, vol_geomT) - + X_t[kkk:kkk+1], proj_geomT, vol_geomT) + astra.matlab.data3d('delete', sino_id) else: - # for divergent 3D geometry (watch GPU memory overflow in - # Astra < 1.8 - sino_id, y = astra.creators.create_sino3d_gpu(X_t, - proj_geom, - vol_geom) - + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, sino_updt = astra.matlab.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + + ## RING REMOVAL + ## REGULARIZATION -- cgit v1.2.3 From 1af73a75ccab1147a8d2387b7056f91f0642549f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 18 Oct 2017 10:04:32 +0100 Subject: removed viewer package from tree --- src/Python/ccpi/viewer/CILViewer.py | 361 ------- src/Python/ccpi/viewer/CILViewer2D.py | 1126 -------------------- src/Python/ccpi/viewer/QVTKWidget.py | 340 ------ src/Python/ccpi/viewer/QVTKWidget2.py | 84 -- src/Python/ccpi/viewer/__init__.py | 1 - .../viewer/__pycache__/CILViewer.cpython-35.pyc | Bin 10542 -> 0 bytes .../viewer/__pycache__/CILViewer2D.cpython-35.pyc | Bin 35633 -> 0 bytes .../viewer/__pycache__/QVTKWidget.cpython-35.pyc | Bin 10099 -> 0 bytes .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc | Bin 1316 -> 0 bytes .../viewer/__pycache__/__init__.cpython-35.pyc | Bin 210 -> 0 bytes src/Python/ccpi/viewer/embedvtk.py | 75 -- 11 files changed, 1987 deletions(-) delete mode 100644 src/Python/ccpi/viewer/CILViewer.py delete mode 100644 src/Python/ccpi/viewer/CILViewer2D.py delete mode 100644 src/Python/ccpi/viewer/QVTKWidget.py delete mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py delete mode 100644 src/Python/ccpi/viewer/__init__.py delete mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/embedvtk.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py deleted file mode 100644 index efcf8be..0000000 --- a/src/Python/ccpi/viewer/CILViewer.py +++ /dev/null @@ -1,361 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Edoardo Pasca -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import vtk -import numpy -import math -from vtk.util import numpy_support - -SLICE_ORIENTATION_XY = 2 # Z -SLICE_ORIENTATION_XZ = 1 # Y -SLICE_ORIENTATION_YZ = 0 # X - - - -class CILViewer(): - '''Simple 3D Viewer based on VTK classes''' - - def __init__(self, dimx=600,dimy=600): - '''creates the rendering pipeline''' - - # create a rendering window and renderer - self.ren = vtk.vtkRenderer() - self.renWin = vtk.vtkRenderWindow() - self.renWin.SetSize(dimx,dimy) - self.renWin.AddRenderer(self.ren) - - # img 3D as slice - self.img3D = None - self.sliceno = 0 - self.sliceOrientation = SLICE_ORIENTATION_XY - self.sliceActor = None - self.voi = None - self.wl = None - self.ia = None - self.sliceActorNo = 0 - # create a renderwindowinteractor - self.iren = vtk.vtkRenderWindowInteractor() - self.iren.SetRenderWindow(self.renWin) - - self.style = vtk.vtkInteractorStyleTrackballCamera() - self.iren.SetInteractorStyle(self.style) - - self.ren.SetBackground(.1, .2, .4) - - self.actors = {} - self.iren.RemoveObservers('MouseWheelForwardEvent') - self.iren.RemoveObservers('MouseWheelBackwardEvent') - - self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0) - self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0) - - self.iren.RemoveObservers('KeyPressEvent') - self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0) - - - self.iren.Initialize() - - - - def getRenderer(self): - '''returns the renderer''' - return self.ren - - def getRenderWindow(self): - '''returns the render window''' - return self.renWin - - def getInteractor(self): - '''returns the render window interactor''' - return self.iren - - def getCamera(self): - '''returns the active camera''' - return self.ren.GetActiveCamera() - - def createPolyDataActor(self, polydata): - '''returns an actor for a given polydata''' - mapper = vtk.vtkPolyDataMapper() - if vtk.VTK_MAJOR_VERSION <= 5: - mapper.SetInput(polydata) - else: - mapper.SetInputData(polydata) - - # actor - actor = vtk.vtkActor() - actor.SetMapper(mapper) - #actor.GetProperty().SetOpacity(0.8) - return actor - - def setPolyDataActor(self, actor): - '''displays the given polydata''' - - self.ren.AddActor(actor) - - self.actors[len(self.actors)+1] = [actor, True] - self.iren.Initialize() - self.renWin.Render() - - def displayPolyData(self, polydata): - self.setPolyDataActor(self.createPolyDataActor(polydata)) - - def hideActor(self, actorno): - '''Hides an actor identified by its number in the list of actors''' - try: - if self.actors[actorno][1]: - self.ren.RemoveActor(self.actors[actorno][0]) - self.actors[actorno][1] = False - except KeyError as ke: - print ("Warning Actor not present") - - def showActor(self, actorno, actor = None): - '''Shows hidden actor identified by its number in the list of actors''' - try: - if not self.actors[actorno][1]: - self.ren.AddActor(self.actors[actorno][0]) - self.actors[actorno][1] = True - return actorno - except KeyError as ke: - # adds it to the actors if not there already - if actor != None: - self.ren.AddActor(actor) - self.actors[len(self.actors)+1] = [actor, True] - return len(self.actors) - - def addActor(self, actor): - '''Adds an actor to the render''' - return self.showActor(0, actor) - - - def saveRender(self, filename, renWin=None): - '''Save the render window to PNG file''' - # screenshot code: - w2if = vtk.vtkWindowToImageFilter() - if renWin == None: - renWin = self.renWin - w2if.SetInput(renWin) - w2if.Update() - - writer = vtk.vtkPNGWriter() - writer.SetFileName("%s.png" % (filename)) - writer.SetInputConnection(w2if.GetOutputPort()) - writer.Write() - - - def startRenderLoop(self): - self.iren.Start() - - - def setupObservers(self, interactor): - interactor.RemoveObservers('LeftButtonPressEvent') - interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction) - interactor.Initialize() - - - def mouseInteraction(self, interactor, event): - if event == 'MouseWheelForwardEvent': - maxSlice = self.img3D.GetDimensions()[self.sliceOrientation] - if (self.sliceno + 1 < maxSlice): - self.hideActor(self.sliceActorNo) - self.sliceno = self.sliceno + 1 - self.displaySliceActor(self.sliceno) - else: - minSlice = 0 - if (self.sliceno - 1 > minSlice): - self.hideActor(self.sliceActorNo) - self.sliceno = self.sliceno - 1 - self.displaySliceActor(self.sliceno) - - - def keyPress(self, interactor, event): - #print ("Pressed key %s" % interactor.GetKeyCode()) - # Slice Orientation - if interactor.GetKeyCode() == "x": - # slice on the other orientation - self.sliceOrientation = SLICE_ORIENTATION_YZ - self.sliceno = int(self.img3D.GetDimensions()[1] / 2) - self.hideActor(self.sliceActorNo) - self.displaySliceActor(self.sliceno) - elif interactor.GetKeyCode() == "y": - # slice on the other orientation - self.sliceOrientation = SLICE_ORIENTATION_XZ - self.sliceno = int(self.img3D.GetDimensions()[1] / 2) - self.hideActor(self.sliceActorNo) - self.displaySliceActor(self.sliceno) - elif interactor.GetKeyCode() == "z": - # slice on the other orientation - self.sliceOrientation = SLICE_ORIENTATION_XY - self.sliceno = int(self.img3D.GetDimensions()[2] / 2) - self.hideActor(self.sliceActorNo) - self.displaySliceActor(self.sliceno) - if interactor.GetKeyCode() == "X": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.ren.SetActiveCamera(camera) - self.ren.ResetCamera() - self.ren.Render() - interactor.SetKeyCode("x") - self.keyPress(interactor, event) - elif interactor.GetKeyCode() == "Y": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.ren.SetActiveCamera(camera) - self.ren.ResetCamera() - self.ren.Render() - interactor.SetKeyCode("y") - self.keyPress(interactor, event) - elif interactor.GetKeyCode() == "Z": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.ren.SetActiveCamera(camera) - self.ren.ResetCamera() - self.ren.Render() - interactor.SetKeyCode("z") - self.keyPress(interactor, event) - else : - print ("Unhandled event %s" % interactor.GetKeyCode()) - - - - def setInput3DData(self, imageData): - self.img3D = imageData - - def setInputAsNumpy(self, numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - self.img3D = shiftScaler.GetOutput() - - def displaySliceActor(self, sliceno = 0): - self.sliceno = sliceno - first = False - - self.sliceActor , self.voi, self.wl , self.ia = \ - self.getSliceActor(self.img3D, - sliceno, - self.sliceActor, - self.voi, - self.wl, - self.ia) - no = self.showActor(self.sliceActorNo, self.sliceActor) - self.sliceActorNo = no - - self.iren.Initialize() - self.renWin.Render() - - return self.sliceActorNo - - - def getSliceActor(self, - imageData , - sliceno=0, - imageActor=None , - voi=None, - windowLevel=None, - imageAccumulate=None): - '''Slices a 3D volume and then creates an actor to be rendered''' - if (voi==None): - voi = vtk.vtkExtractVOI() - #voi = vtk.vtkImageClip() - voi.SetInputData(imageData) - #select one slice in Z - extent = [ i for i in self.img3D.GetExtent()] - extent[self.sliceOrientation * 2] = sliceno - extent[self.sliceOrientation * 2 + 1] = sliceno - voi.SetVOI(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - - voi.Update() - # set window/level for all slices - if imageAccumulate == None: - imageAccumulate = vtk.vtkImageAccumulate() - - if (windowLevel == None): - windowLevel = vtk.vtkImageMapToWindowLevelColors() - imageAccumulate.SetInputData(imageData) - imageAccumulate.Update() - cmax = imageAccumulate.GetMax()[0] - cmin = imageAccumulate.GetMin()[0] - windowLevel.SetLevel((cmax+cmin)/2) - windowLevel.SetWindow(cmax-cmin) - - windowLevel.SetInputData(voi.GetOutput()) - windowLevel.Update() - - if imageActor == None: - imageActor = vtk.vtkImageActor() - imageActor.SetInputData(windowLevel.GetOutput()) - imageActor.SetDisplayExtent(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - imageActor.Update() - return (imageActor , voi, windowLevel, imageAccumulate) - - - # Set interpolation on - def setInterpolateOn(self): - self.sliceActor.SetInterpolate(True) - self.renWin.Render() - - # Set interpolation off - def setInterpolateOff(self): - self.sliceActor.SetInterpolate(False) - self.renWin.Render() \ No newline at end of file diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py deleted file mode 100644 index c1629af..0000000 --- a/src/Python/ccpi/viewer/CILViewer2D.py +++ /dev/null @@ -1,1126 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Edoardo Pasca -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import vtk -import numpy -from vtk.util import numpy_support , vtkImageImportFromArray -from enum import Enum - -SLICE_ORIENTATION_XY = 2 # Z -SLICE_ORIENTATION_XZ = 1 # Y -SLICE_ORIENTATION_YZ = 0 # X - -CONTROL_KEY = 8 -SHIFT_KEY = 4 -ALT_KEY = -128 - - -# Converter class -class Converter(): - - # Utility functions to transform numpy arrays to vtkImageData and viceversa - @staticmethod - def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): - '''Creates a vtkImageImportFromArray object and returns it. - - It handles the different axis order from numpy to VTK''' - importer = vtkImageImportFromArray.vtkImageImportFromArray() - importer.SetArray(numpy.transpose(nparray).copy()) - importer.SetDataSpacing(spacing) - importer.SetDataOrigin(origin) - return importer - - @staticmethod - def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): - '''Converts a 3D numpy array to a vtkImageData''' - importer = Converter.numpy2vtkImporter(nparray, spacing, origin) - importer.Update() - return importer.GetOutput() - - @staticmethod - def vtk2numpy(imgdata): - '''Converts the VTK data to 3D numpy array''' - img_data = numpy_support.vtk_to_numpy( - imgdata.GetPointData().GetScalars()) - - dims = imgdata.GetDimensions() - dims = (dims[2],dims[1],dims[0]) - data3d = numpy.reshape(img_data, dims) - - return numpy.transpose(data3d).copy() - - @staticmethod - def tiffStack2numpy(filename, indices, - extent = None , sampleRate = None ,\ - flatField = None, darkField = None): - '''Converts a stack of TIFF files to numpy array. - - filename must contain the whole path. The filename is supposed to be named and - have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif - - indices are the suffix, generally an increasing number - - Optionally extracts only a selection of the 2D images and (optionally) - normalizes. - ''' - - stack = vtk.vtkImageData() - reader = vtk.vtkTIFFReader() - voi = vtk.vtkExtractVOI() - - #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\" - - stack_image = numpy.asarray([]) - nreduced = len(indices) - - for num in range(len(indices)): - fn = filename % indices[num] - print ("resampling %s" % ( fn ) ) - reader.SetFileName(fn) - reader.Update() - print (reader.GetOutput().GetScalarTypeAsString()) - if num == 0: - if (extent == None): - sliced = reader.GetOutput().GetExtent() - stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1) - else: - sliced = extent - voi.SetVOI(extent) - - if sampleRate is not None: - voi.SetSampleRate(sampleRate) - ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int) - print ("ext {0}".format(ext)) - stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1) - else: - stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1) - if (flatField != None and darkField != None): - stack.AllocateScalars(vtk.VTK_FLOAT, 1) - else: - stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1) - print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) )) - stack_image = Converter.vtk2numpy(stack) - print ("Stack shape %s" % str(numpy.shape(stack_image))) - - if extent!=None: - voi.SetInputData(reader.GetOutput()) - voi.Update() - img = voi.GetOutput() - else: - img = reader.GetOutput() - - theSlice = Converter.vtk2numpy(img).T[0] - if darkField != None and flatField != None: - print("Try to normalize") - #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice): - theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01) - print (theSlice.dtype) - - - print ("Slice shape %s" % str(numpy.shape(theSlice))) - stack_image.T[num] = theSlice.copy() - - return stack_image - - @staticmethod - def normalize(projection, dark, flat, def_val=0): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - - -## Utility functions to transform numpy arrays to vtkImageData and viceversa -#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): -# return Converter.numpy2vtkImporter(nparray, spacing, origin) -# -#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): -# return Converter.numpy2vtk(nparray, spacing, origin) -# -#def vtk2numpy(imgdata): -# return Converter.vtk2numpy(imgdata) -# -#def tiffStack2numpy(filename, indices): -# return Converter.tiffStack2numpy(filename, indices) - -class ViewerEvent(Enum): - # left button - PICK_EVENT = 0 - # alt + right button + move - WINDOW_LEVEL_EVENT = 1 - # shift + right button - ZOOM_EVENT = 2 - # control + right button - PAN_EVENT = 3 - # control + left button - CREATE_ROI_EVENT = 4 - # alt + left button - DELETE_ROI_EVENT = 5 - # release button - NO_EVENT = -1 - - -#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): -class CILInteractorStyle(vtk.vtkInteractorStyleImage): - - def __init__(self, callback): - vtk.vtkInteractorStyleImage.__init__(self) - self.callback = callback - self._viewer = callback - priority = 1.0 - -# self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority) -# self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority) -# self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority) -# self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority) -# self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority) -# self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority) -# self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority) -# self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority) - - self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority) - self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority) - self.AddObserver('KeyPressEvent', self.OnKeyPress, priority) - self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority) - self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority) - self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority) - self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority) - self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority) - - self.InitialEventPosition = (0,0) - - - def SetInitialEventPosition(self, xy): - self.InitialEventPosition = xy - - def GetInitialEventPosition(self): - return self.InitialEventPosition - - def GetKeyCode(self): - return self.GetInteractor().GetKeyCode() - - def SetKeyCode(self, keycode): - self.GetInteractor().SetKeyCode(keycode) - - def GetControlKey(self): - return self.GetInteractor().GetControlKey() == CONTROL_KEY - - def GetShiftKey(self): - return self.GetInteractor().GetShiftKey() == SHIFT_KEY - - def GetAltKey(self): - return self.GetInteractor().GetAltKey() == ALT_KEY - - def GetEventPosition(self): - return self.GetInteractor().GetEventPosition() - - def GetEventPositionInWorldCoordinates(self): - pass - - def GetDeltaEventPosition(self): - x,y = self.GetInteractor().GetEventPosition() - return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1]) - - def Dolly(self, factor): - self.callback.camera.Dolly(factor) - self.callback.ren.ResetCameraClippingRange() - - def GetDimensions(self): - return self._viewer.img3D.GetDimensions() - - def GetInputData(self): - return self._viewer.img3D - - def GetSliceOrientation(self): - return self._viewer.sliceOrientation - - def SetSliceOrientation(self, orientation): - self._viewer.sliceOrientation = orientation - - def GetActiveSlice(self): - return self._viewer.sliceno - - def SetActiveSlice(self, sliceno): - self._viewer.sliceno = sliceno - - def UpdatePipeline(self, reset = False): - self._viewer.updatePipeline(reset) - - def GetActiveCamera(self): - return self._viewer.ren.GetActiveCamera() - - def SetActiveCamera(self, camera): - self._viewer.ren.SetActiveCamera(camera) - - def ResetCamera(self): - self._viewer.ren.ResetCamera() - - def Render(self): - self._viewer.renWin.Render() - - def UpdateSliceActor(self): - self._viewer.sliceActor.Update() - - def AdjustCamera(self): - self._viewer.AdjustCamera() - - def SaveRender(self, filename): - self._viewer.SaveRender(filename) - - def GetRenderWindow(self): - return self._viewer.renWin - - def GetRenderer(self): - return self._viewer.ren - - def GetROIWidget(self): - return self._viewer.ROIWidget - - def SetViewerEvent(self, event): - self._viewer.event = event - - def GetViewerEvent(self): - return self._viewer.event - - def SetInitialCameraPosition(self, position): - self._viewer.InitialCameraPosition = position - - def GetInitialCameraPosition(self): - return self._viewer.InitialCameraPosition - - def SetInitialLevel(self, level): - self._viewer.InitialLevel = level - - def GetInitialLevel(self): - return self._viewer.InitialLevel - - def SetInitialWindow(self, window): - self._viewer.InitialWindow = window - - def GetInitialWindow(self): - return self._viewer.InitialWindow - - def GetWindowLevel(self): - return self._viewer.wl - - def SetROI(self, roi): - self._viewer.ROI = roi - - def GetROI(self): - return self._viewer.ROI - - def UpdateCornerAnnotation(self, text, corner): - self._viewer.updateCornerAnnotation(text, corner) - - def GetPicker(self): - return self._viewer.picker - - def GetCornerAnnotation(self): - return self._viewer.cornerAnnotation - - def UpdateROIHistogram(self): - self._viewer.updateROIHistogram() - - - ############### Handle events - def OnMouseWheelForward(self, interactor, event): - maxSlice = self.GetDimensions()[self.GetSliceOrientation()] - shift = interactor.GetShiftKey() - advance = 1 - if shift: - advance = 10 - - if (self.GetActiveSlice() + advance < maxSlice): - self.SetActiveSlice(self.GetActiveSlice() + advance) - - self.UpdatePipeline() - else: - print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 )) - - def OnMouseWheelBackward(self, interactor, event): - minSlice = 0 - shift = interactor.GetShiftKey() - advance = 1 - if shift: - advance = 10 - if (self.GetActiveSlice() - advance >= minSlice): - self.SetActiveSlice( self.GetActiveSlice() - advance) - self.UpdatePipeline() - else: - print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 )) - - def OnKeyPress(self, interactor, event): - #print ("Pressed key %s" % interactor.GetKeyCode()) - # Slice Orientation - if interactor.GetKeyCode() == "X": - # slice on the other orientation - self.SetSliceOrientation ( SLICE_ORIENTATION_YZ ) - self.SetActiveSlice( int(self.GetDimensions()[1] / 2) ) - self.UpdatePipeline(True) - elif interactor.GetKeyCode() == "Y": - # slice on the other orientation - self.SetSliceOrientation ( SLICE_ORIENTATION_XZ ) - self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) ) - self.UpdatePipeline(True) - elif interactor.GetKeyCode() == "Z": - # slice on the other orientation - self.SetSliceOrientation ( SLICE_ORIENTATION_XY ) - self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) ) - self.UpdatePipeline(True) - if interactor.GetKeyCode() == "x": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.SetActiveCamera(camera) - self.Render() - interactor.SetKeyCode("X") - self.OnKeyPress(interactor, event) - elif interactor.GetKeyCode() == "y": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.SetActiveCamera(camera) - self.Render() - interactor.SetKeyCode("Y") - self.OnKeyPress(interactor, event) - elif interactor.GetKeyCode() == "z": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,1,0) - self.SetActiveCamera(camera) - self.ResetCamera() - self.Render() - interactor.SetKeyCode("Z") - self.OnKeyPress(interactor, event) - elif interactor.GetKeyCode() == "a": - # reset color/window - cmax = self._viewer.ia.GetMax()[0] - cmin = self._viewer.ia.GetMin()[0] - - self.SetInitialLevel( (cmax+cmin)/2 ) - self.SetInitialWindow( cmax-cmin ) - - self.GetWindowLevel().SetLevel(self.GetInitialLevel()) - self.GetWindowLevel().SetWindow(self.GetInitialWindow()) - - self.GetWindowLevel().Update() - - self.UpdateSliceActor() - self.AdjustCamera() - self.Render() - - elif interactor.GetKeyCode() == "s": - filename = "current_render" - self.SaveRender(filename) - elif interactor.GetKeyCode() == "q": - print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), )) - interactor.SetKeyCode("e") - self.OnKeyPress(interactor, event) - else : - #print ("Unhandled event %s" % (interactor.GetKeyCode(), ))) - pass - - def OnLeftButtonPressEvent(self, interactor, event): - alt = interactor.GetAltKey() - shift = interactor.GetShiftKey() - ctrl = interactor.GetControlKey() -# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) -# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) -# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) - - interactor.SetInitialEventPosition(interactor.GetEventPosition()) - - if ctrl and not (alt and shift): - self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT ) - wsize = self.GetRenderWindow().GetSize() - position = interactor.GetEventPosition() - self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05)) - self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1)) - - self.GetROIWidget().On() - self.SetDisplayHistogram(True) - self.Render() - print ("Event %s is CREATE_ROI_EVENT" % (event)) - elif alt and not (shift and ctrl): - self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT ) - self.GetROIWidget().Off() - self._viewer.updateCornerAnnotation("", 1, False) - self.SetDisplayHistogram(False) - self.Render() - print ("Event %s is DELETE_ROI_EVENT" % (event)) - elif not (ctrl and alt and shift): - self.SetViewerEvent ( ViewerEvent.PICK_EVENT ) - self.HandlePickEvent(interactor, event) - print ("Event %s is PICK_EVENT" % (event)) - - - def SetDisplayHistogram(self, display): - if display: - if (self._viewer.displayHistogram == 0): - self.GetRenderer().AddActor(self._viewer.histogramPlotActor) - self.firstHistogram = 1 - self.Render() - - self._viewer.histogramPlotActor.VisibilityOn() - self._viewer.displayHistogram = True - else: - self._viewer.histogramPlotActor.VisibilityOff() - self._viewer.displayHistogram = False - - - def OnLeftButtonReleaseEvent(self, interactor, event): - if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT: - #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate() - #print (bc.GetValue()) - self.OnROIModifiedEvent(interactor, event) - - elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: - self.HandlePickEvent(interactor, event) - - self.SetViewerEvent( ViewerEvent.NO_EVENT ) - - def OnRightButtonPressEvent(self, interactor, event): - alt = interactor.GetAltKey() - shift = interactor.GetShiftKey() - ctrl = interactor.GetControlKey() -# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) -# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) -# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) - - interactor.SetInitialEventPosition(interactor.GetEventPosition()) - - - if alt and not (ctrl and shift): - self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT ) - print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) - self.HandleWindowLevel(interactor, event) - elif shift and not (ctrl and alt): - self.SetViewerEvent( ViewerEvent.ZOOM_EVENT ) - self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition()) - print ("Event %s is ZOOM_EVENT" % (event)) - elif ctrl and not (shift and alt): - self.SetViewerEvent (ViewerEvent.PAN_EVENT ) - self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() ) - print ("Event %s is PAN_EVENT" % (event)) - - def OnRightButtonReleaseEvent(self, interactor, event): - print (event) - if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: - self.SetInitialLevel( self.GetWindowLevel().GetLevel() ) - self.SetInitialWindow ( self.GetWindowLevel().GetWindow() ) - elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \ - self.GetViewerEvent() == ViewerEvent.PAN_EVENT: - self.SetInitialCameraPosition( () ) - - self.SetViewerEvent( ViewerEvent.NO_EVENT ) - - - def OnROIModifiedEvent(self, interactor, event): - - #print ("ROI EVENT " + event) - p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate() - p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate() - wsize = self.GetRenderWindow().GetSize() - - #print (p1.GetValue()) - #print (p2.GetValue()) - pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0] - pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0] - vox1 = self.viewport2imageCoordinate(pp1) - vox2 = self.viewport2imageCoordinate(pp2) - - self.SetROI( (vox1 , vox2) ) - roi = self.GetROI() - print ("Pixel1 %d,%d,%d Value %f" % vox1 ) - print ("Pixel2 %d,%d,%d Value %f" % vox2 ) - if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: - print ("slice orientation : XY") - x = abs(roi[1][0] - roi[0][0]) - y = abs(roi[1][1] - roi[0][1]) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: - print ("slice orientation : XY") - x = abs(roi[1][0] - roi[0][0]) - y = abs(roi[1][2] - roi[0][2]) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: - print ("slice orientation : XY") - x = abs(roi[1][1] - roi[0][1]) - y = abs(roi[1][2] - roi[0][2]) - - text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.) - print (text) - self.UpdateCornerAnnotation(text, 1) - self.UpdateROIHistogram() - self.SetViewerEvent( ViewerEvent.NO_EVENT ) - - def viewport2imageCoordinate(self, viewerposition): - #Determine point index - - self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) - pickPosition = list(self.GetPicker().GetPickPosition()) - pickPosition[self.GetSliceOrientation()] = \ - self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \ - self.GetInputData().GetOrigin()[self.GetSliceOrientation()] - print ("Pick Position " + str (pickPosition)) - - if (pickPosition != [0,0,0]): - dims = self.GetInputData().GetDimensions() - print (dims) - spac = self.GetInputData().GetSpacing() - orig = self.GetInputData().GetOrigin() - imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] - - pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) - return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) - else: - return (0,0,0,0) - - - - - def OnMouseMoveEvent(self, interactor, event): - if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: - print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) - self.HandleWindowLevel(interactor, event) - elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: - self.HandlePickEvent(interactor, event) - elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT: - self.HandleZoomEvent(interactor, event) - elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT: - self.HandlePanEvent(interactor, event) - - - def HandleZoomEvent(self, interactor, event): - dx,dy = interactor.GetDeltaEventPosition() - size = self.GetRenderWindow().GetSize() - dy = - 4 * dy / size[1] - - print ("distance: " + str(self.GetActiveCamera().GetDistance())) - - print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy)) - - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) - #print ("current position " + str(self.InitialCameraPosition)) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - camera.SetPosition(self.GetInitialCameraPosition()) - newposition = [i for i in self.GetInitialCameraPosition()] - if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: - dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) - newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy ) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: - newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy ) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: - newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy ) - #print ("new position " + str(newposition)) - camera.SetPosition(newposition) - self.SetActiveCamera(camera) - - self.Render() - - print ("distance after: " + str(self.GetActiveCamera().GetDistance())) - - def HandlePanEvent(self, interactor, event): - x,y = interactor.GetEventPosition() - x0,y0 = interactor.GetInitialEventPosition() - - ic = self.viewport2imageCoordinate((x,y)) - ic0 = self.viewport2imageCoordinate((x0,y0)) - - dx = 4 *( ic[0] - ic0[0]) - dy = 4* (ic[1] - ic0[1]) - - camera = vtk.vtkCamera() - #print ("current position " + str(self.InitialCameraPosition)) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - camera.SetPosition(self.GetInitialCameraPosition()) - newposition = [i for i in self.GetInitialCameraPosition()] - newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()] - if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: - newposition[0] -= dx - newposition[1] -= dy - newfocalpoint[0] = newposition[0] - newfocalpoint[1] = newposition[1] - elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: - newposition[0] -= dx - newposition[2] -= dy - newfocalpoint[0] = newposition[0] - newfocalpoint[2] = newposition[2] - elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: - newposition[1] -= dx - newposition[2] -= dy - newfocalpoint[2] = newposition[2] - newfocalpoint[1] = newposition[1] - #print ("new position " + str(newposition)) - camera.SetFocalPoint(newfocalpoint) - camera.SetPosition(newposition) - self.SetActiveCamera(camera) - - self.Render() - - def HandleWindowLevel(self, interactor, event): - dx,dy = interactor.GetDeltaEventPosition() - print ("Event delta %d %d" % (dx,dy)) - size = self.GetRenderWindow().GetSize() - - dx = 4 * dx / size[0] - dy = 4 * dy / size[1] - window = self.GetInitialWindow() - level = self.GetInitialLevel() - - if abs(window) > 0.01: - dx = dx * window - else: - dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window); - - if abs(level) > 0.01: - dy = dy * level - else: - dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level) - - - # Abs so that direction does not flip - - if window < 0.0: - dx = -1*dx - if level < 0.0: - dy = -1*dy - - # Compute new window level - - newWindow = dx + window - newLevel = level - dy - - # Stay away from zero and really - - if abs(newWindow) < 0.01: - newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow) - - if abs(newLevel) < 0.01: - newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel) - - self.GetWindowLevel().SetWindow(newWindow) - self.GetWindowLevel().SetLevel(newLevel) - - self.GetWindowLevel().Update() - self.UpdateSliceActor() - self.AdjustCamera() - - self.Render() - - def HandlePickEvent(self, interactor, event): - position = interactor.GetEventPosition() - #print ("PICK " + str(position)) - vox = self.viewport2imageCoordinate(position) - #print ("Pixel %d,%d,%d Value %f" % vox ) - self._viewer.cornerAnnotation.VisibilityOn() - self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0) - self.Render() - -############################################################################### - - - -class CILViewer2D(): - '''Simple Interactive Viewer based on VTK classes''' - - def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None): - '''creates the rendering pipeline''' - # create a rendering window and renderer - if ren == None: - self.ren = vtk.vtkRenderer() - else: - self.ren = ren - if renWin == None: - self.renWin = vtk.vtkRenderWindow() - else: - self.renWin = renWin - if iren == None: - self.iren = vtk.vtkRenderWindowInteractor() - else: - self.iren = iren - - self.renWin.SetSize(dimx,dimy) - self.renWin.AddRenderer(self.ren) - - self.style = CILInteractorStyle(self) - - self.iren.SetInteractorStyle(self.style) - self.iren.SetRenderWindow(self.renWin) - self.iren.Initialize() - self.ren.SetBackground(.1, .2, .4) - - self.camera = vtk.vtkCamera() - self.camera.ParallelProjectionOn() - self.ren.SetActiveCamera(self.camera) - - # data - self.img3D = None - self.sliceno = 0 - self.sliceOrientation = SLICE_ORIENTATION_XY - - #Actors - self.sliceActor = vtk.vtkImageActor() - self.voi = vtk.vtkExtractVOI() - self.wl = vtk.vtkImageMapToWindowLevelColors() - self.ia = vtk.vtkImageAccumulate() - self.sliceActorNo = 0 - - #initial Window/Level - self.InitialLevel = 0 - self.InitialWindow = 0 - - #ViewerEvent - self.event = ViewerEvent.NO_EVENT - - # ROI Widget - self.ROIWidget = vtk.vtkBorderWidget() - self.ROIWidget.SetInteractor(self.iren) - self.ROIWidget.CreateDefaultRepresentation() - self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0) - self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0) - - # edge points of the ROI - self.ROI = () - - #picker - self.picker = vtk.vtkPropPicker() - self.picker.PickFromListOn() - self.picker.AddPickList(self.sliceActor) - - self.iren.SetPicker(self.picker) - - # corner annotation - self.cornerAnnotation = vtk.vtkCornerAnnotation() - self.cornerAnnotation.SetMaximumFontSize(12); - self.cornerAnnotation.PickableOff(); - self.cornerAnnotation.VisibilityOff(); - self.cornerAnnotation.GetTextProperty().ShadowOn(); - self.cornerAnnotation.SetLayerNumber(1); - - - - # cursor doesn't show up - self.cursor = vtk.vtkCursor2D() - self.cursorMapper = vtk.vtkPolyDataMapper2D() - self.cursorActor = vtk.vtkActor2D() - self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0) - self.cursor.SetFocalPoint(0, 0, 0) - self.cursor.AllOff() - self.cursor.AxesOn() - self.cursorActor.PickableOff() - self.cursorActor.VisibilityOn() - self.cursorActor.GetProperty().SetColor(1, 1, 1) - self.cursorActor.SetLayerNumber(1) - self.cursorMapper.SetInputData(self.cursor.GetOutput()) - self.cursorActor.SetMapper(self.cursorMapper) - - # Zoom - self.InitialCameraPosition = () - - # XY Plot actor for histogram - self.displayHistogram = False - self.firstHistogram = 0 - self.roiIA = vtk.vtkImageAccumulate() - self.roiVOI = vtk.vtkExtractVOI() - self.histogramPlotActor = vtk.vtkXYPlotActor() - self.histogramPlotActor.ExchangeAxesOff(); - self.histogramPlotActor.SetXLabelFormat( "%g" ) - self.histogramPlotActor.SetXLabelFormat( "%g" ) - self.histogramPlotActor.SetAdjustXLabels(3) - self.histogramPlotActor.SetXTitle( "Level" ) - self.histogramPlotActor.SetYTitle( "N" ) - self.histogramPlotActor.SetXValuesToValue() - self.histogramPlotActor.SetPlotColor(0, (0,1,1) ) - self.histogramPlotActor.SetPosition(0.6,0.6) - self.histogramPlotActor.SetPosition2(0.4,0.4) - - - - def GetInteractor(self): - return self.iren - - def GetRenderer(self): - return self.ren - - def setInput3DData(self, imageData): - self.img3D = imageData - self.installPipeline() - - def setInputAsNumpy(self, numpyarray, origin=(0,0,0), spacing=(1.,1.,1.), - rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT): - importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin) - importer.Update() - - if rescale: - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(importer.GetOutput()) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - if (iMax - iMin == 0): - scale = 1 - else: - if dtype == vtk.VTK_UNSIGNED_SHORT: - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - elif dtype == vtk.VTK_UNSIGNED_INT: - scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(importer.GetOutput()) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(-iMin) - shiftScaler.SetOutputScalarType(dtype) - shiftScaler.Update() - self.img3D = shiftScaler.GetOutput() - else: - self.img3D = importer.GetOutput() - - self.installPipeline() - - def displaySlice(self, sliceno = 0): - self.sliceno = sliceno - - self.updatePipeline() - - self.renWin.Render() - - return self.sliceActorNo - - def updatePipeline(self, resetcamera = False): - extent = [ i for i in self.img3D.GetExtent()] - extent[self.sliceOrientation * 2] = self.sliceno - extent[self.sliceOrientation * 2 + 1] = self.sliceno - self.voi.SetVOI(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - - self.voi.Update() - self.ia.Update() - self.wl.Update() - self.sliceActor.SetDisplayExtent(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - self.sliceActor.Update() - - self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation])) - - if self.displayHistogram: - self.updateROIHistogram() - - self.AdjustCamera(resetcamera) - - self.renWin.Render() - - - def installPipeline(self): - '''Slices a 3D volume and then creates an actor to be rendered''' - - self.ren.AddViewProp(self.cornerAnnotation) - - self.voi.SetInputData(self.img3D) - #select one slice in Z - extent = [ i for i in self.img3D.GetExtent()] - extent[self.sliceOrientation * 2] = self.sliceno - extent[self.sliceOrientation * 2 + 1] = self.sliceno - self.voi.SetVOI(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - - self.voi.Update() - # set window/level for current slices - - - self.wl = vtk.vtkImageMapToWindowLevelColors() - self.ia.SetInputData(self.voi.GetOutput()) - self.ia.Update() - cmax = self.ia.GetMax()[0] - cmin = self.ia.GetMin()[0] - - self.InitialLevel = (cmax+cmin)/2 - self.InitialWindow = cmax-cmin - - - self.wl.SetLevel(self.InitialLevel) - self.wl.SetWindow(self.InitialWindow) - - self.wl.SetInputData(self.voi.GetOutput()) - self.wl.Update() - - self.sliceActor.SetInputData(self.wl.GetOutput()) - self.sliceActor.SetDisplayExtent(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - self.sliceActor.Update() - self.sliceActor.SetInterpolate(False) - self.ren.AddActor(self.sliceActor) - self.ren.ResetCamera() - self.ren.Render() - - self.AdjustCamera() - - self.ren.AddViewProp(self.cursorActor) - self.cursorActor.VisibilityOn() - - self.iren.Initialize() - self.renWin.Render() - #self.iren.Start() - - def AdjustCamera(self, resetcamera = False): - self.ren.ResetCameraClippingRange() - if resetcamera: - self.ren.ResetCamera() - - - def getROI(self): - return self.ROI - - def getROIExtent(self): - p0 = self.ROI[0] - p1 = self.ROI[1] - return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2]) - - ############### Handle events are moved to the interactor style - - - def viewport2imageCoordinate(self, viewerposition): - #Determine point index - - self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) - pickPosition = list(self.picker.GetPickPosition()) - pickPosition[self.sliceOrientation] = \ - self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \ - self.img3D.GetOrigin()[self.sliceOrientation] - print ("Pick Position " + str (pickPosition)) - - if (pickPosition != [0,0,0]): - dims = self.img3D.GetDimensions() - print (dims) - spac = self.img3D.GetSpacing() - orig = self.img3D.GetOrigin() - imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] - - pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) - return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) - else: - return (0,0,0,0) - - - - def GetRenderWindow(self): - return self.renWin - - - def startRenderLoop(self): - self.iren.Start() - - def GetSliceOrientation(self): - return self.sliceOrientation - - def GetActiveSlice(self): - return self.sliceno - - def updateCornerAnnotation(self, text , idx=0, visibility=True): - if visibility: - self.cornerAnnotation.VisibilityOn() - else: - self.cornerAnnotation.VisibilityOff() - - self.cornerAnnotation.SetText(idx, text) - self.iren.Render() - - def saveRender(self, filename, renWin=None): - '''Save the render window to PNG file''' - # screenshot code: - w2if = vtk.vtkWindowToImageFilter() - if renWin == None: - renWin = self.renWin - w2if.SetInput(renWin) - w2if.Update() - - writer = vtk.vtkPNGWriter() - writer.SetFileName("%s.png" % (filename)) - writer.SetInputConnection(w2if.GetOutputPort()) - writer.Write() - - def updateROIHistogram(self): - - extent = [0 for i in range(6)] - if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: - print ("slice orientation : XY") - extent[0] = self.ROI[0][0] - extent[1] = self.ROI[1][0] - extent[2] = self.ROI[0][1] - extent[3] = self.ROI[1][1] - extent[4] = self.GetActiveSlice() - extent[5] = self.GetActiveSlice()+1 - #y = abs(roi[1][1] - roi[0][1]) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: - print ("slice orientation : XY") - extent[0] = self.ROI[0][0] - extent[1] = self.ROI[1][0] - #x = abs(roi[1][0] - roi[0][0]) - extent[4] = self.ROI[0][2] - extent[5] = self.ROI[1][2] - #y = abs(roi[1][2] - roi[0][2]) - extent[2] = self.GetActiveSlice() - extent[3] = self.GetActiveSlice()+1 - elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: - print ("slice orientation : XY") - extent[2] = self.ROI[0][1] - extent[3] = self.ROI[1][1] - #x = abs(roi[1][1] - roi[0][1]) - extent[4] = self.ROI[0][2] - extent[5] = self.ROI[1][2] - #y = abs(roi[1][2] - roi[0][2]) - extent[0] = self.GetActiveSlice() - extent[1] = self.GetActiveSlice()+1 - - self.roiVOI.SetVOI(extent) - self.roiVOI.SetInputData(self.img3D) - self.roiVOI.Update() - irange = self.roiVOI.GetOutput().GetScalarRange() - - self.roiIA.SetInputData(self.roiVOI.GetOutput()) - self.roiIA.IgnoreZeroOff() - self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 ) - self.roiIA.SetComponentOrigin( int(irange[0]),0,0 ); - self.roiIA.SetComponentSpacing( 1,0,0 ); - self.roiIA.Update() - - self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort()) - self.histogramPlotActor.SetXRange(irange[0],irange[1]) - - self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() ) - - \ No newline at end of file diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py deleted file mode 100644 index 906786b..0000000 --- a/src/Python/ccpi/viewer/QVTKWidget.py +++ /dev/null @@ -1,340 +0,0 @@ -################################################################################ -# File: QVTKWidget.py -# Author: Edoardo Pasca -# Description: PyVE Viewer Qt widget -# -# License: -# This file is part of PyVE. PyVE is an open-source image -# analysis and visualization environment focused on medical -# imaging. More info at http://pyve.sourceforge.net -# -# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or -# without modification, are permitted provided that the following -# conditions are met: -# -# Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following -# disclaimer in the documentation and/or other materials provided -# with the distribution. Neither name of Edoardo Pasca or Lukas -# Batteau nor the names of any contributors may be used to endorse -# or promote products derived from this software without specific -# prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, -# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, -# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY -# OF SUCH DAMAGE. -# -# CHANGE HISTORY -# -# 20120118 Edoardo Pasca Initial version -# -############################################################################### - -import os -from PyQt5 import QtCore, QtGui, QtWidgets -#import itk -import vtk -#from viewer import PyveViewer -from ccpi.viewer.CILViewer2D import CILViewer2D , Converter - -class QVTKWidget(QtWidgets.QWidget): - - """ A QVTKWidget for Python and Qt.""" - - # Map between VTK and Qt cursors. - _CURSOR_MAP = { - 0: QtCore.Qt.ArrowCursor, # VTK_CURSOR_DEFAULT - 1: QtCore.Qt.ArrowCursor, # VTK_CURSOR_ARROW - 2: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZENE - 3: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZENWSE - 4: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZESW - 5: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZESE - 6: QtCore.Qt.SizeVerCursor, # VTK_CURSOR_SIZENS - 7: QtCore.Qt.SizeHorCursor, # VTK_CURSOR_SIZEWE - 8: QtCore.Qt.SizeAllCursor, # VTK_CURSOR_SIZEALL - 9: QtCore.Qt.PointingHandCursor, # VTK_CURSOR_HAND - 10: QtCore.Qt.CrossCursor, # VTK_CURSOR_CROSSHAIR - } - - def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): - # the current button - self._ActiveButton = QtCore.Qt.NoButton - - # private attributes - self.__oldFocus = None - self.__saveX = 0 - self.__saveY = 0 - self.__saveModifiers = QtCore.Qt.NoModifier - self.__saveButtons = QtCore.Qt.NoButton - self.__timeframe = 0 - - # create qt-level widget - QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC) - - # Link to PyVE Viewer - self._PyveViewer = CILViewer2D() - #self._Viewer = self._PyveViewer._vtkPyveViewer - - self._Iren = self._PyveViewer.GetInteractor() - #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() - self._RenderWindow = self._PyveViewer.GetRenderWindow() - #self._RenderWindow = self._Viewer.GetRenderWindow() - - self._Iren.Register(self._RenderWindow) - self._Iren.SetRenderWindow(self._RenderWindow) - self._RenderWindow.SetWindowInfo(str(int(self.winId()))) - - # do all the necessary qt setup - self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent) - self.setAttribute(QtCore.Qt.WA_PaintOnScreen) - self.setMouseTracking(True) # get all mouse events - self.setFocusPolicy(QtCore.Qt.WheelFocus) - self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)) - - self._Timer = QtCore.QTimer(self) - #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent) - - self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer) - self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer) - self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent', - self.CursorChangedEvent) - - # Destructor - def __del__(self): - self._Iren.UnRegister(self._RenderWindow) - #QtWidgets.QWidget.__del__(self) - - # Display image data - def SetInput(self, imageData): - self._PyveViewer.setInput3DData(imageData) - - # GetInteractor - def GetInteractor(self): - return self._Iren - - # Display image data - def GetPyveViewer(self): - return self._PyveViewer - - def __getattr__(self, attr): - """Makes the object behave like a vtkGenericRenderWindowInteractor""" - print (attr) - if attr == '__vtk__': - return lambda t=self._Iren: t - elif hasattr(self._Iren, attr): - return getattr(self._Iren, attr) -# else: -# raise AttributeError( self.__class__.__name__ + \ -# " has no attribute named " + attr ) - - def CreateTimer(self, obj, evt): - self._Timer.start(10) - - def DestroyTimer(self, obj, evt): - self._Timer.stop() - return 1 - - def TimerEvent(self): - self._Iren.InvokeEvent("TimerEvent") - - def CursorChangedEvent(self, obj, evt): - """Called when the CursorChangedEvent fires on the render window.""" - # This indirection is needed since when the event fires, the current - # cursor is not yet set so we defer this by which time the current - # cursor should have been set. - QtCore.QTimer.singleShot(0, self.ShowCursor) - - def HideCursor(self): - """Hides the cursor.""" - self.setCursor(QtCore.Qt.BlankCursor) - - def ShowCursor(self): - """Shows the cursor.""" - vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor() - qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor) - self.setCursor(qt_cursor) - - def sizeHint(self): - return QtCore.QSize(400, 400) - - def paintEngine(self): - return None - - def paintEvent(self, ev): - self._RenderWindow.Render() - - def resizeEvent(self, ev): - self._RenderWindow.Render() - w = self.width() - h = self.height() - - self._RenderWindow.SetSize(w, h) - self._Iren.SetSize(w, h) - - def _GetCtrlShiftAlt(self, ev): - ctrl = shift = alt = False - - if hasattr(ev, 'modifiers'): - if ev.modifiers() & QtCore.Qt.ShiftModifier: - shift = True - if ev.modifiers() & QtCore.Qt.ControlModifier: - ctrl = True - if ev.modifiers() & QtCore.Qt.AltModifier: - alt = True - else: - if self.__saveModifiers & QtCore.Qt.ShiftModifier: - shift = True - if self.__saveModifiers & QtCore.Qt.ControlModifier: - ctrl = True - if self.__saveModifiers & QtCore.Qt.AltModifier: - alt = True - - return ctrl, shift, alt - - def enterEvent(self, ev): - if not self.hasFocus(): - self.__oldFocus = self.focusWidget() - self.setFocus() - - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, - ctrl, shift, chr(0), 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("EnterEvent") - - def leaveEvent(self, ev): - if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus: - self.__oldFocus.setFocus() - self.__oldFocus = None - - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, - ctrl, shift, chr(0), 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("LeaveEvent") - - def mousePressEvent(self, ev): - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - repeat = 0 - if ev.type() == QtCore.QEvent.MouseButtonDblClick: - repeat = 1 - self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), - ctrl, shift, chr(0), repeat, None) - - self._Iren.SetAltKey(alt) - self._ActiveButton = ev.button() - - if self._ActiveButton == QtCore.Qt.LeftButton: - self._Iren.InvokeEvent("LeftButtonPressEvent") - elif self._ActiveButton == QtCore.Qt.RightButton: - self._Iren.InvokeEvent("RightButtonPressEvent") - elif self._ActiveButton == QtCore.Qt.MidButton: - self._Iren.InvokeEvent("MiddleButtonPressEvent") - - def mouseReleaseEvent(self, ev): - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), - ctrl, shift, chr(0), 0, None) - self._Iren.SetAltKey(alt) - - if self._ActiveButton == QtCore.Qt.LeftButton: - self._Iren.InvokeEvent("LeftButtonReleaseEvent") - elif self._ActiveButton == QtCore.Qt.RightButton: - self._Iren.InvokeEvent("RightButtonReleaseEvent") - elif self._ActiveButton == QtCore.Qt.MidButton: - self._Iren.InvokeEvent("MiddleButtonReleaseEvent") - - def mouseMoveEvent(self, ev): - self.__saveModifiers = ev.modifiers() - self.__saveButtons = ev.buttons() - self.__saveX = ev.x() - self.__saveY = ev.y() - - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), - ctrl, shift, chr(0), 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("MouseMoveEvent") - - def keyPressEvent(self, ev): - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - if ev.key() < 256: - key = str(ev.text()) - else: - key = chr(0) - - self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, - ctrl, shift, key, 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("KeyPressEvent") - self._Iren.InvokeEvent("CharEvent") - - def keyReleaseEvent(self, ev): - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - if ev.key() < 256: - key = chr(ev.key()) - else: - key = chr(0) - - self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, - ctrl, shift, key, 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("KeyReleaseEvent") - - def wheelEvent(self, ev): - print ("angleDeltaX %d" % ev.angleDelta().x()) - print ("angleDeltaY %d" % ev.angleDelta().y()) - if ev.angleDelta().y() >= 0: - self._Iren.InvokeEvent("MouseWheelForwardEvent") - else: - self._Iren.InvokeEvent("MouseWheelBackwardEvent") - - def GetRenderWindow(self): - return self._RenderWindow - - def Render(self): - self.update() - - -def QVTKExample(): - """A simple example that uses the QVTKWidget class.""" - - # every QT app needs an app - app = QtWidgets.QApplication(['PyVE QVTKWidget Example']) - page_VTK = QtWidgets.QWidget() - page_VTK.resize(500,500) - layout = QtWidgets.QVBoxLayout(page_VTK) - # create the widget - widget = QVTKWidget(parent=None) - layout.addWidget(widget) - - #reader = vtk.vtkPNGReader() - #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png") - reader = vtk.vtkMetaImageReader() - reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") - reader.Update() - - widget.SetInput(reader.GetOutput()) - - # show the widget - page_VTK.show() - # start event processing - app.exec_() - -if __name__ == "__main__": - QVTKExample() diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py deleted file mode 100644 index e32e1c2..0000000 --- a/src/Python/ccpi/viewer/QVTKWidget2.py +++ /dev/null @@ -1,84 +0,0 @@ -################################################################################ -# File: QVTKWidget.py -# Author: Edoardo Pasca -# Description: PyVE Viewer Qt widget -# -# License: -# This file is part of PyVE. PyVE is an open-source image -# analysis and visualization environment focused on medical -# imaging. More info at http://pyve.sourceforge.net -# -# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or -# without modification, are permitted provided that the following -# conditions are met: -# -# Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following -# disclaimer in the documentation and/or other materials provided -# with the distribution. Neither name of Edoardo Pasca or Lukas -# Batteau nor the names of any contributors may be used to endorse -# or promote products derived from this software without specific -# prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, -# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, -# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY -# OF SUCH DAMAGE. -# -# CHANGE HISTORY -# -# 20120118 Edoardo Pasca Initial version -# -############################################################################### - -import os -from PyQt5 import QtCore, QtGui, QtWidgets -#import itk -import vtk -#from viewer import PyveViewer -from ccpi.viewer.CILViewer2D import CILViewer2D , Converter -from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor - -class QVTKWidget(QVTKRenderWindowInteractor): - - - def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): - kw = dict() - super().__init__(parent, **kw) - - - # Link to PyVE Viewer - self._PyveViewer = CILViewer2D(400,400) - #self._Viewer = self._PyveViewer._vtkPyveViewer - - self._Iren = self._PyveViewer.GetInteractor() - kw['iren'] = self._Iren - #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() - self._RenderWindow = self._PyveViewer.GetRenderWindow() - #self._RenderWindow = self._Viewer.GetRenderWindow() - kw['rw'] = self._RenderWindow - - - - - def GetInteractor(self): - return self._Iren - - # Display image data - def SetInput(self, imageData): - self._PyveViewer.setInput3DData(imageData) - \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py deleted file mode 100644 index 946188b..0000000 --- a/src/Python/ccpi/viewer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ccpi.viewer.CILViewer import CILViewer \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc deleted file mode 100644 index 711f77a..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc deleted file mode 100644 index 77c2ca8..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc deleted file mode 100644 index 3d11b87..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc deleted file mode 100644 index 2fa2eaf..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc deleted file mode 100644 index fcea537..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py deleted file mode 100644 index b5eb0a7..0000000 --- a/src/Python/ccpi/viewer/embedvtk.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Thu Jul 27 12:18:58 2017 - -@author: ofn77899 -""" - -#!/usr/bin/env python - -import sys -import vtk -from PyQt5 import QtCore, QtWidgets -from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor -import QVTKWidget2 - -class MainWindow(QtWidgets.QMainWindow): - - def __init__(self, parent = None): - QtWidgets.QMainWindow.__init__(self, parent) - - self.frame = QtWidgets.QFrame() - - self.vl = QtWidgets.QVBoxLayout() -# self.vtkWidget = QVTKRenderWindowInteractor(self.frame) - - self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame) - self.iren = self.vtkWidget.GetInteractor() - self.vl.addWidget(self.vtkWidget) - - - - - self.ren = vtk.vtkRenderer() - self.vtkWidget.GetRenderWindow().AddRenderer(self.ren) -# self.iren = self.vtkWidget.GetRenderWindow().GetInteractor() -# -# # Create source -# source = vtk.vtkSphereSource() -# source.SetCenter(0, 0, 0) -# source.SetRadius(5.0) -# -# # Create a mapper -# mapper = vtk.vtkPolyDataMapper() -# mapper.SetInputConnection(source.GetOutputPort()) -# -# # Create an actor -# actor = vtk.vtkActor() -# actor.SetMapper(mapper) -# -# self.ren.AddActor(actor) -# -# self.ren.ResetCamera() -# - self.frame.setLayout(self.vl) - self.setCentralWidget(self.frame) - reader = vtk.vtkMetaImageReader() - reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") - reader.Update() - - self.vtkWidget.SetInput(reader.GetOutput()) - - #self.vktWidget.Initialize() - #self.vktWidget.Start() - - self.show() - #self.iren.Initialize() - - -if __name__ == "__main__": - - app = QtWidgets.QApplication(sys.argv) - - window = MainWindow() - - sys.exit(app.exec_()) \ No newline at end of file -- cgit v1.2.3 From 99e8a3130d6ee161fc8e73faf526d7e0a7a9db44 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 18 Oct 2017 11:46:54 +0100 Subject: Modified Region of interest; removed studentt https://github.com/vais-ral/CCPi-FISTA_Reconstruction/commit/6fb8f5d188ed31d7a7077cba8ab7aea17b25b8bf --- src/Python/ccpi/fista/FISTAReconstructor.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 87dd2c0..33e67a3 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -108,9 +108,7 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha', - 'subsets', - 'use_studentt_fidelity', - 'studentt') + 'subsets') self.acceptedInputKeywords = list(kw) # handle keyworded parameters @@ -141,10 +139,12 @@ class FISTAReconstructor(): if not 'region_of_interest'in kwargs.keys() : if self.pars['ideal_image'] == None: - pass + self.pars['region_of_interest'] = None else: - self.pars['region_of_interest'] = numpy.nonzero( - self.pars['ideal_image']>0.0) + ## nonzero if the image is larger than m + fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) + + self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) # the regularizer must be a correctly instantiated object if not 'regularizer' in kwargs.keys() : @@ -165,14 +165,7 @@ class FISTAReconstructor(): if not 'initialize' in kwargs.keys(): self.pars['initialize'] = False - if not 'use_studentt_fidelity' in kwargs.keys(): - self.setParameter(studentt=False) - else: - print ("studentt {0}".format(kwargs['use_studentt_fidelity'])) - if kwargs['use_studentt_fidelity']: - raise Exception('Not implemented') - - self.setParameter(studentt=kwargs['use_studentt_fidelity']) + def setParameter(self, **kwargs): -- cgit v1.2.3 From c097c34a59f80a6d4475a1f783b772fa42a44862 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 18 Oct 2017 16:54:06 +0100 Subject: implemented non ordered subset FISTA in reconstructor --- src/Python/ccpi/fista/FISTAReconstructor.py | 165 ++++++++++++++++++++++++---- 1 file changed, 145 insertions(+), 20 deletions(-) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 33e67a3..fda9cf0 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -85,6 +85,7 @@ class FISTAReconstructor(): self.pars['detectors'] = detectors self.pars['number_of_angles'] = nangles self.pars['SlicesZ'] = sliceZ + self.pars['output_volume'] = None print (self.pars) # handle optional input parameters (at instantiation) @@ -108,7 +109,11 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha', - 'subsets') + 'subsets', + 'output_volume', + 'os_subsets', + 'os_indices', + 'os_bins') self.acceptedInputKeywords = list(kw) # handle keyworded parameters @@ -176,8 +181,6 @@ class FISTAReconstructor(): ''' for key , value in kwargs.items(): if key in self.acceptedInputKeywords: - if key == 'use_studentt_fidelity': - raise Exception('use_studentt_fidelity Not implemented') self.pars[key] = value else: raise Exception('Wrong parameter {0} for '.format(key) + @@ -382,11 +385,15 @@ class FISTAReconstructor(): counter = counter + binsDiscr[jj] - 1 - - return IndicesReorg + # store the OS in parameters + self.setParameter(os_subsets=subsets, + os_bins=binsDiscr, + os_indices=IndicesReorg) def prepareForIteration(self): + print ("FISTA Reconstructor: prepare for iteration") + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) self.objective = numpy.zeros((self.pars['number_of_iterations'])) @@ -401,19 +408,17 @@ class FISTAReconstructor(): if self.getParameter('Lipschitz_constant') is None: self.pars['Lipschitz_constant'] = \ self.calculateLipschitzConstantWithPowerMethod() + # errors vector (if the ground truth is given) + self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); + # objective function values vector + self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); # prepareForIteration def iterate(self, Xin=None): - # convenience variable storage - proj_geom , vol_geom, sino , \ - SlicesZ = self.getParameter([ 'projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ']) - - t = 1 + print ("FISTA Reconstructor: iterate") + if Xin is None: if self.getParameter('initialize'): X = self.initialize() @@ -423,15 +428,25 @@ class FISTAReconstructor(): else: # copy by reference X = Xin - + # store the output volume in the parameters + self.setParameter(output_volume=X) X_t = X.copy() + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ' ]) + + t = 1 for i in range(self.getParameter('number_of_iterations')): X_old = X.copy() t_old = t r_old = self.r.copy() if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'parallel3d': + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': # if the geometry is parallel use slice-by-slice # projection-backprojection routine #sino_updt = zeros(size(sino),'single'); @@ -439,10 +454,9 @@ class FISTAReconstructor(): proj_geomT['DetectorRowCount'] = 1 vol_geomT = vol_geom.copy() vol_geomT['GridSliceCount'] = 1; - sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) for kkk in range(SlicesZ): - print (kkk) - sino_id, sino_updt[kkk] = \ + sino_id, self.sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( X_t[kkk:kkk+1], proj_geomT, vol_geomT) astra.matlab.data3d('delete', sino_id) @@ -450,11 +464,122 @@ class FISTAReconstructor(): # for divergent 3D geometry (watch the GPU memory overflow in # ASTRA versions < 1.8) #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - sino_id, sino_updt = astra.matlab.create_sino3d_gpu( + sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( X_t, proj_geom, vol_geom) ## RING REMOVAL - + self.ringRemoval(i) + ## Projection/Backprojection Routine + self.projectionBackprojection(X, X_t) + astra.matlab.data3d('delete', sino_id) ## REGULARIZATION + X = self.regularize(X) + ## Update Loop + X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) + self.setParameter(output_volume=X) + return X + ## iterate + + def ringRemoval(self, i): + print ("FISTA Reconstructor: ring removal") + residual = self.residual + lambdaR_L1 , alpha_ring , weights , L_const , sino= \ + self.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant', + 'input_sinogram']) + r_x = self.r_x + sino_updt = self.sino_updt + + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + self.r = (r_x - (1./L_const) * vec).copy() + self.objective[i] = (0.5 * (residual ** 2).sum()) + + def projectionBackprojection(self, X, X_t): + print ("FISTA Reconstructor: projection-backprojection routine") + + # a few useful variables + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + residual = self.residual + proj_geom , vol_geom , L_const = \ + self.getParameter(['projector_geometry' , + 'output_geometry', + 'Lipschitz_constant']) + + + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + #astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + def regularize(self, X): + print ("FISTA Reconstructor: regularize") + + regularizer = self.getParameter('regularizer') + if regularizer is not None: + return regularizer(input=X) + else: + return X + + def updateLoop(self, i, X, X_old, r_old, t, t_old): + print ("FISTA Reconstructor: update loop") + lambdaR_L1 = self.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + self.r = numpy.max( + numpy.abs(self.r) - lambdaR_L1 , 0) * \ + numpy.sign(self.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + self.r_x = self.r + \ + (((t_old-1)/t) * (self.r - r_old)) + + if self.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, self.objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], self.objective[i])) + return (X , X_t, t) -- cgit v1.2.3 From 903175ed67f7645fa35edf4623b27999d6cb990f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 20 Oct 2017 17:04:26 +0100 Subject: Further development --- src/Python/ccpi/fista/FISTAReconstructor.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index fda9cf0..85bfac5 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -583,3 +583,27 @@ class FISTAReconstructor(): string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], self.objective[i])) return (X , X_t, t) + + def os_iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + + # some useful constants + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring , + lambdaR_L1 , L_const = self.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant']) -- cgit v1.2.3 From a11c59651ec125e24371a2049606df0f80f458d0 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 11:26:46 +0100 Subject: latest dev --- .../ccpi/reconstruction/FISTAReconstructor.py | 599 +++++++++++++++------ 1 file changed, 427 insertions(+), 172 deletions(-) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index ea96b53..85bfac5 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -21,10 +21,9 @@ import numpy -import h5py #from ccpi.reconstruction.parallelbeam import alg -from ccpi.imaging.Regularizer import Regularizer +#from ccpi.imaging.Regularizer import Regularizer from enum import Enum import astra @@ -74,18 +73,34 @@ class FISTAReconstructor(): # 3. "A novel tomographic reconstruction method based on the robust # Student's t function for suppressing data outliers" D. Kazantsev et.al. # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ + def __init__(self, projector_geometry, output_geometry, input_sinogram, + **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry # proj_geom + self.pars['output_geometry'] = output_geometry # vol_geom + self.pars['input_sinogram'] = input_sinogram # sino + sliceZ, nangles, detectors = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_of_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + self.pars['output_volume'] = None + + print (self.pars) + # handle optional input parameters (at instantiation) # Accepted input keywords - kw = ('number_of_iterations', + kw = ( + # mandatory fields + 'projector_geometry', + 'output_geometry', + 'input_sinogram', + 'detectors', + 'number_of_angles', + 'SlicesZ', + # optional fields + 'number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , 'weights' , @@ -93,7 +108,13 @@ class FISTAReconstructor(): 'initialize' , 'regularizer' , 'ring_lambda_R_L1', - 'ring_alpha') + 'ring_alpha', + 'subsets', + 'output_volume', + 'os_subsets', + 'os_indices', + 'os_bins') + self.acceptedInputKeywords = list(kw) # handle keyworded parameters if kwargs is not None: @@ -110,85 +131,160 @@ class FISTAReconstructor(): if 'weights' in kwargs.keys(): self.pars['weights'] = kwargs['weights'] else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + self.pars['weights'] = \ + numpy.ones(numpy.shape( + self.pars['input_sinogram'])) if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = None - if not self.pars['ideal_image'] in kwargs.keys(): + if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None - if not self.pars['region_of_interest'] : + if not 'region_of_interest'in kwargs.keys() : if self.pars['ideal_image'] == None: - pass + self.pars['region_of_interest'] = None else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : + ## nonzero if the image is larger than m + fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) + + self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) + + # the regularizer must be a correctly instantiated object + if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None + + #RING REMOVAL + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + # ORDERED SUBSET + if not 'subsets' in kwargs.keys(): + self.pars['subsets'] = 0 else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 + self.createOrderedSubsets() + + if not 'initialize' in kwargs.keys(): + self.pars['initialize'] = False + + def setParameter(self, **kwargs): + '''set named parameter for the reconstructor engine + + raises Exception if the named parameter is not recognized + ''' + for key , value in kwargs.items(): + if key in self.acceptedInputKeywords: + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for '.format(key) + + 'reconstructor') + # setParameter + + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars + else: + raise Exception('Unhandled input {0}' .format(str(type(key)))) + + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] weights = self.pars['weights'] SlicesZ = self.pars['SlicesZ'] - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + + + if (proj_geom['type'] == 'parallel') or \ + (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) + sqweight = numpy.sqrt(weights[0]) proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; + proj_geomT['DetectorRowCount'] = 1; vol_geomT = vol_geom.copy(); vol_geomT['GridSliceCount'] = 1; + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + s = numpy.linalg.norm(x1) ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight*y; + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + #end del proj_geomT del vol_geomT + #plt.show() else: #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + print('Calculating Lipshitz constant for divergen beam geometry...') niter = 8; #% number of iteration for PM x1 = numpy.random.rand(SlicesZ , N , N); #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) + sqweight = numpy.sqrt(weights[0]) sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); y = sqweight*y; @@ -217,6 +313,7 @@ class FISTAReconstructor(): #end #clear x1 del x1 + return s @@ -225,130 +322,288 @@ class FISTAReconstructor(): if regularizer is not None: self.pars['regularizer'] = regularizer + + def initialize(self): + # convenience variable storage + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + sino = self.pars['input_sinogram'] + + # a 'warm start' with SIRT method + # Create a data object for the reconstruction + rec_id = astra.matlab.data3d('create', '-vol', + vol_geom); + + #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); + sinogram_id = astra.matlab.data3d('create', '-proj3d', + proj_geom, + sino) + + sirt_config = astra.astra_dict('SIRT3D_CUDA') + sirt_config['ReconstructionDataId' ] = rec_id + sirt_config['ProjectionDataId'] = sinogram_id + + sirt = astra.algorithm.create(sirt_config) + astra.algorithm.run(sirt, iterations=35) + X = astra.matlab.data3d('get', rec_id) + + # clean up memory + astra.matlab.data3d('delete', rec_id) + astra.matlab.data3d('delete', sinogram_id) + astra.algorithm.delete(sirt) + + + + return X + + def createOrderedSubsets(self, subsets=None): + if subsets is None: + try: + subsets = self.getParameter('subsets') + except Exception(): + subsets = 0 + #return subsets + + angles = self.getParameter('projector_geometry')['ProjectionAngles'] + + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + # store the OS in parameters + self.setParameter(os_subsets=subsets, + os_bins=binsDiscr, + os_indices=IndicesReorg) + + + def prepareForIteration(self): + print ("FISTA Reconstructor: prepare for iteration") + + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) + self.objective = numpy.zeros((self.pars['number_of_iterations'])) + + #2D array (for 3D data) of sparse "ring" + detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) + self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) + # another ring variable + self.r_x = self.r.copy() + + self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + + if self.getParameter('Lipschitz_constant') is None: + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() + # errors vector (if the ground truth is given) + self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); + # objective function values vector + self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); + + + # prepareForIteration + + def iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ' ]) + + t = 1 + + for i in range(self.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = self.r.copy() + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + for kkk in range(SlicesZ): + sino_id, self.sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomT, vol_geomT) + astra.matlab.data3d('delete', sino_id) + else: + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + + ## RING REMOVAL + self.ringRemoval(i) + ## Projection/Backprojection Routine + self.projectionBackprojection(X, X_t) + astra.matlab.data3d('delete', sino_id) + ## REGULARIZATION + X = self.regularize(X) + ## Update Loop + X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) + self.setParameter(output_volume=X) + return X + ## iterate - + def ringRemoval(self, i): + print ("FISTA Reconstructor: ring removal") + residual = self.residual + lambdaR_L1 , alpha_ring , weights , L_const , sino= \ + self.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant', + 'input_sinogram']) + r_x = self.r_x + sino_updt = self.sino_updt + + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + self.r = (r_x - (1./L_const) * vec).copy() + self.objective[i] = (0.5 * (residual ** 2).sum()) + def projectionBackprojection(self, X, X_t): + print ("FISTA Reconstructor: projection-backprojection routine") + + # a few useful variables + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + residual = self.residual + proj_geom , vol_geom , L_const = \ + self.getParameter(['projector_geometry' , + 'output_geometry', + 'Lipschitz_constant']) + + + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + #astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + def regularize(self, X): + print ("FISTA Reconstructor: regularize") + + regularizer = self.getParameter('regularizer') + if regularizer is not None: + return regularizer(input=X) + else: + return X + + def updateLoop(self, i, X, X_old, r_old, t, t_old): + print ("FISTA Reconstructor: update loop") + lambdaR_L1 = self.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + self.r = numpy.max( + numpy.abs(self.r) - lambdaR_L1 , 0) * \ + numpy.sign(self.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + self.r_x = self.r + \ + (((t_old-1)/t) * (self.r - r_old)) + + if self.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, self.objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], self.objective[i])) + return (X , X_t, t) -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" -##nx = h5py.File(fname, "r") -## -### the data are stored in a particular location in the hdf5 -##for item in nx['entry1/tomo_entry/data'].keys(): -## print (item) -## -##data = nx.get('entry1/tomo_entry/data/rotation_angle') -##angles = numpy.zeros(data.shape) -##data.read_direct(angles) -##print (angles) -### angles should be in degrees -## -##data = nx.get('entry1/tomo_entry/data/data') -##stack = numpy.zeros(data.shape) -##data.read_direct(stack) -##print (data.shape) -## -##print ("Data Loaded") -## -## -### Normalize -##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -##itype = numpy.zeros(data.shape) -##data.read_direct(itype) -### 2 is dark field -##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -##dark = darks[0] -##for i in range(1, len(darks)): -## dark += darks[i] -##dark = dark / len(darks) -###dark[0][0] = dark[0][1] -## -### 1 is flat field -##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -##flat = flats[0] -##for i in range(1, len(flats)): -## flat += flats[i] -##flat = flat / len(flats) -###flat[0][0] = dark[0][1] -## -## -### 0 is projection data -##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = numpy.asarray (angle_proj) -##angle_proj = angle_proj.astype(numpy.float32) -## -### normalized data are -### norm = (projection - dark)/(flat-dark) -## -##def normalize(projection, dark, flat, def_val=0.1): -## a = (projection - dark) -## b = (flat-dark) -## with numpy.errstate(divide='ignore', invalid='ignore'): -## c = numpy.true_divide( a, b ) -## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 -## return c -## -## -##norm = [normalize(projection, dark, flat) for projection in proj] -##norm = numpy.asarray (norm) -##norm = norm.astype(numpy.float32) - - -##niterations = 15 -##threads = 3 -## -##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## iteration_values, False) -##print ("iteration values %s" % str(iteration_values)) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -## -## -####numpy.save("cgls_recon.npy", img_data) -##import matplotlib.pyplot as plt -##fig, ax = plt.subplots(1,6,sharey=True) -##ax[0].imshow(img_cgls[80]) -##ax[0].axis('off') # clear x- and y-axes -##ax[1].imshow(img_sirt[80]) -##ax[1].axis('off') # clear x- and y-axes -##ax[2].imshow(img_mlem[80]) -##ax[2].axis('off') # clear x- and y-axesplt.show() -##ax[3].imshow(img_cgls_conv[80]) -##ax[3].axis('off') # clear x- and y-axesplt.show() -##ax[4].imshow(img_cgls_tikhonov[80]) -##ax[4].axis('off') # clear x- and y-axesplt.show() -##ax[5].imshow(img_cgls_TVreg[80]) -##ax[5].axis('off') # clear x- and y-axesplt.show() -## -## -##plt.show() -## + def os_iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + # some useful constants + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring , + lambdaR_L1 , L_const = self.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant']) -- cgit v1.2.3 From 546104f8dfea5691801137c1be99d09e1e999d82 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 11:31:36 +0100 Subject: removed fista directory use the standard package reconstruction directory for the fista code --- src/Python/ccpi/fista/FISTAReconstructor.py | 609 ---------------------------- src/Python/ccpi/fista/Reconstructor.py | 425 ------------------- src/Python/ccpi/fista/__init__.py | 0 3 files changed, 1034 deletions(-) delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py delete mode 100644 src/Python/ccpi/fista/Reconstructor.py delete mode 100644 src/Python/ccpi/fista/__init__.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py deleted file mode 100644 index 85bfac5..0000000 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ /dev/null @@ -1,609 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -#from ccpi.reconstruction.parallelbeam import alg - -#from ccpi.imaging.Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, - **kwargs): - # handle parmeters: - # obligatory parameters - self.pars = dict() - self.pars['projector_geometry'] = projector_geometry # proj_geom - self.pars['output_geometry'] = output_geometry # vol_geom - self.pars['input_sinogram'] = input_sinogram # sino - sliceZ, nangles, detectors = numpy.shape(input_sinogram) - self.pars['detectors'] = detectors - self.pars['number_of_angles'] = nangles - self.pars['SlicesZ'] = sliceZ - self.pars['output_volume'] = None - - print (self.pars) - # handle optional input parameters (at instantiation) - - # Accepted input keywords - kw = ( - # mandatory fields - 'projector_geometry', - 'output_geometry', - 'input_sinogram', - 'detectors', - 'number_of_angles', - 'SlicesZ', - # optional fields - 'number_of_iterations', - 'Lipschitz_constant' , - 'ideal_image' , - 'weights' , - 'region_of_interest' , - 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha', - 'subsets', - 'output_volume', - 'os_subsets', - 'os_indices', - 'os_bins') - self.acceptedInputKeywords = list(kw) - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = \ - numpy.ones(numpy.shape( - self.pars['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = None - - if not 'ideal_image' in kwargs.keys(): - self.pars['ideal_image'] = None - - if not 'region_of_interest'in kwargs.keys() : - if self.pars['ideal_image'] == None: - self.pars['region_of_interest'] = None - else: - ## nonzero if the image is larger than m - fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) - - self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) - - # the regularizer must be a correctly instantiated object - if not 'regularizer' in kwargs.keys() : - self.pars['regularizer'] = None - - #RING REMOVAL - if not 'ring_lambda_R_L1' in kwargs.keys(): - self.pars['ring_lambda_R_L1'] = 0 - if not 'ring_alpha' in kwargs.keys(): - self.pars['ring_alpha'] = 1 - - # ORDERED SUBSET - if not 'subsets' in kwargs.keys(): - self.pars['subsets'] = 0 - else: - self.createOrderedSubsets() - - if not 'initialize' in kwargs.keys(): - self.pars['initialize'] = False - - - - - def setParameter(self, **kwargs): - '''set named parameter for the reconstructor engine - - raises Exception if the named parameter is not recognized - - ''' - for key , value in kwargs.items(): - if key in self.acceptedInputKeywords: - self.pars[key] = value - else: - raise Exception('Wrong parameter {0} for '.format(key) + - 'reconstructor') - # setParameter - - def getParameter(self, key): - if type(key) is str: - if key in self.acceptedInputKeywords: - return self.pars[key] - else: - raise Exception('Unrecongnised parameter: {0} '.format(key) ) - elif type(key) is list: - outpars = [] - for k in key: - outpars.append(self.getParameter(k)) - return outpars - else: - raise Exception('Unhandled input {0}' .format(str(type(key)))) - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - N = self.pars['output_geometry']['GridColCount'] - proj_geom = self.pars['projector_geometry'] - vol_geom = self.pars['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - - - if (proj_geom['type'] == 'parallel') or \ - (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 5;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights[0]) - proj_geomT = proj_geom.copy(); - proj_geomT['DetectorRowCount'] = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - - - for i in range(niter): - # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); - # s = norm(x1(:)); - # x1 = x1/s; - # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - # y = sqweight.*y; - # astra_mex_data3d('delete', sino_id); - # astra_mex_data3d('delete', id); - #print ("iteration {0}".format(i)) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geomT, - vol_geomT) - - y = (sqweight * y).copy() # element wise multiplication - - #b=fig.add_subplot(2,1,2) - #imgplot = plt.imshow(x1[0]) - #plt.show() - - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - del x1 - - idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), - proj_geomT, - vol_geomT) - del y - - - s = numpy.linalg.norm(x1) - ### this line? - x1 = (x1/s).copy(); - - # ### this line? - # sino_id, y = astra.creators.create_sino3d_gpu(x1, - # proj_geomT, - # vol_geomT); - # y = sqweight * y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx) - print ("iteration {0} s= {1}".format(i,s)) - - #end - del proj_geomT - del vol_geomT - #plt.show() - else: - #% divergen beam geometry - print('Calculating Lipshitz constant for divergen beam geometry...') - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - - return s - - - def setRegularizer(self, regularizer): - if regularizer is not None: - self.pars['regularizer'] = regularizer - - - def initialize(self): - # convenience variable storage - proj_geom = self.pars['projector_geometry'] - vol_geom = self.pars['output_geometry'] - sino = self.pars['input_sinogram'] - - # a 'warm start' with SIRT method - # Create a data object for the reconstruction - rec_id = astra.matlab.data3d('create', '-vol', - vol_geom); - - #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); - sinogram_id = astra.matlab.data3d('create', '-proj3d', - proj_geom, - sino) - - sirt_config = astra.astra_dict('SIRT3D_CUDA') - sirt_config['ReconstructionDataId' ] = rec_id - sirt_config['ProjectionDataId'] = sinogram_id - - sirt = astra.algorithm.create(sirt_config) - astra.algorithm.run(sirt, iterations=35) - X = astra.matlab.data3d('get', rec_id) - - # clean up memory - astra.matlab.data3d('delete', rec_id) - astra.matlab.data3d('delete', sinogram_id) - astra.algorithm.delete(sirt) - - - - return X - - def createOrderedSubsets(self, subsets=None): - if subsets is None: - try: - subsets = self.getParameter('subsets') - except Exception(): - subsets = 0 - #return subsets - - angles = self.getParameter('projector_geometry')['ProjectionAngles'] - - #binEdges = numpy.linspace(angles.min(), - # angles.max(), - # subsets + 1) - binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) - # get rearranged subset indices - IndicesReorg = numpy.zeros((numpy.shape(angles))) - counterM = 0 - for ii in range(binsDiscr.max()): - counter = 0 - for jj in range(subsets): - curr_index = ii + jj + counter - #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) - if binsDiscr[jj] > ii: - if (counterM < numpy.size(IndicesReorg)): - IndicesReorg[counterM] = curr_index - counterM = counterM + 1 - - counter = counter + binsDiscr[jj] - 1 - - # store the OS in parameters - self.setParameter(os_subsets=subsets, - os_bins=binsDiscr, - os_indices=IndicesReorg) - - - def prepareForIteration(self): - print ("FISTA Reconstructor: prepare for iteration") - - self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) - self.objective = numpy.zeros((self.pars['number_of_iterations'])) - - #2D array (for 3D data) of sparse "ring" - detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) - self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) - # another ring variable - self.r_x = self.r.copy() - - self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) - - if self.getParameter('Lipschitz_constant') is None: - self.pars['Lipschitz_constant'] = \ - self.calculateLipschitzConstantWithPowerMethod() - # errors vector (if the ground truth is given) - self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); - # objective function values vector - self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); - - - # prepareForIteration - - def iterate(self, Xin=None): - print ("FISTA Reconstructor: iterate") - - if Xin is None: - if self.getParameter('initialize'): - X = self.initialize() - else: - N = vol_geom['GridColCount'] - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - # copy by reference - X = Xin - # store the output volume in the parameters - self.setParameter(output_volume=X) - X_t = X.copy() - # convenience variable storage - proj_geom , vol_geom, sino , \ - SlicesZ = self.getParameter([ 'projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ' ]) - - t = 1 - - for i in range(self.getParameter('number_of_iterations')): - X_old = X.copy() - t_old = t - r_old = self.r.copy() - if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat_vec': - # if the geometry is parallel use slice-by-slice - # projection-backprojection routine - #sino_updt = zeros(size(sino),'single'); - proj_geomT = proj_geom.copy() - proj_geomT['DetectorRowCount'] = 1 - vol_geomT = vol_geom.copy() - vol_geomT['GridSliceCount'] = 1; - self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) - for kkk in range(SlicesZ): - sino_id, self.sino_updt[kkk] = \ - astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomT, vol_geomT) - astra.matlab.data3d('delete', sino_id) - else: - # for divergent 3D geometry (watch the GPU memory overflow in - # ASTRA versions < 1.8) - #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( - X_t, proj_geom, vol_geom) - - - ## RING REMOVAL - self.ringRemoval(i) - ## Projection/Backprojection Routine - self.projectionBackprojection(X, X_t) - astra.matlab.data3d('delete', sino_id) - ## REGULARIZATION - X = self.regularize(X) - ## Update Loop - X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) - self.setParameter(output_volume=X) - return X - ## iterate - - def ringRemoval(self, i): - print ("FISTA Reconstructor: ring removal") - residual = self.residual - lambdaR_L1 , alpha_ring , weights , L_const , sino= \ - self.getParameter(['ring_lambda_R_L1', - 'ring_alpha' , 'weights', - 'Lipschitz_constant', - 'input_sinogram']) - r_x = self.r_x - sino_updt = self.sino_updt - - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) - if lambdaR_L1 > 0 : - for kkk in range(anglesNumb): - - residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - vec = residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:].squeeze() - self.r = (r_x - (1./L_const) * vec).copy() - self.objective[i] = (0.5 * (residual ** 2).sum()) - - def projectionBackprojection(self, X, X_t): - print ("FISTA Reconstructor: projection-backprojection routine") - - # a few useful variables - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) - residual = self.residual - proj_geom , vol_geom , L_const = \ - self.getParameter(['projector_geometry' , - 'output_geometry', - 'Lipschitz_constant']) - - - if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat_vec': - # if the geometry is parallel use slice-by-slice - # projection-backprojection routine - #sino_updt = zeros(size(sino),'single'); - proj_geomT = proj_geom.copy() - proj_geomT['DetectorRowCount'] = 1 - vol_geomT = vol_geom.copy() - vol_geomT['GridSliceCount'] = 1; - x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) - - for kkk in range(SlicesZ): - - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residual[kkk:kkk+1], - proj_geomT, vol_geomT) - astra.matlab.data3d('delete', x_id) - else: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residual, proj_geom, vol_geom) - - X = X_t - (1/L_const) * x_temp - #astra.matlab.data3d('delete', sino_id) - astra.matlab.data3d('delete', x_id) - - def regularize(self, X): - print ("FISTA Reconstructor: regularize") - - regularizer = self.getParameter('regularizer') - if regularizer is not None: - return regularizer(input=X) - else: - return X - - def updateLoop(self, i, X, X_old, r_old, t, t_old): - print ("FISTA Reconstructor: update loop") - lambdaR_L1 = self.getParameter('ring_lambda_R_L1') - if lambdaR_L1 > 0: - self.r = numpy.max( - numpy.abs(self.r) - lambdaR_L1 , 0) * \ - numpy.sign(self.r) - t = (1 + numpy.sqrt(1 + 4 * t**2))/2 - X_t = X + (((t_old -1)/t) * (X - X_old)) - - if lambdaR_L1 > 0: - self.r_x = self.r + \ - (((t_old-1)/t) * (self.r - r_old)) - - if self.getParameter('region_of_interest') is None: - string = 'Iteration Number {0} | Objective {1} \n' - print (string.format( i, self.objective[i])) - else: - ROI , X_ideal = fistaRecon.getParameter('region_of_interest', - 'ideal_image') - - Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) - string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' - print (string.format(i,Resid_error[i], self.objective[i])) - return (X , X_t, t) - - def os_iterate(self, Xin=None): - print ("FISTA Reconstructor: iterate") - - if Xin is None: - if self.getParameter('initialize'): - X = self.initialize() - else: - N = vol_geom['GridColCount'] - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - # copy by reference - X = Xin - # store the output volume in the parameters - self.setParameter(output_volume=X) - X_t = X.copy() - - # some useful constants - proj_geom , vol_geom, sino , \ - SlicesZ, weights , alpha_ring , - lambdaR_L1 , L_const = self.getParameter( - ['projector_geometry' , 'output_geometry', - 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , - 'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py deleted file mode 100644 index d29ac0d..0000000 --- a/src/Python/ccpi/fista/Reconstructor.py +++ /dev/null @@ -1,425 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - return shiftScaler.GetOutput() - -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") -#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py deleted file mode 100644 index e69de29..0000000 -- cgit v1.2.3 From cf741b21f5a66d4b6157bef401a8ca240d8702b8 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 12:59:53 +0100 Subject: fix wrong indentation --- src/Python/ccpi/reconstruction/FISTAReconstructor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index 85bfac5..f43966c 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -602,7 +602,7 @@ class FISTAReconstructor(): # some useful constants proj_geom , vol_geom, sino , \ - SlicesZ, weights , alpha_ring , + SlicesZ, weights , alpha_ring ,\ lambdaR_L1 , L_const = self.getParameter( ['projector_geometry' , 'output_geometry', 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , -- cgit v1.2.3 From 455ca86825c157512f61441d3d27b8148ca795a7 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 16:37:21 +0100 Subject: Add regularization step Add regularization step OS seems to work --- src/Python/ccpi/reconstruction/FISTAReconstructor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index f43966c..c903712 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -363,6 +363,9 @@ class FISTAReconstructor(): except Exception(): subsets = 0 #return subsets + else: + self.setParameter(subsets=subsets) + angles = self.getParameter('projector_geometry')['ProjectionAngles'] @@ -371,7 +374,7 @@ class FISTAReconstructor(): # subsets + 1) binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) # get rearranged subset indices - IndicesReorg = numpy.zeros((numpy.shape(angles))) + IndicesReorg = numpy.zeros((numpy.shape(angles)), dtype=numpy.int32) counterM = 0 for ii in range(binsDiscr.max()): counter = 0 -- cgit v1.2.3