summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2017-10-25 10:56:57 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2017-10-25 10:56:57 +0100
commitc21df8581e052541a2dd39a46a4c1f50e335fd9e (patch)
tree5c8f755f910a9b5006791e715328242f68fbe504
parent31097954f87d0f30f667b29a12f7098710c284ab (diff)
parent455ca86825c157512f61441d3d27b8148ca795a7 (diff)
downloadregularization-c21df8581e052541a2dd39a46a4c1f50e335fd9e.tar.gz
regularization-c21df8581e052541a2dd39a46a4c1f50e335fd9e.tar.bz2
regularization-c21df8581e052541a2dd39a46a4c1f50e335fd9e.tar.xz
regularization-c21df8581e052541a2dd39a46a4c1f50e335fd9e.zip
Merge branch origin/pythonize into pythonize
-rw-r--r--CMakeLists.txt2
-rw-r--r--demos/DemoRD1.m100
-rw-r--r--demos/Demo_Phantom3D_Cone.m5
-rw-r--r--demos/Demo_Phantom3D_Parallel.m79
-rw-r--r--demos/Demo_RealData3D_Parallel.m (renamed from demos/DemoRD2.m)16
-rw-r--r--main_func/FISTA_REC.m4
-rw-r--r--main_func/regularizers_CPU/FGP_TV.c6
-rw-r--r--src/Python/CMakeLists.txt111
-rw-r--r--src/Python/ccpi/fista/FISTAReconstructor.py609
-rw-r--r--src/Python/ccpi/fista/Reconstructor.py425
-rw-r--r--src/Python/ccpi/fista/__init__.py0
-rw-r--r--src/Python/ccpi/reconstruction/FISTAReconstructor.py602
-rw-r--r--src/Python/compile-fista.bat.in7
-rw-r--r--src/Python/compile-fista.sh.in9
-rw-r--r--src/Python/conda-recipe/meta.yaml2
-rw-r--r--src/Python/fista-recipe/build.sh10
-rw-r--r--src/Python/fista-recipe/meta.yaml29
-rw-r--r--src/Python/setup-fista.py.in27
-rw-r--r--src/Python/setup.py.in4
-rw-r--r--src/Python/test/test_reconstructor-os.py (renamed from src/Python/test_reconstructor-os.py)31
-rw-r--r--src/Python/test/test_reconstructor.py309
-rw-r--r--supp/sino_add_artifacts.m33
22 files changed, 1058 insertions, 1362 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d05cdd9..4cfad7e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,6 +25,6 @@ set (CIL_VERSION_MAJOR 0)
set (CIL_VERSION_MINOR 9)
set (CIL_VERSION_PATCH 1)
-set (CIL_VERSION '${CIL_VERSION_MAJOR}.${CIL_VERSION_MINOR}.${CIL_VERSION_PATCH}')
+set (CIL_VERSION '${CIL_VERSION_MAJOR}.${CIL_VERSION_MINOR}.${CIL_VERSION_PATCH}' CACHE INTERNAL "Core Imaging Library version" FORCE)
add_subdirectory(src)
diff --git a/demos/DemoRD1.m b/demos/DemoRD1.m
deleted file mode 100644
index 5bb5f6b..0000000
--- a/demos/DemoRD1.m
+++ /dev/null
@@ -1,100 +0,0 @@
-% Demonstration of tomographic reconstruction from neutron tomography
-% dataset (basalt sample) using Student t data fidelity
-clear all
-close all
-
-% adding paths
-addpath('../data/');
-addpath('../main_func/'); addpath('../main_func/regularizers_CPU/');
-addpath('../supp/');
-
-load('sino_basalt.mat') % load real neutron data
-
-size_det = size(sino_basalt, 1); % detector size
-angSize = size(sino_basalt,2); % angles dim
-recon_size = 650; % reconstruction size
-
-FBP = iradon(sino_basalt, rad2deg(angles),recon_size);
-figure; imshow(FBP , [0, 0.45]); title ('FBP reconstruction');
-%%
-% set projection/reconstruction geometry here
-Z_slices = 1;
-det_row_count = Z_slices;
-proj_geom = astra_create_proj_geom('parallel3d', 1, 1, det_row_count, size_det, angles);
-vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices);
-%%
-fprintf('%s\n', 'Reconstruction using FISTA-LS without regularization...');
-clear params
-params.proj_geom = proj_geom; % pass geometry to the function
-params.vol_geom = vol_geom;
-params.sino = sino_basalt;
-params.iterFISTA = 50;
-params.show = 0;
-params.maxvalplot = 0.6; params.slice = 1;
-
-tic; [X_fista] = FISTA_REC(params); toc;
-figure; imshow(X_fista , [0, 0.45]); title ('FISTA-LS reconstruction');
-%%
-fprintf('%s\n', 'Reconstruction using FISTA-LS-TV...');
-clear params
-params.proj_geom = proj_geom; % pass geometry to the function
-params.vol_geom = vol_geom;
-params.sino = sino_basalt;
-params.iterFISTA = 60;
-params.Regul_LambdaTV = 0.0003; % TV regularization parameter
-params.show = 0;
-params.maxvalplot = 0.6; params.slice = 1;
-
-tic; [X_fista_TV] = FISTA_REC(params); toc;
-figure; imshow(X_fista_TV , [0, 0.45]); title ('FISTA-LS-TV reconstruction');
-%%
-%%
-fprintf('%s\n', 'Reconstruction using FISTA-GH-TV...');
-clear params
-params.proj_geom = proj_geom; % pass geometry to the function
-params.vol_geom = vol_geom;
-params.sino = sino_basalt;
-params.iterFISTA = 60;
-params.Regul_LambdaTV = 0.0003; % TV regularization parameter
-params.Ring_LambdaR_L1 = 0.001; % Soft-Thresh L1 ring variable parameter
-params.Ring_Alpha = 20; % acceleration for ring variable
-params.show = 0;
-params.maxvalplot = 0.6; params.slice = 1;
-
-tic; [X_fista_GH_TV] = FISTA_REC(params); toc;
-figure; imshow(X_fista_GH_TV , [0, 0.45]); title ('FISTA-GH-TV reconstruction');
-%%
-%%
-fprintf('%s\n', 'Reconstruction using FISTA-Student-TV...');
-clear params
-params.proj_geom = proj_geom; % pass geometry to the function
-params.vol_geom = vol_geom;
-params.sino = sino_basalt;
-params.iterFISTA = 50;
-params.L_const = 3500; % Lipshitz constant
-params.Regul_LambdaTV = 0.0003; % TV regularization parameter
-params.fidelity = 'student'; % choosing Student t penalty
-params.show = 1;
-params.initialize = 1; % warm start with SIRT
-params.maxvalplot = 0.6; params.slice = 1;
-
-tic; [X_fistaStudentTV] = FISTA_REC(params); toc;
-figure; imshow(X_fistaStudentTV , [0, 0.45]); title ('FISTA-Student-TV reconstruction');
-%%
-
-fprintf('%s\n', 'Segmentation using OTSU method ...');
-level = graythresh(X_fista);
-Segm_FISTA = im2bw(X_fista,level);
-figure; imshow(Segm_FISTA, []); title ('Segmented FISTA-LS reconstruction');
-
-level = graythresh(X_fista_TV);
-Segm_FISTA_TV = im2bw(X_fista_TV,level);
-figure; imshow(Segm_FISTA_TV, []); title ('Segmented FISTA-LS-TV reconstruction');
-
-level = graythresh(X_fista_GH_TV);
-BW_FISTA_GH_TV = im2bw(X_fista_GH_TV,level);
-figure; imshow(BW_FISTA_GH_TV, []); title ('Segmented FISTA-GH-TV reconstruction');
-
-level = graythresh(X_fistaStudentTV);
-BW_FISTA_Student_TV = im2bw(X_fistaStudentTV,level);
-figure; imshow(BW_FISTA_Student_TV, []); title ('Segmented FISTA-Student-LS reconstruction'); \ No newline at end of file
diff --git a/demos/Demo_Phantom3D_Cone.m b/demos/Demo_Phantom3D_Cone.m
index 3a9178b..a8f2c92 100644
--- a/demos/Demo_Phantom3D_Cone.m
+++ b/demos/Demo_Phantom3D_Cone.m
@@ -8,7 +8,6 @@ addpath('../data/');
addpath('../main_func/'); addpath('../main_func/regularizers_CPU/'); addpath('../main_func/regularizers_GPU/NL_Regul/'); addpath('../main_func/regularizers_GPU/Diffus_HO/');
addpath('../supp/');
-
%%
% build 3D phantom using TomoPhantom
modelNo = 3; % see Phantom3DLibrary.dat file in TomoPhantom
@@ -16,9 +15,11 @@ N = 256; % x-y-z size (cubic image)
angles = 0:1.5:360; % angles vector in degrees
angles_rad = angles*(pi/180); % conversion to radians
det_size = round(sqrt(2)*N); % detector size
-% in order to run functions you have to go to the directory:
+
+%---------TomoPhantom routines---------%
pathTP = '/home/algol/Documents/MATLAB/TomoPhantom/functions/models/Phantom3DLibrary.dat'; % path to TomoPhantom parameters file
TomoPhantom = buildPhantom3D(modelNo,N,pathTP); % generate 3D phantom
+%--------------------------------------%
%%
% using ASTRA-toolbox to set the projection geometry (cone beam)
% eg: astra.create_proj_geom('cone', 1.0 (resol), 1.0 (resol), detectorRowCount, detectorColCount, angles, originToSource, originToDetector)
diff --git a/demos/Demo_Phantom3D_Parallel.m b/demos/Demo_Phantom3D_Parallel.m
index 6a54450..402bdd2 100644
--- a/demos/Demo_Phantom3D_Parallel.m
+++ b/demos/Demo_Phantom3D_Parallel.m
@@ -9,48 +9,74 @@ addpath('../main_func/'); addpath('../main_func/regularizers_CPU/'); addpath('..
addpath('../supp/');
%%
-% build 3D phantom using TomoPhantom and generate projection data
+% Main reconstruction/data generation parameters
modelNo = 2; % see Phantom3DLibrary.dat file in TomoPhantom
N = 256; % x-y-z size (cubic image)
angles = 1:0.5:180; % angles vector in degrees
angles_rad = angles*(pi/180); % conversion to radians
det_size = round(sqrt(2)*N); % detector size
-% in order to run functions you have to go to the directory:
+
+%---------TomoPhantom routines---------%
pathTP = '/home/algol/Documents/MATLAB/TomoPhantom/functions/models/Phantom3DLibrary.dat'; % path to TomoPhantom parameters file
TomoPhantom = buildPhantom3D(modelNo,N,pathTP); % generate 3D phantom
sino_tomophan3D = buildSino3D(modelNo, N, det_size, single(angles),pathTP); % generate ideal data
+%--------------------------------------%
% Adding noise and distortions if required
-sino_artifacts = sino_add_artifacts(sino_tomophan3D,'rings');
+sino_tomophan3D = sino_add_artifacts(sino_tomophan3D,'rings');
+% adding Poisson noise
+dose = 3e9; % photon flux (controls noise level)
+multifactor = max(sino_tomophan3D(:));
+dataExp = dose.*exp(-sino_tomophan3D/multifactor); % noiseless raw data
+dataRaw = astra_add_noise_to_sino(dataExp, dose); % pre-log noisy raw data (weights)
+sino3D_log = log(dose./max(dataRaw,1))*multifactor; %log corrected data -> sinogram
+clear dataExp sino_tomophan3D
%
%%
+%-------------Astra toolbox------------%
+% one can generate data using ASTRA toolbox
+proj_geom = astra_create_proj_geom('parallel', 1, det_size, angles_rad);
+vol_geom = astra_create_vol_geom(N,N);
+sino_ASTRA3D = zeros(det_size, length(angles), N, 'single');
+for i = 1:N
+[sino_id, sinoT] = astra_create_sino_cuda(TomoPhantom(:,:,i), proj_geom, vol_geom);
+sino_ASTRA3D(:,:,i) = sinoT';
+astra_mex_data2d('delete', sino_id);
+end
+%--------------------------------------%
+%%
% using ASTRA-toolbox to set the projection geometry (parallel beam)
proj_geom = astra_create_proj_geom('parallel', 1, det_size, angles_rad);
vol_geom = astra_create_vol_geom(N,N);
%%
fprintf('%s\n', 'Reconstructing with FBP using ASTRA-toolbox ...');
-for i = 1:k
+reconASTRA_3D = zeros(size(TomoPhantom),'single');
+for k = 1:N
vol_id = astra_mex_data2d('create', '-vol', vol_geom, 0);
-proj_id = astra_mex_data2d('create', '-proj3d', proj_geom, sino_artifacts(:,:,k));
+proj_id = astra_mex_data2d('create', '-sino', proj_geom, sino3D_log(:,:,k)');
cfg = astra_struct('FBP_CUDA');
cfg.ProjectionDataId = proj_id;
cfg.ReconstructionDataId = vol_id;
cfg.option.MinConstraint = 0;
alg_id = astra_mex_algorithm('create', cfg);
-astra_mex_algorithm('iterate', alg_id, 15);
-reconASTRA_3D = astra_mex_data2d('get', vol_id);
+astra_mex_algorithm('iterate', alg_id, 1);
+rec = astra_mex_data2d('get', vol_id);
+reconASTRA_3D(:,:,k) = single(rec);
end
+figure; imshow(reconASTRA_3D(:,:,128), [0 1.3]);
%%
-fprintf('%s\n', 'Reconstruction using FISTA-LS without regularization...');
+%%
+fprintf('%s\n', 'Reconstruction using OS-FISTA-PWLS without regularization...');
clear params
% define parameters
params.proj_geom = proj_geom; % pass geometry to the function
params.vol_geom = vol_geom;
-params.sino = single(sino_tomophan3D); % sinogram
-params.iterFISTA = 5; %max number of outer iterations
+params.sino = single(sino3D_log); % sinogram
+params.iterFISTA = 12; %max number of outer iterations
params.X_ideal = TomoPhantom; % ideal phantom
+params.weights = dataRaw./max(dataRaw(:)); % statistical weight for PWLS
+params.subsets = 12; % the number of subsets
params.show = 1; % visualize reconstruction on each iteration
-params.subsets = 12;
-params.slice = round(N/2); params.maxvalplot = 1;
+params.slice = 1; params.maxvalplot = 1.3;
tic; [X_FISTA, output] = FISTA_REC(params); toc;
error_FISTA = output.Resid_error; obj_FISTA = output.objective;
@@ -63,4 +89,33 @@ subplot(1,2,2); imshow(Resid3D(:,:,params.slice),[0 0.1]); title('residual'); c
figure(3);
subplot(1,2,1); plot(error_FISTA); title('RMSE plot');
subplot(1,2,2); plot(obj_FISTA); title('Objective plot');
+%%
+%%
+fprintf('%s\n', 'Reconstruction using OS-FISTA-GH without FGP-TV regularization...');
+clear params
+% define parameters
+params.proj_geom = proj_geom; % pass geometry to the function
+params.vol_geom = vol_geom;
+params.sino = single(sino3D_log); % sinogram
+params.iterFISTA = 15; %max number of outer iterations
+params.X_ideal = TomoPhantom; % ideal phantom
+params.weights = dataRaw./max(dataRaw(:)); % statistical weight for PWLS
+params.subsets = 8; % the number of subsets
+params.Regul_Lambda_FGPTV = 0.003; % TV regularization parameter for FGP-TV
+params.Ring_LambdaR_L1 = 0.02; % Soft-Thresh L1 ring variable parameter
+params.Ring_Alpha = 21; % to boost ring removal procedure
+params.show = 1; % visualize reconstruction on each iteration
+params.slice = 1; params.maxvalplot = 1.3;
+tic; [X_FISTA_GH_TV, output] = FISTA_REC(params); toc;
+
+error_FISTA_GH_TV = output.Resid_error; obj_FISTA_GH_TV = output.objective;
+fprintf('%s %.4f\n', 'Min RMSE for FISTA-PWLS reconstruction is:', min(error_FISTA_GH_TV(:)));
+
+Resid3D = (TomoPhantom - X_FISTA_GH_TV).^2;
+figure(2);
+subplot(1,2,1); imshow(X_FISTA_GH_TV(:,:,params.slice),[0 params.maxvalplot]); title('FISTA-LS reconstruction'); colorbar;
+subplot(1,2,2); imshow(Resid3D(:,:,params.slice),[0 0.1]); title('residual'); colorbar;
+figure(3);
+subplot(1,2,1); plot(error_FISTA_GH_TV); title('RMSE plot');
+subplot(1,2,2); plot(obj_FISTA_GH_TV); title('Objective plot');
%% \ No newline at end of file
diff --git a/demos/DemoRD2.m b/demos/Demo_RealData3D_Parallel.m
index 717a55d..e4c9eb0 100644
--- a/demos/DemoRD2.m
+++ b/demos/Demo_RealData3D_Parallel.m
@@ -11,12 +11,12 @@ addpath('../supp/');
load('DendrRawData.mat') % load raw data of 3D dendritic set
angles_rad = angles*(pi/180); % conversion to radians
-size_det = size(data_raw3D,1); % detectors dim
+det_size = size(data_raw3D,1); % detectors dim
angSize = size(data_raw3D, 2); % angles dim
slices_tot = size(data_raw3D, 3); % no of slices
recon_size = 950; % reconstruction size
-Sino3D = zeros(size_det, angSize, slices_tot, 'single'); % log-corrected sino
+Sino3D = zeros(det_size, angSize, slices_tot, 'single'); % log-corrected sino
% normalizing the data
for jj = 1:slices_tot
sino = data_raw3D(:,:,jj);
@@ -30,10 +30,8 @@ Weights3D = single(data_raw3D); % weights for PW model
clear data_raw3D
%%
% set projection/reconstruction geometry here
-Z_slices = 5;
-det_row_count = Z_slices;
-proj_geom = astra_create_proj_geom('parallel3d', 1, 1, det_row_count, size_det, angles_rad);
-vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices);
+proj_geom = astra_create_proj_geom('parallel', 1, det_size, angles_rad);
+vol_geom = astra_create_vol_geom(recon_size,recon_size);
%%
fprintf('%s\n', 'Reconstruction using FBP...');
FBP = iradon(Sino3D(:,:,10), angles,recon_size);
@@ -50,7 +48,7 @@ params.iterFISTA = 12;
params.weights = Weights3D;
params.subsets = 16; % the number of ordered subsets
params.show = 1;
-params.maxvalplot = 2.5; params.slice = 2;
+params.maxvalplot = 2.5; params.slice = 1;
tic; [X_fista, outputFISTA] = FISTA_REC(params); toc;
figure; imshow(X_fista(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-PWLS reconstruction');
@@ -76,13 +74,13 @@ params.proj_geom = proj_geom; % pass geometry to the function
params.vol_geom = vol_geom;
params.sino = Sino3D;
params.iterFISTA = 12;
-params.Regul_Lambda_FGPTV = 0.005; % TV regularization parameter for FGP-TV
+% params.Regul_Lambda_FGPTV = 0.005; % TV regularization parameter for FGP-TV
params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter
params.Ring_Alpha = 21; % to boost ring removal procedure
params.weights = Weights3D;
params.subsets = 16; % the number of ordered subsets
params.show = 1;
-params.maxvalplot = 2.5; params.slice = 2;
+params.maxvalplot = 2.5; params.slice = 1;
tic; [X_fista_GH_TV, outputGHTV] = FISTA_REC(params); toc;
figure; imshow(X_fista_GH_TV(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-GH-TV reconstruction');
diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m
index 1e4228d..3d22b97 100644
--- a/main_func/FISTA_REC.m
+++ b/main_func/FISTA_REC.m
@@ -685,6 +685,10 @@ else
counterInd = counterInd + numProjSub;
end
+ if (i == 1)
+ r_old = r;
+ end
+
% working with a 'ring vector'
if (lambdaR_L1 > 0)
r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector
diff --git a/main_func/regularizers_CPU/FGP_TV.c b/main_func/regularizers_CPU/FGP_TV.c
index 66442c9..30cea1a 100644
--- a/main_func/regularizers_CPU/FGP_TV.c
+++ b/main_func/regularizers_CPU/FGP_TV.c
@@ -66,7 +66,7 @@ void mexFunction(
A = (float *) mxGetData(prhs[0]); /*noisy image (2D/3D) */
lambda = (float) mxGetScalar(prhs[1]); /* regularization parameter */
iter = 50; /* default iterations number */
- epsil = 0.001; /* default tolerance constant */
+ epsil = 0.0001; /* default tolerance constant */
methTV = 0; /* default isotropic TV penalty */
if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int) mxGetScalar(prhs[2]); /* iterations number */
@@ -89,7 +89,7 @@ void mexFunction(
tk = 1.0f;
tkp1=1.0f;
- count = 1;
+ count = 0;
re_old = 0.0f;
if (number_of_dims == 2) {
@@ -128,7 +128,7 @@ void mexFunction(
}
re = sqrt(re)/sqrt(re1);
if (re < epsil) count++;
- if (count > 3) {
+ if (count > 4) {
Obj_func_CALC2D(A, D, funcvalA, lambda, dimX, dimY);
break; }
diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt
index b464059..66630cb 100644
--- a/src/Python/CMakeLists.txt
+++ b/src/Python/CMakeLists.txt
@@ -12,21 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-message("CIL VERSION " ${CIL_VERSION})
-
-
# variables that must be set for conda compilation
#PREFIX=C:\Apps\Miniconda2\envs\cil\Library
#LIBRARY_INC=C:\\Apps\\Miniconda2\\envs\\cil\\Library\\include
set (NUMPY_VERSION 1.12)
-#set (PYTHON_VERSION 3.5)
-
-#https://groups.google.com/a/continuum.io/forum/#!topic/anaconda/R9gWjl09UFs
-#set (CONDA_ENVIRONMENT "cil")
## Tries to parse the output of conda env list to determine the current
## active conda environment
+message ("Trying to determine your active conda environment...")
execute_process(COMMAND "conda" "env" "list"
OUTPUT_VARIABLE _CONDA_ENVS
RESULT_VARIABLE _CONDA_RESULT
@@ -44,7 +38,7 @@ execute_process(COMMAND "conda" "env" "list"
endif()
endforeach()
else()
- message("conda result false" ${_CONDA_ERR})
+ message(FATAL_ERROR "ERROR with conda command " ${_CONDA_ERR})
endif()
if (${CONDA_ENVIRONMENT} AND ${CONDA_ENVIRONMENT_PATH})
@@ -55,24 +49,28 @@ else()
message("Using current conda environmnet path " ${CONDA_ENVIRONMENT_PATH})
endif()
-
-
message("CIL VERSION " ${CIL_VERSION})
# set the Python variables for the Conda environment
include(FindAnacondaEnvironment.cmake)
findPythonForAnacondaEnvironment(${CONDA_ENVIRONMENT_PATH})
+
message("Python found " ${PYTHON_VERSION_STRING})
message("Python found Major " ${PYTHON_VERSION_MAJOR})
message("Python found Minor " ${PYTHON_VERSION_MINOR})
+
findPythonPackagesPath()
message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH})
-# copy the Pyhon files of the package
-file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/)
-file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi)
-file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging)
-file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging)
+## CACHE SOME VARIABLES ##
+set (CONDA_ENVIRONMENT ${CONDA_ENVIRONMENT} CACHE INTERNAL "active conda environment" FORCE)
+set (CONDA_ENVIRONMENT_PATH ${CONDA_ENVIRONMENT_PATH} CACHE INTERNAL "active conda environment" FORCE)
+
+set (PYTHON_VERSION_STRING ${PYTHON_VERSION_STRING} CACHE INTERNAL "conda environment Python version string" FORCE)
+set (PYTHON_VERSION_MAJOR ${PYTHON_VERSION_MAJOR} CACHE INTERNAL "conda environment Python version major" FORCE)
+set (PYTHON_VERSION_MINOR ${PYTHON_VERSION_MINOR} CACHE INTERNAL "conda environment Python version minor" FORCE)
+set (PYTHON_VERSION_PATCH ${PYTHON_VERSION_PATCH} CACHE INTERNAL "conda environment Python version patch" FORCE)
+set (PYTHON_PACKAGES_PATH ${PYTHON_PACKAGES_PATH} CACHE INTERNAL "conda environment Python packages path" FORCE)
if (WIN32)
#set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory")
@@ -84,21 +82,92 @@ elseif (UNIX)
set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}/include" CACHE PATH "env dir")
endif()
+######### CONFIGURE REGULARIZER PACKAGE #############
+
+# copy the Pyhon files of the package regularizer
+file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi)
+# regularizers
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging)
+
# Copy and configure the relative conda build and recipes
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py)
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe)
file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe)
if (WIN32)
- file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/)
- configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat)
+
+ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/)
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat)
+
+elseif(UNIX)
+
+ message ("We are on UNIX")
+ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/)
+ # assumes we will use bash
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh)
+
+endif()
+
+########## CONFIGURE FISTA RECONSTRUCTOR PACKAGE
+# fista reconstructor
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/FISTAReconstructor.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction)
+
+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup-fista.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup-fista.py)
+file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe)
+
+if (WIN32)
+
+ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/)
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.bat)
+
elseif(UNIX)
- message ("We are on UNIX")
- file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/)
- # assumes we will use bash
- configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh)
+ message ("We are on UNIX")
+ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/)
+ # assumes we will use bash
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.sh)
endif()
+############################# TARGETS
+
+########################## REGULARIZER PACKAGE ###############################
+
+# runs cmake on the build tree to update the code from source
+add_custom_target(update_code
+ COMMAND ${CMAKE_COMMAND}
+ ARGS ${CMAKE_SOURCE_DIR}
+ WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
+ )
+
+
+add_custom_target(fista
+ COMMAND bash
+ compile-fista.sh
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ DEPENDS ${update_code}
+ )
+
+add_custom_target(regularizers
+ COMMAND bash
+ compile.sh
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ DEPENDS ${update_code}
+ )
+
+add_custom_target(install-fista
+ COMMAND conda
+ install --force --use-local ccpi-fista=${CIL_VERSION} -c ccpi -c conda-forge
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ DEPENDS ${fista})
+
+add_custom_target(install-regularizers
+ COMMAND conda
+ install --force --use-local ccpi-regularizers=${CIL_VERSION} -c ccpi -c conda-forge
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ DEPENDS ${fista})
### add tests
#add_executable(RegularizersTest )
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
deleted file mode 100644
index 85bfac5..0000000
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ /dev/null
@@ -1,609 +0,0 @@
-# -*- coding: utf-8 -*-
-###############################################################################
-#This work is part of the Core Imaging Library developed by
-#Visual Analytics and Imaging System Group of the Science Technology
-#Facilities Council, STFC
-#
-#Copyright 2017 Edoardo Pasca, Srikanth Nagella
-#Copyright 2017 Daniil Kazantsev
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#http://www.apache.org/licenses/LICENSE-2.0
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-###############################################################################
-
-
-
-import numpy
-#from ccpi.reconstruction.parallelbeam import alg
-
-#from ccpi.imaging.Regularizer import Regularizer
-from enum import Enum
-
-import astra
-
-
-
-class FISTAReconstructor():
- '''FISTA-based reconstruction algorithm using ASTRA-toolbox
-
- '''
- # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
- # ___Input___:
- # params.[] file:
- # - .proj_geom (geometry of the projector) [required]
- # - .vol_geom (geometry of the reconstructed object) [required]
- # - .sino (vectorized in 2D or 3D sinogram) [required]
- # - .iterFISTA (iterations for the main loop, default 40)
- # - .L_const (Lipschitz constant, default Power method) )
- # - .X_ideal (ideal image, if given)
- # - .weights (statisitcal weights, size of the sinogram)
- # - .ROI (Region-of-interest, only if X_ideal is given)
- # - .initialize (a 'warm start' using SIRT method from ASTRA)
- #----------------Regularization choices------------------------
- # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
- # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
- # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
- # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
- # - .Regul_Iterations (iterations for the selected penalty, default 25)
- # - .Regul_tauLLT (time step parameter for LLT term)
- # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
- # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
- #----------------Visualization parameters------------------------
- # - .show (visualize reconstruction 1/0, (0 default))
- # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
- # - .slice (for 3D volumes - slice number to imshow)
- # ___Output___:
- # 1. X - reconstructed image/volume
- # 2. output - a structure with
- # - .Resid_error - residual error (if X_ideal is given)
- # - .objective: value of the objective function
- # - .L_const: Lipshitz constant to avoid recalculations
-
- # References:
- # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
- # Problems" by A. Beck and M Teboulle
- # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
- # 3. "A novel tomographic reconstruction method based on the robust
- # Student's t function for suppressing data outliers" D. Kazantsev et.al.
- # D. Kazantsev, 2016-17
- def __init__(self, projector_geometry, output_geometry, input_sinogram,
- **kwargs):
- # handle parmeters:
- # obligatory parameters
- self.pars = dict()
- self.pars['projector_geometry'] = projector_geometry # proj_geom
- self.pars['output_geometry'] = output_geometry # vol_geom
- self.pars['input_sinogram'] = input_sinogram # sino
- sliceZ, nangles, detectors = numpy.shape(input_sinogram)
- self.pars['detectors'] = detectors
- self.pars['number_of_angles'] = nangles
- self.pars['SlicesZ'] = sliceZ
- self.pars['output_volume'] = None
-
- print (self.pars)
- # handle optional input parameters (at instantiation)
-
- # Accepted input keywords
- kw = (
- # mandatory fields
- 'projector_geometry',
- 'output_geometry',
- 'input_sinogram',
- 'detectors',
- 'number_of_angles',
- 'SlicesZ',
- # optional fields
- 'number_of_iterations',
- 'Lipschitz_constant' ,
- 'ideal_image' ,
- 'weights' ,
- 'region_of_interest' ,
- 'initialize' ,
- 'regularizer' ,
- 'ring_lambda_R_L1',
- 'ring_alpha',
- 'subsets',
- 'output_volume',
- 'os_subsets',
- 'os_indices',
- 'os_bins')
- self.acceptedInputKeywords = list(kw)
-
- # handle keyworded parameters
- if kwargs is not None:
- for key, value in kwargs.items():
- if key in kw:
- #print("{0} = {1}".format(key, value))
- self.pars[key] = value
-
- # set the default values for the parameters if not set
- if 'number_of_iterations' in kwargs.keys():
- self.pars['number_of_iterations'] = kwargs['number_of_iterations']
- else:
- self.pars['number_of_iterations'] = 40
- if 'weights' in kwargs.keys():
- self.pars['weights'] = kwargs['weights']
- else:
- self.pars['weights'] = \
- numpy.ones(numpy.shape(
- self.pars['input_sinogram']))
- if 'Lipschitz_constant' in kwargs.keys():
- self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
- else:
- self.pars['Lipschitz_constant'] = None
-
- if not 'ideal_image' in kwargs.keys():
- self.pars['ideal_image'] = None
-
- if not 'region_of_interest'in kwargs.keys() :
- if self.pars['ideal_image'] == None:
- self.pars['region_of_interest'] = None
- else:
- ## nonzero if the image is larger than m
- fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1)
-
- self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0)
-
- # the regularizer must be a correctly instantiated object
- if not 'regularizer' in kwargs.keys() :
- self.pars['regularizer'] = None
-
- #RING REMOVAL
- if not 'ring_lambda_R_L1' in kwargs.keys():
- self.pars['ring_lambda_R_L1'] = 0
- if not 'ring_alpha' in kwargs.keys():
- self.pars['ring_alpha'] = 1
-
- # ORDERED SUBSET
- if not 'subsets' in kwargs.keys():
- self.pars['subsets'] = 0
- else:
- self.createOrderedSubsets()
-
- if not 'initialize' in kwargs.keys():
- self.pars['initialize'] = False
-
-
-
-
- def setParameter(self, **kwargs):
- '''set named parameter for the reconstructor engine
-
- raises Exception if the named parameter is not recognized
-
- '''
- for key , value in kwargs.items():
- if key in self.acceptedInputKeywords:
- self.pars[key] = value
- else:
- raise Exception('Wrong parameter {0} for '.format(key) +
- 'reconstructor')
- # setParameter
-
- def getParameter(self, key):
- if type(key) is str:
- if key in self.acceptedInputKeywords:
- return self.pars[key]
- else:
- raise Exception('Unrecongnised parameter: {0} '.format(key) )
- elif type(key) is list:
- outpars = []
- for k in key:
- outpars.append(self.getParameter(k))
- return outpars
- else:
- raise Exception('Unhandled input {0}' .format(str(type(key))))
-
-
- def calculateLipschitzConstantWithPowerMethod(self):
- ''' using Power method (PM) to establish L constant'''
-
- N = self.pars['output_geometry']['GridColCount']
- proj_geom = self.pars['projector_geometry']
- vol_geom = self.pars['output_geometry']
- weights = self.pars['weights']
- SlicesZ = self.pars['SlicesZ']
-
-
-
- if (proj_geom['type'] == 'parallel') or \
- (proj_geom['type'] == 'parallel3d'):
- #% for parallel geometry we can do just one slice
- #print('Calculating Lipshitz constant for parallel beam geometry...')
- niter = 5;# % number of iteration for the PM
- #N = params.vol_geom.GridColCount;
- #x1 = rand(N,N,1);
- x1 = numpy.random.rand(1,N,N)
- #sqweight = sqrt(weights(:,:,1));
- sqweight = numpy.sqrt(weights[0])
- proj_geomT = proj_geom.copy();
- proj_geomT['DetectorRowCount'] = 1;
- vol_geomT = vol_geom.copy();
- vol_geomT['GridSliceCount'] = 1;
-
- #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
-
-
- for i in range(niter):
- # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT);
- # s = norm(x1(:));
- # x1 = x1/s;
- # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
- # y = sqweight.*y;
- # astra_mex_data3d('delete', sino_id);
- # astra_mex_data3d('delete', id);
- #print ("iteration {0}".format(i))
-
- sino_id, y = astra.creators.create_sino3d_gpu(x1,
- proj_geomT,
- vol_geomT)
-
- y = (sqweight * y).copy() # element wise multiplication
-
- #b=fig.add_subplot(2,1,2)
- #imgplot = plt.imshow(x1[0])
- #plt.show()
-
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id)
- del x1
-
- idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),
- proj_geomT,
- vol_geomT)
- del y
-
-
- s = numpy.linalg.norm(x1)
- ### this line?
- x1 = (x1/s).copy();
-
- # ### this line?
- # sino_id, y = astra.creators.create_sino3d_gpu(x1,
- # proj_geomT,
- # vol_geomT);
- # y = sqweight * y;
- astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx)
- print ("iteration {0} s= {1}".format(i,s))
-
- #end
- del proj_geomT
- del vol_geomT
- #plt.show()
- else:
- #% divergen beam geometry
- print('Calculating Lipshitz constant for divergen beam geometry...')
- niter = 8; #% number of iteration for PM
- x1 = numpy.random.rand(SlicesZ , N , N);
- #sqweight = sqrt(weights);
- sqweight = numpy.sqrt(weights[0])
-
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
- y = sqweight*y;
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id);
-
- for i in range(niter):
- #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
- proj_geom,
- vol_geom)
- s = numpy.linalg.norm(x1)
- ### this line?
- x1 = x1/s;
- ### this line?
- #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
- sino_id, y = astra.creators.create_sino3d_gpu(x1,
- proj_geom,
- vol_geom);
-
- y = sqweight*y;
- #astra_mex_data3d('delete', sino_id);
- #astra_mex_data3d('delete', id);
- astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
- #end
- #clear x1
- del x1
-
-
- return s
-
-
- def setRegularizer(self, regularizer):
- if regularizer is not None:
- self.pars['regularizer'] = regularizer
-
-
- def initialize(self):
- # convenience variable storage
- proj_geom = self.pars['projector_geometry']
- vol_geom = self.pars['output_geometry']
- sino = self.pars['input_sinogram']
-
- # a 'warm start' with SIRT method
- # Create a data object for the reconstruction
- rec_id = astra.matlab.data3d('create', '-vol',
- vol_geom);
-
- #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino);
- sinogram_id = astra.matlab.data3d('create', '-proj3d',
- proj_geom,
- sino)
-
- sirt_config = astra.astra_dict('SIRT3D_CUDA')
- sirt_config['ReconstructionDataId' ] = rec_id
- sirt_config['ProjectionDataId'] = sinogram_id
-
- sirt = astra.algorithm.create(sirt_config)
- astra.algorithm.run(sirt, iterations=35)
- X = astra.matlab.data3d('get', rec_id)
-
- # clean up memory
- astra.matlab.data3d('delete', rec_id)
- astra.matlab.data3d('delete', sinogram_id)
- astra.algorithm.delete(sirt)
-
-
-
- return X
-
- def createOrderedSubsets(self, subsets=None):
- if subsets is None:
- try:
- subsets = self.getParameter('subsets')
- except Exception():
- subsets = 0
- #return subsets
-
- angles = self.getParameter('projector_geometry')['ProjectionAngles']
-
- #binEdges = numpy.linspace(angles.min(),
- # angles.max(),
- # subsets + 1)
- binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
- # get rearranged subset indices
- IndicesReorg = numpy.zeros((numpy.shape(angles)))
- counterM = 0
- for ii in range(binsDiscr.max()):
- counter = 0
- for jj in range(subsets):
- curr_index = ii + jj + counter
- #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
- if binsDiscr[jj] > ii:
- if (counterM < numpy.size(IndicesReorg)):
- IndicesReorg[counterM] = curr_index
- counterM = counterM + 1
-
- counter = counter + binsDiscr[jj] - 1
-
- # store the OS in parameters
- self.setParameter(os_subsets=subsets,
- os_bins=binsDiscr,
- os_indices=IndicesReorg)
-
-
- def prepareForIteration(self):
- print ("FISTA Reconstructor: prepare for iteration")
-
- self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
- self.objective = numpy.zeros((self.pars['number_of_iterations']))
-
- #2D array (for 3D data) of sparse "ring"
- detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram'])
- self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float)
- # another ring variable
- self.r_x = self.r.copy()
-
- self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
-
- if self.getParameter('Lipschitz_constant') is None:
- self.pars['Lipschitz_constant'] = \
- self.calculateLipschitzConstantWithPowerMethod()
- # errors vector (if the ground truth is given)
- self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations')));
- # objective function values vector
- self.objective = numpy.zeros((self.getParameter('number_of_iterations')));
-
-
- # prepareForIteration
-
- def iterate(self, Xin=None):
- print ("FISTA Reconstructor: iterate")
-
- if Xin is None:
- if self.getParameter('initialize'):
- X = self.initialize()
- else:
- N = vol_geom['GridColCount']
- X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
- else:
- # copy by reference
- X = Xin
- # store the output volume in the parameters
- self.setParameter(output_volume=X)
- X_t = X.copy()
- # convenience variable storage
- proj_geom , vol_geom, sino , \
- SlicesZ = self.getParameter([ 'projector_geometry' ,
- 'output_geometry',
- 'input_sinogram',
- 'SlicesZ' ])
-
- t = 1
-
- for i in range(self.getParameter('number_of_iterations')):
- X_old = X.copy()
- t_old = t
- r_old = self.r.copy()
- if self.getParameter('projector_geometry')['type'] == 'parallel' or \
- self.getParameter('projector_geometry')['type'] == 'fanflat' or \
- self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
- # if the geometry is parallel use slice-by-slice
- # projection-backprojection routine
- #sino_updt = zeros(size(sino),'single');
- proj_geomT = proj_geom.copy()
- proj_geomT['DetectorRowCount'] = 1
- vol_geomT = vol_geom.copy()
- vol_geomT['GridSliceCount'] = 1;
- self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
- for kkk in range(SlicesZ):
- sino_id, self.sino_updt[kkk] = \
- astra.creators.create_sino3d_gpu(
- X_t[kkk:kkk+1], proj_geomT, vol_geomT)
- astra.matlab.data3d('delete', sino_id)
- else:
- # for divergent 3D geometry (watch the GPU memory overflow in
- # ASTRA versions < 1.8)
- #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
- sino_id, self.sino_updt = astra.creators.create_sino3d_gpu(
- X_t, proj_geom, vol_geom)
-
-
- ## RING REMOVAL
- self.ringRemoval(i)
- ## Projection/Backprojection Routine
- self.projectionBackprojection(X, X_t)
- astra.matlab.data3d('delete', sino_id)
- ## REGULARIZATION
- X = self.regularize(X)
- ## Update Loop
- X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
- self.setParameter(output_volume=X)
- return X
- ## iterate
-
- def ringRemoval(self, i):
- print ("FISTA Reconstructor: ring removal")
- residual = self.residual
- lambdaR_L1 , alpha_ring , weights , L_const , sino= \
- self.getParameter(['ring_lambda_R_L1',
- 'ring_alpha' , 'weights',
- 'Lipschitz_constant',
- 'input_sinogram'])
- r_x = self.r_x
- sino_updt = self.sino_updt
-
- SlicesZ, anglesNumb, Detectors = \
- numpy.shape(self.getParameter('input_sinogram'))
- if lambdaR_L1 > 0 :
- for kkk in range(anglesNumb):
-
- residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
- ((sino_updt[:,kkk,:]).squeeze() - \
- (sino[:,kkk,:]).squeeze() -\
- (alpha_ring * r_x)
- )
- vec = residual.sum(axis = 1)
- #if SlicesZ > 1:
- # vec = vec[:,1,:].squeeze()
- self.r = (r_x - (1./L_const) * vec).copy()
- self.objective[i] = (0.5 * (residual ** 2).sum())
-
- def projectionBackprojection(self, X, X_t):
- print ("FISTA Reconstructor: projection-backprojection routine")
-
- # a few useful variables
- SlicesZ, anglesNumb, Detectors = \
- numpy.shape(self.getParameter('input_sinogram'))
- residual = self.residual
- proj_geom , vol_geom , L_const = \
- self.getParameter(['projector_geometry' ,
- 'output_geometry',
- 'Lipschitz_constant'])
-
-
- if self.getParameter('projector_geometry')['type'] == 'parallel' or \
- self.getParameter('projector_geometry')['type'] == 'fanflat' or \
- self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
- # if the geometry is parallel use slice-by-slice
- # projection-backprojection routine
- #sino_updt = zeros(size(sino),'single');
- proj_geomT = proj_geom.copy()
- proj_geomT['DetectorRowCount'] = 1
- vol_geomT = vol_geom.copy()
- vol_geomT['GridSliceCount'] = 1;
- x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
-
- for kkk in range(SlicesZ):
-
- x_id, x_temp[kkk] = \
- astra.creators.create_backprojection3d_gpu(
- residual[kkk:kkk+1],
- proj_geomT, vol_geomT)
- astra.matlab.data3d('delete', x_id)
- else:
- x_id, x_temp = \
- astra.creators.create_backprojection3d_gpu(
- residual, proj_geom, vol_geom)
-
- X = X_t - (1/L_const) * x_temp
- #astra.matlab.data3d('delete', sino_id)
- astra.matlab.data3d('delete', x_id)
-
- def regularize(self, X):
- print ("FISTA Reconstructor: regularize")
-
- regularizer = self.getParameter('regularizer')
- if regularizer is not None:
- return regularizer(input=X)
- else:
- return X
-
- def updateLoop(self, i, X, X_old, r_old, t, t_old):
- print ("FISTA Reconstructor: update loop")
- lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
- if lambdaR_L1 > 0:
- self.r = numpy.max(
- numpy.abs(self.r) - lambdaR_L1 , 0) * \
- numpy.sign(self.r)
- t = (1 + numpy.sqrt(1 + 4 * t**2))/2
- X_t = X + (((t_old -1)/t) * (X - X_old))
-
- if lambdaR_L1 > 0:
- self.r_x = self.r + \
- (((t_old-1)/t) * (self.r - r_old))
-
- if self.getParameter('region_of_interest') is None:
- string = 'Iteration Number {0} | Objective {1} \n'
- print (string.format( i, self.objective[i]))
- else:
- ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
- 'ideal_image')
-
- Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
- string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
- print (string.format(i,Resid_error[i], self.objective[i]))
- return (X , X_t, t)
-
- def os_iterate(self, Xin=None):
- print ("FISTA Reconstructor: iterate")
-
- if Xin is None:
- if self.getParameter('initialize'):
- X = self.initialize()
- else:
- N = vol_geom['GridColCount']
- X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
- else:
- # copy by reference
- X = Xin
- # store the output volume in the parameters
- self.setParameter(output_volume=X)
- X_t = X.copy()
-
- # some useful constants
- proj_geom , vol_geom, sino , \
- SlicesZ, weights , alpha_ring ,
- lambdaR_L1 , L_const = self.getParameter(
- ['projector_geometry' , 'output_geometry',
- 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' ,
- 'ring_lambda_R_L1', 'Lipschitz_constant'])
diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py
deleted file mode 100644
index d29ac0d..0000000
--- a/src/Python/ccpi/fista/Reconstructor.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# -*- coding: utf-8 -*-
-###############################################################################
-#This work is part of the Core Imaging Library developed by
-#Visual Analytics and Imaging System Group of the Science Technology
-#Facilities Council, STFC
-#
-#Copyright 2017 Edoardo Pasca, Srikanth Nagella
-#Copyright 2017 Daniil Kazantsev
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#http://www.apache.org/licenses/LICENSE-2.0
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-###############################################################################
-
-
-
-import numpy
-import h5py
-from ccpi.reconstruction.parallelbeam import alg
-
-from Regularizer import Regularizer
-from enum import Enum
-
-import astra
-
-
-
-class FISTAReconstructor():
- '''FISTA-based reconstruction algorithm using ASTRA-toolbox
-
- '''
- # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
- # ___Input___:
- # params.[] file:
- # - .proj_geom (geometry of the projector) [required]
- # - .vol_geom (geometry of the reconstructed object) [required]
- # - .sino (vectorized in 2D or 3D sinogram) [required]
- # - .iterFISTA (iterations for the main loop, default 40)
- # - .L_const (Lipschitz constant, default Power method) )
- # - .X_ideal (ideal image, if given)
- # - .weights (statisitcal weights, size of the sinogram)
- # - .ROI (Region-of-interest, only if X_ideal is given)
- # - .initialize (a 'warm start' using SIRT method from ASTRA)
- #----------------Regularization choices------------------------
- # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
- # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
- # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
- # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
- # - .Regul_Iterations (iterations for the selected penalty, default 25)
- # - .Regul_tauLLT (time step parameter for LLT term)
- # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
- # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
- #----------------Visualization parameters------------------------
- # - .show (visualize reconstruction 1/0, (0 default))
- # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
- # - .slice (for 3D volumes - slice number to imshow)
- # ___Output___:
- # 1. X - reconstructed image/volume
- # 2. output - a structure with
- # - .Resid_error - residual error (if X_ideal is given)
- # - .objective: value of the objective function
- # - .L_const: Lipshitz constant to avoid recalculations
-
- # References:
- # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
- # Problems" by A. Beck and M Teboulle
- # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
- # 3. "A novel tomographic reconstruction method based on the robust
- # Student's t function for suppressing data outliers" D. Kazantsev et.al.
- # D. Kazantsev, 2016-17
- def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
- self.params = dict()
- self.params['projector_geometry'] = projector_geometry
- self.params['output_geometry'] = output_geometry
- self.params['input_sinogram'] = input_sinogram
- detectors, nangles, sliceZ = numpy.shape(input_sinogram)
- self.params['detectors'] = detectors
- self.params['number_og_angles'] = nangles
- self.params['SlicesZ'] = sliceZ
-
- # Accepted input keywords
- kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' ,
- 'weights' , 'region_of_interest' , 'initialize' ,
- 'regularizer' ,
- 'ring_lambda_R_L1',
- 'ring_alpha')
-
- # handle keyworded parameters
- if kwargs is not None:
- for key, value in kwargs.items():
- if key in kw:
- #print("{0} = {1}".format(key, value))
- self.pars[key] = value
-
- # set the default values for the parameters if not set
- if 'number_of_iterations' in kwargs.keys():
- self.pars['number_of_iterations'] = kwargs['number_of_iterations']
- else:
- self.pars['number_of_iterations'] = 40
- if 'weights' in kwargs.keys():
- self.pars['weights'] = kwargs['weights']
- else:
- self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
- if 'Lipschitz_constant' in kwargs.keys():
- self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
- else:
- self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
-
- if not self.pars['ideal_image'] in kwargs.keys():
- self.pars['ideal_image'] = None
-
- if not self.pars['region_of_interest'] :
- if self.pars['ideal_image'] == None:
- pass
- else:
- self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
-
- if not self.pars['regularizer'] :
- self.pars['regularizer'] = None
- else:
- # the regularizer must be a correctly instantiated object
- if not self.pars['ring_lambda_R_L1']:
- self.pars['ring_lambda_R_L1'] = 0
- if not self.pars['ring_alpha']:
- self.pars['ring_alpha'] = 1
-
-
-
-
- def calculateLipschitzConstantWithPowerMethod(self):
- ''' using Power method (PM) to establish L constant'''
-
- #N = params.vol_geom.GridColCount
- N = self.pars['output_geometry'].GridColCount
- proj_geom = self.params['projector_geometry']
- vol_geom = self.params['output_geometry']
- weights = self.pars['weights']
- SlicesZ = self.pars['SlicesZ']
-
- if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
- #% for parallel geometry we can do just one slice
- #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
- niter = 15;# % number of iteration for the PM
- #N = params.vol_geom.GridColCount;
- #x1 = rand(N,N,1);
- x1 = numpy.random.rand(1,N,N)
- #sqweight = sqrt(weights(:,:,1));
- sqweight = numpy.sqrt(weights.T[0])
- proj_geomT = proj_geom.copy();
- proj_geomT.DetectorRowCount = 1;
- vol_geomT = vol_geom.copy();
- vol_geomT['GridSliceCount'] = 1;
-
-
- for i in range(niter):
- if i == 0:
- #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
- y = sqweight * y # element wise multiplication
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id)
-
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
- s = numpy.linalg.norm(x1)
- ### this line?
- x1 = x1/s;
- ### this line?
- sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
- y = sqweight*y;
- astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
- #end
- del proj_geomT
- del vol_geomT
- else
- #% divergen beam geometry
- #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
- niter = 8; #% number of iteration for PM
- x1 = numpy.random.rand(SlicesZ , N , N);
- #sqweight = sqrt(weights);
- sqweight = numpy.sqrt(weights.T[0])
-
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
- y = sqweight*y;
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id);
-
- for i in range(niter):
- #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
- proj_geom,
- vol_geom)
- s = numpy.linalg.norm(x1)
- ### this line?
- x1 = x1/s;
- ### this line?
- #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
- sino_id, y = astra.creators.create_sino3d_gpu(x1,
- proj_geom,
- vol_geom);
-
- y = sqweight*y;
- #astra_mex_data3d('delete', sino_id);
- #astra_mex_data3d('delete', id);
- astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
- #end
- #clear x1
- del x1
-
- return s
-
-
- def setRegularizer(self, regularizer):
- if regularizer
- self.pars['regularizer'] = regularizer
-
-
-
-
-
-def getEntry(location):
- for item in nx[location].keys():
- print (item)
-
-
-print ("Loading Data")
-
-##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
-####ind = [i * 1049 for i in range(360)]
-#### use only 360 images
-##images = 200
-##ind = [int(i * 1049 / images) for i in range(images)]
-##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
-
-#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
-fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
-nx = h5py.File(fname, "r")
-
-# the data are stored in a particular location in the hdf5
-for item in nx['entry1/tomo_entry/data'].keys():
- print (item)
-
-data = nx.get('entry1/tomo_entry/data/rotation_angle')
-angles = numpy.zeros(data.shape)
-data.read_direct(angles)
-print (angles)
-# angles should be in degrees
-
-data = nx.get('entry1/tomo_entry/data/data')
-stack = numpy.zeros(data.shape)
-data.read_direct(stack)
-print (data.shape)
-
-print ("Data Loaded")
-
-
-# Normalize
-data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
-itype = numpy.zeros(data.shape)
-data.read_direct(itype)
-# 2 is dark field
-darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
-dark = darks[0]
-for i in range(1, len(darks)):
- dark += darks[i]
-dark = dark / len(darks)
-#dark[0][0] = dark[0][1]
-
-# 1 is flat field
-flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
-flat = flats[0]
-for i in range(1, len(flats)):
- flat += flats[i]
-flat = flat / len(flats)
-#flat[0][0] = dark[0][1]
-
-
-# 0 is projection data
-proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
-angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
-angle_proj = numpy.asarray (angle_proj)
-angle_proj = angle_proj.astype(numpy.float32)
-
-# normalized data are
-# norm = (projection - dark)/(flat-dark)
-
-def normalize(projection, dark, flat, def_val=0.1):
- a = (projection - dark)
- b = (flat-dark)
- with numpy.errstate(divide='ignore', invalid='ignore'):
- c = numpy.true_divide( a, b )
- c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
- return c
-
-
-norm = [normalize(projection, dark, flat) for projection in proj]
-norm = numpy.asarray (norm)
-norm = norm.astype(numpy.float32)
-
-#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm,
-# angles = angle_proj, center_of_rotation = 86.2 ,
-# flat_field = flat, dark_field = dark,
-# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
-
-#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj,
-# angles = angle_proj, center_of_rotation = 86.2 ,
-# flat_field = flat, dark_field = dark,
-# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
-#img_cgls = recon.reconstruct()
-#
-#pars = dict()
-#pars['algorithm'] = Reconstructor.Algorithm.SIRT
-#pars['projection_data'] = proj
-#pars['angles'] = angle_proj
-#pars['center_of_rotation'] = numpy.double(86.2)
-#pars['flat_field'] = flat
-#pars['iterations'] = 15
-#pars['dark_field'] = dark
-#pars['resolution'] = 1
-#pars['isLogScale'] = False
-#pars['threads'] = 3
-#
-#img_sirt = recon.reconstruct(pars)
-#
-#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM
-#img_mlem = recon.reconstruct()
-
-############################################################
-############################################################
-#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV
-#recon.pars['regularize'] = numpy.double(0.1)
-#img_cgls_conv = recon.reconstruct()
-
-niterations = 15
-threads = 3
-
-img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-
-iteration_values = numpy.zeros((niterations,))
-img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
- iteration_values, False)
-print ("iteration values %s" % str(iteration_values))
-
-iteration_values = numpy.zeros((niterations,))
-img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
- numpy.double(1e-5), iteration_values , False)
-print ("iteration values %s" % str(iteration_values))
-iteration_values = numpy.zeros((niterations,))
-img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
- numpy.double(1e-5), iteration_values , False)
-print ("iteration values %s" % str(iteration_values))
-
-
-##numpy.save("cgls_recon.npy", img_data)
-import matplotlib.pyplot as plt
-fig, ax = plt.subplots(1,6,sharey=True)
-ax[0].imshow(img_cgls[80])
-ax[0].axis('off') # clear x- and y-axes
-ax[1].imshow(img_sirt[80])
-ax[1].axis('off') # clear x- and y-axes
-ax[2].imshow(img_mlem[80])
-ax[2].axis('off') # clear x- and y-axesplt.show()
-ax[3].imshow(img_cgls_conv[80])
-ax[3].axis('off') # clear x- and y-axesplt.show()
-ax[4].imshow(img_cgls_tikhonov[80])
-ax[4].axis('off') # clear x- and y-axesplt.show()
-ax[5].imshow(img_cgls_TVreg[80])
-ax[5].axis('off') # clear x- and y-axesplt.show()
-
-
-plt.show()
-
-#viewer = edo.CILViewer()
-#viewer.setInputAsNumpy(img_cgls2)
-#viewer.displaySliceActor(0)
-#viewer.startRenderLoop()
-
-import vtk
-
-def NumpyToVTKImageData(numpyarray):
- if (len(numpy.shape(numpyarray)) == 3):
- doubleImg = vtk.vtkImageData()
- shape = numpy.shape(numpyarray)
- doubleImg.SetDimensions(shape[0], shape[1], shape[2])
- doubleImg.SetOrigin(0,0,0)
- doubleImg.SetSpacing(1,1,1)
- doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1)
- #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation())
- doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1)
-
- for i in range(shape[0]):
- for j in range(shape[1]):
- for k in range(shape[2]):
- doubleImg.SetScalarComponentFromDouble(
- i,j,k,0, numpyarray[i][j][k])
- #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) )
- # rescale to appropriate VTK_UNSIGNED_SHORT
- stats = vtk.vtkImageAccumulate()
- stats.SetInputData(doubleImg)
- stats.Update()
- iMin = stats.GetMin()[0]
- iMax = stats.GetMax()[0]
- scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin)
-
- shiftScaler = vtk.vtkImageShiftScale ()
- shiftScaler.SetInputData(doubleImg)
- shiftScaler.SetScale(scale)
- shiftScaler.SetShift(iMin)
- shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT)
- shiftScaler.Update()
- return shiftScaler.GetOutput()
-
-#writer = vtk.vtkMetaImageWriter()
-#writer.SetFileName(alg + "_recon.mha")
-#writer.SetInputData(NumpyToVTKImageData(img_cgls2))
-#writer.Write()
diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/src/Python/ccpi/fista/__init__.py
+++ /dev/null
diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
index ea96b53..c903712 100644
--- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py
+++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
@@ -21,10 +21,9 @@
import numpy
-import h5py
#from ccpi.reconstruction.parallelbeam import alg
-from ccpi.imaging.Regularizer import Regularizer
+#from ccpi.imaging.Regularizer import Regularizer
from enum import Enum
import astra
@@ -74,18 +73,34 @@ class FISTAReconstructor():
# 3. "A novel tomographic reconstruction method based on the robust
# Student's t function for suppressing data outliers" D. Kazantsev et.al.
# D. Kazantsev, 2016-17
- def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
- self.params = dict()
- self.params['projector_geometry'] = projector_geometry
- self.params['output_geometry'] = output_geometry
- self.params['input_sinogram'] = input_sinogram
- detectors, nangles, sliceZ = numpy.shape(input_sinogram)
- self.params['detectors'] = detectors
- self.params['number_og_angles'] = nangles
- self.params['SlicesZ'] = sliceZ
+ def __init__(self, projector_geometry, output_geometry, input_sinogram,
+ **kwargs):
+ # handle parmeters:
+ # obligatory parameters
+ self.pars = dict()
+ self.pars['projector_geometry'] = projector_geometry # proj_geom
+ self.pars['output_geometry'] = output_geometry # vol_geom
+ self.pars['input_sinogram'] = input_sinogram # sino
+ sliceZ, nangles, detectors = numpy.shape(input_sinogram)
+ self.pars['detectors'] = detectors
+ self.pars['number_of_angles'] = nangles
+ self.pars['SlicesZ'] = sliceZ
+ self.pars['output_volume'] = None
+
+ print (self.pars)
+ # handle optional input parameters (at instantiation)
# Accepted input keywords
- kw = ('number_of_iterations',
+ kw = (
+ # mandatory fields
+ 'projector_geometry',
+ 'output_geometry',
+ 'input_sinogram',
+ 'detectors',
+ 'number_of_angles',
+ 'SlicesZ',
+ # optional fields
+ 'number_of_iterations',
'Lipschitz_constant' ,
'ideal_image' ,
'weights' ,
@@ -93,7 +108,13 @@ class FISTAReconstructor():
'initialize' ,
'regularizer' ,
'ring_lambda_R_L1',
- 'ring_alpha')
+ 'ring_alpha',
+ 'subsets',
+ 'output_volume',
+ 'os_subsets',
+ 'os_indices',
+ 'os_bins')
+ self.acceptedInputKeywords = list(kw)
# handle keyworded parameters
if kwargs is not None:
@@ -110,85 +131,160 @@ class FISTAReconstructor():
if 'weights' in kwargs.keys():
self.pars['weights'] = kwargs['weights']
else:
- self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
+ self.pars['weights'] = \
+ numpy.ones(numpy.shape(
+ self.pars['input_sinogram']))
if 'Lipschitz_constant' in kwargs.keys():
self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
else:
- self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+ self.pars['Lipschitz_constant'] = None
- if not self.pars['ideal_image'] in kwargs.keys():
+ if not 'ideal_image' in kwargs.keys():
self.pars['ideal_image'] = None
- if not self.pars['region_of_interest'] :
+ if not 'region_of_interest'in kwargs.keys() :
if self.pars['ideal_image'] == None:
- pass
+ self.pars['region_of_interest'] = None
else:
- self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
-
- if not self.pars['regularizer'] :
+ ## nonzero if the image is larger than m
+ fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1)
+
+ self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0)
+
+ # the regularizer must be a correctly instantiated object
+ if not 'regularizer' in kwargs.keys() :
self.pars['regularizer'] = None
+
+ #RING REMOVAL
+ if not 'ring_lambda_R_L1' in kwargs.keys():
+ self.pars['ring_lambda_R_L1'] = 0
+ if not 'ring_alpha' in kwargs.keys():
+ self.pars['ring_alpha'] = 1
+
+ # ORDERED SUBSET
+ if not 'subsets' in kwargs.keys():
+ self.pars['subsets'] = 0
else:
- # the regularizer must be a correctly instantiated object
- if not self.pars['ring_lambda_R_L1']:
- self.pars['ring_lambda_R_L1'] = 0
- if not self.pars['ring_alpha']:
- self.pars['ring_alpha'] = 1
+ self.createOrderedSubsets()
+
+ if not 'initialize' in kwargs.keys():
+ self.pars['initialize'] = False
+
+ def setParameter(self, **kwargs):
+ '''set named parameter for the reconstructor engine
+
+ raises Exception if the named parameter is not recognized
+ '''
+ for key , value in kwargs.items():
+ if key in self.acceptedInputKeywords:
+ self.pars[key] = value
+ else:
+ raise Exception('Wrong parameter {0} for '.format(key) +
+ 'reconstructor')
+ # setParameter
+
+ def getParameter(self, key):
+ if type(key) is str:
+ if key in self.acceptedInputKeywords:
+ return self.pars[key]
+ else:
+ raise Exception('Unrecongnised parameter: {0} '.format(key) )
+ elif type(key) is list:
+ outpars = []
+ for k in key:
+ outpars.append(self.getParameter(k))
+ return outpars
+ else:
+ raise Exception('Unhandled input {0}' .format(str(type(key))))
+
+
def calculateLipschitzConstantWithPowerMethod(self):
''' using Power method (PM) to establish L constant'''
- #N = params.vol_geom.GridColCount
- N = self.pars['output_geometry'].GridColCount
- proj_geom = self.params['projector_geometry']
- vol_geom = self.params['output_geometry']
+ N = self.pars['output_geometry']['GridColCount']
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
weights = self.pars['weights']
SlicesZ = self.pars['SlicesZ']
- if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+
+
+ if (proj_geom['type'] == 'parallel') or \
+ (proj_geom['type'] == 'parallel3d'):
#% for parallel geometry we can do just one slice
- #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
- niter = 15;# % number of iteration for the PM
+ #print('Calculating Lipshitz constant for parallel beam geometry...')
+ niter = 5;# % number of iteration for the PM
#N = params.vol_geom.GridColCount;
#x1 = rand(N,N,1);
x1 = numpy.random.rand(1,N,N)
#sqweight = sqrt(weights(:,:,1));
- sqweight = numpy.sqrt(weights.T[0])
+ sqweight = numpy.sqrt(weights[0])
proj_geomT = proj_geom.copy();
- proj_geomT.DetectorRowCount = 1;
+ proj_geomT['DetectorRowCount'] = 1;
vol_geomT = vol_geom.copy();
vol_geomT['GridSliceCount'] = 1;
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+
for i in range(niter):
- if i == 0:
- #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
- y = sqweight * y # element wise multiplication
- #astra_mex_data3d('delete', sino_id);
- astra.matlab.data3d('delete', sino_id)
+ # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT);
+ # s = norm(x1(:));
+ # x1 = x1/s;
+ # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ # y = sqweight.*y;
+ # astra_mex_data3d('delete', sino_id);
+ # astra_mex_data3d('delete', id);
+ #print ("iteration {0}".format(i))
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geomT,
+ vol_geomT)
+
+ y = (sqweight * y).copy() # element wise multiplication
+
+ #b=fig.add_subplot(2,1,2)
+ #imgplot = plt.imshow(x1[0])
+ #plt.show()
+
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+ del x1
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
+ idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),
+ proj_geomT,
+ vol_geomT)
+ del y
+
+
s = numpy.linalg.norm(x1)
### this line?
- x1 = x1/s;
- ### this line?
- sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
- y = sqweight*y;
+ x1 = (x1/s).copy();
+
+ # ### this line?
+ # sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ # proj_geomT,
+ # vol_geomT);
+ # y = sqweight * y;
astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
+ astra.matlab.data3d('delete', idx)
+ print ("iteration {0} s= {1}".format(i,s))
+
#end
del proj_geomT
del vol_geomT
+ #plt.show()
else:
#% divergen beam geometry
- #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
+ print('Calculating Lipshitz constant for divergen beam geometry...')
niter = 8; #% number of iteration for PM
x1 = numpy.random.rand(SlicesZ , N , N);
#sqweight = sqrt(weights);
- sqweight = numpy.sqrt(weights.T[0])
+ sqweight = numpy.sqrt(weights[0])
sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
y = sqweight*y;
@@ -217,6 +313,7 @@ class FISTAReconstructor():
#end
#clear x1
del x1
+
return s
@@ -225,130 +322,291 @@ class FISTAReconstructor():
if regularizer is not None:
self.pars['regularizer'] = regularizer
+
+ def initialize(self):
+ # convenience variable storage
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
+ sino = self.pars['input_sinogram']
+
+ # a 'warm start' with SIRT method
+ # Create a data object for the reconstruction
+ rec_id = astra.matlab.data3d('create', '-vol',
+ vol_geom);
+
+ #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino);
+ sinogram_id = astra.matlab.data3d('create', '-proj3d',
+ proj_geom,
+ sino)
+
+ sirt_config = astra.astra_dict('SIRT3D_CUDA')
+ sirt_config['ReconstructionDataId' ] = rec_id
+ sirt_config['ProjectionDataId'] = sinogram_id
+
+ sirt = astra.algorithm.create(sirt_config)
+ astra.algorithm.run(sirt, iterations=35)
+ X = astra.matlab.data3d('get', rec_id)
+
+ # clean up memory
+ astra.matlab.data3d('delete', rec_id)
+ astra.matlab.data3d('delete', sinogram_id)
+ astra.algorithm.delete(sirt)
+
+
+
+ return X
+
+ def createOrderedSubsets(self, subsets=None):
+ if subsets is None:
+ try:
+ subsets = self.getParameter('subsets')
+ except Exception():
+ subsets = 0
+ #return subsets
+ else:
+ self.setParameter(subsets=subsets)
+
+
+ angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+ #binEdges = numpy.linspace(angles.min(),
+ # angles.max(),
+ # subsets + 1)
+ binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
+ # get rearranged subset indices
+ IndicesReorg = numpy.zeros((numpy.shape(angles)), dtype=numpy.int32)
+ counterM = 0
+ for ii in range(binsDiscr.max()):
+ counter = 0
+ for jj in range(subsets):
+ curr_index = ii + jj + counter
+ #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
+ if binsDiscr[jj] > ii:
+ if (counterM < numpy.size(IndicesReorg)):
+ IndicesReorg[counterM] = curr_index
+ counterM = counterM + 1
+
+ counter = counter + binsDiscr[jj] - 1
+
+ # store the OS in parameters
+ self.setParameter(os_subsets=subsets,
+ os_bins=binsDiscr,
+ os_indices=IndicesReorg)
+
+
+ def prepareForIteration(self):
+ print ("FISTA Reconstructor: prepare for iteration")
+
+ self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
+ self.objective = numpy.zeros((self.pars['number_of_iterations']))
+
+ #2D array (for 3D data) of sparse "ring"
+ detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram'])
+ self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float)
+ # another ring variable
+ self.r_x = self.r.copy()
+
+ self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
+
+ if self.getParameter('Lipschitz_constant') is None:
+ self.pars['Lipschitz_constant'] = \
+ self.calculateLipschitzConstantWithPowerMethod()
+ # errors vector (if the ground truth is given)
+ self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations')));
+ # objective function values vector
+ self.objective = numpy.zeros((self.getParameter('number_of_iterations')));
+
+
+ # prepareForIteration
+
+ def iterate(self, Xin=None):
+ print ("FISTA Reconstructor: iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+ # convenience variable storage
+ proj_geom , vol_geom, sino , \
+ SlicesZ = self.getParameter([ 'projector_geometry' ,
+ 'output_geometry',
+ 'input_sinogram',
+ 'SlicesZ' ])
+
+ t = 1
+
+ for i in range(self.getParameter('number_of_iterations')):
+ X_old = X.copy()
+ t_old = t
+ r_old = self.r.copy()
+ if self.getParameter('projector_geometry')['type'] == 'parallel' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+ for kkk in range(SlicesZ):
+ sino_id, self.sino_updt[kkk] = \
+ astra.creators.create_sino3d_gpu(
+ X_t[kkk:kkk+1], proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', sino_id)
+ else:
+ # for divergent 3D geometry (watch the GPU memory overflow in
+ # ASTRA versions < 1.8)
+ #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
+ sino_id, self.sino_updt = astra.creators.create_sino3d_gpu(
+ X_t, proj_geom, vol_geom)
+
+
+ ## RING REMOVAL
+ self.ringRemoval(i)
+ ## Projection/Backprojection Routine
+ self.projectionBackprojection(X, X_t)
+ astra.matlab.data3d('delete', sino_id)
+ ## REGULARIZATION
+ X = self.regularize(X)
+ ## Update Loop
+ X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
+ self.setParameter(output_volume=X)
+ return X
+ ## iterate
-
+ def ringRemoval(self, i):
+ print ("FISTA Reconstructor: ring removal")
+ residual = self.residual
+ lambdaR_L1 , alpha_ring , weights , L_const , sino= \
+ self.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant',
+ 'input_sinogram'])
+ r_x = self.r_x
+ sino_updt = self.sino_updt
+
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ if lambdaR_L1 > 0 :
+ for kkk in range(anglesNumb):
+
+ residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+ vec = residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:].squeeze()
+ self.r = (r_x - (1./L_const) * vec).copy()
+ self.objective[i] = (0.5 * (residual ** 2).sum())
+ def projectionBackprojection(self, X, X_t):
+ print ("FISTA Reconstructor: projection-backprojection routine")
+
+ # a few useful variables
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ residual = self.residual
+ proj_geom , vol_geom , L_const = \
+ self.getParameter(['projector_geometry' ,
+ 'output_geometry',
+ 'Lipschitz_constant'])
+
+
+ if self.getParameter('projector_geometry')['type'] == 'parallel' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
+
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residual[kkk:kkk+1],
+ proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residual, proj_geom, vol_geom)
+
+ X = X_t - (1/L_const) * x_temp
+ #astra.matlab.data3d('delete', sino_id)
+ astra.matlab.data3d('delete', x_id)
-def getEntry(location):
- for item in nx[location].keys():
- print (item)
-
-
-print ("Loading Data")
-
-##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
-####ind = [i * 1049 for i in range(360)]
-#### use only 360 images
-##images = 200
-##ind = [int(i * 1049 / images) for i in range(images)]
-##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
-
-#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
-#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
-##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5"
-##nx = h5py.File(fname, "r")
-##
-### the data are stored in a particular location in the hdf5
-##for item in nx['entry1/tomo_entry/data'].keys():
-## print (item)
-##
-##data = nx.get('entry1/tomo_entry/data/rotation_angle')
-##angles = numpy.zeros(data.shape)
-##data.read_direct(angles)
-##print (angles)
-### angles should be in degrees
-##
-##data = nx.get('entry1/tomo_entry/data/data')
-##stack = numpy.zeros(data.shape)
-##data.read_direct(stack)
-##print (data.shape)
-##
-##print ("Data Loaded")
-##
-##
-### Normalize
-##data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
-##itype = numpy.zeros(data.shape)
-##data.read_direct(itype)
-### 2 is dark field
-##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
-##dark = darks[0]
-##for i in range(1, len(darks)):
-## dark += darks[i]
-##dark = dark / len(darks)
-###dark[0][0] = dark[0][1]
-##
-### 1 is flat field
-##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
-##flat = flats[0]
-##for i in range(1, len(flats)):
-## flat += flats[i]
-##flat = flat / len(flats)
-###flat[0][0] = dark[0][1]
-##
-##
-### 0 is projection data
-##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
-##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
-##angle_proj = numpy.asarray (angle_proj)
-##angle_proj = angle_proj.astype(numpy.float32)
-##
-### normalized data are
-### norm = (projection - dark)/(flat-dark)
-##
-##def normalize(projection, dark, flat, def_val=0.1):
-## a = (projection - dark)
-## b = (flat-dark)
-## with numpy.errstate(divide='ignore', invalid='ignore'):
-## c = numpy.true_divide( a, b )
-## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
-## return c
-##
-##
-##norm = [normalize(projection, dark, flat) for projection in proj]
-##norm = numpy.asarray (norm)
-##norm = norm.astype(numpy.float32)
-
-
-##niterations = 15
-##threads = 3
-##
-##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## iteration_values, False)
-##print ("iteration values %s" % str(iteration_values))
-##
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## numpy.double(1e-5), iteration_values , False)
-##print ("iteration values %s" % str(iteration_values))
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## numpy.double(1e-5), iteration_values , False)
-##print ("iteration values %s" % str(iteration_values))
-##
-##
-####numpy.save("cgls_recon.npy", img_data)
-##import matplotlib.pyplot as plt
-##fig, ax = plt.subplots(1,6,sharey=True)
-##ax[0].imshow(img_cgls[80])
-##ax[0].axis('off') # clear x- and y-axes
-##ax[1].imshow(img_sirt[80])
-##ax[1].axis('off') # clear x- and y-axes
-##ax[2].imshow(img_mlem[80])
-##ax[2].axis('off') # clear x- and y-axesplt.show()
-##ax[3].imshow(img_cgls_conv[80])
-##ax[3].axis('off') # clear x- and y-axesplt.show()
-##ax[4].imshow(img_cgls_tikhonov[80])
-##ax[4].axis('off') # clear x- and y-axesplt.show()
-##ax[5].imshow(img_cgls_TVreg[80])
-##ax[5].axis('off') # clear x- and y-axesplt.show()
-##
-##
-##plt.show()
-##
+ def regularize(self, X):
+ print ("FISTA Reconstructor: regularize")
+
+ regularizer = self.getParameter('regularizer')
+ if regularizer is not None:
+ return regularizer(input=X)
+ else:
+ return X
+
+ def updateLoop(self, i, X, X_old, r_old, t, t_old):
+ print ("FISTA Reconstructor: update loop")
+ lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
+ if lambdaR_L1 > 0:
+ self.r = numpy.max(
+ numpy.abs(self.r) - lambdaR_L1 , 0) * \
+ numpy.sign(self.r)
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+
+ if lambdaR_L1 > 0:
+ self.r_x = self.r + \
+ (((t_old-1)/t) * (self.r - r_old))
+
+ if self.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, self.objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], self.objective[i]))
+ return (X , X_t, t)
+
+ def os_iterate(self, Xin=None):
+ print ("FISTA Reconstructor: iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+ # some useful constants
+ proj_geom , vol_geom, sino , \
+ SlicesZ, weights , alpha_ring ,\
+ lambdaR_L1 , L_const = self.getParameter(
+ ['projector_geometry' , 'output_geometry',
+ 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' ,
+ 'ring_lambda_R_L1', 'Lipschitz_constant'])
diff --git a/src/Python/compile-fista.bat.in b/src/Python/compile-fista.bat.in
new file mode 100644
index 0000000..b1db686
--- /dev/null
+++ b/src/Python/compile-fista.bat.in
@@ -0,0 +1,7 @@
+set CIL_VERSION=@CIL_VERSION@
+
+set PREFIX=@CONDA_ENVIRONMENT_PREFIX@
+set LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@
+
+REM activate @CONDA_ENVIRONMENT@
+conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge
diff --git a/src/Python/compile-fista.sh.in b/src/Python/compile-fista.sh.in
new file mode 100644
index 0000000..267f014
--- /dev/null
+++ b/src/Python/compile-fista.sh.in
@@ -0,0 +1,9 @@
+#!/bin/sh
+# compile within the right conda environment
+#module load python/anaconda
+#source activate @CONDA_ENVIRONMENT@
+
+export CIL_VERSION=@CIL_VERSION@
+export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@
+
+conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi
diff --git a/src/Python/conda-recipe/meta.yaml b/src/Python/conda-recipe/meta.yaml
index c5b7a89..7068e9d 100644
--- a/src/Python/conda-recipe/meta.yaml
+++ b/src/Python/conda-recipe/meta.yaml
@@ -1,5 +1,5 @@
package:
- name: ccpi-fista
+ name: ccpi-regularizers
version: {{ environ['CIL_VERSION'] }}
diff --git a/src/Python/fista-recipe/build.sh b/src/Python/fista-recipe/build.sh
new file mode 100644
index 0000000..e3f3552
--- /dev/null
+++ b/src/Python/fista-recipe/build.sh
@@ -0,0 +1,10 @@
+if [ -z "$CIL_VERSION" ]; then
+ echo "Need to set CIL_VERSION"
+ exit 1
+fi
+mkdir "$SRC_DIR/ccpifista"
+cp -r "$RECIPE_DIR/.." "$SRC_DIR/ccpifista"
+
+cd $SRC_DIR/ccpifista
+
+$PYTHON setup-fista.py install
diff --git a/src/Python/fista-recipe/meta.yaml b/src/Python/fista-recipe/meta.yaml
new file mode 100644
index 0000000..265541f
--- /dev/null
+++ b/src/Python/fista-recipe/meta.yaml
@@ -0,0 +1,29 @@
+package:
+ name: ccpi-fista
+ version: {{ environ['CIL_VERSION'] }}
+
+
+build:
+ preserve_egg_dir: False
+ script_env:
+ - CIL_VERSION
+# number: 0
+
+requirements:
+ build:
+ - python
+ - numpy
+ - setuptools
+
+ run:
+ - python
+ - numpy
+ #- astra-toolbox
+ - ccpi-regularizers
+
+
+
+about:
+ home: http://www.ccpi.ac.uk
+ license: Apache v.2.0 license
+ summary: 'CCPi Core Imaging Library (Viewer)'
diff --git a/src/Python/setup-fista.py.in b/src/Python/setup-fista.py.in
new file mode 100644
index 0000000..c5c9f4d
--- /dev/null
+++ b/src/Python/setup-fista.py.in
@@ -0,0 +1,27 @@
+from distutils.core import setup
+#from setuptools import setup, find_packages
+import os
+
+cil_version=os.environ['CIL_VERSION']
+if cil_version == '':
+ print("Please set the environmental variable CIL_VERSION")
+ sys.exit(1)
+
+setup(
+ name="ccpi-fista",
+ version=cil_version,
+ packages=['ccpi','ccpi.reconstruction'],
+ install_requires=['numpy'],
+
+ zip_safe = False,
+
+ # metadata for upload to PyPI
+ author="Edoardo Pasca",
+ author_email="edo.paskino@gmail.com",
+ description='CCPi Core Imaging Library - FISTA Reconstructor module',
+ license="Apache v2.0",
+ keywords="tomography interative reconstruction",
+ url="http://www.ccpi.ac.uk", # project home page, if any
+
+ # could also include long_description, download_url, classifiers, etc.
+)
diff --git a/src/Python/setup.py.in b/src/Python/setup.py.in
index 0a1f4ad..12e8af1 100644
--- a/src/Python/setup.py.in
+++ b/src/Python/setup.py.in
@@ -44,7 +44,7 @@ else:
setup(
name='ccpi',
- description='CCPi Core Imaging Library - FISTA Reconstruction Module',
+ description='CCPi Core Imaging Library - Image Regularizers',
version=cil_version,
cmdclass = {'build_ext': build_ext},
ext_modules = [Extension("ccpi.imaging.cpu_regularizers",
@@ -65,3 +65,5 @@ setup(
zip_safe = False,
packages = {'ccpi','ccpi.imaging'},
)
+
+
diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test/test_reconstructor-os.py
index aee70a4..6c82ae0 100644
--- a/src/Python/test_reconstructor-os.py
+++ b/src/Python/test/test_reconstructor-os.py
@@ -9,9 +9,10 @@ Based on DemoRD2.m
import h5py
import numpy
-from ccpi.fista.FISTAReconstructor import FISTAReconstructor
+from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor
import astra
import matplotlib.pyplot as plt
+from ccpi.imaging.Regularizer import Regularizer
def RMSE(signal1, signal2):
'''RMSE Root Mean Squared Error'''
@@ -76,9 +77,18 @@ fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
fistaRecon.setParameter(ring_alpha = 21)
fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
+
+reg = Regularizer(Regularizer.Algorithm.LLT_model)
+reg.setParameter(regularization_parameter=25,
+ time_step=0.0003,
+ tolerance_constant=0.0001,
+ number_of_iterations=300)
+
## Ordered subset
if True:
subsets = 16
+ fistaRecon.setParameter(subsets=subsets)
+ fistaRecon.createOrderedSubsets()
angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles']
#binEdges = numpy.linspace(angles.min(),
# angles.max(),
@@ -146,6 +156,7 @@ if True:
fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram']))
residual2 = fistaRecon.residual2
sino_updt_FULL = fistaRecon.residual.copy()
+ r_x = fistaRecon.r.copy()
print ("starting iterations")
## % Outer FISTA iterations loop
@@ -206,8 +217,13 @@ if True:
# the number of projections per subset
numProjSub = fistaRecon.getParameter('os_bins')[ss]
CurrSubIndices = fistaRecon.getParameter('os_indices')\
- [counterInd:counterInd+numProjSub-1]
- proj_geomSUB['ProjectionAngles'] = angles[CurrSubIndeces]
+ [counterInd:counterInd+numProjSub]
+ #print ("Len CurrSubIndices {0}".format(numProjSub))
+ mask = numpy.zeros(numpy.shape(angles), dtype=bool)
+ cc = 0
+ for j in range(len(CurrSubIndices)):
+ mask[int(CurrSubIndices[j])] = True
+ proj_geomSUB['ProjectionAngles'] = angles[mask]
shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram')))
shape[1] = numProjSub
@@ -246,7 +262,8 @@ if True:
## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram
## end
for kkk in range(numProjSub):
- indC = CurrSubIndices[kkk]
+ #print ("ring removal indC ... {0}".format(kkk))
+ indC = int(CurrSubIndices[kkk])
residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
(sino_updt_Sub[:,kkk,:].squeeze() - \
sino[:,indC,:].squeeze() - alpha_ring * r_x)
@@ -288,7 +305,8 @@ if True:
# regularizer = fistaRecon.getParameter('regularizer')
# for slices:
# out = regularizer(input=X)
- print ("skipping regularizer")
+ print ("regularizer")
+ #X = reg(input=X)
## FINAL
@@ -312,7 +330,8 @@ if True:
Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
print (string.format(i,Resid_error[i], objective[i]))
-
+
+ numpy.save("X_out_os.npy", X)
else:
fistaRecon = FISTAReconstructor(proj_geom,
diff --git a/src/Python/test/test_reconstructor.py b/src/Python/test/test_reconstructor.py
new file mode 100644
index 0000000..3342301
--- /dev/null
+++ b/src/Python/test/test_reconstructor.py
@@ -0,0 +1,309 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Aug 23 16:34:49 2017
+
+@author: ofn77899
+Based on DemoRD2.m
+"""
+
+import h5py
+import numpy
+
+from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor
+import astra
+import matplotlib.pyplot as plt
+
+def RMSE(signal1, signal2):
+ '''RMSE Root Mean Squared Error'''
+ if numpy.shape(signal1) == numpy.shape(signal2):
+ err = (signal1 - signal2)
+ err = numpy.sum( err * err )/numpy.size(signal1); # MSE
+ err = sqrt(err); # RMSE
+ return err
+ else:
+ raise Exception('Input signals must have the same shape')
+
+filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5'
+nx = h5py.File(filename, "r")
+#getEntry(nx, '/')
+# I have exported the entries as children of /
+entries = [entry for entry in nx['/'].keys()]
+print (entries)
+
+Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32")
+Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32")
+angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
+angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32")
+recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0]
+size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0]
+slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
+
+Z_slices = 20
+det_row_count = Z_slices
+# next definition is just for consistency of naming
+det_col_count = size_det
+
+detectorSpacingX = 1.0
+detectorSpacingY = detectorSpacingX
+
+
+proj_geom = astra.creators.create_proj_geom('parallel3d',
+ detectorSpacingX,
+ detectorSpacingY,
+ det_row_count,
+ det_col_count,
+ angles_rad)
+
+#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices);
+image_size_x = recon_size
+image_size_y = recon_size
+image_size_z = Z_slices
+vol_geom = astra.creators.create_vol_geom( image_size_x,
+ image_size_y,
+ image_size_z)
+
+## First pass the arguments to the FISTAReconstructor and test the
+## Lipschitz constant
+
+fistaRecon = FISTAReconstructor(proj_geom,
+ vol_geom,
+ Sino3D ,
+ weights=Weights3D)
+
+print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+fistaRecon.setParameter(number_of_iterations = 12)
+fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
+fistaRecon.setParameter(ring_alpha = 21)
+fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
+
+reg = Regularizer(Regularizer.Algorithm.LLT_model)
+reg.setParameter(regularization_parameter=25,
+ time_step=0.0003,
+ tolerance_constant=0.0001,
+ number_of_iterations=300)
+fistaRecon.setParameter(regularizer = reg)
+
+## Ordered subset
+if False:
+ subsets = 16
+ angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles']
+ #binEdges = numpy.linspace(angles.min(),
+ # angles.max(),
+ # subsets + 1)
+ binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
+ # get rearranged subset indices
+ IndicesReorg = numpy.zeros((numpy.shape(angles)))
+ counterM = 0
+ for ii in range(binsDiscr.max()):
+ counter = 0
+ for jj in range(subsets):
+ curr_index = ii + jj + counter
+ #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
+ if binsDiscr[jj] > ii:
+ if (counterM < numpy.size(IndicesReorg)):
+ IndicesReorg[counterM] = curr_index
+ counterM = counterM + 1
+
+ counter = counter + binsDiscr[jj] - 1
+
+
+if False:
+ print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+ print ("prepare for iteration")
+ fistaRecon.prepareForIteration()
+
+
+
+ print("initializing ...")
+ if False:
+ # if X doesn't exist
+ #N = params.vol_geom.GridColCount
+ N = vol_geom['GridColCount']
+ print ("N " + str(N))
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ #X = fistaRecon.initialize()
+ X = numpy.load("X.npy")
+
+ print (numpy.shape(X))
+ X_t = X.copy()
+ print ("initialized")
+ proj_geom , vol_geom, sino , \
+ SlicesZ = fistaRecon.getParameter(['projector_geometry' ,
+ 'output_geometry',
+ 'input_sinogram',
+ 'SlicesZ'])
+
+ #fistaRecon.setParameter(number_of_iterations = 3)
+ iterFISTA = fistaRecon.getParameter('number_of_iterations')
+ # errors vector (if the ground truth is given)
+ Resid_error = numpy.zeros((iterFISTA));
+ # objective function values vector
+ objective = numpy.zeros((iterFISTA));
+
+
+ t = 1
+
+
+ print ("starting iterations")
+## % Outer FISTA iterations loop
+ for i in range(fistaRecon.getParameter('number_of_iterations')):
+ X_old = X.copy()
+ t_old = t
+ r_old = fistaRecon.r.copy()
+ if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' :
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+ for kkk in range(SlicesZ):
+ sino_id, sino_updt[kkk] = \
+ astra.creators.create_sino3d_gpu(
+ X_t[kkk:kkk+1], proj_geom, vol_geom)
+ astra.matlab.data3d('delete', sino_id)
+ else:
+ # for divergent 3D geometry (watch the GPU memory overflow in
+ # ASTRA versions < 1.8)
+ #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
+ sino_id, sino_updt = astra.creators.create_sino3d_gpu(
+ X_t, proj_geom, vol_geom)
+
+ ## RING REMOVAL
+ residual = fistaRecon.residual
+ lambdaR_L1 , alpha_ring , weights , L_const= \
+ fistaRecon.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant'])
+ r_x = fistaRecon.r_x
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(fistaRecon.getParameter('input_sinogram'))
+ if lambdaR_L1 > 0 :
+ print ("ring removal")
+ for kkk in range(anglesNumb):
+
+ residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+ vec = residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:].squeeze()
+ fistaRecon.r = (r_x - (1./L_const) * vec).copy()
+ objective[i] = (0.5 * (residual ** 2).sum())
+## % the ring removal part (Group-Huber fidelity)
+## for kkk = 1:anglesNumb
+## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).*
+## (squeeze(sino_updt(:,kkk,:)) -
+## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x));
+## end
+## vec = sum(residual,2);
+## if (SlicesZ > 1)
+## vec = squeeze(vec(:,1,:));
+## end
+## r = r_x - (1./L_const).*vec;
+## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output
+
+
+
+ # Projection/Backprojection Routine
+ if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
+ print ("Projection/Backprojection Routine")
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residual[kkk:kkk+1],
+ proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residual, proj_geom, vol_geom)
+
+ X = X_t - (1/L_const) * x_temp
+ astra.matlab.data3d('delete', sino_id)
+ astra.matlab.data3d('delete', x_id)
+
+
+ ## REGULARIZATION
+ ## SKIPPING FOR NOW
+ ## Should be simpli
+ # regularizer = fistaRecon.getParameter('regularizer')
+ # for slices:
+ # out = regularizer(input=X)
+ print ("skipping regularizer")
+
+
+ ## FINAL
+ print ("final")
+ lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1')
+ if lambdaR_L1 > 0:
+ fistaRecon.r = numpy.max(
+ numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \
+ numpy.sign(fistaRecon.r)
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+
+ if lambdaR_L1 > 0:
+ fistaRecon.r_x = fistaRecon.r + \
+ (((t_old-1)/t) * (fistaRecon.r - r_old))
+
+ if fistaRecon.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], objective[i]))
+
+## if (lambdaR_L1 > 0)
+## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector
+## end
+##
+## t = (1 + sqrt(1 + 4*t^2))/2; % updating t
+## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X
+##
+## if (lambdaR_L1 > 0)
+## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r
+## end
+##
+## if (show == 1)
+## figure(10); imshow(X(:,:,slice), [0 maxvalplot]);
+## if (lambdaR_L1 > 0)
+## figure(11); plot(r); title('Rings offset vector')
+## end
+## pause(0.01);
+## end
+## if (strcmp(X_ideal, 'none' ) == 0)
+## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI));
+## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i));
+## else
+## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i));
+## end
+else:
+ fistaRecon = FISTAReconstructor(proj_geom,
+ vol_geom,
+ Sino3D ,
+ weights=Weights3D)
+
+ print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+ fistaRecon.setParameter(number_of_iterations = 12)
+ fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
+ fistaRecon.setParameter(ring_alpha = 21)
+ fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
+ fistaRecon.prepareForIteration()
+ X = fistaRecon.iterate(numpy.load("X.npy"))
+ numpy.save("X_out.npy", X)
diff --git a/supp/sino_add_artifacts.m b/supp/sino_add_artifacts.m
new file mode 100644
index 0000000..f601914
--- /dev/null
+++ b/supp/sino_add_artifacts.m
@@ -0,0 +1,33 @@
+function sino_artifacts = sino_add_artifacts(sino,artifact_type)
+% function to add various distortions to the sinogram space, current
+% version includes: random rings and zingers (streaks)
+% Input:
+% 1. sinogram
+% 2. artifact type: 'rings' or 'zingers' (streaks)
+
+
+[Detectors, anglesNumb, SlicesZ] = size(sino);
+fprintf('%s %i %s %i %s %i %s \n', 'Sinogram has a dimension of', Detectors, 'detectors;', anglesNumb, 'projections;', SlicesZ, 'vertical slices.');
+
+sino_artifacts = sino;
+
+if (strcmp(artifact_type,'rings'))
+ fprintf('%s \n', 'Adding rings...');
+ NumRings = round(Detectors/20); % Number of rings relatively to the size of Detectors
+ IntenOff = linspace(0.05,0.5,NumRings); % the intensity of rings in the selected range
+
+ for k = 1:SlicesZ
+ % generate random indices to propagate rings
+ RandInd = randperm(Detectors,Detectors);
+ for jj = 1:NumRings
+ ind_c = RandInd(jj);
+ sino_artifacts(ind_c,1:end,k) = sino_artifacts(ind_c,1:end,k) + IntenOff(jj).*sino_artifacts(ind_c,1:end,k); % generate a constant offset
+ end
+
+ end
+elseif (strcmp(artifact_type,'zingers'))
+ fprintf('%s \n', 'Adding zingers...');
+else
+ fprintf('%s \n', 'Nothing selected, the same sinogram returned...');
+end
+end \ No newline at end of file