From cbeb8a1174498751e38a3de8cd6fe55abae20192 Mon Sep 17 00:00:00 2001 From: Daniil Kazantsev Date: Thu, 3 Aug 2017 00:10:53 +0100 Subject: some cleaning inside C functions, updated mex_compile file and the work on ordered-subset FISTA started --- demos/DemoRD2.m | 4 +- main_func/FISTA_REC.m | 40 ++++++++++++- main_func/compile_mex.m | 11 ++-- main_func/regularizers_CPU/FGP_TV_core.c | 26 +++++++++ main_func/regularizers_CPU/FGP_TV_core.h | 27 --------- main_func/regularizers_CPU/LLT_model.c | 26 +++++++++ main_func/regularizers_CPU/LLT_model_core.c | 25 ++++++++ main_func/regularizers_CPU/LLT_model_core.h | 25 -------- main_func/regularizers_CPU/PatchBased_Regul.c | 43 +++++++------- main_func/regularizers_CPU/PatchBased_Regul_core.c | 57 ++++++++---------- main_func/regularizers_CPU/PatchBased_Regul_core.h | 33 ----------- main_func/regularizers_CPU/SplitBregman_TV_core.c | 1 - main_func/regularizers_CPU/SplitBregman_TV_core.h | 24 -------- main_func/regularizers_CPU/TGV_PD.c | 68 +++++----------------- main_func/regularizers_CPU/TGV_PD_core.c | 1 - main_func/regularizers_CPU/TGV_PD_core.h | 29 --------- 16 files changed, 182 insertions(+), 258 deletions(-) diff --git a/demos/DemoRD2.m b/demos/DemoRD2.m index 0db3595..293c1cd 100644 --- a/demos/DemoRD2.m +++ b/demos/DemoRD2.m @@ -30,7 +30,7 @@ Weights3D = single(data_raw3D); % weights for PW model clear data_raw3D %% % set projection/reconstruction geometry here -Z_slices = 3; +Z_slices = 20; det_row_count = Z_slices; proj_geom = astra_create_proj_geom('parallel3d', 1, 1, det_row_count, size_det, angles_rad); vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); @@ -44,7 +44,7 @@ clear params params.proj_geom = proj_geom; % pass geometry to the function params.vol_geom = vol_geom; params.sino = Sino3D; -params.iterFISTA = 35; +params.iterFISTA = 1; params.weights = Weights3D; params.show = 1; params.maxvalplot = 2.5; params.slice = 2; diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index 2823691..1e93719 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -99,7 +99,7 @@ else clear proj_geomT vol_geomT else % divergen beam geometry - fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry... 
will take some time!'); niter = 8; % number of iteration for PM x1 = rand(N,N,SlicesZ); sqweight = sqrt(weights); @@ -216,6 +216,39 @@ if (isfield(params,'initialize')) else X = zeros(N,N,SlicesZ, 'single'); % storage for the solution end +if (isfield(params,'OS')) + % Ordered Subsets reorganisation of data and angles + OS = 1; + subsets = 8; + angles = proj_geom.ProjectionAngles; + binEdges = linspace(min(angles),max(angles),subsets+1); + + % assign values to bins + [binsDiscr,~] = histc(angles, [binEdges(1:end-1) Inf]); + + % rearrange angles into subsets + AnglesReorg = zeros(length(angles),1); + SinoReorg = zeros(Detectors, anglesNumb, SlicesZ, 'single'); + + counterM = 0; + for ii = 1:max(binsDiscr(:)) + counter = 0; + for jj = 1:subsets + curr_index = ii+jj-1 + counter; + if (binsDiscr(jj) >= ii) + counterM = counterM + 1; + AnglesReorg(counterM) = angles(curr_index); + SinoReorg(:,counterM,:) = squeeze(sino(:,curr_index,:)); + end + counter = (counter + binsDiscr(jj)) - 1; + end + end + sino = SinoReorg; + clear SinoReorg; +else + OS = 0; % normal FISTA +end + %----------------Reconstruction part------------------------ Resid_error = zeros(iterFISTA,1); % errors vector (if the ground truth is given) @@ -241,7 +274,7 @@ for i = 1:iterFISTA % ring removal part (Group-Huber fidelity) for kkk = 1:anglesNumb residual(:,kkk,:) = squeeze(weights(:,kkk,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); - end + end vec = sum(residual,2); if (SlicesZ > 1) vec = squeeze(vec(:,1,:)); @@ -249,7 +282,7 @@ for i = 1:iterFISTA r = r_x - (1./L_const).*vec; else % no ring removal - residual = weights.*(sino_updt - sino); + residual = weights.*(sino_updt - sino); end objective(i) = (0.5*norm(residual(:))^2)/(Detectors*anglesNumb*SlicesZ); % for the objective function output @@ -305,5 +338,6 @@ end output.Resid_error = Resid_error; output.objective = objective; output.L_const = L_const; +output.sino = sino; %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% end diff --git a/main_func/compile_mex.m b/main_func/compile_mex.m index 7bfa8eb..66c05da 100644 --- a/main_func/compile_mex.m +++ b/main_func/compile_mex.m @@ -1,10 +1,11 @@ -% compile mex's once +% compile mex's in Matlab once cd regularizers_CPU/ -mex LLT_model.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" -mex FGP_TV.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" -mex SplitBregman_TV.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" -mex TGV_PD.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +mex LLT_model.c LLT_model_core.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +mex FGP_TV.c FGP_TV_core.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +mex SplitBregman_TV.c SplitBregman_TV_core.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +mex TGV_PD.c TGV_PD_core.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +mex PatchBased_Regul.c PatchBased_Regul_core.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" cd ../../ cd demos diff --git a/main_func/regularizers_CPU/FGP_TV_core.c b/main_func/regularizers_CPU/FGP_TV_core.c index c383a71..a55991c 100644 --- a/main_func/regularizers_CPU/FGP_TV_core.c +++ b/main_func/regularizers_CPU/FGP_TV_core.c @@ -19,6 +19,32 @@ limitations under the License. 
#include "FGP_TV_core.h" +/* C-OMP implementation of FGP-TV [1] denoising/regularization model (2D/3D case) + * + * Input Parameters: + * 1. Noisy image/volume [REQUIRED] + * 2. lambda - regularization parameter [REQUIRED] + * 3. Number of iterations [OPTIONAL parameter] + * 4. eplsilon: tolerance constant [OPTIONAL parameter] + * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] + * + * Output: + * [1] Filtered/regularized image + * [2] last function value + * + * Example of image denoising: + * figure; + * Im = double(imread('lena_gray_256.tif'))/255; % loading image + * u0 = Im + .05*randn(size(Im)); % adding noise + * u = FGP_TV(single(u0), 0.05, 100, 1e-04); + * + * This function is based on the Matlab's code and paper by + * [1] Amir Beck and Marc Teboulle, "Fast Gradient-Based Algorithms for Constrained Total Variation Image Denoising and Deblurring Problems" + * + * D. Kazantsev, 2016-17 + * + */ + /* 2D-case related Functions */ /*****************************************************************/ float Obj_func2D(float *A, float *D, float *R1, float *R2, float lambda, int dimX, int dimY) diff --git a/main_func/regularizers_CPU/FGP_TV_core.h b/main_func/regularizers_CPU/FGP_TV_core.h index 8d6af3e..e5328fb 100644 --- a/main_func/regularizers_CPU/FGP_TV_core.h +++ b/main_func/regularizers_CPU/FGP_TV_core.h @@ -24,33 +24,6 @@ limitations under the License. #include #include "omp.h" -/* C-OMP implementation of FGP-TV [1] denoising/regularization model (2D/3D case) -* -* Input Parameters: -* 1. Noisy image/volume [REQUIRED] -* 2. lambda - regularization parameter [REQUIRED] -* 3. Number of iterations [OPTIONAL parameter] -* 4. eplsilon: tolerance constant [OPTIONAL parameter] -* 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] -* -* Output: -* [1] Filtered/regularized image -* [2] last function value -* -* Example of image denoising: -* figure; -* Im = double(imread('lena_gray_256.tif'))/255; % loading image -* u0 = Im + .05*randn(size(Im)); % adding noise -* u = FGP_TV(single(u0), 0.05, 100, 1e-04); -* -* to compile with OMP support: mex FGP_TV.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" -* This function is based on the Matlab's code and paper by -* [1] Amir Beck and Marc Teboulle, "Fast Gradient-Based Algorithms for Constrained Total Variation Image Denoising and Deblurring Problems" -* -* D. Kazantsev, 2016-17 -* -*/ - float copyIm(float *A, float *B, int dimX, int dimY, int dimZ); float Obj_func2D(float *A, float *D, float *R1, float *R2, float lambda, int dimX, int dimY); float Grad_func2D(float *P1, float *P2, float *D, float *R1, float *R2, float lambda, int dimX, int dimY); diff --git a/main_func/regularizers_CPU/LLT_model.c b/main_func/regularizers_CPU/LLT_model.c index 6b50d33..47146a5 100644 --- a/main_func/regularizers_CPU/LLT_model.c +++ b/main_func/regularizers_CPU/LLT_model.c @@ -20,6 +20,32 @@ limitations under the License. #include "mex.h" #include "LLT_model_core.h" +/* C-OMP implementation of Lysaker, Lundervold and Tai (LLT) model of higher order regularization penalty +* +* Input Parameters: +* 1. U0 - origanal noise image/volume +* 2. lambda - regularization parameter +* 3. tau - time-step for explicit scheme +* 4. iter - iterations number +* 5. epsil - tolerance constant (to terminate earlier) +* 6. 
switcher - default is 0, switch to (1) to restrictive smoothing in Z dimension (in test) +* +* Output: +* Filtered/regularized image +* +* Example: +* figure; +* Im = double(imread('lena_gray_256.tif'))/255; % loading image +* u0 = Im + .03*randn(size(Im)); % adding noise +* [Den] = LLT_model(single(u0), 10, 0.1, 1); +* +* +* to compile with OMP support: mex LLT_model.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +* References: Lysaker, Lundervold and Tai (LLT) 2003, IEEE +* +* 28.11.16/Harwell +*/ + void mexFunction( int nlhs, mxArray *plhs[], diff --git a/main_func/regularizers_CPU/LLT_model_core.c b/main_func/regularizers_CPU/LLT_model_core.c index 3098269..7a1cdbe 100644 --- a/main_func/regularizers_CPU/LLT_model_core.c +++ b/main_func/regularizers_CPU/LLT_model_core.c @@ -19,6 +19,31 @@ limitations under the License. #include "LLT_model_core.h" +/* C-OMP implementation of Lysaker, Lundervold and Tai (LLT) model of higher order regularization penalty +* +* Input Parameters: +* 1. U0 - origanal noise image/volume +* 2. lambda - regularization parameter +* 3. tau - time-step for explicit scheme +* 4. iter - iterations number +* 5. epsil - tolerance constant (to terminate earlier) +* 6. switcher - default is 0, switch to (1) to restrictive smoothing in Z dimension (in test) +* +* Output: +* Filtered/regularized image +* +* Example: +* figure; +* Im = double(imread('lena_gray_256.tif'))/255; % loading image +* u0 = Im + .03*randn(size(Im)); % adding noise +* [Den] = LLT_model(single(u0), 10, 0.1, 1); +* +* References: Lysaker, Lundervold and Tai (LLT) 2003, IEEE +* +* 28.11.16/Harwell +*/ + + float der2D(float *U, float *D1, float *D2, int dimX, int dimY, int dimZ) { int i, j, i_p, i_m, j_m, j_p; diff --git a/main_func/regularizers_CPU/LLT_model_core.h b/main_func/regularizers_CPU/LLT_model_core.h index ef8803d..10f52fe 100644 --- a/main_func/regularizers_CPU/LLT_model_core.h +++ b/main_func/regularizers_CPU/LLT_model_core.h @@ -26,31 +26,6 @@ limitations under the License. #define EPS 0.01 -/* C-OMP implementation of Lysaker, Lundervold and Tai (LLT) model of higher order regularization penalty -* -* Input Parameters: -* 1. U0 - origanal noise image/volume -* 2. lambda - regularization parameter -* 3. tau - time-step for explicit scheme -* 4. iter - iterations number -* 5. epsil - tolerance constant (to terminate earlier) -* 6. switcher - default is 0, switch to (1) to restrictive smoothing in Z dimension (in test) -* -* Output: -* Filtered/regularized image -* -* Example: -* figure; -* Im = double(imread('lena_gray_256.tif'))/255; % loading image -* u0 = Im + .03*randn(size(Im)); % adding noise -* [Den] = LLT_model(single(u0), 10, 0.1, 1); -* -* -* to compile with OMP support: mex LLT_model.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" -* References: Lysaker, Lundervold and Tai (LLT) 2003, IEEE -* -* 28.11.16/Harwell -*/ /* 2D functions */ float der2D(float *U, float *D1, float *D2, int dimX, int dimY, int dimZ); float div_upd2D(float *U0, float *U, float *D1, float *D2, int dimX, int dimY, int dimZ, float lambda, float tau); diff --git a/main_func/regularizers_CPU/PatchBased_Regul.c b/main_func/regularizers_CPU/PatchBased_Regul.c index 24ee210..59eb3b3 100644 --- a/main_func/regularizers_CPU/PatchBased_Regul.c +++ b/main_func/regularizers_CPU/PatchBased_Regul.c @@ -27,27 +27,23 @@ limitations under the License. * References: 1. Yang Z. & Jacob M. "Nonlocal Regularization of Inverse Problems" * 2. Kazantsev D. et al. 
"4D-CT reconstruction with unified spatial-temporal patch-based regularization" * - * Input Parameters (mandatory): - * 1. Image (2D or 3D) - * 2. ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window) - * 3. ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window) - * 4. h - parameter for the PB penalty function - * 5. lambda - regularization parameter + * Input Parameters: + * 1. Image (2D or 3D) [required] + * 2. ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window) [optional] + * 3. ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window) [optional] + * 4. h - parameter for the PB penalty function [optional] + * 5. lambda - regularization parameter [optional] * Output: * 1. regularized (denoised) Image (N x N)/volume (N x N x N) * - * Quick 2D denoising example in Matlab: + * 2D denoising example in Matlab: Im = double(imread('lena_gray_256.tif'))/255; % loading image u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise - ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); - * - * Please see more tests in a file: - TestTemporalSmoothing.m - + ImDen = PatchBased_Regul(single(u0), 3, 1, 0.08, 0.05); * * Matlab + C/mex compilers needed - * to compile with OMP support: mex PB_Regul_CPU.c CFLAGS="\$CFLAGS -fopenmp -Wall" LDFLAGS="\$LDFLAGS -fopenmp" + * to compile with OMP support: mex PatchBased_Regul.c CFLAGS="\$CFLAGS -fopenmp -Wall" LDFLAGS="\$LDFLAGS -fopenmp" * * D. Kazantsev * * 02/07/2014 @@ -70,17 +66,23 @@ void mexFunction( M = dims[1]; Z = dims[2]; - if ((numdims < 2) || (numdims > 3)) {mexErrMsgTxt("The input should be 2D image or 3D volume");} + if ((numdims < 2) || (numdims > 3)) {mexErrMsgTxt("The input is 2D image or 3D volume");} if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) {mexErrMsgTxt("The input in single precision is required"); } if(nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter"); /*Handling inputs*/ - A = (float *) mxGetData(prhs[0]); /* the image to regularize/filter */ - SearchW_real = (int) mxGetScalar(prhs[1]); /* the searching window ratio */ - SimilW = (int) mxGetScalar(prhs[2]); /* the similarity window ratio */ - h = (float) mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ - lambda = (float) mxGetScalar(prhs[4]); /* regularization parameter */ + A = (float *) mxGetData(prhs[0]); /* the image/volume to regularize/filter */ + SearchW_real = 3; /*default value*/ + SimilW = 1; /*default value*/ + h = 0.1; + lambda = 0.1; + + if ((nrhs == 2) || (nrhs == 3) || (nrhs == 4) || (nrhs == 5)) SearchW_real = (int) mxGetScalar(prhs[1]); /* the searching window ratio */ + if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) SimilW = (int) mxGetScalar(prhs[2]); /* the similarity window ratio */ + if ((nrhs == 4) || (nrhs == 5)) h = (float) mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ + if ((nrhs == 5)) lambda = (float) mxGetScalar(prhs[4]); /* regularization parameter */ + if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); @@ -89,7 +91,6 @@ void mexFunction( /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ /* SimilW_full = 2*SimilW + 1; */ /* the full similarity window size */ - padXY = SearchW + 2*SimilW; /* padding sizes */ newsizeX = N + 2*(padXY); /* the X size of the padded array */ @@ -135,4 +136,4 @@ void mexFunction( switchpad_crop = 1; /*cropping*/ pad_crop(Bp, B, M, N, Z, newsizeY, 
newsizeX, newsizeZ, padXY, switchpad_crop); } /*end else ndims*/ -} \ No newline at end of file +} diff --git a/main_func/regularizers_CPU/PatchBased_Regul_core.c b/main_func/regularizers_CPU/PatchBased_Regul_core.c index 6f0a48d..acfb464 100644 --- a/main_func/regularizers_CPU/PatchBased_Regul_core.c +++ b/main_func/regularizers_CPU/PatchBased_Regul_core.c @@ -19,40 +19,33 @@ limitations under the License. #include "PatchBased_Regul_core.h" -/* C-OMP implementation of patch-based (PB) regularization (2D and 3D cases). -* This method finds self-similar patches in data and performs one fixed point iteration to mimimize the PB penalty function -* -* References: 1. Yang Z. & Jacob M. "Nonlocal Regularization of Inverse Problems" -* 2. Kazantsev D. et al. "4D-CT reconstruction with unified spatial-temporal patch-based regularization" -* -* Input Parameters (mandatory): -* 1. Image (2D or 3D) -* 2. ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window) -* 3. ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window) -* 4. h - parameter for the PB penalty function -* 5. lambda - regularization parameter +/* C-OMP implementation of patch-based (PB) regularization (2D and 3D cases). + * This method finds self-similar patches in data and performs one fixed point iteration to mimimize the PB penalty function + * + * References: 1. Yang Z. & Jacob M. "Nonlocal Regularization of Inverse Problems" + * 2. Kazantsev D. et al. "4D-CT reconstruction with unified spatial-temporal patch-based regularization" + * + * Input Parameters: + * 1. Image (2D or 3D) [required] + * 2. ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window) [optional] + * 3. ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window) [optional] + * 4. h - parameter for the PB penalty function [optional] + * 5. lambda - regularization parameter [optional] -* Output: -* 1. regularized (denoised) Image (N x N)/volume (N x N x N) -* -* Quick 2D denoising example in Matlab: -Im = double(imread('lena_gray_256.tif'))/255; % loading image -u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); -* -* Please see more tests in a file: -TestTemporalSmoothing.m - -* -* Matlab + C/mex compilers needed -* to compile with OMP support: mex PB_Regul_CPU.c CFLAGS="\$CFLAGS -fopenmp -Wall" LDFLAGS="\$LDFLAGS -fopenmp" -* -* D. Kazantsev * -* 02/07/2014 -* Harwell, UK -*/ + * Output: + * 1. regularized (denoised) Image (N x N)/volume (N x N x N) + * + * 2D denoising example in Matlab: + Im = double(imread('lena_gray_256.tif'))/255; % loading image + u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise + ImDen = PatchBased_Regul(single(u0), 3, 1, 0.08, 0.05); + + * D. Kazantsev * + * 02/07/2014 + * Harwell, UK + */ -/*2D version*/ +/*2D version function */ float PB_FUNC2D(float *A, float *B, int dimX, int dimY, int padXY, int SearchW, int SimilW, float h, float lambda) { int i, j, i_n, j_n, i_m, j_m, i_p, j_p, i_l, j_l, i1, j1, i2, j2, i3, j3, i5,j5, count, SimilW_full; diff --git a/main_func/regularizers_CPU/PatchBased_Regul_core.h b/main_func/regularizers_CPU/PatchBased_Regul_core.h index b83cf10..5aa6415 100644 --- a/main_func/regularizers_CPU/PatchBased_Regul_core.h +++ b/main_func/regularizers_CPU/PatchBased_Regul_core.h @@ -26,39 +26,6 @@ limitations under the License. #include #include "omp.h" -/* C-OMP implementation of patch-based (PB) regularization (2D and 3D cases). 
-* This method finds self-similar patches in data and performs one fixed point iteration to mimimize the PB penalty function -* -* References: 1. Yang Z. & Jacob M. "Nonlocal Regularization of Inverse Problems" -* 2. Kazantsev D. et al. "4D-CT reconstruction with unified spatial-temporal patch-based regularization" -* -* Input Parameters (mandatory): -* 1. Image (2D or 3D) -* 2. ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window) -* 3. ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window) -* 4. h - parameter for the PB penalty function -* 5. lambda - regularization parameter - -* Output: -* 1. regularized (denoised) Image (N x N)/volume (N x N x N) -* -* Quick 2D denoising example in Matlab: -Im = double(imread('lena_gray_256.tif'))/255; % loading image -u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); -* -* Please see more tests in a file: -TestTemporalSmoothing.m - -* -* Matlab + C/mex compilers needed -* to compile with OMP support: mex PB_Regul_CPU.c CFLAGS="\$CFLAGS -fopenmp -Wall" LDFLAGS="\$LDFLAGS -fopenmp" -* -* D. Kazantsev * -* 02/07/2014 -* Harwell, UK -*/ - float pad_crop(float *A, float *Ap, int OldSizeX, int OldSizeY, int OldSizeZ, int NewSizeX, int NewSizeY, int NewSizeZ, int padXY, int switchpad_crop); float PB_FUNC2D(float *A, float *B, int dimX, int dimY, int padXY, int SearchW, int SimilW, float h, float lambda); float PB_FUNC3D(float *A, float *B, int dimX, int dimY, int dimZ, int padXY, int SearchW, int SimilW, float h, float lambda); diff --git a/main_func/regularizers_CPU/SplitBregman_TV_core.c b/main_func/regularizers_CPU/SplitBregman_TV_core.c index 26ad5b1..283dd43 100644 --- a/main_func/regularizers_CPU/SplitBregman_TV_core.c +++ b/main_func/regularizers_CPU/SplitBregman_TV_core.c @@ -37,7 +37,6 @@ limitations under the License. * u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; * u = SplitBregman_TV(single(u0), 10, 30, 1e-04); * -* to compile with OMP support: mex SplitBregman_TV.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" * References: * The Split Bregman Method for L1 Regularized Problems, by Tom Goldstein and Stanley Osher. * D. Kazantsev, 2016* diff --git a/main_func/regularizers_CPU/SplitBregman_TV_core.h b/main_func/regularizers_CPU/SplitBregman_TV_core.h index ed9112c..a7aaabb 100644 --- a/main_func/regularizers_CPU/SplitBregman_TV_core.h +++ b/main_func/regularizers_CPU/SplitBregman_TV_core.h @@ -23,30 +23,6 @@ limitations under the License. #include #include "omp.h" -/* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) -* -* Input Parameters: -* 1. Noisy image/volume -* 2. lambda - regularization parameter -* 3. Number of iterations [OPTIONAL parameter] -* 4. eplsilon - tolerance constant [OPTIONAL parameter] -* 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] -* -* Output: -* Filtered/regularized image -* -* Example: -* figure; -* Im = double(imread('lena_gray_256.tif'))/255; % loading image -* u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; -* u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -* -* to compile with OMP support: mex SplitBregman_TV.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" -* References: -* The Split Bregman Method for L1 Regularized Problems, by Tom Goldstein and Stanley Osher. -* D. 
Kazantsev, 2016* -*/ - float copyIm(float *A, float *B, int dimX, int dimY, int dimZ); float gauss_seidel2D(float *U, float *A, float *Dx, float *Dy, float *Bx, float *By, int dimX, int dimY, float lambda, float mu); float updDxDy_shrinkAniso2D(float *U, float *Dx, float *Dy, float *Bx, float *By, int dimX, int dimY, float lambda); diff --git a/main_func/regularizers_CPU/TGV_PD.c b/main_func/regularizers_CPU/TGV_PD.c index 8bce18a..d4bb787 100644 --- a/main_func/regularizers_CPU/TGV_PD.c +++ b/main_func/regularizers_CPU/TGV_PD.c @@ -39,7 +39,7 @@ limitations under the License. * u0 = Im + .03*randn(size(Im)); % adding noise * tic; u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); toc; * - * to compile with OMP support: mex TGV_PD.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" + * to compile with OMP support: mex TGV_PD.c TGV_PD_core.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" * References: * K. Bredies "Total Generalized Variation" * @@ -53,7 +53,7 @@ void mexFunction( { int number_of_dims, iter, dimX, dimY, dimZ, ll; const int *dim_array; - float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0; + float *A, *U, *U_old, *P1, *P2, *Q1, *Q2, *Q3, *V1, *V1_old, *V2, *V2_old, lambda, L2, tau, sigma, alpha1, alpha0; number_of_dims = mxGetNumberOfDimensions(prhs[0]); dim_array = mxGetDimensions(prhs[0]); @@ -88,40 +88,10 @@ void mexFunction( V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); } - else if (number_of_dims == 3) { - mexErrMsgTxt("The input data should be 2D"); - /*3D case*/ -// dimZ = dim_array[2]; -// U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// -// P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// -// Q1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// Q2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// Q3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// Q4 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// Q5 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// Q6 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// Q7 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// Q8 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// Q9 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// -// U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// -// V1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// V1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// V2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// V2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, 
mxSINGLE_CLASS, mxREAL)); -// V3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); -// V3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - } - else {mexErrMsgTxt("The input data should be 2D");} - - - /*printf("%i \n", i);*/ + V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + + /*printf("%i \n", i);*/ L2 = 12.0; /*Lipshitz constant*/ tau = 1.0/pow(L2,0.5); sigma = 1.0/pow(L2,0.5); @@ -129,7 +99,6 @@ void mexFunction( /*Copy A to U*/ copyIm(A, U, dimX, dimY, dimZ); - if (number_of_dims == 2) { /* Here primal-dual iterations begin for 2D */ for(ll = 0; ll < iter; ll++) { @@ -164,23 +133,12 @@ void mexFunction( /*get new V*/ newU(V1, V1_old, dimX, dimY, dimZ); newU(V2, V2_old, dimX, dimY, dimZ); - } /*end of iterations*/ + } /*end of iterations*/ } - -// /*3D version*/ -// if (number_of_dims == 3) { -// /* Here primal-dual iterations begin for 3D */ -// for(ll = 0; ll < iter; ll++) { -// -// /* Calculate Dual Variable P */ -// DualP_3D(U, V1, V2, V3, P1, P2, P3, dimX, dimY, dimZ, sigma); -// -// /*Projection onto convex set for P*/ -// ProjP_3D(P1, P2, P3, dimX, dimY, dimZ, alpha1); -// -// /* Calculate Dual Variable Q */ -// DualQ_3D(V1, V2, V2, Q1, Q2, Q3, Q4, Q5, Q6, Q7, Q8, Q9, dimX, dimY, dimZ, sigma); -// -// } /*end of iterations*/ -// } + else if (number_of_dims == 3) { + mexErrMsgTxt("The input data should be a 2D array"); + /*3D case*/ + } + else {mexErrMsgTxt("The input data should be a 2D array");} + } diff --git a/main_func/regularizers_CPU/TGV_PD_core.c b/main_func/regularizers_CPU/TGV_PD_core.c index 4c0427c..1164b73 100644 --- a/main_func/regularizers_CPU/TGV_PD_core.c +++ b/main_func/regularizers_CPU/TGV_PD_core.c @@ -38,7 +38,6 @@ limitations under the License. * u0 = Im + .03*randn(size(Im)); % adding noise * tic; u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); toc; * - * to compile with OMP support: mex TGV_PD.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" * References: * K. Bredies "Total Generalized Variation" * diff --git a/main_func/regularizers_CPU/TGV_PD_core.h b/main_func/regularizers_CPU/TGV_PD_core.h index e25ae68..04ba95c 100644 --- a/main_func/regularizers_CPU/TGV_PD_core.h +++ b/main_func/regularizers_CPU/TGV_PD_core.h @@ -24,32 +24,6 @@ limitations under the License. #include #include "omp.h" -/* C-OMP implementation of Primal-Dual denoising method for -* Total Generilized Variation (TGV)-L2 model (2D case only) -* -* Input Parameters: -* 1. Noisy image/volume (2D) -* 2. lambda - regularization parameter -* 3. parameter to control first-order term (alpha1) -* 4. parameter to control the second-order term (alpha0) -* 5. Number of CP iterations -* -* Output: -* Filtered/regularized image -* -* Example: -* figure; -* Im = double(imread('lena_gray_256.tif'))/255; % loading image -* u0 = Im + .03*randn(size(Im)); % adding noise -* tic; u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); toc; -* -* to compile with OMP support: mex TGV_PD.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" -* References: -* K. 
Bredies "Total Generalized Variation" -* -* 28.11.16/Harwell -*/ - /* 2D functions */ float DualP_2D(float *U, float *V1, float *V2, float *P1, float *P2, int dimX, int dimY, int dimZ, float sigma); float ProjP_2D(float *P1, float *P2, int dimX, int dimY, int dimZ, float alpha1); @@ -57,8 +31,5 @@ float DualQ_2D(float *V1, float *V2, float *Q1, float *Q2, float *Q3, int dimX, float ProjQ_2D(float *Q1, float *Q2, float *Q3, int dimX, int dimY, int dimZ, float alpha0); float DivProjP_2D(float *U, float *A, float *P1, float *P2, int dimX, int dimY, int dimZ, float lambda, float tau); float UpdV_2D(float *V1, float *V2, float *P1, float *P2, float *Q1, float *Q2, float *Q3, int dimX, int dimY, int dimZ, float tau); -/*3D functions*/ -float DualP_3D(float *U, float *V1, float *V2, float *V3, float *P1, float *P2, float *P3, int dimX, int dimY, int dimZ, float sigma); - float newU(float *U, float *U_old, int dimX, int dimY, int dimZ); float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); -- cgit v1.2.3 From e412e73b172a99776d108d3b4c1f8662a20e9fce Mon Sep 17 00:00:00 2001 From: Daniil Kazantsev Date: Thu, 3 Aug 2017 00:26:46 +0100 Subject: 2D or 3D regularization choices added --- main_func/FISTA_REC.m | 114 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 93 insertions(+), 21 deletions(-) diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index 1e93719..43ed0cb 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -3,6 +3,7 @@ function [X, output] = FISTA_REC(params) % <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> % ___Input___: % params.[] file: +%----------------General Parameters------------------------ % - .proj_geom (geometry of the projector) [required] % - .vol_geom (geometry of the reconstructed object) [required] % - .sino (vectorized in 2D or 3D sinogram) [required] @@ -13,12 +14,19 @@ function [X, output] = FISTA_REC(params) % - .ROI (Region-of-interest, only if X_ideal is given) % - .initialize (a 'warm start' using SIRT method from ASTRA) %----------------Regularization choices------------------------ -% - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) -% - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) -% - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) +% 1 .Regul_Lambda_FGPTV (FGP-TV regularization parameter) +% 2 .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) +% 3 .Regul_LambdaLLT (Higher order LLT regularization parameter) +% 3.1 .Regul_tauLLT (time step parameter for LLT (HO) term) +% 4 .Regul_LambdaPatchBased (Patch-based nonlocal regularization parameter) +% 4.1 .Regul_PB_SearchW (ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window)) +% 4.2 .Regul_PB_SimilW (ratio of the similarity window (e.g. 
1 = (2*1+1) = 3 pixels window)) +% 4.3 .Regul_PB_h (PB penalty function threshold) +% 5 .Regul_LambdaTGV (Total Generalized variation regularization parameter) % - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) % - .Regul_Iterations (iterations for the selected penalty, default 25) -% - .Regul_tauLLT (time step parameter for LLT term) +% - .Regul_Dimension ('2D' or '3D' way to apply regularization, '3D' is the default) +%----------------Ring removal------------------------ % - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) % - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) %----------------Visualization parameters------------------------ @@ -150,8 +158,8 @@ if (isfield(params,'Regul_Iterations')) else IterationsRegul = 25; end -if (isfield(params,'Regul_LambdaHO')) - lambdaHO = params.Regul_LambdaHO; +if (isfield(params,'Regul_LambdaLLT')) + lambdaHO = params.Regul_LambdaLLT; else lambdaHO = 0; end @@ -165,6 +173,26 @@ if (isfield(params,'Regul_tauLLT')) else tauHO = 0.0001; end +if (isfield(params,'Regul_LambdaPatchBased')) + lambdaPB = params.Regul_LambdaPatchBased; +else + lambdaPB = 0; +end +if (isfield(params,'Regul_PB_SearchW')) + SearchW = params.Regul_PB_SearchW; +else + SearchW = 3; % default +end +if (isfield(params,'Regul_PB_SimilW')) + SimilW = params.Regul_PB_SimilW; +else + SimilW = 1; % default +end +if (isfield(params,'Regul_PB_h')) + h_PB = params.Regul_PB_h; +else + h_PB = 0.1; % default +end if (isfield(params,'Ring_LambdaR_L1')) lambdaR_L1 = params.Ring_LambdaR_L1; else @@ -175,6 +203,14 @@ if (isfield(params,'Ring_Alpha')) else alpha_ring = 1; end +if (isfield(params,'Regul_Dimension')) + Dimension = params.Regul_Dimension; + if ((strcmp('2D', Dimension) ~= 1) && (strcmp('3D', Dimension) ~= 1)) + Dimension = '3D'; + end +else + Dimension = '3D'; +end if (isfield(params,'show')) show = params.show; else @@ -293,21 +329,57 @@ for i = 1:iterFISTA astra_mex_data3d('delete', sino_id); astra_mex_data3d('delete', id); - if (lambdaFGP_TV > 0) - % FGP-TV regularization - [X, f_val] = FGP_TV(single(X), lambdaFGP_TV, IterationsRegul, tol, 'iso'); - objective(i) = objective(i) + f_val; - end - if (lambdaSB_TV > 0) - % Split Bregman regularization - X = SplitBregman_TV(single(X), lambdaSB_TV, IterationsRegul, tol); % (more memory efficent) - end - if (lambdaHO > 0) - % Higher Order (LLT) regularization - X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, 3.0e-05, 0); - X = 0.5.*(X + X2); % averaged combination of two solutions - end - + % regularization + if (lambdaFGP_TV > 0) + % FGP-TV regularization + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + [X(:,:,kkk), f_val] = FGP_TV(single(X(:,:,kkk)), lambdaFGP_TV, IterationsRegul, tol, 'iso'); + end + else + % 3D regularization + [X, f_val] = FGP_TV(single(X), lambdaFGP_TV, IterationsRegul, tol, 'iso'); + end + objective(i) = objective(i) + f_val; + end + if (lambdaSB_TV > 0) + % Split Bregman regularization + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = SplitBregman_TV(single(X(:,:,kkk)), lambdaSB_TV, IterationsRegul, tol); % (more memory efficent) + end + else + % 3D regularization + X = SplitBregman_TV(single(X), lambdaSB_TV, IterationsRegul, tol); % (more memory efficent) + end + end + if (lambdaHO > 0) + % Higher Order (LLT) regularization + X2 = zeros(N,N,SlicesZ,'single'); + if ((strcmp('2D', Dimension) == 1)) + % 
2D regularization + for kkk = 1:SlicesZ + X2(:,:,kkk) = LLT_model(single(X(:,:,kkk)), lambdaHO, tauHO, iterHO, 3.0e-05, 0); + end + else + % 3D regularization + X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, 3.0e-05, 0); + end + X = 0.5.*(X + X2); % averaged combination of two solutions + end + if (lambdaPB > 0) + % Patch-Based regularization (can be slow on CPU) + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = PatchBased_Regul(single(X(:,:,kkk)), SearchW, SimilW, h_PB, lambdaPB); + end + else + X = PatchBased_Regul(single(X), SearchW, SimilW, h_PB, lambdaPB); + end + end if (lambdaR_L1 > 0) -- cgit v1.2.3 From d7281b95f70dd3707f82640972a52f4155c2fde5 Mon Sep 17 00:00:00 2001 From: Daniil Kazantsev Date: Thu, 3 Aug 2017 15:25:50 +0100 Subject: Ordered Subset FISTA added and demos updated in DemoRD2 --- demos/DemoRD2.m | 86 ++++++-- main_func/FISTA_REC.m | 404 +++++++++++++++++++++++++----------- main_func/regularizers_CPU/TGV_PD.c | 6 +- 3 files changed, 356 insertions(+), 140 deletions(-) diff --git a/demos/DemoRD2.m b/demos/DemoRD2.m index 293c1cd..f177e26 100644 --- a/demos/DemoRD2.m +++ b/demos/DemoRD2.m @@ -1,6 +1,6 @@ % Demonstration of tomographic 3D reconstruction from X-ray synchrotron % dataset (dendrites) using various data fidelities -% ! can take up to 15-20 minutes to run for the whole 3D data ! +% ! It is advisable not to run the whole script, it will take lots of time to reconstruct the whole 3D data using many algorithms ! clear all close all %% @@ -30,7 +30,7 @@ Weights3D = single(data_raw3D); % weights for PW model clear data_raw3D %% % set projection/reconstruction geometry here -Z_slices = 20; +Z_slices = 1; det_row_count = Z_slices; proj_geom = astra_create_proj_geom('parallel3d', 1, 1, det_row_count, size_det, angles_rad); vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); @@ -38,70 +38,114 @@ vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); fprintf('%s\n', 'Reconstruction using FBP...'); FBP = iradon(Sino3D(:,:,10), angles,recon_size); figure; imshow(FBP , [0, 3]); title ('FBP reconstruction'); + +%--------FISTA_REC modular reconstruction alogrithms--------- %% -fprintf('%s\n', 'Reconstruction using FISTA-PWLS without regularization...'); +fprintf('%s\n', 'Reconstruction using FISTA-OS-PWLS without regularization...'); clear params params.proj_geom = proj_geom; % pass geometry to the function params.vol_geom = vol_geom; params.sino = Sino3D; -params.iterFISTA = 1; +params.iterFISTA = 12; params.weights = Weights3D; +params.subsets = 16; % the number of ordered subsets params.show = 1; params.maxvalplot = 2.5; params.slice = 2; -tic; [X_fista, output] = FISTA_REC(params); toc; -figure; imshow(X_fista(:,:,params.slice) , [0, 2.5]); title ('FISTA-PWLS reconstruction'); +tic; [X_fista, outputFISTA] = FISTA_REC(params); toc; +figure; imshow(X_fista(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-PWLS reconstruction'); %% -fprintf('%s\n', 'Reconstruction using FISTA-PWLS-TV...'); +fprintf('%s\n', 'Reconstruction using FISTA-OS-PWLS-TV...'); clear params params.proj_geom = proj_geom; % pass geometry to the function params.vol_geom = vol_geom; params.sino = Sino3D; -params.iterFISTA = 40; +params.iterFISTA = 12; params.Regul_Lambda_FGPTV = 0.005; % TV regularization parameter for FGP-TV params.weights = Weights3D; +params.subsets = 16; % the number of ordered subsets params.show = 1; params.maxvalplot = 2.5; params.slice = 2; -tic; [X_fista_TV] = FISTA_REC(params); toc; -figure; 
imshow(X_fista_TV(:,:,params.slice) , [0, 2.5]); title ('FISTA-PWLS-TV reconstruction'); -%% +tic; [X_fista_TV, outputTV] = FISTA_REC(params); toc; +figure; imshow(X_fista_TV(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-PWLS-TV reconstruction'); %% -fprintf('%s\n', 'Reconstruction using FISTA-GH-TV...'); +fprintf('%s\n', 'Reconstruction using FISTA-OS-GH-TV...'); clear params params.proj_geom = proj_geom; % pass geometry to the function params.vol_geom = vol_geom; params.sino = Sino3D; -params.iterFISTA = 40; +params.iterFISTA = 12; params.Regul_Lambda_FGPTV = 0.005; % TV regularization parameter for FGP-TV params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter params.Ring_Alpha = 21; % to boost ring removal procedure params.weights = Weights3D; +params.subsets = 16; % the number of ordered subsets params.show = 1; params.maxvalplot = 2.5; params.slice = 2; -tic; [X_fista_GH_TV] = FISTA_REC(params); toc; -figure; imshow(X_fista_GH_TV(:,:,params.slice) , [0, 2.5]); title ('FISTA-GH-TV reconstruction'); -%% +tic; [X_fista_GH_TV, outputGHTV] = FISTA_REC(params); toc; +figure; imshow(X_fista_GH_TV(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-GH-TV reconstruction'); %% -fprintf('%s\n', 'Reconstruction using FISTA-GH-TV-LLT...'); +fprintf('%s\n', 'Reconstruction using FISTA-OS-GH-TV-LLT...'); clear params params.proj_geom = proj_geom; % pass geometry to the function params.vol_geom = vol_geom; params.sino = Sino3D; -params.iterFISTA = 5; +params.iterFISTA = 12; params.Regul_Lambda_FGPTV = 0.005; % TV regularization parameter for FGP-TV -params.Regul_LambdaHO = 250; % regularization parameter for LLT problem +params.Regul_LambdaLLT = 100; % regularization parameter for LLT problem params.Regul_tauLLT = 0.0005; % time-step parameter for the explicit scheme params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter params.Ring_Alpha = 21; % to boost ring removal procedure params.weights = Weights3D; +params.subsets = 16; % the number of ordered subsets params.show = 1; params.maxvalplot = 2.5; params.slice = 2; -tic; [X_fista_GH_TVLLT] = FISTA_REC(params); toc; -figure; imshow(X_fista_GH_TVLLT(:,:,params.slice) , [0, 2.5]); title ('FISTA-GH-TV-LLT reconstruction'); +tic; [X_fista_GH_TVLLT, outputGH_TVLLT] = FISTA_REC(params); toc; +figure; imshow(X_fista_GH_TVLLT(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-GH-TV-LLT reconstruction'); + +%% +% fprintf('%s\n', 'Reconstruction using FISTA-OS-PB...'); +% % very-slow on CPU +% clear params +% params.proj_geom = proj_geom; % pass geometry to the function +% params.vol_geom = vol_geom; +% params.sino = Sino3D(:,:,1); +% params.iterFISTA = 12; +% params.Regul_LambdaPatchBased = 1; % PB regularization parameter +% params.Regul_PB_h = 0.1; % threhsold parameter +% params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter +% params.Ring_Alpha = 21; % to boost ring removal procedure +% params.weights = Weights3D(:,:,1); +% params.subsets = 16; % the number of ordered subsets +% params.show = 1; +% params.maxvalplot = 2.5; params.slice = 1; +% +% tic; [X_fista_GH_PB, outputPB] = FISTA_REC(params); toc; +% figure; imshow(X_fista_GH_PB(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-PB reconstruction'); + %% +fprintf('%s\n', 'Reconstruction using FISTA-OS-GH-TGV...'); +% still testing... 
+clear params +params.proj_geom = proj_geom; % pass geometry to the function +params.vol_geom = vol_geom; +params.sino = Sino3D; +params.iterFISTA = 12; +params.Regul_LambdaTGV = 0.5; % TGV regularization parameter +params.Regul_Iterations = 5; +params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter +params.Ring_Alpha = 21; % to boost ring removal procedure +params.weights = Weights3D; +params.subsets = 16; % the number of ordered subsets +params.show = 1; +params.maxvalplot = 2.5; params.slice = 1; + +tic; [X_fista_GH_TGV, outputTGV] = FISTA_REC(params); toc; +figure; imshow(X_fista_GH_TGV(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-GH-TGV reconstruction'); + %% % fprintf('%s\n', 'Reconstruction using FISTA-Student-TV...'); % clear params diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index 43ed0cb..381fa34 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -1,6 +1,6 @@ function [X, output] = FISTA_REC(params) -% <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> +% <<<< FISTA-based reconstruction routine using ASTRA-toolbox >>>> % ___Input___: % params.[] file: %----------------General Parameters------------------------ @@ -19,9 +19,9 @@ function [X, output] = FISTA_REC(params) % 3 .Regul_LambdaLLT (Higher order LLT regularization parameter) % 3.1 .Regul_tauLLT (time step parameter for LLT (HO) term) % 4 .Regul_LambdaPatchBased (Patch-based nonlocal regularization parameter) -% 4.1 .Regul_PB_SearchW (ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window)) -% 4.2 .Regul_PB_SimilW (ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window)) -% 4.3 .Regul_PB_h (PB penalty function threshold) +% 4.1 .Regul_PB_SearchW (ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window)) +% 4.2 .Regul_PB_SimilW (ratio of the similarity window (e.g. 
1 = (2*1+1) = 3 pixels window)) +% 4.3 .Regul_PB_h (PB penalty function threshold) % 5 .Regul_LambdaTGV (Total Generalized variation regularization parameter) % - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) % - .Regul_Iterations (iterations for the selected penalty, default 25) @@ -193,6 +193,11 @@ if (isfield(params,'Regul_PB_h')) else h_PB = 0.1; % default end +if (isfield(params,'Regul_LambdaTGV')) + LambdaTGV = params.Regul_LambdaTGV; +else + LambdaTGV = 0.01; +end if (isfield(params,'Ring_LambdaR_L1')) lambdaR_L1 = params.Ring_LambdaR_L1; else @@ -252,20 +257,17 @@ if (isfield(params,'initialize')) else X = zeros(N,N,SlicesZ, 'single'); % storage for the solution end -if (isfield(params,'OS')) +if (isfield(params,'subsets')) % Ordered Subsets reorganisation of data and angles - OS = 1; - subsets = 8; + subsets = params.subsets; % subsets number angles = proj_geom.ProjectionAngles; binEdges = linspace(min(angles),max(angles),subsets+1); % assign values to bins [binsDiscr,~] = histc(angles, [binEdges(1:end-1) Inf]); - % rearrange angles into subsets - AnglesReorg = zeros(length(angles),1); - SinoReorg = zeros(Detectors, anglesNumb, SlicesZ, 'single'); - + % get rearranged subset indices + IndicesReorg = zeros(length(angles),1); counterM = 0; for ii = 1:max(binsDiscr(:)) counter = 0; @@ -273,143 +275,313 @@ if (isfield(params,'OS')) curr_index = ii+jj-1 + counter; if (binsDiscr(jj) >= ii) counterM = counterM + 1; - AnglesReorg(counterM) = angles(curr_index); - SinoReorg(:,counterM,:) = squeeze(sino(:,curr_index,:)); + IndicesReorg(counterM) = curr_index; end counter = (counter + binsDiscr(jj)) - 1; end end - sino = SinoReorg; - clear SinoReorg; -else - OS = 0; % normal FISTA +else + subsets = 0; % Classical FISTA end - %----------------Reconstruction part------------------------ Resid_error = zeros(iterFISTA,1); % errors vector (if the ground truth is given) objective = zeros(iterFISTA,1); % objective function values vector -t = 1; -X_t = X; -r = zeros(Detectors,SlicesZ, 'single'); % 2D array (for 3D data) of sparse "ring" vectors -r_x = r; % another ring variable -residual = zeros(size(sino),'single'); - -% Outer FISTA iterations loop -for i = 1:iterFISTA - - X_old = X; - t_old = t; - r_old = r; +if (subsets == 0) + % Classical FISTA + t = 1; + X_t = X; - [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + r = zeros(Detectors,SlicesZ, 'single'); % 2D array (for 3D data) of sparse "ring" vectors + r_x = r; % another ring variable + residual = zeros(size(sino),'single'); - if (lambdaR_L1 > 0) - % ring removal part (Group-Huber fidelity) - for kkk = 1:anglesNumb - residual(:,kkk,:) = squeeze(weights(:,kkk,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); + % Outer FISTA iterations loop + for i = 1:iterFISTA + + X_old = X; + t_old = t; + r_old = r; + + [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + + if (lambdaR_L1 > 0) + % the ring removal part (Group-Huber fidelity) + for kkk = 1:anglesNumb + residual(:,kkk,:) = squeeze(weights(:,kkk,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); + end + vec = sum(residual,2); + if (SlicesZ > 1) + vec = squeeze(vec(:,1,:)); + end + r = r_x - (1./L_const).*vec; + else + % no ring removal + residual = weights.*(sino_updt - sino); end - vec = sum(residual,2); - if (SlicesZ > 1) - vec = squeeze(vec(:,1,:)); + + objective(i) = (0.5*norm(residual(:))^2)/(Detectors*anglesNumb*SlicesZ); % for the objective function 
output + + [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom); + + X = X_t - (1/L_const).*x_temp; + astra_mex_data3d('delete', sino_id); + astra_mex_data3d('delete', id); + + % regularization + if (lambdaFGP_TV > 0) + % FGP-TV regularization + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + [X(:,:,kkk), f_val] = FGP_TV(single(X(:,:,kkk)), lambdaFGP_TV, IterationsRegul, tol, 'iso'); + end + else + % 3D regularization + [X, f_val] = FGP_TV(single(X), lambdaFGP_TV, IterationsRegul, tol, 'iso'); + end + objective(i) = objective(i) + f_val; end - r = r_x - (1./L_const).*vec; - else - % no ring removal - residual = weights.*(sino_updt - sino); - end - - objective(i) = (0.5*norm(residual(:))^2)/(Detectors*anglesNumb*SlicesZ); % for the objective function output - - [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom); - - X = X_t - (1/L_const).*x_temp; - astra_mex_data3d('delete', sino_id); - astra_mex_data3d('delete', id); - - % regularization - if (lambdaFGP_TV > 0) - % FGP-TV regularization - if ((strcmp('2D', Dimension) == 1)) - % 2D regularization - for kkk = 1:SlicesZ - [X(:,:,kkk), f_val] = FGP_TV(single(X(:,:,kkk)), lambdaFGP_TV, IterationsRegul, tol, 'iso'); + if (lambdaSB_TV > 0) + % Split Bregman regularization + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = SplitBregman_TV(single(X(:,:,kkk)), lambdaSB_TV, IterationsRegul, tol); % (more memory efficent) + end + else + % 3D regularization + X = SplitBregman_TV(single(X), lambdaSB_TV, IterationsRegul, tol); % (more memory efficent) end - else - % 3D regularization - [X, f_val] = FGP_TV(single(X), lambdaFGP_TV, IterationsRegul, tol, 'iso'); end - objective(i) = objective(i) + f_val; - end - if (lambdaSB_TV > 0) - % Split Bregman regularization - if ((strcmp('2D', Dimension) == 1)) - % 2D regularization - for kkk = 1:SlicesZ - X(:,:,kkk) = SplitBregman_TV(single(X(:,:,kkk)), lambdaSB_TV, IterationsRegul, tol); % (more memory efficent) + if (lambdaHO > 0) + % Higher Order (LLT) regularization + X2 = zeros(N,N,SlicesZ,'single'); + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X2(:,:,kkk) = LLT_model(single(X(:,:,kkk)), lambdaHO, tauHO, iterHO, 3.0e-05, 0); + end + else + % 3D regularization + X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, 3.0e-05, 0); end - else - % 3D regularization - X = SplitBregman_TV(single(X), lambdaSB_TV, IterationsRegul, tol); % (more memory efficent) + X = 0.5.*(X + X2); % averaged combination of two solutions end - end - if (lambdaHO > 0) - % Higher Order (LLT) regularization - X2 = zeros(N,N,SlicesZ,'single'); - if ((strcmp('2D', Dimension) == 1)) - % 2D regularization - for kkk = 1:SlicesZ - X2(:,:,kkk) = LLT_model(single(X(:,:,kkk)), lambdaHO, tauHO, iterHO, 3.0e-05, 0); + if (lambdaPB > 0) + % Patch-Based regularization (can be slow on CPU) + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = PatchBased_Regul(single(X(:,:,kkk)), SearchW, SimilW, h_PB, lambdaPB); + end + else + X = PatchBased_Regul(single(X), SearchW, SimilW, h_PB, lambdaPB); end - else - % 3D regularization - X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, 3.0e-05, 0); end - X = 0.5.*(X + X2); % averaged combination of two solutions - end - if (lambdaPB > 0) - % Patch-Based regularization (can be slow on CPU) - if ((strcmp('2D', Dimension) == 1)) - % 2D regularization - for kkk = 1:SlicesZ - X(:,:,kkk) = 
PatchBased_Regul(single(X(:,:,kkk)), SearchW, SimilW, h_PB, lambdaPB); + if (LambdaTGV > 0) + % Total Generalized variation (currently only 2D) + lamTGV1 = 1.1; % smoothing trade-off parameters, see Pock's paper + lamTGV2 = 0.5; % second-order term + for kkk = 1:SlicesZ + X(:,:,kkk) = TGV_PD(single(X(:,:,kkk)), LambdaTGV, lamTGV1, lamTGV2, IterationsRegul); + end + end + + if (lambdaR_L1 > 0) + r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector + end + + t = (1 + sqrt(1 + 4*t^2))/2; % updating t + X_t = X + ((t_old-1)/t).*(X - X_old); % updating X + + if (lambdaR_L1 > 0) + r_x = r + ((t_old-1)/t).*(r - r_old); % updating r + end + + if (show == 1) + figure(10); imshow(X(:,:,slice), [0 maxvalplot]); + if (lambdaR_L1 > 0) + figure(11); plot(r); title('Rings offset vector') end + pause(0.01); + end + if (strcmp(X_ideal, 'none' ) == 0) + Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); + fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); else - X = PatchBased_Regul(single(X), SearchW, SimilW, h_PB, lambdaPB); + fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); end - end - - - if (lambdaR_L1 > 0) - r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector end +else + % Ordered Subsets (OS) FISTA reconstruction routine (normally one order of magnitude faster than classical) + t = 1; + X_t = X; + proj_geomSUB = proj_geom; - t = (1 + sqrt(1 + 4*t^2))/2; % updating t - X_t = X + ((t_old-1)/t).*(X - X_old); % updating X - if (lambdaR_L1 > 0) - r_x = r + ((t_old-1)/t).*(r - r_old); % updating r - end + r = zeros(Detectors,SlicesZ, 'single'); % 2D array (for 3D data) of sparse "ring" vectors + r_x = r; % another ring variable + residual2 = zeros(size(sino),'single'); - if (show == 1) - figure(10); imshow(X(:,:,slice), [0 maxvalplot]); - if (lambdaR_L1 > 0) - figure(11); plot(r); title('Rings offset vector') + % Outer FISTA iterations loop + for i = 1:iterFISTA + + % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required + % one solution is to work with a full sinogram at times + if ((i >= 3) && (lambdaR_L1 > 0)) + [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom); + astra_mex_data3d('delete', sino_id2); + end + + % subsets loop + counterInd = 1; + for ss = 1:subsets + X_old = X; + t_old = t; + r_old = r; + + numProjSub = binsDiscr(ss); % the number of projections per subset + CurrSubIndeces = IndicesReorg(counterInd:(counterInd + numProjSub - 1)); % extract indeces attached to the subset + proj_geomSUB.ProjectionAngles = angles(CurrSubIndeces); + + if (lambdaR_L1 > 0) + + % the ring removal part (Group-Huber fidelity) + % first 2 iterations do additional work reconstructing whole dataset to ensure + % stablility + if (i < 3) + [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + astra_mex_data3d('delete', sino_id2); + else + [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); + end + + for kkk = 1:anglesNumb + residual2(:,kkk,:) = squeeze(weights(:,kkk,:)).*(squeeze(sino_updt2(:,kkk,:)) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); + end + + residual = zeros(Detectors, numProjSub, SlicesZ,'single'); + for kkk = 1:numProjSub + indC = CurrSubIndeces(kkk); + if (i < 3) + residual(:,kkk,:) = squeeze(residual2(:,indC,:)); + else + residual(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt(:,kkk,:)) - 
(squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); + end + end + vec = sum(residual2,2); + if (SlicesZ > 1) + vec = squeeze(vec(:,1,:)); + end + r = r_x - (1./L_const).*vec; + else + [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); + % no ring removal + residual = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt - squeeze(sino(:,CurrSubIndeces,:))); + end + + [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geomSUB, vol_geom); + + X = X_t - (1/L_const).*x_temp; + astra_mex_data3d('delete', sino_id); + astra_mex_data3d('delete', id); + + % regularization + if (lambdaFGP_TV > 0) + % FGP-TV regularization + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + [X(:,:,kkk), f_val] = FGP_TV(single(X(:,:,kkk)), lambdaFGP_TV/subsets, IterationsRegul, tol, 'iso'); + end + else + % 3D regularization + [X, f_val] = FGP_TV(single(X), lambdaFGP_TV/subsets, IterationsRegul, tol, 'iso'); + end + objective(i) = objective(i) + f_val; + end + if (lambdaSB_TV > 0) + % Split Bregman regularization + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = SplitBregman_TV(single(X(:,:,kkk)), lambdaSB_TV/subsets, IterationsRegul, tol); % (more memory efficent) + end + else + % 3D regularization + X = SplitBregman_TV(single(X), lambdaSB_TV/subsets, IterationsRegul, tol); % (more memory efficent) + end + end + if (lambdaHO > 0) + % Higher Order (LLT) regularization + X2 = zeros(N,N,SlicesZ,'single'); + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X2(:,:,kkk) = LLT_model(single(X(:,:,kkk)), lambdaHO/subsets, tauHO/subsets, iterHO, 2.0e-05, 0); + end + else + % 3D regularization + X2 = LLT_model(single(X), lambdaHO/subsets, tauHO/subsets, iterHO, 2.0e-05, 0); + end + X = 0.5.*(X + X2); % averaged combination of two solutions + end + if (lambdaPB > 0) + % Patch-Based regularization (can be slow on CPU) + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = PatchBased_Regul(single(X(:,:,kkk)), SearchW, SimilW, h_PB, lambdaPB/subsets); + end + else + X = PatchBased_Regul(single(X), SearchW, SimilW, h_PB, lambdaPB/subsets); + end + end + if (LambdaTGV > 0) + % Total Generalized variation (currently only 2D) + lamTGV1 = 1.1; % smoothing trade-off parameters, see Pock's paper + lamTGV2 = 0.5; % second-order term + for kkk = 1:SlicesZ + X(:,:,kkk) = TGV_PD(single(X(:,:,kkk)), LambdaTGV/subsets, lamTGV1, lamTGV2, IterationsRegul); + end + end + + if (lambdaR_L1 > 0) + r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector + end + + t = (1 + sqrt(1 + 4*t^2))/2; % updating t + X_t = X + ((t_old-1)/t).*(X - X_old); % updating X + + if (lambdaR_L1 > 0) + r_x = r + ((t_old-1)/t).*(r - r_old); % updating r + end + + counterInd = counterInd + numProjSub; + end + + if (show == 1) + figure(10); imshow(X(:,:,slice), [0 maxvalplot]); + if (lambdaR_L1 > 0) + figure(11); plot(r); title('Rings offset vector') + end + pause(0.01); + end + + if (strcmp(X_ideal, 'none' ) == 0) + Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); + fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); + else + fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); end - pause(0.01); - end - if (strcmp(X_ideal, 'none' ) == 0) - Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); - fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 
'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); - else - fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); end end + output.Resid_error = Resid_error; output.objective = objective; output.L_const = L_const; -output.sino = sino; -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + end diff --git a/main_func/regularizers_CPU/TGV_PD.c b/main_func/regularizers_CPU/TGV_PD.c index d4bb787..6593d8e 100644 --- a/main_func/regularizers_CPU/TGV_PD.c +++ b/main_func/regularizers_CPU/TGV_PD.c @@ -37,9 +37,9 @@ limitations under the License. * figure; * Im = double(imread('lena_gray_256.tif'))/255; % loading image * u0 = Im + .03*randn(size(Im)); % adding noise - * tic; u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); toc; + * tic; u = TGV_PD(single(u0), 0.02, 1.3, 1, 550); toc; * - * to compile with OMP support: mex TGV_PD.c TGV_PD_core.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" + * to compile with OMP support: mex TGV_PD.c TGV_PD_core.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" * References: * K. Bredies "Total Generalized Variation" * @@ -92,7 +92,7 @@ void mexFunction( /*printf("%i \n", i);*/ - L2 = 12.0; /*Lipshitz constant*/ + L2 = 12.0f; /*Lipshitz constant*/ tau = 1.0/pow(L2,0.5); sigma = 1.0/pow(L2,0.5); -- cgit v1.2.3 From 69c2afc4670966c7e12420be7afd55ea69dfc8e8 Mon Sep 17 00:00:00 2001 From: Daniil Kazantsev Date: Thu, 24 Aug 2017 12:46:12 +0100 Subject: a minor bug with TGV being constantly switched on removed --- main_func/FISTA_REC.m | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index 381fa34..a6e0ae5 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -98,7 +98,7 @@ else for i = 1:niter [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); s = norm(x1(:)); - x1 = x1/s; + x1 = x1./s; [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); y = sqweight.*y; astra_mex_data3d('delete', sino_id); @@ -196,7 +196,7 @@ end if (isfield(params,'Regul_LambdaTGV')) LambdaTGV = params.Regul_LambdaTGV; else - LambdaTGV = 0.01; + LambdaTGV = 0; end if (isfield(params,'Ring_LambdaR_L1')) lambdaR_L1 = params.Ring_LambdaR_L1; @@ -369,6 +369,7 @@ if (subsets == 0) X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, 3.0e-05, 0); end X = 0.5.*(X + X2); % averaged combination of two solutions + end if (lambdaPB > 0) % Patch-Based regularization (can be slow on CPU) @@ -384,7 +385,7 @@ if (subsets == 0) if (LambdaTGV > 0) % Total Generalized variation (currently only 2D) lamTGV1 = 1.1; % smoothing trade-off parameters, see Pock's paper - lamTGV2 = 0.5; % second-order term + lamTGV2 = 0.8; % second-order term for kkk = 1:SlicesZ X(:,:,kkk) = TGV_PD(single(X(:,:,kkk)), LambdaTGV, lamTGV1, lamTGV2, IterationsRegul); end -- cgit v1.2.3 From ac2d90b52b15127eadbbf0d2f300d9da31f755c7 Mon Sep 17 00:00:00 2001 From: algol Date: Thu, 7 Sep 2017 14:59:36 +0100 Subject: parallel3D updated for big data in Matlab --- main_func/FISTA_REC.m | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index a6e0ae5..8dd569f 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -305,7 +305,17 @@ if (subsets == 0) t_old = t; r_old = r; + % if the geometry is parallel use slice-by-slice projection-backprojection routine + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'parallel3d'))
+ sino_updt = zeros(size(sino),'single'); + for kkk = 1:SlicesZ + [sino_id, sino_updt(:,kkk,:)] = astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); + astra_mex_data3d('delete', sino_id); + end + else + % for divergent 3D geometry (for Matlab watch the GPU memory overflow) [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + end if (lambdaR_L1 > 0) % the ring removal part (Group-Huber fidelity) @@ -324,8 +334,16 @@ if (subsets == 0) objective(i) = (0.5*norm(residual(:))^2)/(Detectors*anglesNumb*SlicesZ); % for the objective function output + % if the geometry is parallel use slice-by-slice projection-backprojection routine + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'parallel3d')) + x_temp = zeros(size(X),'single'); + for kkk = 1:SlicesZ + [id, x_temp(:,:,kkk)] = astra_create_backprojection3d_cuda(squeeze(residual(:,kkk,:)), proj_geomT, vol_geomT); + astra_mex_data3d('delete', id); + end + else [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom); - + end X = X_t - (1/L_const).*x_temp; astra_mex_data3d('delete', sino_id); astra_mex_data3d('delete', id); -- cgit v1.2.3 From 078b9e2db2e25d663a1140cc71ee4d16c36cc161 Mon Sep 17 00:00:00 2001 From: "Jakob Jorgensen, algol@harwell" Date: Thu, 7 Sep 2017 15:09:50 +0100 Subject: bugsf --- main_func/FISTA_REC.m | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index 8dd569f..00e59ab 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -104,7 +104,7 @@ else astra_mex_data3d('delete', sino_id); astra_mex_data3d('delete', id); end - clear proj_geomT vol_geomT + %clear proj_geomT vol_geomT else % divergen beam geometry fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry... 
will take some time!'); @@ -309,7 +309,7 @@ if (subsets == 0) if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'parallel3d')) sino_updt = zeros(size(sino),'single'); for kkk = 1:SlicesZ - [sino_id, sino_updt(:,kkk,:)] = astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); + [sino_id, sino_updt(:,:,kkk)] = astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); astra_mex_data3d('delete', sino_id); end else @@ -338,7 +338,7 @@ if (subsets == 0) if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'parallel3d')) x_temp = zeros(size(X),'single'); for kkk = 1:SlicesZ - [id, x_temp(:,:,kkk)] = astra_create_backprojection3d_cuda(squeeze(residual(:,kkk,:)), proj_geomT, vol_geomT); + [id, x_temp(:,:,kkk)] = astra_create_backprojection3d_cuda(squeeze(residual(:,:,kkk)), proj_geomT, vol_geomT); astra_mex_data3d('delete', id); end else -- cgit v1.2.3 From 62ab6cd46c3f1c189328c8d41899db7444c7ac29 Mon Sep 17 00:00:00 2001 From: Daniil Kazantsev Date: Mon, 11 Sep 2017 09:36:13 +0100 Subject: 2 new GPU regularizers, FGP objective fixed, FISTA_REC updated --- demos/DemoRD2.m | 62 +++-- main_func/FISTA_REC.m | 138 ++++++++--- main_func/regularizers_CPU/FGP_TV.c | 47 +--- main_func/regularizers_CPU/FGP_TV_core.c | 49 ++++ main_func/regularizers_CPU/FGP_TV_core.h | 2 + .../Diffus_HO/Diff4thHajiaboli_GPU.cpp | 114 +++++++++ .../Diffus_HO/Diff4th_GPU_kernel.cu | 270 +++++++++++++++++++++ .../Diffus_HO/Diff4th_GPU_kernel.h | 6 + main_func/regularizers_GPU/NL_Regul/NLM_GPU.cpp | 171 +++++++++++++ .../regularizers_GPU/NL_Regul/NLM_GPU_kernel.cu | 239 ++++++++++++++++++ .../regularizers_GPU/NL_Regul/NLM_GPU_kernel.h | 6 + 11 files changed, 1016 insertions(+), 88 deletions(-) create mode 100644 main_func/regularizers_GPU/Diffus_HO/Diff4thHajiaboli_GPU.cpp create mode 100644 main_func/regularizers_GPU/Diffus_HO/Diff4th_GPU_kernel.cu create mode 100644 main_func/regularizers_GPU/Diffus_HO/Diff4th_GPU_kernel.h create mode 100644 main_func/regularizers_GPU/NL_Regul/NLM_GPU.cpp create mode 100644 main_func/regularizers_GPU/NL_Regul/NLM_GPU_kernel.cu create mode 100644 main_func/regularizers_GPU/NL_Regul/NLM_GPU_kernel.h diff --git a/demos/DemoRD2.m b/demos/DemoRD2.m index f177e26..717a55d 100644 --- a/demos/DemoRD2.m +++ b/demos/DemoRD2.m @@ -6,7 +6,7 @@ close all %% % % adding paths addpath('../data/'); -addpath('../main_func/'); addpath('../main_func/regularizers_CPU/'); +addpath('../main_func/'); addpath('../main_func/regularizers_CPU/'); addpath('../main_func/regularizers_GPU/NL_Regul/'); addpath('../main_func/regularizers_GPU/Diffus_HO/'); addpath('../supp/'); load('DendrRawData.mat') % load raw data of 3D dendritic set @@ -30,7 +30,7 @@ Weights3D = single(data_raw3D); % weights for PW model clear data_raw3D %% % set projection/reconstruction geometry here -Z_slices = 1; +Z_slices = 5; det_row_count = Z_slices; proj_geom = astra_create_proj_geom('parallel3d', 1, 1, det_row_count, size_det, angles_rad); vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); @@ -107,25 +107,46 @@ tic; [X_fista_GH_TVLLT, outputGH_TVLLT] = FISTA_REC(params); toc; figure; imshow(X_fista_GH_TVLLT(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-GH-TV-LLT reconstruction'); %% -% fprintf('%s\n', 'Reconstruction using FISTA-OS-PB...'); -% % very-slow on CPU -% clear params -% params.proj_geom = proj_geom; % pass geometry to the function -% params.vol_geom = vol_geom; -% params.sino = Sino3D(:,:,1); -% params.iterFISTA = 12; -% params.Regul_LambdaPatchBased = 1; % PB regularization parameter 
-% params.Regul_PB_h = 0.1; % threhsold parameter -% params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter -% params.Ring_Alpha = 21; % to boost ring removal procedure -% params.weights = Weights3D(:,:,1); -% params.subsets = 16; % the number of ordered subsets -% params.show = 1; -% params.maxvalplot = 2.5; params.slice = 1; -% -% tic; [X_fista_GH_PB, outputPB] = FISTA_REC(params); toc; -% figure; imshow(X_fista_GH_PB(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-PB reconstruction'); +fprintf('%s\n', 'Reconstruction using FISTA-OS-GH-HigherOrderDiffusion...'); +% !GPU version! +clear params +params.proj_geom = proj_geom; % pass geometry to the function +params.vol_geom = vol_geom; +params.sino = Sino3D(:,:,1:5); +params.iterFISTA = 25; +params.Regul_LambdaDiffHO = 2; % DiffHO regularization parameter +params.Regul_DiffHO_EdgePar = 0.05; % threshold parameter +params.Regul_Iterations = 150; +params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter +params.Ring_Alpha = 21; % to boost ring removal procedure +params.weights = Weights3D(:,:,1:5); +params.subsets = 16; % the number of ordered subsets +params.show = 1; +params.maxvalplot = 2.5; params.slice = 1; + +tic; [X_fista_GH_HO, outputHO] = FISTA_REC(params); toc; +figure; imshow(X_fista_GH_HO(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-HigherOrderDiffusion reconstruction'); +%% +fprintf('%s\n', 'Reconstruction using FISTA-PB...'); +% !GPU version! +clear params +params.proj_geom = proj_geom; % pass geometry to the function +params.vol_geom = vol_geom; +params.sino = Sino3D(:,:,1); +params.iterFISTA = 25; +params.Regul_LambdaPatchBased_GPU = 3; % PB regularization parameter +params.Regul_PB_h = 0.04; % threhsold parameter +params.Regul_PB_SearchW = 3; +params.Regul_PB_SimilW = 1; +params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter +params.Ring_Alpha = 21; % to boost ring removal procedure +params.weights = Weights3D(:,:,1); +params.show = 1; +params.maxvalplot = 2.5; params.slice = 1; + +tic; [X_fista_GH_PB, outputPB] = FISTA_REC(params); toc; +figure; imshow(X_fista_GH_PB(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-PB reconstruction'); %% fprintf('%s\n', 'Reconstruction using FISTA-OS-GH-TGV...'); % still testing... @@ -146,6 +167,7 @@ params.maxvalplot = 2.5; params.slice = 1; tic; [X_fista_GH_TGV, outputTGV] = FISTA_REC(params); toc; figure; imshow(X_fista_GH_TGV(:,:,params.slice) , [0, 2.5]); title ('FISTA-OS-GH-TGV reconstruction'); + %% % fprintf('%s\n', 'Reconstruction using FISTA-Student-TV...'); % clear params diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index 00e59ab..edccbe1 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -1,6 +1,15 @@ function [X, output] = FISTA_REC(params) % <<<< FISTA-based reconstruction routine using ASTRA-toolbox >>>> +% This code solves regularised PWLS problem using FISTA approach. +% The code contains multiple regularisation penalties as well as it can be +% accelerated by using ordered-subset version. Various projection +% geometries supported. 
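+%
+% A sketch of the model being minimised (the notation below is illustrative
+% only and does not appear in the code):
+%    F(X) = 0.5*||A*X - b||^2_W + lambda*R(X),   ||r||^2_W = sum(W.*r.^2)
+% where A is the forward projector, b the measured sinogram, W the
+% statistical weights and R(X) the selected regularisation penalty; the
+% ordered-subset mode cycles the FISTA gradient step over subsets of the
+% projection angles.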
+ +% DISCLAIMER +% It is recommended to use ASTRA version 1.8 or later in order to avoid +% crashing due to GPU memory overflow for big datasets + % ___Input___: % params.[] file: %----------------General Parameters------------------------ @@ -18,11 +27,17 @@ function [X, output] = FISTA_REC(params) % 2 .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) % 3 .Regul_LambdaLLT (Higher order LLT regularization parameter) % 3.1 .Regul_tauLLT (time step parameter for LLT (HO) term) -% 4 .Regul_LambdaPatchBased (Patch-based nonlocal regularization parameter) +% 4 .Regul_LambdaPatchBased_CPU (Patch-based nonlocal regularization parameter) % 4.1 .Regul_PB_SearchW (ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window)) % 4.2 .Regul_PB_SimilW (ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window)) % 4.3 .Regul_PB_h (PB penalty function threshold) -% 5 .Regul_LambdaTGV (Total Generalized variation regularization parameter) +% 5 .Regul_LambdaPatchBased_GPU (Patch-based nonlocal regularization parameter) +% 5.1 .Regul_PB_SearchW (ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window)) +% 5.2 .Regul_PB_SimilW (ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window)) +% 5.3 .Regul_PB_h (PB penalty function threshold) +% 6 .Regul_LambdaDiffHO (Higher-Order Diffusion regularization parameter) +% 6.1 .Regul_DiffHO_EdgePar (edge-preserving noise related parameter) +% 7 .Regul_LambdaTGV (Total Generalized variation regularization parameter) % - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) % - .Regul_Iterations (iterations for the selected penalty, default 25) % - .Regul_Dimension ('2D' or '3D' way to apply regularization, '3D' is the default) @@ -173,11 +188,16 @@ if (isfield(params,'Regul_tauLLT')) else tauHO = 0.0001; end -if (isfield(params,'Regul_LambdaPatchBased')) - lambdaPB = params.Regul_LambdaPatchBased; +if (isfield(params,'Regul_LambdaPatchBased_CPU')) + lambdaPB = params.Regul_LambdaPatchBased_CPU; else lambdaPB = 0; end +if (isfield(params,'Regul_LambdaPatchBased_GPU')) + lambdaPB_GPU = params.Regul_LambdaPatchBased_GPU; +else + lambdaPB_GPU = 0; +end if (isfield(params,'Regul_PB_SearchW')) SearchW = params.Regul_PB_SearchW; else @@ -193,6 +213,16 @@ if (isfield(params,'Regul_PB_h')) else h_PB = 0.1; % default end +if (isfield(params,'Regul_LambdaDiffHO')) + LambdaDiff_HO = params.Regul_LambdaDiffHO; +else + LambdaDiff_HO = 0; +end +if (isfield(params,'Regul_DiffHO_EdgePar')) + LambdaDiff_HO_EdgePar = params.Regul_DiffHO_EdgePar; +else + LambdaDiff_HO_EdgePar = 0.01; +end if (isfield(params,'Regul_LambdaTGV')) LambdaTGV = params.Regul_LambdaTGV; else @@ -305,16 +335,16 @@ if (subsets == 0) t_old = t; r_old = r; - % if the geometry is parallel use slice-by-slice projection-backprojection routine + % if the geometry is parallel use slice-by-slice projection-backprojection routine if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'parallel3d')) - sino_updt = zeros(size(sino),'single'); - for kkk = 1:SlicesZ - [sino_id, sino_updt(:,:,kkk)] = astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); - astra_mex_data3d('delete', sino_id); - end - else - % for divergent 3D geometry (for Matlab watch the GPU memory overflow) - [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_updt = zeros(size(sino),'single'); + for kkk = 1:SlicesZ + [sino_id, sino_updt(:,:,kkk)] = astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); + astra_mex_data3d('delete', sino_id); + end 
+ else + % for divergent 3D geometry (watch the GPU memory overflow in earlier ASTRA versions < 1.8) + [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); end if (lambdaR_L1 > 0) @@ -332,17 +362,17 @@ if (subsets == 0) residual = weights.*(sino_updt - sino); end - objective(i) = (0.5*norm(residual(:))^2)/(Detectors*anglesNumb*SlicesZ); % for the objective function output + objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output - % if the geometry is parallel use slice-by-slice projection-backprojection routine + % if the geometry is parallel use slice-by-slice projection-backprojection routine if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'parallel3d')) - x_temp = zeros(size(X),'single'); - for kkk = 1:SlicesZ - [id, x_temp(:,:,kkk)] = astra_create_backprojection3d_cuda(squeeze(residual(:,:,kkk)), proj_geomT, vol_geomT); - astra_mex_data3d('delete', id); - end + x_temp = zeros(size(X),'single'); + for kkk = 1:SlicesZ + [id, x_temp(:,:,kkk)] = astra_create_backprojection3d_cuda(squeeze(residual(:,:,kkk)), proj_geomT, vol_geomT); + astra_mex_data3d('delete', id); + end else - [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom); + [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom); end X = X_t - (1/L_const).*x_temp; astra_mex_data3d('delete', sino_id); @@ -360,7 +390,7 @@ if (subsets == 0) % 3D regularization [X, f_val] = FGP_TV(single(X), lambdaFGP_TV, IterationsRegul, tol, 'iso'); end - objective(i) = objective(i) + f_val; + objective(i) = (objective(i) + f_val)./(Detectors*anglesNumb*SlicesZ); end if (lambdaSB_TV > 0) % Split Bregman regularization @@ -390,7 +420,7 @@ if (subsets == 0) end if (lambdaPB > 0) - % Patch-Based regularization (can be slow on CPU) + % Patch-Based regularization (can be very slow on CPU) if ((strcmp('2D', Dimension) == 1)) % 2D regularization for kkk = 1:SlicesZ @@ -400,13 +430,35 @@ if (subsets == 0) X = PatchBased_Regul(single(X), SearchW, SimilW, h_PB, lambdaPB); end end - if (LambdaTGV > 0) - % Total Generalized variation (currently only 2D) - lamTGV1 = 1.1; % smoothing trade-off parameters, see Pock's paper - lamTGV2 = 0.8; % second-order term + if (lambdaPB_GPU > 0) + % Patch-Based regularization (GPU CUDA implementation) + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization for kkk = 1:SlicesZ - X(:,:,kkk) = TGV_PD(single(X(:,:,kkk)), LambdaTGV, lamTGV1, lamTGV2, IterationsRegul); + X(:,:,kkk) = NLM_GPU(single(X(:,:,kkk)), SearchW, SimilW, h_PB, lambdaPB_GPU); end + else + X = NLM_GPU(single(X), SearchW, SimilW, h_PB, lambdaPB_GPU); + end + end + if (LambdaDiff_HO > 0) + % Higher-order diffusion penalty (GPU CUDA implementation) + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = Diff4thHajiaboli_GPU(single(X(:,:,kkk)), LambdaDiff_HO_EdgePar, LambdaDiff_HO, IterationsRegul); + end + else + X = Diff4thHajiaboli_GPU(X, LambdaDiff_HO_EdgePar, LambdaDiff_HO, IterationsRegul); + end + end + if (LambdaTGV > 0) + % Total Generalized variation (currently only 2D) + lamTGV1 = 1.1; % smoothing trade-off parameters, see Pock's paper + lamTGV2 = 0.8; % second-order term + for kkk = 1:SlicesZ + X(:,:,kkk) = TGV_PD(single(X(:,:,kkk)), LambdaTGV, lamTGV1, lamTGV2, IterationsRegul); + end end if (lambdaR_L1 > 0) @@ -470,7 +522,7 @@ else % the ring removal part (Group-Huber fidelity) % first 2 iterations do additional work reconstructing whole dataset to ensure - % stablility + % the stablility if (i < 
3) [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); astra_mex_data3d('delete', sino_id2); @@ -546,7 +598,7 @@ else % 3D regularization X2 = LLT_model(single(X), lambdaHO/subsets, tauHO/subsets, iterHO, 2.0e-05, 0); end - X = 0.5.*(X + X2); % averaged combination of two solutions + X = 0.5.*(X + X2); % the averaged combination of two solutions end if (lambdaPB > 0) % Patch-Based regularization (can be slow on CPU) @@ -559,14 +611,36 @@ else X = PatchBased_Regul(single(X), SearchW, SimilW, h_PB, lambdaPB/subsets); end end + if (lambdaPB_GPU > 0) + % Patch-Based regularization (GPU CUDA implementation) + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = NLM_GPU(single(X(:,:,kkk)), SearchW, SimilW, h_PB, lambdaPB_GPU); + end + else + X = NLM_GPU(single(X), SearchW, SimilW, h_PB, lambdaPB_GPU); + end + end + if (LambdaDiff_HO > 0) + % Higher-order diffusion penalty (GPU CUDA implementation) + if ((strcmp('2D', Dimension) == 1)) + % 2D regularization + for kkk = 1:SlicesZ + X(:,:,kkk) = Diff4thHajiaboli_GPU(single(X(:,:,kkk)), LambdaDiff_HO_EdgePar, LambdaDiff_HO, round(IterationsRegul/subsets)); + end + else + X = Diff4thHajiaboli_GPU(X, LambdaDiff_HO_EdgePar, LambdaDiff_HO, round(IterationsRegul/subsets)); + end + end if (LambdaTGV > 0) % Total Generalized variation (currently only 2D) lamTGV1 = 1.1; % smoothing trade-off parameters, see Pock's paper lamTGV2 = 0.5; % second-order term for kkk = 1:SlicesZ - X(:,:,kkk) = TGV_PD(single(X(:,:,kkk)), LambdaTGV/subsets, lamTGV1, lamTGV2, IterationsRegul); + X(:,:,kkk) = TGV_PD(single(X(:,:,kkk)), LambdaTGV/subsets, lamTGV1, lamTGV2, IterationsRegul); end - end + end if (lambdaR_L1 > 0) r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector diff --git a/main_func/regularizers_CPU/FGP_TV.c b/main_func/regularizers_CPU/FGP_TV.c index 5d8cfb9..cfe5b9e 100644 --- a/main_func/regularizers_CPU/FGP_TV.c +++ b/main_func/regularizers_CPU/FGP_TV.c @@ -54,7 +54,7 @@ void mexFunction( { int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; const int *dim_array; - float *A, *D=NULL, *D_old=NULL, *P1=NULL, *P2=NULL, *P3=NULL, *P1_old=NULL, *P2_old=NULL, *P3_old=NULL, *R1=NULL, *R2=NULL, *R3=NULL, lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + float *A, *D=NULL, *D_old=NULL, *P1=NULL, *P2=NULL, *P3=NULL, *P1_old=NULL, *P2_old=NULL, *P3_old=NULL, *R1=NULL, *R2=NULL, *R3=NULL, lambda, tk, tkp1, re, re1, re_old, epsil; number_of_dims = mxGetNumberOfDimensions(prhs[0]); dim_array = mxGetDimensions(prhs[0]); @@ -78,7 +78,6 @@ void mexFunction( mxFree(penalty_type); } /*output function value (last iteration) */ - funcval = 0.0f; plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL); float *funcvalA = (float *) mxGetData(plhs[1]); @@ -117,7 +116,7 @@ void mexFunction( /*updating R and t*/ tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f; - Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY); + Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY); /* calculate norm */ re = 0.0f; re1 = 0.0f; @@ -129,23 +128,17 @@ void mexFunction( re = sqrt(re)/sqrt(re1); if (re < epsil) count++; if (count > 3) { - Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); - funcval = 0.0f; - for(j=0; j 2) { if (re > re_old) { - Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); - funcval = 0.0f; - for(j=0; j 3) { - Obj_func3D(A, D, P1, P2, P3,lambda, dimX, dimY, dimZ); - funcval = 0.0f; - for(j=0; j 2) { if (re > re_old) { - Obj_func3D(A, D, P1, P2, 
P3,lambda, dimX, dimY, dimZ); - funcval = 0.0f; - for(j=0; j +#include +#include +#include +#include +#include +#include "Diff4th_GPU_kernel.h" + +/* + * 2D and 3D CUDA implementation of the 4th order PDE denoising model by Hajiaboli + * + * Reference : + * "An anisotropic fourth-order diffusion filter for image noise removal" by M. Hajiaboli + * + * Example + * figure; + * Im = double(imread('lena_gray_256.tif'))/255; % loading image + * u0 = Im + .05*randn(size(Im)); % adding noise + * u = Diff4thHajiaboli_GPU(single(u0), 0.02, 150); + * subplot (1,2,1); imshow(u0,[ ]); title('Noisy Image') + * subplot (1,2,2); imshow(u,[ ]); title('Denoised Image') + * + * + * Linux/Matlab compilation: + * compile in terminal: nvcc -Xcompiler -fPIC -shared -o Diff4th_GPU_kernel.o Diff4th_GPU_kernel.cu + * then compile in Matlab: mex -I/usr/local/cuda-7.5/include -L/usr/local/cuda-7.5/lib64 -lcudart Diff4thHajiaboli_GPU.cpp Diff4th_GPU_kernel.o + */ + +void mexFunction( + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[]) +{ + int numdims, dimZ, size; + float *A, *B, *A_L, *B_L; + const int *dims; + + numdims = mxGetNumberOfDimensions(prhs[0]); + dims = mxGetDimensions(prhs[0]); + + float sigma = (float)mxGetScalar(prhs[1]); /* edge-preserving parameter */ + float lambda = (float)mxGetScalar(prhs[2]); /* regularization parameter */ + int iter = (int)mxGetScalar(prhs[3]); /* iterations number */ + + if (numdims == 2) { + + int N, M, Z, i, j; + Z = 0; // for the 2D case + float tau = 0.01; // time step is sufficiently small for an explicit methods + + /*Input data*/ + A = (float*)mxGetData(prhs[0]); + N = dims[0] + 2; + M = dims[1] + 2; + A_L = (float*)mxGetData(mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); + B_L = (float*)mxGetData(mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); + + /*Output data*/ + B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(dims[0], dims[1], mxSINGLE_CLASS, mxREAL)); + + // copy A to the bigger A_L with boundaries + #pragma omp parallel for shared(A_L, A) private(i,j) + for (i=0; i < N; i++) { + for (j=0; j < M; j++) { + if (((i > 0) && (i < N-1)) && ((j > 0) && (j < M-1))) A_L[i*M+j] = A[(i-1)*(dims[1])+(j-1)]; + }} + + // Running CUDA code here + Diff4th_GPU_kernel(A_L, B_L, N, M, Z, (float)sigma, iter, (float)tau, lambda); + + // copy the processed B_L to a smaller B + #pragma omp parallel for shared(B_L, B) private(i,j) + for (i=0; i < N; i++) { + for (j=0; j < M; j++) { + if (((i > 0) && (i < N-1)) && ((j > 0) && (j < M-1))) B[(i-1)*(dims[1])+(j-1)] = B_L[i*M+j]; + }} + } + if (numdims == 3) { + // 3D image denoising / regularization + int N, M, Z, i, j, k; + float tau = 0.0007; // Time Step is small for an explicit methods + A = (float*)mxGetData(prhs[0]); + N = dims[0] + 2; + M = dims[1] + 2; + Z = dims[2] + 2; + int N_dims[] = {N, M, Z}; + A_L = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + B_L = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + + /* output data */ + B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)); + + // copy A to the bigger A_L with boundaries + #pragma omp parallel for shared(A_L, A) private(i,j,k) + for (i=0; i < N; i++) { + for (j=0; j < M; j++) { + for (k=0; k < Z; k++) { + if (((i > 0) && (i < N-1)) && ((j > 0) && (j < M-1)) && ((k > 0) && (k < Z-1))) { + A_L[(N*M)*(k)+(i)*M+(j)] = A[(dims[0]*dims[1])*(k-1)+(i-1)*dims[1]+(j-1)]; + }}}} + + // Running CUDA kernel here for diffusivity + Diff4th_GPU_kernel(A_L, B_L, N, M, Z, 
(float)sigma, iter, (float)tau, lambda); + + // copy the processed B_L to a smaller B + #pragma omp parallel for shared(B_L, B) private(i,j,k) + for (i=0; i < N; i++) { + for (j=0; j < M; j++) { + for (k=0; k < Z; k++) { + if (((i > 0) && (i < N-1)) && ((j > 0) && (j < M-1)) && ((k > 0) && (k < Z-1))) { + B[(dims[0]*dims[1])*(k-1)+(i-1)*dims[1]+(j-1)] = B_L[(N*M)*(k)+(i)*M+(j)]; + }}}} + } +} \ No newline at end of file diff --git a/main_func/regularizers_GPU/Diffus_HO/Diff4th_GPU_kernel.cu b/main_func/regularizers_GPU/Diffus_HO/Diff4th_GPU_kernel.cu new file mode 100644 index 0000000..178af00 --- /dev/null +++ b/main_func/regularizers_GPU/Diffus_HO/Diff4th_GPU_kernel.cu @@ -0,0 +1,270 @@ +#include +#include +#include +#include "Diff4th_GPU_kernel.h" + +#define checkCudaErrors(err) __checkCudaErrors (err, __FILE__, __LINE__) + +inline void __checkCudaErrors(cudaError err, const char *file, const int line) +{ + if (cudaSuccess != err) + { + fprintf(stderr, "%s(%i) : CUDA Runtime API error %d: %s.\n", + file, line, (int)err, cudaGetErrorString(err)); + exit(EXIT_FAILURE); + } +} + +#define idivup(a, b) ( ((a)%(b) != 0) ? (a)/(b)+1 : (a)/(b) ) +#define sizeT (sizeX*sizeY*sizeZ) +#define epsilon 0.00000001 + +///////////////////////////////////////////////// +// 2D Image denosing - Second Step (The second derrivative) +__global__ void Diff4th2D_derriv(float* B, float* A, float *A0, int N, int M, float sigma, int iter, float tau, float lambda) +{ + float gradXXc = 0, gradYYc = 0; + int i = blockIdx.x*blockDim.x + threadIdx.x; + int j = blockIdx.y*blockDim.y + threadIdx.y; + + int index = j + i*N; + + if (((i < 1) || (i > N-2)) || ((j < 1) || (j > M-2))) { + return; } + + int indexN = (j)+(i-1)*(N); if (A[indexN] == 0) indexN = index; + int indexS = (j)+(i+1)*(N); if (A[indexS] == 0) indexS = index; + int indexW = (j-1)+(i)*(N); if (A[indexW] == 0) indexW = index; + int indexE = (j+1)+(i)*(N); if (A[indexE] == 0) indexE = index; + + gradXXc = B[indexN] + B[indexS] - 2*B[index] ; + gradYYc = B[indexW] + B[indexE] - 2*B[index] ; + A[index] = A[index] - tau*((A[index] - A0[index]) + lambda*(gradXXc + gradYYc)); +} + +// 2D Image denosing - The First Step +__global__ void Diff4th2D(float* A, float* B, int N, int M, float sigma, int iter, float tau) +{ + float gradX, gradX_sq, gradY, gradY_sq, gradXX, gradYY, gradXY, sq_sum, xy_2, V_norm, V_orth, c, c_sq; + + int i = blockIdx.x*blockDim.x + threadIdx.x; + int j = blockIdx.y*blockDim.y + threadIdx.y; + + int index = j + i*N; + + V_norm = 0.0f; V_orth = 0.0f; + + if (((i < 1) || (i > N-2)) || ((j < 1) || (j > M-2))) { + return; } + + int indexN = (j)+(i-1)*(N); if (A[indexN] == 0) indexN = index; + int indexS = (j)+(i+1)*(N); if (A[indexS] == 0) indexS = index; + int indexW = (j-1)+(i)*(N); if (A[indexW] == 0) indexW = index; + int indexE = (j+1)+(i)*(N); if (A[indexE] == 0) indexE = index; + int indexNW = (j-1)+(i-1)*(N); if (A[indexNW] == 0) indexNW = index; + int indexNE = (j+1)+(i-1)*(N); if (A[indexNE] == 0) indexNE = index; + int indexWS = (j-1)+(i+1)*(N); if (A[indexWS] == 0) indexWS = index; + int indexES = (j+1)+(i+1)*(N); if (A[indexES] == 0) indexES = index; + + gradX = 0.5f*(A[indexN]-A[indexS]); + gradX_sq = gradX*gradX; + gradXX = A[indexN] + A[indexS] - 2*A[index]; + + gradY = 0.5f*(A[indexW]-A[indexE]); + gradY_sq = gradY*gradY; + gradYY = A[indexW] + A[indexE] - 2*A[index]; + + gradXY = 0.25f*(A[indexNW] - A[indexNE] - A[indexWS] + A[indexES]); + xy_2 = 2.0f*gradX*gradY*gradXY; + sq_sum = gradX_sq + gradY_sq; + + if (sq_sum <= epsilon) 
{ + V_norm = (gradXX*gradX_sq + xy_2 + gradYY*gradY_sq)/epsilon; + V_orth = (gradXX*gradY_sq - xy_2 + gradYY*gradX_sq)/epsilon; } + else { + V_norm = (gradXX*gradX_sq + xy_2 + gradYY*gradY_sq)/sq_sum; + V_orth = (gradXX*gradY_sq - xy_2 + gradYY*gradX_sq)/sq_sum; } + + c = 1.0f/(1.0f + sq_sum/sigma); + c_sq = c*c; + B[index] = c_sq*V_norm + c*V_orth; +} + +///////////////////////////////////////////////// +// 3D data parocerssing +__global__ void Diff4th3D_derriv(float *B, float *A, float *A0, int N, int M, int Z, float sigma, int iter, float tau, float lambda) +{ + float gradXXc = 0, gradYYc = 0, gradZZc = 0; + int xIndex = blockDim.x * blockIdx.x + threadIdx.x; + int yIndex = blockDim.y * blockIdx.y + threadIdx.y; + int zIndex = blockDim.z * blockIdx.z + threadIdx.z; + + int index = xIndex + M*yIndex + N*M*zIndex; + + if (((xIndex < 1) || (xIndex > N-2)) || ((yIndex < 1) || (yIndex > M-2)) || ((zIndex < 1) || (zIndex > Z-2))) { + return; } + + int indexN = (xIndex-1) + M*yIndex + N*M*zIndex; if (A[indexN] == 0) indexN = index; + int indexS = (xIndex+1) + M*yIndex + N*M*zIndex; if (A[indexS] == 0) indexS = index; + int indexW = xIndex + M*(yIndex-1) + N*M*zIndex; if (A[indexW] == 0) indexW = index; + int indexE = xIndex + M*(yIndex+1) + N*M*zIndex; if (A[indexE] == 0) indexE = index; + int indexU = xIndex + M*yIndex + N*M*(zIndex-1); if (A[indexU] == 0) indexU = index; + int indexD = xIndex + M*yIndex + N*M*(zIndex+1); if (A[indexD] == 0) indexD = index; + + gradXXc = B[indexN] + B[indexS] - 2*B[index] ; + gradYYc = B[indexW] + B[indexE] - 2*B[index] ; + gradZZc = B[indexU] + B[indexD] - 2*B[index] ; + + A[index] = A[index] - tau*((A[index] - A0[index]) + lambda*(gradXXc + gradYYc + gradZZc)); +} + +__global__ void Diff4th3D(float* A, float* B, int N, int M, int Z, float sigma, int iter, float tau) +{ + float gradX, gradX_sq, gradY, gradY_sq, gradZ, gradZ_sq, gradXX, gradYY, gradZZ, gradXY, gradXZ, gradYZ, sq_sum, xy_2, xyz_1, xyz_2, V_norm, V_orth, c, c_sq; + + int xIndex = blockDim.x * blockIdx.x + threadIdx.x; + int yIndex = blockDim.y * blockIdx.y + threadIdx.y; + int zIndex = blockDim.z * blockIdx.z + threadIdx.z; + + int index = xIndex + M*yIndex + N*M*zIndex; + V_norm = 0.0f; V_orth = 0.0f; + + if (((xIndex < 1) || (xIndex > N-2)) || ((yIndex < 1) || (yIndex > M-2)) || ((zIndex < 1) || (zIndex > Z-2))) { + return; } + + B[index] = 0; + + int indexN = (xIndex-1) + M*yIndex + N*M*zIndex; if (A[indexN] == 0) indexN = index; + int indexS = (xIndex+1) + M*yIndex + N*M*zIndex; if (A[indexS] == 0) indexS = index; + int indexW = xIndex + M*(yIndex-1) + N*M*zIndex; if (A[indexW] == 0) indexW = index; + int indexE = xIndex + M*(yIndex+1) + N*M*zIndex; if (A[indexE] == 0) indexE = index; + int indexU = xIndex + M*yIndex + N*M*(zIndex-1); if (A[indexU] == 0) indexU = index; + int indexD = xIndex + M*yIndex + N*M*(zIndex+1); if (A[indexD] == 0) indexD = index; + + int indexNW = (xIndex-1) + M*(yIndex-1) + N*M*zIndex; if (A[indexNW] == 0) indexNW = index; + int indexNE = (xIndex-1) + M*(yIndex+1) + N*M*zIndex; if (A[indexNE] == 0) indexNE = index; + int indexWS = (xIndex+1) + M*(yIndex-1) + N*M*zIndex; if (A[indexWS] == 0) indexWS = index; + int indexES = (xIndex+1) + M*(yIndex+1) + N*M*zIndex; if (A[indexES] == 0) indexES = index; + + int indexUW = (xIndex-1) + M*(yIndex) + N*M*(zIndex-1); if (A[indexUW] == 0) indexUW = index; + int indexUE = (xIndex+1) + M*(yIndex) + N*M*(zIndex-1); if (A[indexUE] == 0) indexUE = index; + int indexDW = (xIndex-1) + M*(yIndex) + N*M*(zIndex+1); if (A[indexDW] == 
0) indexDW = index; + int indexDE = (xIndex+1) + M*(yIndex) + N*M*(zIndex+1); if (A[indexDE] == 0) indexDE = index; + + int indexUN = (xIndex) + M*(yIndex-1) + N*M*(zIndex-1); if (A[indexUN] == 0) indexUN = index; + int indexUS = (xIndex) + M*(yIndex+1) + N*M*(zIndex-1); if (A[indexUS] == 0) indexUS = index; + int indexDN = (xIndex) + M*(yIndex-1) + N*M*(zIndex+1); if (A[indexDN] == 0) indexDN = index; + int indexDS = (xIndex) + M*(yIndex+1) + N*M*(zIndex+1); if (A[indexDS] == 0) indexDS = index; + + gradX = 0.5f*(A[indexN]-A[indexS]); + gradX_sq = gradX*gradX; + gradXX = A[indexN] + A[indexS] - 2*A[index]; + + gradY = 0.5f*(A[indexW]-A[indexE]); + gradY_sq = gradY*gradY; + gradYY = A[indexW] + A[indexE] - 2*A[index]; + + gradZ = 0.5f*(A[indexU]-A[indexD]); + gradZ_sq = gradZ*gradZ; + gradZZ = A[indexU] + A[indexD] - 2*A[index]; + + gradXY = 0.25f*(A[indexNW] - A[indexNE] - A[indexWS] + A[indexES]); + gradXZ = 0.25f*(A[indexUW] - A[indexUE] - A[indexDW] + A[indexDE]); + gradYZ = 0.25f*(A[indexUN] - A[indexUS] - A[indexDN] + A[indexDS]); + + xy_2 = 2.0f*gradX*gradY*gradXY; + xyz_1 = 2.0f*gradX*gradZ*gradXZ; + xyz_2 = 2.0f*gradY*gradZ*gradYZ; + + sq_sum = gradX_sq + gradY_sq + gradZ_sq; + + if (sq_sum <= epsilon) { + V_norm = (gradXX*gradX_sq + gradYY*gradY_sq + gradZZ*gradZ_sq + xy_2 + xyz_1 + xyz_2)/epsilon; + V_orth = ((gradY_sq + gradZ_sq)*gradXX + (gradX_sq + gradZ_sq)*gradYY + (gradX_sq + gradY_sq)*gradZZ - xy_2 - xyz_1 - xyz_2)/epsilon; } + else { + V_norm = (gradXX*gradX_sq + gradYY*gradY_sq + gradZZ*gradZ_sq + xy_2 + xyz_1 + xyz_2)/sq_sum; + V_orth = ((gradY_sq + gradZ_sq)*gradXX + (gradX_sq + gradZ_sq)*gradYY + (gradX_sq + gradY_sq)*gradZZ - xy_2 - xyz_1 - xyz_2)/sq_sum; } + + c = 1; + if ((1.0f + sq_sum/sigma) != 0.0f) {c = 1.0f/(1.0f + sq_sum/sigma);} + + c_sq = c*c; + B[index] = c_sq*V_norm + c*V_orth; +} + +/******************************************************/ +/********* HOST FUNCTION*************/ +extern "C" void Diff4th_GPU_kernel(float* A, float* B, int N, int M, int Z, float sigma, int iter, float tau, float lambda) +{ + int deviceCount = -1; // number of devices + cudaGetDeviceCount(&deviceCount); + if (deviceCount == 0) { + fprintf(stderr, "No CUDA devices found\n"); + return; + } + + int BLKXSIZE, BLKYSIZE,BLKZSIZE; + float *Ad, *Bd, *Cd; + sigma = sigma*sigma; + + if (Z == 0){ + // 4th order diffusion for 2D case + BLKXSIZE = 8; + BLKYSIZE = 16; + + dim3 dimBlock(BLKXSIZE,BLKYSIZE); + dim3 dimGrid(idivup(N,BLKXSIZE), idivup(M,BLKYSIZE)); + + checkCudaErrors(cudaMalloc((void**)&Ad,N*M*sizeof(float))); + checkCudaErrors(cudaMalloc((void**)&Bd,N*M*sizeof(float))); + checkCudaErrors(cudaMalloc((void**)&Cd,N*M*sizeof(float))); + + checkCudaErrors(cudaMemcpy(Ad,A,N*M*sizeof(float),cudaMemcpyHostToDevice)); + checkCudaErrors(cudaMemcpy(Bd,A,N*M*sizeof(float),cudaMemcpyHostToDevice)); + checkCudaErrors(cudaMemcpy(Cd,A,N*M*sizeof(float),cudaMemcpyHostToDevice)); + + int n = 1; + while (n <= iter) { + Diff4th2D<<>>(Bd, Cd, N, M, sigma, iter, tau); + cudaDeviceSynchronize(); + checkCudaErrors( cudaPeekAtLastError() ); + Diff4th2D_derriv<<>>(Cd, Bd, Ad, N, M, sigma, iter, tau, lambda); + cudaDeviceSynchronize(); + checkCudaErrors( cudaPeekAtLastError() ); + n++; + } + checkCudaErrors(cudaMemcpy(B,Bd,N*M*sizeof(float),cudaMemcpyDeviceToHost)); + cudaFree(Ad); cudaFree(Bd); cudaFree(Cd); + } + + if (Z != 0){ + // 4th order diffusion for 3D case + BLKXSIZE = 8; + BLKYSIZE = 8; + BLKZSIZE = 8; + + dim3 dimBlock(BLKXSIZE,BLKYSIZE,BLKZSIZE); + dim3 dimGrid(idivup(N,BLKXSIZE), 
idivup(M,BLKYSIZE),idivup(Z,BLKXSIZE)); + + checkCudaErrors(cudaMalloc((void**)&Ad,N*M*Z*sizeof(float))); + checkCudaErrors(cudaMalloc((void**)&Bd,N*M*Z*sizeof(float))); + checkCudaErrors(cudaMalloc((void**)&Cd,N*M*Z*sizeof(float))); + + checkCudaErrors(cudaMemcpy(Ad,A,N*M*Z*sizeof(float),cudaMemcpyHostToDevice)); + checkCudaErrors(cudaMemcpy(Bd,A,N*M*Z*sizeof(float),cudaMemcpyHostToDevice)); + checkCudaErrors(cudaMemcpy(Cd,A,N*M*Z*sizeof(float),cudaMemcpyHostToDevice)); + + int n = 1; + while (n <= iter) { + Diff4th3D<<>>(Bd, Cd, N, M, Z, sigma, iter, tau); + cudaDeviceSynchronize(); + checkCudaErrors( cudaPeekAtLastError() ); + Diff4th3D_derriv<<>>(Cd, Bd, Ad, N, M, Z, sigma, iter, tau, lambda); + cudaDeviceSynchronize(); + checkCudaErrors( cudaPeekAtLastError() ); + n++; + } + checkCudaErrors(cudaMemcpy(B,Bd,N*M*Z*sizeof(float),cudaMemcpyDeviceToHost)); + cudaFree(Ad); cudaFree(Bd); cudaFree(Cd); + } +} \ No newline at end of file diff --git a/main_func/regularizers_GPU/Diffus_HO/Diff4th_GPU_kernel.h b/main_func/regularizers_GPU/Diffus_HO/Diff4th_GPU_kernel.h new file mode 100644 index 0000000..cfbb45a --- /dev/null +++ b/main_func/regularizers_GPU/Diffus_HO/Diff4th_GPU_kernel.h @@ -0,0 +1,6 @@ +#ifndef __DIFF_HO_H_ +#define __DIFF_HO_H_ + +extern "C" void Diff4th_GPU_kernel(float* A, float* B, int N, int M, int Z, float sigma, int iter, float tau, float lambda); + +#endif diff --git a/main_func/regularizers_GPU/NL_Regul/NLM_GPU.cpp b/main_func/regularizers_GPU/NL_Regul/NLM_GPU.cpp new file mode 100644 index 0000000..ff0cc90 --- /dev/null +++ b/main_func/regularizers_GPU/NL_Regul/NLM_GPU.cpp @@ -0,0 +1,171 @@ +#include "mex.h" +#include +#include +#include +#include +#include +#include +#include "NLM_GPU_kernel.h" + +/* CUDA implementation of the patch-based (PB) regularization for 2D and 3D images/volumes + * This method finds self-similar patches in data and performs one fixed point iteration to mimimize the PB penalty function + * + * References: 1. Yang Z. & Jacob M. "Nonlocal Regularization of Inverse Problems" + * 2. Kazantsev D. at. all "4D-CT reconstruction with unified spatial-temporal patch-based regularization" + * + * Input Parameters (mandatory): + * 1. Image/volume (2D/3D) + * 2. ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window) + * 3. ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window) + * 4. h - parameter for the PB penalty function + * 5. lambda - regularization parameter + + * Output: + * 1. regularized (denoised) Image/volume (N x N x N) + * + * In matlab check what kind of GPU you have with "gpuDevice" command, + * then set your ComputeCapability, here I use -arch compute_35 + * + * Quick 2D denoising example in Matlab: + Im = double(imread('lena_gray_256.tif'))/255; % loading image + u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise + ImDen = NLM_GPU(single(u0), 3, 2, 0.15, 1); + + * Linux/Matlab compilation: + * compile in terminal: nvcc -Xcompiler -fPIC -shared -o NLM_GPU_kernel.o NLM_GPU_kernel.cu + * then compile in Matlab: mex -I/usr/local/cuda-7.5/include -L/usr/local/cuda-7.5/lib64 -lcudart NLM_GPU.cpp NLM_GPU_kernel.o + * + * D. 
Kazantsev + * 2014-17 + * Harwell/Manchester UK + */ + +float pad_crop(float *A, float *Ap, int OldSizeX, int OldSizeY, int OldSizeZ, int NewSizeX, int NewSizeY, int NewSizeZ, int padXY, int switchpad_crop); + +void mexFunction( + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[]) +{ + int N, M, Z, i_n, j_n, k_n, numdims, SearchW, SimilW, SearchW_real, padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop, count, SearchW_full, SimilW_full; + const int *dims; + float *A, *B=NULL, *Ap=NULL, *Bp=NULL, *Eucl_Vec, h, h2, lambda, val, denh2; + + numdims = mxGetNumberOfDimensions(prhs[0]); + dims = mxGetDimensions(prhs[0]); + + N = dims[0]; + M = dims[1]; + Z = dims[2]; + + if ((numdims < 2) || (numdims > 3)) {mexErrMsgTxt("The input should be 2D image or 3D volume");} + if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) {mexErrMsgTxt("The input in single precision is required"); } + + if(nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter"); + + /*Handling inputs*/ + A = (float *) mxGetData(prhs[0]); /* the image to regularize/filter */ + SearchW_real = (int) mxGetScalar(prhs[1]); /* the searching window ratio */ + SimilW = (int) mxGetScalar(prhs[2]); /* the similarity window ratio */ + h = (float) mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ + lambda = (float) mxGetScalar(prhs[4]); + + if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); + + SearchW = SearchW_real + 2*SimilW; + + SearchW_full = 2*SearchW + 1; /* the full searching window size */ + SimilW_full = 2*SimilW + 1; /* the full similarity window size */ + h2 = h*h; + + padXY = SearchW + 2*SimilW; /* padding sizes */ + newsizeX = N + 2*(padXY); /* the X size of the padded array */ + newsizeY = M + 2*(padXY); /* the Y size of the padded array */ + newsizeZ = Z + 2*(padXY); /* the Z size of the padded array */ + int N_dims[] = {newsizeX, newsizeY, newsizeZ}; + + /******************************2D case ****************************/ + if (numdims == 2) { + /*Handling output*/ + B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); + /*allocating memory for the padded arrays */ + Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + Eucl_Vec = (float*)mxGetData(mxCreateNumericMatrix(SimilW_full*SimilW_full, 1, mxSINGLE_CLASS, mxREAL)); + + /*Gaussian kernel */ + count = 0; + for(i_n=-SimilW; i_n<=SimilW; i_n++) { + for(j_n=-SimilW; j_n<=SimilW; j_n++) { + val = (float)(i_n*i_n + j_n*j_n)/(2*SimilW*SimilW); + Eucl_Vec[count] = exp(-val); + count = count + 1; + }} /*main neighb loop */ + + /**************************************************************************/ + /*Perform padding of image A to the size of [newsizeX * newsizeY] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + NLM_GPU_kernel(Ap, Bp, Eucl_Vec, newsizeY, newsizeX, 0, numdims, SearchW, SimilW, SearchW_real, (float)h2, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + } + else + { + /******************************3D case ****************************/ + /*Handling output*/ + B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)); + /*allocating memory for the padded arrays */ + Ap = 
(float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + Eucl_Vec = (float*)mxGetData(mxCreateNumericMatrix(SimilW_full*SimilW_full*SimilW_full, 1, mxSINGLE_CLASS, mxREAL)); + + /*Gaussian kernel */ + count = 0; + for(i_n=-SimilW; i_n<=SimilW; i_n++) { + for(j_n=-SimilW; j_n<=SimilW; j_n++) { + for(k_n=-SimilW; k_n<=SimilW; k_n++) { + val = (float)(i_n*i_n + j_n*j_n + k_n*k_n)/(2*SimilW*SimilW*SimilW); + Eucl_Vec[count] = exp(-val); + count = count + 1; + }}} /*main neighb loop */ + /**************************************************************************/ + /*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + NLM_GPU_kernel(Ap, Bp, Eucl_Vec, newsizeY, newsizeX, newsizeZ, numdims, SearchW, SimilW, SearchW_real, (float)h2, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + } /*end else ndims*/ +} + +float pad_crop(float *A, float *Ap, int OldSizeX, int OldSizeY, int OldSizeZ, int NewSizeX, int NewSizeY, int NewSizeZ, int padXY, int switchpad_crop) +{ + /* padding-cropping function */ + int i,j,k; + if (NewSizeZ > 1) { + for (i=0; i < NewSizeX; i++) { + for (j=0; j < NewSizeY; j++) { + for (k=0; k < NewSizeZ; k++) { + if (((i >= padXY) && (i < NewSizeX-padXY)) && ((j >= padXY) && (j < NewSizeY-padXY)) && ((k >= padXY) && (k < NewSizeZ-padXY))) { + if (switchpad_crop == 0) Ap[NewSizeX*NewSizeY*k + i*NewSizeY+j] = A[OldSizeX*OldSizeY*(k - padXY) + (i-padXY)*(OldSizeY)+(j-padXY)]; + else Ap[OldSizeX*OldSizeY*(k - padXY) + (i-padXY)*(OldSizeY)+(j-padXY)] = A[NewSizeX*NewSizeY*k + i*NewSizeY+j]; + } + }}} + } + else { + for (i=0; i < NewSizeX; i++) { + for (j=0; j < NewSizeY; j++) { + if (((i >= padXY) && (i < NewSizeX-padXY)) && ((j >= padXY) && (j < NewSizeY-padXY))) { + if (switchpad_crop == 0) Ap[i*NewSizeY+j] = A[(i-padXY)*(OldSizeY)+(j-padXY)]; + else Ap[(i-padXY)*(OldSizeY)+(j-padXY)] = A[i*NewSizeY+j]; + } + }} + } + return *Ap; +} \ No newline at end of file diff --git a/main_func/regularizers_GPU/NL_Regul/NLM_GPU_kernel.cu b/main_func/regularizers_GPU/NL_Regul/NLM_GPU_kernel.cu new file mode 100644 index 0000000..17da3a8 --- /dev/null +++ b/main_func/regularizers_GPU/NL_Regul/NLM_GPU_kernel.cu @@ -0,0 +1,239 @@ +#include +#include +#include +#include "NLM_GPU_kernel.h" + +#define checkCudaErrors(err) __checkCudaErrors (err, __FILE__, __LINE__) + +inline void __checkCudaErrors(cudaError err, const char *file, const int line) +{ + if (cudaSuccess != err) + { + fprintf(stderr, "%s(%i) : CUDA Runtime API error %d: %s.\n", + file, line, (int)err, cudaGetErrorString(err)); + exit(EXIT_FAILURE); + } +} + +extern __shared__ float sharedmem[]; + +// run PB den kernel here +__global__ void NLM_kernel(float *Ad, float* Bd, float *Eucl_Vec_d, int N, int M, int Z, int SearchW, int SimilW, int SearchW_real, int SearchW_full, int SimilW_full, int padXY, float h2, float lambda, dim3 imagedim, dim3 griddim, dim3 kerneldim, dim3 sharedmemdim, int nUpdatePerThread, float neighborsize) +{ + + int i1, j1, k1, i2, j2, k2, i3, j3, k3, i_l, j_l, k_l, count; + float value, Weight_norm, normsum, Weight; + + int bidx = blockIdx.x; + int bidy = blockIdx.y%griddim.y; + int bidz = (int)((blockIdx.y)/griddim.y); + + // global index for 
block endpoint + int beidx = __mul24(bidx,blockDim.x); + int beidy = __mul24(bidy,blockDim.y); + int beidz = __mul24(bidz,blockDim.z); + + int tid = __mul24(threadIdx.z,__mul24(blockDim.x,blockDim.y)) + + __mul24(threadIdx.y,blockDim.x) + threadIdx.x; + + #ifdef __DEVICE_EMULATION__ + printf("tid : %d", tid); + #endif + + // update shared memory + int nthreads = blockDim.x*blockDim.y*blockDim.z; + int sharedMemSize = sharedmemdim.x * sharedmemdim.y * sharedmemdim.z; + for(int i=0; i= padXY && idx < (imagedim.x - padXY) && + idy >= padXY && idy < (imagedim.y - padXY)) + { + int i_centr = threadIdx.x + (SearchW); /*indices of the centrilized (main) pixel */ + int j_centr = threadIdx.y + (SearchW); /*indices of the centrilized (main) pixel */ + + if ((i_centr > 0) && (i_centr < N) && (j_centr > 0) && (j_centr < M)) { + + Weight_norm = 0; value = 0.0; + /* Massive Search window loop */ + for(i1 = i_centr - SearchW_real ; i1 <= i_centr + SearchW_real; i1++) { + for(j1 = j_centr - SearchW_real ; j1<= j_centr + SearchW_real ; j1++) { + /* if inside the searching window */ + count = 0; normsum = 0.0; + for(i_l=-SimilW; i_l<=SimilW; i_l++) { + for(j_l=-SimilW; j_l<=SimilW; j_l++) { + i2 = i1+i_l; j2 = j1+j_l; + i3 = i_centr+i_l; j3 = j_centr+j_l; /*coordinates of the inner patch loop */ + if ((i2 > 0) && (i2 < N) && (j2 > 0) && (j2 < M)) { + if ((i3 > 0) && (i3 < N) && (j3 > 0) && (j3 < M)) { + normsum += Eucl_Vec_d[count]*pow((sharedmem[(j3)*sharedmemdim.x+(i3)] - sharedmem[j2*sharedmemdim.x+i2]), 2); + }} + count++; + }} + if (normsum != 0) Weight = (expf(-normsum/h2)); + else Weight = 0.0; + Weight_norm += Weight; + value += sharedmem[j1*sharedmemdim.x+i1]*Weight; + }} + + if (Weight_norm != 0) Bd[idz*imagedim.x*imagedim.y + idy*imagedim.x + idx] = value/Weight_norm; + else Bd[idz*imagedim.x*imagedim.y + idy*imagedim.x + idx] = Ad[idz*imagedim.x*imagedim.y + idy*imagedim.x + idx]; + } + } /*boundary conditions end*/ + } + else { + /*3D case*/ + /*checking boundaries to be within the image and avoid padded spaces */ + if( idx >= padXY && idx < (imagedim.x - padXY) && + idy >= padXY && idy < (imagedim.y - padXY) && + idz >= padXY && idz < (imagedim.z - padXY) ) + { + int i_centr = threadIdx.x + SearchW; /*indices of the centrilized (main) pixel */ + int j_centr = threadIdx.y + SearchW; /*indices of the centrilized (main) pixel */ + int k_centr = threadIdx.z + SearchW; /*indices of the centrilized (main) pixel */ + + if ((i_centr > 0) && (i_centr < N) && (j_centr > 0) && (j_centr < M) && (k_centr > 0) && (k_centr < Z)) { + + Weight_norm = 0; value = 0.0; + /* Massive Search window loop */ + for(i1 = i_centr - SearchW_real ; i1 <= i_centr + SearchW_real; i1++) { + for(j1 = j_centr - SearchW_real ; j1<= j_centr + SearchW_real ; j1++) { + for(k1 = k_centr - SearchW_real ; k1<= k_centr + SearchW_real ; k1++) { + /* if inside the searching window */ + count = 0; normsum = 0.0; + for(i_l=-SimilW; i_l<=SimilW; i_l++) { + for(j_l=-SimilW; j_l<=SimilW; j_l++) { + for(k_l=-SimilW; k_l<=SimilW; k_l++) { + i2 = i1+i_l; j2 = j1+j_l; k2 = k1+k_l; + i3 = i_centr+i_l; j3 = j_centr+j_l; k3 = k_centr+k_l; /*coordinates of the inner patch loop */ + if ((i2 > 0) && (i2 < N) && (j2 > 0) && (j2 < M) && (k2 > 0) && (k2 < Z)) { + if ((i3 > 0) && (i3 < N) && (j3 > 0) && (j3 < M) && (k3 > 0) && (k3 < Z)) { + normsum += Eucl_Vec_d[count]*pow((sharedmem[(k3)*sharedmemdim.x*sharedmemdim.y + (j3)*sharedmemdim.x+(i3)] - sharedmem[(k2)*sharedmemdim.x*sharedmemdim.y + j2*sharedmemdim.x+i2]), 2); + }} + count++; + }}} + if (normsum 
!= 0) Weight = (expf(-normsum/h2)); + else Weight = 0.0; + Weight_norm += Weight; + value += sharedmem[k1*sharedmemdim.x*sharedmemdim.y + j1*sharedmemdim.x+i1]*Weight; + }}} /* BIG search window loop end*/ + + + if (Weight_norm != 0) Bd[idz*imagedim.x*imagedim.y + idy*imagedim.x + idx] = value/Weight_norm; + else Bd[idz*imagedim.x*imagedim.y + idy*imagedim.x + idx] = Ad[idz*imagedim.x*imagedim.y + idy*imagedim.x + idx]; + } + } /* boundary conditions end */ + } +} + +///////////////////////////////////////////////// +// HOST FUNCTION +extern "C" void NLM_GPU_kernel(float *A, float* B, float *Eucl_Vec, int N, int M, int Z, int dimension, int SearchW, int SimilW, int SearchW_real, float h2, float lambda) +{ + int deviceCount = -1; // number of devices + cudaGetDeviceCount(&deviceCount); + if (deviceCount == 0) { + fprintf(stderr, "No CUDA devices found\n"); + return; + } + +// cudaDeviceReset(); + + int padXY, SearchW_full, SimilW_full, blockWidth, blockHeight, blockDepth, nBlockX, nBlockY, nBlockZ, kernel_depth; + float *Ad, *Bd, *Eucl_Vec_d; + + if (dimension == 2) { + blockWidth = 16; + blockHeight = 16; + blockDepth = 1; + Z = 1; + kernel_depth = 0; + } + else { + blockWidth = 8; + blockHeight = 8; + blockDepth = 8; + kernel_depth = SearchW; + } + + // compute how many blocks are needed + nBlockX = ceil((float)N / (float)blockWidth); + nBlockY = ceil((float)M / (float)blockHeight); + nBlockZ = ceil((float)Z / (float)blockDepth); + + dim3 dimGrid(nBlockX,nBlockY*nBlockZ); + dim3 dimBlock(blockWidth, blockHeight, blockDepth); + dim3 imagedim(N,M,Z); + dim3 griddim(nBlockX,nBlockY,nBlockZ); + + dim3 kerneldim(SearchW,SearchW,kernel_depth); + dim3 sharedmemdim((SearchW*2)+blockWidth,(SearchW*2)+blockHeight,(kernel_depth*2)+blockDepth); + int sharedmemsize = sizeof(float)*sharedmemdim.x*sharedmemdim.y*sharedmemdim.z; + int updateperthread = ceil((float)(sharedmemdim.x*sharedmemdim.y*sharedmemdim.z)/(float)(blockWidth*blockHeight*blockDepth)); + float neighborsize = (2*SearchW+1)*(2*SearchW+1)*(2*kernel_depth+1); + + padXY = SearchW + 2*SimilW; /* padding sizes */ + + SearchW_full = 2*SearchW + 1; /* the full searching window size */ + SimilW_full = 2*SimilW + 1; /* the full similarity window size */ + + /*allocate space for images on device*/ + checkCudaErrors( cudaMalloc((void**)&Ad,N*M*Z*sizeof(float)) ); + checkCudaErrors( cudaMalloc((void**)&Bd,N*M*Z*sizeof(float)) ); + /*allocate space for vectors on device*/ + if (dimension == 2) { + checkCudaErrors( cudaMalloc((void**)&Eucl_Vec_d,SimilW_full*SimilW_full*sizeof(float)) ); + checkCudaErrors( cudaMemcpy(Eucl_Vec_d,Eucl_Vec,SimilW_full*SimilW_full*sizeof(float),cudaMemcpyHostToDevice) ); + } + else { + checkCudaErrors( cudaMalloc((void**)&Eucl_Vec_d,SimilW_full*SimilW_full*SimilW_full*sizeof(float)) ); + checkCudaErrors( cudaMemcpy(Eucl_Vec_d,Eucl_Vec,SimilW_full*SimilW_full*SimilW_full*sizeof(float),cudaMemcpyHostToDevice) ); + } + + /* copy data from the host to device */ + checkCudaErrors( cudaMemcpy(Ad,A,N*M*Z*sizeof(float),cudaMemcpyHostToDevice) ); + + // Run CUDA kernel here + NLM_kernel<<>>(Ad, Bd, Eucl_Vec_d, M, N, Z, SearchW, SimilW, SearchW_real, SearchW_full, SimilW_full, padXY, h2, lambda, imagedim, griddim, kerneldim, sharedmemdim, updateperthread, neighborsize); + + checkCudaErrors( cudaPeekAtLastError() ); +// gpuErrchk( cudaDeviceSynchronize() ); + + checkCudaErrors( cudaMemcpy(B,Bd,N*M*Z*sizeof(float),cudaMemcpyDeviceToHost) ); + cudaFree(Ad); cudaFree(Bd); cudaFree(Eucl_Vec_d); +} diff --git 
a/main_func/regularizers_GPU/NL_Regul/NLM_GPU_kernel.h b/main_func/regularizers_GPU/NL_Regul/NLM_GPU_kernel.h new file mode 100644 index 0000000..bc9d4a3 --- /dev/null +++ b/main_func/regularizers_GPU/NL_Regul/NLM_GPU_kernel.h @@ -0,0 +1,6 @@ +#ifndef __NLMREG_KERNELS_H_ +#define __NLMREG_KERNELS_H_ + +extern "C" void NLM_GPU_kernel(float *A, float* B, float *Eucl_Vec, int N, int M, int Z, int dimension, int SearchW, int SimilW, int SearchW_real, float denh2, float lambda); + +#endif -- cgit v1.2.3 From 6fb8f5d188ed31d7a7077cba8ab7aea17b25b8bf Mon Sep 17 00:00:00 2001 From: algol Date: Mon, 11 Sep 2017 12:23:44 +0100 Subject: StudentT fidelity added, can be skipped during pythonising --- main_func/FISTA_REC.m | 56 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index edccbe1..6987dca 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -19,7 +19,8 @@ function [X, output] = FISTA_REC(params) % - .iterFISTA (iterations for the main loop, default 40) % - .L_const (Lipschitz constant, default Power method) ) % - .X_ideal (ideal image, if given) -% - .weights (statisitcal weights, size of the sinogram) +% - .weights (statisitcal weights for the PWLS model, size of the sinogram) +% - .fidelity (use 'studentt' fidelity) % - .ROI (Region-of-interest, only if X_ideal is given) % - .initialize (a 'warm start' using SIRT method from ASTRA) %----------------Regularization choices------------------------ @@ -92,6 +93,15 @@ if (isfield(params,'weights')) else weights = ones(size(sino)); end +if (isfield(params,'fidelity')) + studentt = 0; + if (strcmp(params.fidelity,'studentt') == 1) + studentt = 1; + lambdaR_L1 = 0; + end +else + studentt = 0; +end if (isfield(params,'L_const')) L_const = params.L_const; else @@ -357,12 +367,25 @@ if (subsets == 0) vec = squeeze(vec(:,1,:)); end r = r_x - (1./L_const).*vec; - else - % no ring removal + objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + else + if (studentt == 1) + % artifacts removal with Students t penalty + residual = weights.*(sino_updt - sino); + for kkk = 1:SlicesZ + res_vec = reshape(residual(:,:,kkk), Detectors*anglesNumb, 1); % 1D vectorized sinogram + %s = 100; + %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec); + [ff, gr] = studentst(res_vec, 1); + residual(:,:,kkk) = reshape(gr, Detectors, anglesNumb); + end + objective(i) = ff; % for the objective function output + else + % no ring removal (LS model) residual = weights.*(sino_updt - sino); - end - - objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + end + end % if the geometry is parallel use slice-by-slice projection-backprojection routine if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'parallel3d')) @@ -549,9 +572,24 @@ else end r = r_x - (1./L_const).*vec; else - [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); - % no ring removal - residual = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt - squeeze(sino(:,CurrSubIndeces,:))); + [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); + + if (studentt == 1) + % artifacts removal with Students t penalty + residual = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt - squeeze(sino(:,CurrSubIndeces,:))); + + for kkk = 1:SlicesZ + res_vec = reshape(residual(:,:,kkk), Detectors*numProjSub, 1); % 1D vectorized sinogram 
+ %s = 100; + %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec); + [ff, gr] = studentst(res_vec, 1); + residual(:,:,kkk) = reshape(gr, Detectors, numProjSub); + end + objective(i) = ff; % for the objective function output + else + % no ring removal (LS model) + residual = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt - squeeze(sino(:,CurrSubIndeces,:))); + end end [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geomSUB, vol_geom); -- cgit v1.2.3 From 33da5e128e51e5adc1f6085b4772b55b2cd76e2e Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 2 Aug 2017 16:51:14 +0100 Subject: added utils --- main_func/regularizers_CPU/FGP_TV_core.h | 30 ++++++++++++++++++++++- main_func/regularizers_CPU/LLT_model_core.h | 3 ++- main_func/regularizers_CPU/SplitBregman_TV_core.h | 28 ++++++++++++++++++++- main_func/regularizers_CPU/TGV_PD_core.h | 3 ++- 4 files changed, 60 insertions(+), 4 deletions(-) diff --git a/main_func/regularizers_CPU/FGP_TV_core.h b/main_func/regularizers_CPU/FGP_TV_core.h index 697fd84..d256c9a 100644 --- a/main_func/regularizers_CPU/FGP_TV_core.h +++ b/main_func/regularizers_CPU/FGP_TV_core.h @@ -23,8 +23,36 @@ limitations under the License. #include #include #include "omp.h" +#include "utils.h" -float copyIm(float *A, float *B, int dimX, int dimY, int dimZ); +/* C-OMP implementation of FGP-TV [1] denoising/regularization model (2D/3D case) +* +* Input Parameters: +* 1. Noisy image/volume [REQUIRED] +* 2. lambda - regularization parameter [REQUIRED] +* 3. Number of iterations [OPTIONAL parameter] +* 4. eplsilon: tolerance constant [OPTIONAL parameter] +* 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] +* +* Output: +* [1] Filtered/regularized image +* [2] last function value +* +* Example of image denoising: +* figure; +* Im = double(imread('lena_gray_256.tif'))/255; % loading image +* u0 = Im + .05*randn(size(Im)); % adding noise +* u = FGP_TV(single(u0), 0.05, 100, 1e-04); +* +* to compile with OMP support: mex FGP_TV.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +* This function is based on the Matlab's code and paper by +* [1] Amir Beck and Marc Teboulle, "Fast Gradient-Based Algorithms for Constrained Total Variation Image Denoising and Deblurring Problems" +* +* D. Kazantsev, 2016-17 +* +*/ + +//float copyIm(float *A, float *B, int dimX, int dimY, int dimZ); float Obj_func2D(float *A, float *D, float *R1, float *R2, float lambda, int dimX, int dimY); float Grad_func2D(float *P1, float *P2, float *D, float *R1, float *R2, float lambda, int dimX, int dimY); float Proj_func2D(float *P1, float *P2, int methTV, int dimX, int dimY); diff --git a/main_func/regularizers_CPU/LLT_model_core.h b/main_func/regularizers_CPU/LLT_model_core.h index 10f52fe..560bb9c 100644 --- a/main_func/regularizers_CPU/LLT_model_core.h +++ b/main_func/regularizers_CPU/LLT_model_core.h @@ -23,6 +23,7 @@ limitations under the License. 
#include #include #include "omp.h" +#include "utils.h" #define EPS 0.01 @@ -36,4 +37,4 @@ float div_upd3D(float *U0, float *U, float *D1, float *D2, float *D3, unsigned s float calcMap(float *U, unsigned short *Map, int dimX, int dimY, int dimZ); float cleanMap(unsigned short *Map, int dimX, int dimY, int dimZ); -float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); +//float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); diff --git a/main_func/regularizers_CPU/SplitBregman_TV_core.h b/main_func/regularizers_CPU/SplitBregman_TV_core.h index a7aaabb..78aef09 100644 --- a/main_func/regularizers_CPU/SplitBregman_TV_core.h +++ b/main_func/regularizers_CPU/SplitBregman_TV_core.h @@ -23,7 +23,33 @@ limitations under the License. #include #include "omp.h" -float copyIm(float *A, float *B, int dimX, int dimY, int dimZ); +#include "utils.h" + +/* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) +* +* Input Parameters: +* 1. Noisy image/volume +* 2. lambda - regularization parameter +* 3. Number of iterations [OPTIONAL parameter] +* 4. eplsilon - tolerance constant [OPTIONAL parameter] +* 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] +* +* Output: +* Filtered/regularized image +* +* Example: +* figure; +* Im = double(imread('lena_gray_256.tif'))/255; % loading image +* u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +* u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +* +* to compile with OMP support: mex SplitBregman_TV.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +* References: +* The Split Bregman Method for L1 Regularized Problems, by Tom Goldstein and Stanley Osher. +* D. Kazantsev, 2016* +*/ + +//float copyIm(float *A, float *B, int dimX, int dimY, int dimZ); float gauss_seidel2D(float *U, float *A, float *Dx, float *Dy, float *Bx, float *By, int dimX, int dimY, float lambda, float mu); float updDxDy_shrinkAniso2D(float *U, float *Dx, float *Dy, float *Bx, float *By, int dimX, int dimY, float lambda); float updDxDy_shrinkIso2D(float *U, float *Dx, float *Dy, float *Bx, float *By, int dimX, int dimY, float lambda); diff --git a/main_func/regularizers_CPU/TGV_PD_core.h b/main_func/regularizers_CPU/TGV_PD_core.h index 04ba95c..cbe7ea4 100644 --- a/main_func/regularizers_CPU/TGV_PD_core.h +++ b/main_func/regularizers_CPU/TGV_PD_core.h @@ -23,6 +23,7 @@ limitations under the License. #include #include #include "omp.h" +#include "utils.h" /* 2D functions */ float DualP_2D(float *U, float *V1, float *V2, float *P1, float *P2, int dimX, int dimY, int dimZ, float sigma); @@ -32,4 +33,4 @@ float ProjQ_2D(float *Q1, float *Q2, float *Q3, int dimX, int dimY, int dimZ, fl float DivProjP_2D(float *U, float *A, float *P1, float *P2, int dimX, int dimY, int dimZ, float lambda, float tau); float UpdV_2D(float *V1, float *V2, float *P1, float *P2, float *Q1, float *Q2, float *Q3, int dimX, int dimY, int dimZ, float tau); float newU(float *U, float *U_old, int dimX, int dimY, int dimZ); -float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); +//float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); -- cgit v1.2.3 From 068380a2a2c84b46e486f59c397329f53a2eae1f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 2 Aug 2017 16:51:37 +0100 Subject: Added utils.c utils.h a few regularizers defined the same copyIm function. I moved it into this new common utils. 
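For the Python port that appears later in this series, the same operation needs no dedicated C helper; as a purely illustrative aside (names and shapes below are made up, not taken from these patches), the NumPy equivalent of copyIm is just an in-place copy into a caller-allocated buffer:

    import numpy as np

    def copy_im(A, U):
        # element-wise copy of image/volume A into the existing buffer U,
        # mirroring what the shared C copyIm helper below does with OpenMP
        np.copyto(U, A)
        return U

    A = np.random.rand(64, 64, 8).astype(np.float32)   # a small test volume
    U = np.empty_like(A)                                # caller-allocated output, as in the C code
    copy_im(A, U)
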
--- main_func/regularizers_CPU/utils.c | 29 +++++++++++++++++++++++++++++ main_func/regularizers_CPU/utils.h | 27 +++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 main_func/regularizers_CPU/utils.c create mode 100644 main_func/regularizers_CPU/utils.h diff --git a/main_func/regularizers_CPU/utils.c b/main_func/regularizers_CPU/utils.c new file mode 100644 index 0000000..0e83d2c --- /dev/null +++ b/main_func/regularizers_CPU/utils.c @@ -0,0 +1,29 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "utils.h" + +/* Copy Image */ +float copyIm(float *A, float *U, int dimX, int dimY, int dimZ) +{ + int j; +#pragma omp parallel for shared(A, U) private(j) + for (j = 0; j +#include +#include +#include +#include +#include "omp.h" + +float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); -- cgit v1.2.3 From 5aaa46237fbf0a6bb008fe81576cabc61e3b1fce Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 3 Aug 2017 15:26:29 +0100 Subject: Added Python modules Matlab2Python_utils.cpp contains utilities for handling numpy arrays. Together with setup_test.py it creates a functional module for testing. fista_module.cpp and setup.py are meant for the real fista module. --- src/Python/Matlab2Python_utils.cpp | 206 ++++++++++++++++++++++++ src/Python/fista_module.cpp | 315 +++++++++++++++++++++++++++++++++++++ src/Python/setup.py | 58 +++++++ src/Python/setup_test.py | 58 +++++++ 4 files changed, 637 insertions(+) create mode 100644 src/Python/Matlab2Python_utils.cpp create mode 100644 src/Python/fista_module.cpp create mode 100644 src/Python/setup.py create mode 100644 src/Python/setup_test.py diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp new file mode 100644 index 0000000..138e8da --- /dev/null +++ b/src/Python/Matlab2Python_utils.cpp @@ -0,0 +1,206 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include + +#include +#include +#include "boost/tuple/tuple.hpp" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** + +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value MATLAB Type MEX Type C Primitive Type +mxINT8_CLASS int8 int8_T char, byte +mxUINT8_CLASS uint8 uint8_T unsigned char, byte +mxINT16_CLASS int16 int16_T short +mxUINT16_CLASS uint16 uint16_T unsigned short +mxINT32_CLASS int32 int32_T int +mxUINT32_CLASS uint32 uint32_T unsigned int +mxINT64_CLASS int64 int64_T long long +mxUINT64_CLASS uint64 uint64_T unsigned long long +mxSINGLE_CLASS single float float +mxDOUBLE_CLASS double double double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. 
+ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*/ + +template +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract(plh[0]); +} + + + +template +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. + A = reinterpret_cast(prhs[0]); + */ + return reinterpret_cast(prhs[0]); +} + +template +np::ndarray zeros(int dims , int * dim_array, T el) { + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + + +bp::list mexFunction( np::ndarray input ) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast( input.get_data() ); + int * B = reinterpret_cast( zz.get_data() ); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index , &val, sizeof(int)); + std::memcpy(C + index , &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} + + +BOOST_PYTHON_MODULE(fista) +{ + np::initialize(); + + //To specify that this module is a package + bp::object package = bp::scope(); + package.attr("__path__") = "fista"; + + np::dtype dt1 = np::dtype::get_builtin(); + np::dtype dt2 = np::dtype::get_builtin(); + + //import_array(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + 
//numpy_boost_python_register_type(); + def("mexFunction", mexFunction); +} \ No newline at end of file diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp new file mode 100644 index 0000000..5344083 --- /dev/null +++ b/src/Python/fista_module.cpp @@ -0,0 +1,315 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include + +#include +#include +#include "boost/tuple/tuple.hpp" + +// include the regularizers +#include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" +#include "SplitBregman_TV_core.h" +#include "TGV_PD_core.h" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** +mxGetData +args: pm Pointer to an mxArray +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. 
+ +mxClassID Value MATLAB Type MEX Type C Primitive Type +mxINT8_CLASS int8 int8_T char, byte +mxUINT8_CLASS uint8 uint8_T unsigned char, byte +mxINT16_CLASS int16 int16_T short +mxUINT16_CLASS uint16 uint16_T unsigned short +mxINT32_CLASS int32 int32_T int +mxUINT32_CLASS uint32 uint32_T unsigned int +mxINT64_CLASS int64 int64_T long long +mxUINT64_CLASS uint64 uint64_T unsigned long long +mxSINGLE_CLASS single float float +mxDOUBLE_CLASS double double double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); +args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. + dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. + For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. + classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. + For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. + ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). + Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). + +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). + If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not + enough free heap space to create the mxArray. +*/ + +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { + /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) + * + * Input Parameters: + * 1. Noisy image/volume + * 2. lambda - regularization parameter + * 3. Number of iterations [OPTIONAL parameter] + * 4. eplsilon - tolerance constant [OPTIONAL parameter] + * 5. 
TV-type: 'iso' or 'l1' [OPTIONAL parameter] + * + * Output: + * Filtered/regularized image + * + * All sanity checks and default values are set in Python + */ + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + const int dim_array[3]; + float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -11; + } + else { + dim_array[2] = input.shape(2); + } + + /*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast(input.get_data()); + + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + //iter = 35; /* default iterations number */ + iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + methTV = TV_type; + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + lambda = 2.0f*mu; + count = 1; + re_old = 0.0f; + /*Handling Matlab output data*/ + dimY = dim_array[0]; dimX = dim_array[1]; dimZ = dim_array[2]; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /* + mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); +args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. + dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. + For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. + classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. + For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. + ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). + Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). + + mxCreateNumericArray initializes all its real data elements to 0. 
+*/ + +/* + U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +*/ + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U = A = reinterpret_castinput.get_data(); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + /*printf("%f %i %i \n", re, ll, count); */ + + /*copyIm(U_old, U, dimX, dimY, dimZ); */ + } + printf("SB iterations stopped at iteration: %i\n", ll); + } + if (number_of_dims == 3) { + U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + printf("SB iterations stopped at iteration: %i\n", ll); + } + bp::list result; + return result; +} + + +BOOST_PYTHON_MODULE(fista) +{ + np::initialize(); + + //To specify that this module is a package + bp::object package = bp::scope(); + package.attr("__path__") = "fista"; + + np::dtype dt1 = np::dtype::get_builtin(); + np::dtype dt2 = np::dtype::get_builtin(); + + + def("mexFunction", mexFunction); +} \ No newline at end of file diff --git a/src/Python/setup.py b/src/Python/setup.py new file mode 100644 index 0000000..ffb9c02 --- /dev/null +++ b/src/Python/setup.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = 
os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] + extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + extra_include_dirs += ["../ContourTree/", "../Core/","."] + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("fista", + sources=[ "Matlab2Python_utils.cpp", + ], + include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.reconstruction'}, +) diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py new file mode 100644 index 0000000..ffb9c02 --- /dev/null +++ b/src/Python/setup_test.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] + extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + extra_include_dirs += ["../ContourTree/", "../Core/","."] + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("fista", + sources=[ "Matlab2Python_utils.cpp", + ], + include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.reconstruction'}, +) 
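As an illustrative usage sketch only (the version string, array shape and dtype are assumptions, not part of these sources): once the Boost.Python/Boost.NumPy libraries referenced above are installed, the test module built from Matlab2Python_utils.cpp by setup_test.py could be exercised roughly as follows.

    # build the test extension in place (setup_test.py requires CIL_VERSION to be set):
    #   export CIL_VERSION=0.9
    #   python setup_test.py build_ext --inplace

    import numpy as np
    import fista  # module name declared in BOOST_PYTHON_MODULE(fista)

    # a small 3D integer volume; the wrapper copies it into int and float outputs
    vol = np.arange(2 * 3 * 4, dtype=np.int32).reshape(2, 3, 4)

    ndims, dx, dy, dz, as_int, as_float = fista.mexFunction(vol)
    print(ndims, dx, dy, dz)    # dimensionality and shape as seen on the C++ side
    print(as_float.dtype)       # float32 copy produced by the wrapper

The real reconstruction module is expected to be built from fista_module.cpp via setup.py in the same way.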
-- cgit v1.2.3 From 22d41e596544668e71f3abef321d48f0a54f0f53 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 3 Aug 2017 16:52:20 +0100 Subject: added FGP_TV wrapper --- src/Python/fista_module.cpp | 576 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 473 insertions(+), 103 deletions(-) diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 5344083..2492884 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -26,12 +26,10 @@ limitations under the License. #include #include "boost/tuple/tuple.hpp" -// include the regularizers -#include "FGP_TV_core.h" -#include "LLT_model_core.h" -#include "PatchBased_Regul_core.h" #include "SplitBregman_TV_core.h" -#include "TGV_PD_core.h" +#include "FGP_TV_core.h" + + #if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) #include @@ -43,7 +41,6 @@ __if_not_exists(uint16_t) { typedef __int8 uint16_t; } namespace bp = boost::python; namespace np = boost::python::numpy; - /*! in the Matlab implementation this is called as void mexFunction( int nlhs, mxArray *plhs[], @@ -56,9 +53,7 @@ nlhs Array of pointers to the OUTPUT mxArrays plhs int number of OUTPUT mxArrays *********************************************************** -mxGetData -args: pm Pointer to an mxArray -Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. + *********************************************************** double mxGetScalar(const mxArray *pm); args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. @@ -92,77 +87,143 @@ double *mxGetPr(const mxArray *pm); args: pm Pointer to an mxArray of type double Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. **************************************************************** -mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. 
+For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). - If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not - enough free heap space to create the mxArray. +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. */ template -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - np::dtype dtype = np::dtype::get_builtin(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract(plh[0]); } -bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { - /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) - * - * Input Parameters: - * 1. Noisy image/volume - * 2. lambda - regularization parameter - * 3. Number of iterations [OPTIONAL parameter] - * 4. eplsilon - tolerance constant [OPTIONAL parameter] - * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] - * - * Output: - * Filtered/regularized image - * - * All sanity checks and default values are set in Python + +template +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. 
+ A = reinterpret_cast(prhs[0]); */ + return reinterpret_cast(prhs[0]); +} + + + + +bp::list mexFunction(np::ndarray input) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast(input.get_data()); + int * B = reinterpret_cast(zz.get_data()); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int dim_array[3]; + const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; - + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - number_of_dims = input.get_nd(); + + int number_of_dims = input.get_nd(); + int dim_array[3]; dim_array[0] = input.shape(0); dim_array[1] = input.shape(1); if (number_of_dims == 2) { - dim_array[2] = -11; + dim_array[2] = -1; } else { dim_array[2] = input.shape(2); } - /*Handling Matlab input data*/ + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); - /*Handling Matlab input data*/ + ///*Handling Matlab input data*/ //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ A = reinterpret_cast(input.get_data()); - //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ mu = (float)d_mu; + //iter = 35; /* default iterations number */ - iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ epsil = (float)d_epsil; //methTV = 0; /* default isotropic TV penalty */ - methTV = TV_type; //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ //if (nrhs == 5) { @@ -182,34 +243,31 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, doub if (number_of_dims == 2) { dimZ = 1; /*2D case*/ - /* - mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. 
If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - - mxCreateNumericArray initializes all its real data elements to 0. -*/ - -/* - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); -*/ //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U = A = reinterpret_castinput.get_data(); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ @@ -245,59 +303,370 @@ args: ndim: Number of dimensions. 
If you specify a value for ndim that is less /*printf("%f %i %i \n", re, ll, count); */ /*copyIm(U_old, U, dimX, dimY, dimZ); */ + result.append(npU); + result.append(ll); + } + //printf("SB iterations stopped at iteration: %i\n", ll); + if (number_of_dims == 3) { + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npDz = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npBz = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Dz = reinterpret_cast(npDz.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + Bz = reinterpret_cast(npBz.get_data()); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + //printf("SB iterations stopped at iteration: %i\n", ll); + result.append(npU); + result.append(ll); } - printf("SB iterations stopped at iteration: %i\n", ll); } - if (number_of_dims == 3) { - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + return result; - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ +} - /* begin outer SB iterations */ +bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + 
//dim_array = mxGetDimensions(prhs[0]); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + ///*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast(input.get_data()); + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + + //iter = 35; /* default iterations number */ + + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + //plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL); + bp::tuple shape1 = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray out1 = np::zeros(shape1, dtype); + + //float *funcvalA = (float *)mxGetData(plhs[1]); + float * funcvalA = reinterpret_cast(out1.get_data()); + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2]; + + tk = 1.0f; + tkp1 = 1.0f; + count = 1; + re_old = 0.0f; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + 
np::ndarray npR2 = zeros(2, dim_array, (float)0); + + D = reinterpret_cast(npD.get_data()); + D_old = reinterpret_cast(npD_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + P1_old = reinterpret_cast(npP1_old.get_data()); + P2_old = reinterpret_cast(npP2_old.get_data()); + R1 = reinterpret_cast(npR1.get_data()); + R2 = reinterpret_cast(npR2.get_data()); + + /* begin iterations */ for (ll = 0; ll 4) break; + if (count > 3) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j 2) { - if (re > re_old) break; + if (re > re_old) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j(npD); + result.append(out1); + result.append(ll); + } + if (number_of_dims == 3) { + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP3 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npP3_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + np::ndarray npR2 = np::zeros(shape, dtype); + np::ndarray npR3 = np::zeros(shape, dtype); + + D = reinterpret_cast(npD.get_data()); + D_old = reinterpret_cast(npD_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + P3 = reinterpret_cast(npP3.get_data()); + P1_old = reinterpret_cast(npP1_old.get_data()); + P2_old = reinterpret_cast(npP2_old.get_data()); + P3_old = reinterpret_cast(npP3_old.get_data()); + R1 = reinterpret_cast(npR1.get_data()); + R2 = reinterpret_cast(npR2.get_data()); + R2 = reinterpret_cast(npR3.get_data()); + /* begin iterations */ + for (ll = 0; ll 3) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j 2) { + if (re > re_old) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j(npD); + result.append(out1); + result.append(ll); } - bp::list result; + return result; } - BOOST_PYTHON_MODULE(fista) { @@ -310,6 +679,7 @@ BOOST_PYTHON_MODULE(fista) np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); - def("mexFunction", mexFunction); + def("SplitBregman_TV", SplitBregman_TV); + def("FGP_TV", FGP_TV); } \ No newline at end of file -- cgit v1.2.3 From 
a16f24574d727cc486c51f0dcf581aeeb8a71381 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:08:36 +0100 Subject: test 2D data for regularizers --- data/lena_gray_512.tif | Bin 0 -> 262598 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 data/lena_gray_512.tif diff --git a/data/lena_gray_512.tif b/data/lena_gray_512.tif new file mode 100644 index 0000000..f80cafc Binary files /dev/null and b/data/lena_gray_512.tif differ -- cgit v1.2.3 From 846e09d5c27736a4af80af7f9a3adcda3c6d4bf0 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:12:27 +0100 Subject: fix typo and add include "matrix.h" --- main_func/regularizers_CPU/FGP_TV.c | 3 ++- main_func/regularizers_CPU/LLT_model.c | 10 ++++++++-- main_func/regularizers_CPU/SplitBregman_TV.c | 3 ++- main_func/regularizers_CPU/SplitBregman_TV_core.c | 12 +----------- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/main_func/regularizers_CPU/FGP_TV.c b/main_func/regularizers_CPU/FGP_TV.c index cfe5b9e..66442c9 100644 --- a/main_func/regularizers_CPU/FGP_TV.c +++ b/main_func/regularizers_CPU/FGP_TV.c @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +16,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "matrix.h" #include "mex.h" #include "FGP_TV_core.h" diff --git a/main_func/regularizers_CPU/LLT_model.c b/main_func/regularizers_CPU/LLT_model.c index 47146a5..0050654 100644 --- a/main_func/regularizers_CPU/LLT_model.c +++ b/main_func/regularizers_CPU/LLT_model.c @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,12 +18,19 @@ limitations under the License. */ #include "mex.h" +#include "matrix.h" #include "LLT_model_core.h" /* C-OMP implementation of Lysaker, Lundervold and Tai (LLT) model of higher order regularization penalty * * Input Parameters: * 1. U0 - origanal noise image/volume +======= +/* C-OMP implementation of Lysaker, Lundervold and Tai (LLT) model of higher order regularization penalty +* +* Input Parameters: +* 1. U0 - original noise image/volume +>>>>>>> fix typo and add include "matrix.h" * 2. lambda - regularization parameter * 3. tau - time-step for explicit scheme * 4. iter - iterations number @@ -46,7 +53,6 @@ limitations under the License. 
* 28.11.16/Harwell */ - void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) diff --git a/main_func/regularizers_CPU/SplitBregman_TV.c b/main_func/regularizers_CPU/SplitBregman_TV.c index 0dc638d..38f6a9d 100644 --- a/main_func/regularizers_CPU/SplitBregman_TV.c +++ b/main_func/regularizers_CPU/SplitBregman_TV.c @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,6 +18,7 @@ limitations under the License. */ #include "mex.h" +#include #include "SplitBregman_TV_core.h" /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) diff --git a/main_func/regularizers_CPU/SplitBregman_TV_core.c b/main_func/regularizers_CPU/SplitBregman_TV_core.c index 283dd43..4109a4b 100644 --- a/main_func/regularizers_CPU/SplitBregman_TV_core.c +++ b/main_func/regularizers_CPU/SplitBregman_TV_core.c @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -257,13 +257,3 @@ float updBxByBz3D(float *U, float *Dx, float *Dy, float *Dz, float *Bx, float *B }}} return 1; } -/* General Functions */ -/*****************************************************************/ -/* Copy Image */ -float copyIm(float *A, float *B, int dimX, int dimY, int dimZ) -{ - int j; -#pragma omp parallel for shared(A, B) private(j) - for(j=0; j Date: Fri, 4 Aug 2017 16:13:19 +0100 Subject: removed definition of CopyImage --- main_func/regularizers_CPU/FGP_TV_core.c | 13 ++----------- main_func/regularizers_CPU/LLT_model_core.c | 11 ++--------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/main_func/regularizers_CPU/FGP_TV_core.c b/main_func/regularizers_CPU/FGP_TV_core.c index 9cde327..03cd445 100644 --- a/main_func/regularizers_CPU/FGP_TV_core.c +++ b/main_func/regularizers_CPU/FGP_TV_core.c @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -263,13 +263,4 @@ float Rupd_func3D(float *P1, float *P1_old, float *P2, float *P2_old, float *P3, return 1; } -/* General Functions */ -/*****************************************************************/ -/* Copy Image */ -float copyIm(float *A, float *B, int dimX, int dimY, int dimZ) -{ - int j; -#pragma omp parallel for shared(A, B) private(j) - for (j = 0; j Date: Fri, 4 Aug 2017 16:14:24 +0100 Subject: extern C function called by C++ adds extern "C" to the definition of the functions if the code is called from C++ rather than from C/Matlab --- main_func/regularizers_CPU/FGP_TV_core.h | 11 ++++-- main_func/regularizers_CPU/LLT_model_core.h | 10 ++++-- main_func/regularizers_CPU/PatchBased_Regul_core.h | 40 +++++++++++++++++++++- main_func/regularizers_CPU/SplitBregman_TV_core.h | 12 +++++-- main_func/regularizers_CPU/TGV_PD_core.h | 35 +++++++++++++++++-- 5 files 
changed, 98 insertions(+), 10 deletions(-) diff --git a/main_func/regularizers_CPU/FGP_TV_core.h b/main_func/regularizers_CPU/FGP_TV_core.h index d256c9a..7e7229a 100644 --- a/main_func/regularizers_CPU/FGP_TV_core.h +++ b/main_func/regularizers_CPU/FGP_TV_core.h @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +17,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include +//#include #include #include #include @@ -51,7 +51,9 @@ limitations under the License. * D. Kazantsev, 2016-17 * */ - +#ifdef __cplusplus +extern "C" { +#endif //float copyIm(float *A, float *B, int dimX, int dimY, int dimZ); float Obj_func2D(float *A, float *D, float *R1, float *R2, float lambda, int dimX, int dimY); float Grad_func2D(float *P1, float *P2, float *D, float *R1, float *R2, float lambda, int dimX, int dimY); @@ -64,3 +66,6 @@ float Grad_func3D(float *P1, float *P2, float *P3, float *D, float *R1, float *R float Proj_func3D(float *P1, float *P2, float *P3, int dimX, int dimY, int dimZ); float Rupd_func3D(float *P1, float *P1_old, float *P2, float *P2_old, float *P3, float *P3_old, float *R1, float *R2, float *R3, float tkp1, float tk, int dimX, int dimY, int dimZ); float Obj_func_CALC3D(float *A, float *D, float *funcvalA, float lambda, int dimX, int dimY, int dimZ); +#ifdef __cplusplus +} +#endif diff --git a/main_func/regularizers_CPU/LLT_model_core.h b/main_func/regularizers_CPU/LLT_model_core.h index 560bb9c..13fce5a 100644 --- a/main_func/regularizers_CPU/LLT_model_core.h +++ b/main_func/regularizers_CPU/LLT_model_core.h @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +17,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include +//#include #include #include #include @@ -28,6 +28,9 @@ limitations under the License. #define EPS 0.01 /* 2D functions */ +#ifdef __cplusplus +extern "C" { +#endif float der2D(float *U, float *D1, float *D2, int dimX, int dimY, int dimZ); float div_upd2D(float *U0, float *U, float *D1, float *D2, int dimX, int dimY, int dimZ, float lambda, float tau); @@ -38,3 +41,6 @@ float calcMap(float *U, unsigned short *Map, int dimX, int dimY, int dimZ); float cleanMap(unsigned short *Map, int dimX, int dimY, int dimZ); //float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/main_func/regularizers_CPU/PatchBased_Regul_core.h b/main_func/regularizers_CPU/PatchBased_Regul_core.h index 5aa6415..d4a8a46 100644 --- a/main_func/regularizers_CPU/PatchBased_Regul_core.h +++ b/main_func/regularizers_CPU/PatchBased_Regul_core.h @@ -19,13 +19,51 @@ limitations under the License. #define _USE_MATH_DEFINES -#include +//#include #include #include #include #include #include "omp.h" +/* C-OMP implementation of patch-based (PB) regularization (2D and 3D cases). 
+* This method finds self-similar patches in data and performs one fixed point iteration to mimimize the PB penalty function +* +* References: 1. Yang Z. & Jacob M. "Nonlocal Regularization of Inverse Problems" +* 2. Kazantsev D. et al. "4D-CT reconstruction with unified spatial-temporal patch-based regularization" +* +* Input Parameters (mandatory): +* 1. Image (2D or 3D) +* 2. ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window) +* 3. ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window) +* 4. h - parameter for the PB penalty function +* 5. lambda - regularization parameter + +* Output: +* 1. regularized (denoised) Image (N x N)/volume (N x N x N) +* +* Quick 2D denoising example in Matlab: +Im = double(imread('lena_gray_256.tif'))/255; % loading image +u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); +* +* Please see more tests in a file: +TestTemporalSmoothing.m + +* +* Matlab + C/mex compilers needed +* to compile with OMP support: mex PB_Regul_CPU.c CFLAGS="\$CFLAGS -fopenmp -Wall" LDFLAGS="\$LDFLAGS -fopenmp" +* +* D. Kazantsev * +* 02/07/2014 +* Harwell, UK +*/ +#ifdef __cplusplus +extern "C" { +#endif float pad_crop(float *A, float *Ap, int OldSizeX, int OldSizeY, int OldSizeZ, int NewSizeX, int NewSizeY, int NewSizeZ, int padXY, int switchpad_crop); float PB_FUNC2D(float *A, float *B, int dimX, int dimY, int padXY, int SearchW, int SimilW, float h, float lambda); float PB_FUNC3D(float *A, float *B, int dimX, int dimY, int dimZ, int padXY, int SearchW, int SimilW, float h, float lambda); +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/main_func/regularizers_CPU/SplitBregman_TV_core.h b/main_func/regularizers_CPU/SplitBregman_TV_core.h index 78aef09..6ed3ff9 100644 --- a/main_func/regularizers_CPU/SplitBregman_TV_core.h +++ b/main_func/regularizers_CPU/SplitBregman_TV_core.h @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,7 +16,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include +//#include #include #include #include @@ -49,6 +49,10 @@ limitations under the License. * D. 
Kazantsev, 2016* */ +#ifdef __cplusplus +extern "C" { +#endif + //float copyIm(float *A, float *B, int dimX, int dimY, int dimZ); float gauss_seidel2D(float *U, float *A, float *Dx, float *Dy, float *Bx, float *By, int dimX, int dimY, float lambda, float mu); float updDxDy_shrinkAniso2D(float *U, float *Dx, float *Dy, float *Bx, float *By, int dimX, int dimY, float lambda); @@ -59,3 +63,7 @@ float gauss_seidel3D(float *U, float *A, float *Dx, float *Dy, float *Dz, float float updDxDyDz_shrinkAniso3D(float *U, float *Dx, float *Dy, float *Dz, float *Bx, float *By, float *Bz, int dimX, int dimY, int dimZ, float lambda); float updDxDyDz_shrinkIso3D(float *U, float *Dx, float *Dy, float *Dz, float *Bx, float *By, float *Bz, int dimX, int dimY, int dimZ, float lambda); float updBxByBz3D(float *U, float *Dx, float *Dy, float *Dz, float *Bx, float *By, float *Bz, int dimX, int dimY, int dimZ); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/main_func/regularizers_CPU/TGV_PD_core.h b/main_func/regularizers_CPU/TGV_PD_core.h index cbe7ea4..d5378df 100644 --- a/main_func/regularizers_CPU/TGV_PD_core.h +++ b/main_func/regularizers_CPU/TGV_PD_core.h @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +17,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include +//#include #include #include #include @@ -25,6 +25,34 @@ limitations under the License. #include "omp.h" #include "utils.h" +/* C-OMP implementation of Primal-Dual denoising method for +* Total Generilized Variation (TGV)-L2 model (2D case only) +* +* Input Parameters: +* 1. Noisy image/volume (2D) +* 2. lambda - regularization parameter +* 3. parameter to control first-order term (alpha1) +* 4. parameter to control the second-order term (alpha0) +* 5. Number of CP iterations +* +* Output: +* Filtered/regularized image +* +* Example: +* figure; +* Im = double(imread('lena_gray_256.tif'))/255; % loading image +* u0 = Im + .03*randn(size(Im)); % adding noise +* tic; u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); toc; +* +* to compile with OMP support: mex TGV_PD.c CFLAGS="\$CFLAGS -fopenmp -Wall -std=c99" LDFLAGS="\$LDFLAGS -fopenmp" +* References: +* K. 
Bredies "Total Generalized Variation" +* +* 28.11.16/Harwell +*/ +#ifdef __cplusplus +extern "C" { +#endif /* 2D functions */ float DualP_2D(float *U, float *V1, float *V2, float *P1, float *P2, int dimX, int dimY, int dimZ, float sigma); float ProjP_2D(float *P1, float *P2, int dimX, int dimY, int dimZ, float alpha1); @@ -34,3 +62,6 @@ float DivProjP_2D(float *U, float *A, float *P1, float *P2, int dimX, int dimY, float UpdV_2D(float *V1, float *V2, float *P1, float *P2, float *Q1, float *Q2, float *Q3, int dimX, int dimY, int dimZ, float tau); float newU(float *U, float *U_old, int dimX, int dimY, int dimZ); //float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); +#ifdef __cplusplus +} +#endif -- cgit v1.2.3 From c9834452a55de7e55301fa95ffdf8768694400b3 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:14:41 +0100 Subject: extern C fix --- main_func/regularizers_CPU/utils.h | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/main_func/regularizers_CPU/utils.h b/main_func/regularizers_CPU/utils.h index c720006..53463a3 100644 --- a/main_func/regularizers_CPU/utils.h +++ b/main_func/regularizers_CPU/utils.h @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,11 +17,16 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include +//#include +//#include #include #include -#include +//#include #include "omp.h" - +#ifdef __cplusplus +extern "C" { +#endif float copyIm(float *A, float *U, int dimX, int dimY, int dimZ); +#ifdef __cplusplus +} +#endif -- cgit v1.2.3 From 12dbe738d5a2af5573e33a31f1745a50dba165ba Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:15:03 +0100 Subject: compilation fixes --- src/Python/setup.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/Python/setup.py b/src/Python/setup.py index ffb9c02..a8feb1c 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -29,14 +29,14 @@ extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\env extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] extra_libraries = [] if platform.system() == 'Windows': - extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] - extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ] + extra_include_dirs += ["..\\..\\main_func\\regularizers_CPU\\","."] if sys.version_info.major == 3 : extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] else: extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] else: - extra_include_dirs += ["../ContourTree/", "../Core/","."] + extra_include_dirs += ["../../main_func/regularizers_CPU","."] if sys.version_info.major == 3: extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] else: @@ -47,8 +47,12 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("fista", - sources=[ "Matlab2Python_utils.cpp", + ext_modules = [Extension("regularizers", + 
sources=["fista_module.cpp", + "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", + "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", + "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 36e4c296223f67bb917511089ec59533460f1695 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:15:17 +0100 Subject: test facility for regularizers --- src/Python/test_regularizers.py | 265 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 src/Python/test_regularizers.py diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py new file mode 100644 index 0000000..6abfba4 --- /dev/null +++ b/src/Python/test_regularizers.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug 4 11:10:05 2017 + +@author: ofn77899 +""" + +from ccpi.viewer.CILViewer2D import Converter +import vtk + +import regularizers +import matplotlib.pyplot as plt +import numpy as np +import os +from enum import Enum + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) + 4) + 5) + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) + + ''' + class Algorithm(Enum): + SplitBregman_TV = regularizers.SplitBregman_TV + FGP_TV = regularizers.FGP_TV + LLT_model = regularizers.LLT_model + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm): + + self.algorithm = algorithm + self.pars = self.parsForAlgorithm(algorithm) + # __init__ + + def parsForAlgorithm(self, algorithm): + pars = dict() + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + return pars + # parsForAlgorithm + + def __call__(self, input, regularization_parameter, **kwargs): + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + self.pars['input'] = input + self.pars['regularization_parameter'] = regularization_parameter + #for key, value in self.pars.items(): + # print("{0} = {1}".format(key, value)) + + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if None in self.pars: + raise Exception("Not all parameters have been provided") + else: + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + 
out.append(reg.pars) + return out + + +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +reader = vtk.vtkTIFFReader() +reader.SetFileName(os.path.normpath(filename)) +reader.Update() +#vtk returns 3D images, let's take just the one slice there is as 2D +Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 + +#imgplot = plt.imshow(Im) +perc = 0.05 +u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +# plot +fig = plt.figure() +a=fig.add_subplot(2,3,1) +a.set_title('Original') +imgplot = plt.imshow(Im) + +a=fig.add_subplot(2,3,2) +a.set_title('noise') +imgplot = plt.imshow(u0) + + +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + +out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) +pars = out2[2] + +a=fig.add_subplot(2,3,3) +a.set_title('SplitBregman_TV') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['TV_penalty'].name) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, + number_of_iterations=10) +pars = out2[-1] + +a=fig.add_subplot(2,3,4) +a.set_title('FGP_TV') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['TV_penalty'].name) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., + time_step=0.1, + tolerance_constant=1e-4, + number_of_iterations=10) +pars = out2[-1] + +a=fig.add_subplot(2,3,5) +a.set_title('LLT_model') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' +textstr = textstr % 
(pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['time_step'] + ) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + + + -- cgit v1.2.3 From fd496731c8e9d4975864d76dbb6574cbeee7cf98 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:37 +0100 Subject: Added 3 regularizers SplitBregman_TV FGP_TV LLT_model --- src/Python/fista_module.cpp | 266 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 232 insertions(+), 34 deletions(-) diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 2492884..d890b10 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,8 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "utils.h" @@ -131,6 +133,18 @@ T * mxGetData(const np::ndarray pm) { return reinterpret_cast(prhs[0]); } +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + @@ -169,7 +183,6 @@ bp::list mexFunction(np::ndarray input) { } } - bp::list result; result.append(number_of_dims); @@ -189,14 +202,14 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int *dim_array; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; + //const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -252,26 +265,26 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); - np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU = np::zeros(shape, dtype); np::ndarray npU_old = np::zeros(shape, dtype); - np::ndarray npDx = np::zeros(shape, dtype); - np::ndarray npDy = np::zeros(shape, dtype); - np::ndarray npBx = np::zeros(shape, dtype); - np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); - U = reinterpret_cast(npU.get_data()); + U = reinterpret_cast(npU.get_data()); U_old = reinterpret_cast(npU_old.get_data()); - Dx = 
reinterpret_cast(npDx.get_data()); - Dy = reinterpret_cast(npDy.get_data()); - Bx = reinterpret_cast(npBx.get_data()); - By = reinterpret_cast(npBy.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ - for (ll = 0; ll(npU); - result.append(ll); + } //printf("SB iterations stopped at iteration: %i\n", ll); - if (number_of_dims == 3) { + result.append(npU); + result.append(ll); + } + if (number_of_dims == 3) { /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); @@ -375,24 +390,25 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi result.append(npU); result.append(ll); } - } return result; -} + } + + bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; - float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval, mu; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -512,7 +528,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j(input.get_data()); + lambda = (float)d_lambda; + tau = (float)d_tau; + // iter is passed as parameter + epsil = (float)d_epsil; + // switcher is passed as parameter + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = 1; + + if (number_of_dims == 2) { + /*2D case*/ + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der2D(U, D1, D2, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll 
> 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + + result.append(npU); + } + else if (number_of_dims == 3) { + /*3D case*/ + dimZ = dim_array[2]; + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + if (switcher != 0) { + Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + }*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + np::ndarray npD3 = np::zeros(shape, dtype); + np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin()); + Map = reinterpret_cast(npMap.get_data()); + if (switcher != 0) { + //Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + + Map = reinterpret_cast(npMap.get_data()); + } + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + D3 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + + if (switcher == 1) { + /* apply restrictive smoothing */ + calcMap(U, Map, dimX, dimY, dimZ); + /*clear outliers */ + cleanMap(Map, dimX, dimY, dimZ); + } + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der3D(U, D1, D2, D3, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + result.append(npU); + if (switcher != 0) result.append(npMap); + + } + return result; +} + + +BOOST_PYTHON_MODULE(regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -682,4 +879,5 @@ BOOST_PYTHON_MODULE(fista) def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); + def("LLT_model", LLT_model); } \ No newline at end of file -- cgit v1.2.3 From 4bef3726577ddf1bf2b594620e106573c6f18693 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:53 +0100 Subject: minor change --- src/Python/Matlab2Python_utils.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index 138e8da..6aaad90 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -128,7 +128,11 
@@ T * mxGetData(const np::ndarray pm) { template np::ndarray zeros(int dims , int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); np::ndarray zz = np::zeros(shape, dtype); return zz; @@ -163,7 +167,7 @@ bp::list mexFunction( np::ndarray input ) { for (int k = 0; k < dim_array[2]; k++) { int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; int val = (*(A + index)); - float fval = (float)val; + float fval = sqrt((float)val); std::memcpy(B + index , &val, sizeof(int)); std::memcpy(C + index , &fval, sizeof(float)); } @@ -186,7 +190,7 @@ bp::list mexFunction( np::ndarray input ) { } -BOOST_PYTHON_MODULE(fista) +BOOST_PYTHON_MODULE(prova) { np::initialize(); -- cgit v1.2.3 From 662ab4ac9c3d89cdc1527c2a2bdcf442f3b6a173 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:17:18 +0100 Subject: test for general boost::python / numpy routines --- src/Python/test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 src/Python/test.py diff --git a/src/Python/test.py b/src/Python/test.py new file mode 100644 index 0000000..e283f89 --- /dev/null +++ b/src/Python/test.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Aug 3 14:08:09 2017 + +@author: ofn77899 +""" + +import fista +import numpy as np + +a = np.asarray([i for i in range(3*4*5)]) +a = a.reshape([3,4,5]) +print (a) +b = fista.mexFunction(a) +#print (b) +print (b[4].shape) +print (b[4]) +print (b[5]) \ No newline at end of file -- cgit v1.2.3 From 927e1caa9b1d0df9f4f2eb95170f6bd8be657656 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:48:41 +0100 Subject: typo and include matrix.h --- main_func/regularizers_CPU/PatchBased_Regul.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main_func/regularizers_CPU/PatchBased_Regul.c b/main_func/regularizers_CPU/PatchBased_Regul.c index 59eb3b3..9c925df 100644 --- a/main_func/regularizers_CPU/PatchBased_Regul.c +++ b/main_func/regularizers_CPU/PatchBased_Regul.c @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,6 +18,7 @@ limitations under the License. */ #include "mex.h" +#include "matrix.h" #include "PatchBased_Regul_core.h" -- cgit v1.2.3 From ecfb1146dc1de9ea6d8c6587d15417a9690f5ab4 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:49:21 +0100 Subject: added PatchBased_Regul --- src/Python/fista_module.cpp | 123 +++++++++++++++++++++++++++++++++++++++++++- src/Python/setup.py | 1 + 2 files changed, 123 insertions(+), 1 deletion(-) diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index d890b10..c2d9352 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -29,6 +29,7 @@ limitations under the License. 
#include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" #include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" #include "utils.h" @@ -793,7 +794,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d if (switcher != 0) { Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); }*/ - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); np::dtype dtype = np::dtype::get_builtin(); @@ -865,6 +866,126 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } +bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { + // the result is in the following list + bp::list result; + + int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop; + //const int *dims; + float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda; + + numdims = input.get_nd(); + int dims[3]; + + dims[0] = input.shape(0); + dims[1] = input.shape(1); + if (numdims == 2) { + dims[2] = -1; + } + else { + dims[2] = input.shape(2); + } + /*numdims = mxGetNumberOfDimensions(prhs[0]); + dims = mxGetDimensions(prhs[0]);*/ + + N = dims[0]; + M = dims[1]; + Z = dims[2]; + + //if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); } + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + //if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter"); + + ///*Handling inputs*/ + //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ + //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ + //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ + //lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */ + + //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); + //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + + SearchW = SearchW_real + 2 * SimilW; + + /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ + /* SimilW_full = 2*SimilW + 1; */ /* the full similarity window size */ + + + padXY = SearchW + 2 * SimilW; /* padding sizes */ + newsizeX = N + 2 * (padXY); /* the X size of the padded array */ + newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ + newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ + int N_dims[] = { newsizeX, newsizeY, newsizeZ }; + + /******************************2D case ****************************/ + if (numdims == 2) { + ///*Handling output*/ + //B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + ///**************************************************************************/ + + bp::tuple shape = bp::make_tuple(N, M); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npB = np::zeros(shape, dtype); + + shape = bp::make_tuple(newsizeX, newsizeY); + np::ndarray npAp = np::zeros(shape, dtype); + np::ndarray npBp = np::zeros(shape, 
dtype); + B = reinterpret_cast(npB.get_data()); + Ap = reinterpret_cast(npAp.get_data()); + Bp = reinterpret_cast(npBp.get_data()); + + /*Perform padding of image A to the size of [newsizeX * newsizeY] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append(npB); + } + else + { + /******************************3D case ****************************/ + ///*Handling output*/ + //B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + /**************************************************************************/ + bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]); + bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npB = np::zeros(shape, dtype); + np::ndarray npAp = np::zeros(shape_AB, dtype); + np::ndarray npBp = np::zeros(shape_AB, dtype); + B = reinterpret_cast(npB.get_data()); + Ap = reinterpret_cast(npAp.get_data()); + Bp = reinterpret_cast(npBp.get_data()); + /*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + result.append(npB); + } /*end else ndims*/ + + return result; +} + BOOST_PYTHON_MODULE(regularizers) { np::initialize(); diff --git a/src/Python/setup.py b/src/Python/setup.py index a8feb1c..a4eed14 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -52,6 +52,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From c20554d5bec0168a54a6c472c0cc2f378511eccc Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 16:46:51 +0100 Subject: typo --- main_func/regularizers_CPU/TGV_PD.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_func/regularizers_CPU/TGV_PD.c b/main_func/regularizers_CPU/TGV_PD.c index 6593d8e..c9cb440 100644 --- a/main_func/regularizers_CPU/TGV_PD.c +++ b/main_func/regularizers_CPU/TGV_PD.c @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); -- cgit v1.2.3 From 
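By this point the module builds as regularizers and exposes SplitBregman_TV, FGP_TV and LLT_model; PatchBased_Regul is compiled into the extension here but only registered with BOOST_PYTHON_MODULE later in the series. A short sketch of exercising the newer wrappers, assuming the extension has been rebuilt from the sources listed in setup.py; the parameter values are the illustrative ones quoted in the core headers and test script:

# hedged smoke-test sketch for the newly added wrappers
import numpy as np
import regularizers

u0 = np.random.rand(256, 256).astype('float32')   # the C cores expect single precision

# LLT_model(input, lambda, tau, iterations, tolerance, switcher);
# switcher=0 skips the restrictive Z-smoothing branch used for 3D volumes
u_llt = regularizers.LLT_model(u0, 10.0, 0.1, 10, 1e-4, 0)[0]

# PatchBased_Regul becomes callable once def("PatchBased_Regul", ...) is added to the
# module (a later commit in this series), where its argument order ends up as
# (input, lambda, searching window ratio, similarity window ratio, h):
# u_pb = regularizers.PatchBased_Regul(u0, 0.05, 3, 1, 0.08)[0]

As with the earlier wrappers, the outputs are new numpy arrays returned in a list; only the argument parsing and output allocation differ from the mex interface, while the C cores themselves are reused unchanged.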
8094acfff8659e5c959066495936a8d5aff56fd5 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 16:47:28 +0100 Subject: removed copyIm --- main_func/regularizers_CPU/TGV_PD_core.c | 8 -------- 1 file changed, 8 deletions(-) diff --git a/main_func/regularizers_CPU/TGV_PD_core.c b/main_func/regularizers_CPU/TGV_PD_core.c index 1164b73..4139d10 100644 --- a/main_func/regularizers_CPU/TGV_PD_core.c +++ b/main_func/regularizers_CPU/TGV_PD_core.c @@ -186,14 +186,6 @@ float UpdV_2D(float *V1, float *V2, float *P1, float *P2, float *Q1, float *Q2, }} return 1; } -/* Copy Image */ -float copyIm(float *A, float *U, int dimX, int dimY, int dimZ) -{ - int j; -#pragma omp parallel for shared(A, U) private(j) - for(j=0; j Date: Mon, 7 Aug 2017 17:21:12 +0100 Subject: added TGV_PD, removed useless code --- src/Python/fista_module.cpp | 245 ++++++++++++++++++++++++++------------------ 1 file changed, 146 insertions(+), 99 deletions(-) diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index c2d9352..eacda3d 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -30,6 +30,7 @@ limitations under the License. #include "FGP_TV_core.h" #include "LLT_model_core.h" #include "PatchBased_Regul_core.h" +#include "TGV_PD_core.h" #include "utils.h" @@ -103,101 +104,8 @@ If unsuccessful in a MEX file, the MEX file terminates and returns control to th enough free heap space to create the mxArray. */ -void mexErrMessageText(char* text) { - std::cerr << text << std::endl; -} - -/* -double mxGetScalar(const mxArray *pm); -args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. -Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. -*/ - -template -double mxGetScalar(const np::ndarray plh) { - return (double)bp::extract(plh[0]); -} - - - -template -T * mxGetData(const np::ndarray pm) { - //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. - //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. - /*Access the numpy array pointer: - char * get_data() const; - Returns: Array’s raw data pointer as a char - Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. - probably this would work. 
- A = reinterpret_cast(prhs[0]); - */ - return reinterpret_cast(prhs[0]); -} - -template -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape; - if (dims == 3) - shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - else if (dims == 2) - shape = bp::make_tuple(dim_array[0], dim_array[1]); - np::dtype dtype = np::dtype::get_builtin(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; -} - - -bp::list mexFunction(np::ndarray input) { - int number_of_dims = input.get_nd(); - int dim_array[3]; - - dim_array[0] = input.shape(0); - dim_array[1] = input.shape(1); - if (number_of_dims == 2) { - dim_array[2] = -1; - } - else { - dim_array[2] = input.shape(2); - } - - /**************************************************************************/ - np::ndarray zz = zeros(3, dim_array, (int)0); - np::ndarray fzz = zeros(3, dim_array, (float)0); - /**************************************************************************/ - - int * A = reinterpret_cast(input.get_data()); - int * B = reinterpret_cast(zz.get_data()); - float * C = reinterpret_cast(fzz.get_data()); - - //Copy data and cast - for (int i = 0; i < dim_array[0]; i++) { - for (int j = 0; j < dim_array[1]; j++) { - for (int k = 0; k < dim_array[2]; k++) { - int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; - int val = (*(A + index)); - float fval = (float)val; - std::memcpy(B + index, &val, sizeof(int)); - std::memcpy(C + index, &fval, sizeof(float)); - } - } - } - - bp::list result; - - result.append(number_of_dims); - result.append(dim_array[0]); - result.append(dim_array[1]); - result.append(dim_array[2]); - result.append(zz); - result.append(fzz); - - //result.append(tup); - return result; - -} - bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list @@ -487,7 +395,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me np::ndarray npP1_old = np::zeros(shape, dtype); np::ndarray npP2_old = np::zeros(shape, dtype); np::ndarray npR1 = np::zeros(shape, dtype); - np::ndarray npR2 = zeros(2, dim_array, (float)0); + np::ndarray npR2 = np::zeros(shape, dtype); D = reinterpret_cast(npD.get_data()); D_old = reinterpret_cast(npD_old.get_data()); @@ -866,7 +774,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } -bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { +bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW, double d_h) { // the result is in the following list bp::list result; @@ -899,6 +807,7 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub ///*Handling inputs*/ //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + A = reinterpret_cast(input.get_data()); //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ @@ -907,6 +816,8 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + lambda = (float)d_lambda; + h = (float)d_h; SearchW = SearchW_real + 2 * SimilW; /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ @@ -918,7 
+829,6 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ int N_dims[] = { newsizeX, newsizeY, newsizeZ }; - /******************************2D case ****************************/ if (numdims == 2) { ///*Handling output*/ @@ -943,12 +853,13 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub /*Perform padding of image A to the size of [newsizeX * newsizeY] */ switchpad_crop = 0; /*padding*/ pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); - + /* Do PB regularization with the padded array */ PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); - + switchpad_crop = 1; /*cropping*/ pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append(npB); } else @@ -983,6 +894,141 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub result.append(npB); } /*end else ndims*/ + return result; +} + +bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) { + // the result is in the following list + bp::list result; + int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll; + //const int *dim_array; + float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/ + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + A = reinterpret_cast(input.get_data()); + + //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/ + //alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/ + //alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/ + //iter = (int)mxGetScalar(prhs[4]); /*iterations number*/ + //if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations"); + lambda = (float)d_lambda; + alpha1 = (float)d_alpha1; + alpha0 = (float)d_alpha0; + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; + + if (number_of_dims == 2) { + /*2D case*/ + dimZ = 1; + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npQ1 = np::zeros(shape, dtype); + np::ndarray npQ2 = np::zeros(shape, dtype); + np::ndarray npQ3 = np::zeros(shape, dtype); + np::ndarray npV1 = np::zeros(shape, dtype); + np::ndarray npV1_old = np::zeros(shape, dtype); + np::ndarray npV2 = np::zeros(shape, dtype); + np::ndarray npV2_old = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + Q1 = reinterpret_cast(npQ1.get_data()); + Q2 = 
reinterpret_cast(npQ2.get_data()); + Q3 = reinterpret_cast(npQ3.get_data()); + V1 = reinterpret_cast(npV1.get_data()); + V1_old = reinterpret_cast(npV1_old.get_data()); + V2 = reinterpret_cast(npV2.get_data()); + V2_old = reinterpret_cast(npV2_old.get_data()); + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + /*dual variables*/ + /*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + /*printf("%i \n", i);*/ + L2 = 12.0; /*Lipshitz constant*/ + tau = 1.0 / pow(L2, 0.5); + sigma = 1.0 / pow(L2, 0.5); + + /*Copy A to U*/ + copyIm(A, U, dimX, dimY, dimZ); + /* Here primal-dual iterations begin for 2D */ + for (ll = 0; ll < iter; ll++) { + + /* Calculate Dual Variable P */ + DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for P*/ + ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1); + + /* Calculate Dual Variable Q */ + DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for Q*/ + ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0); + + /*saving U into U_old*/ + copyIm(U, U_old, dimX, dimY, dimZ); + + /*adjoint operation -> divergence and projection of P*/ + DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau); + + /*get updated solution U*/ + newU(U, U_old, dimX, dimY, dimZ); + + /*saving V into V_old*/ + copyIm(V1, V1_old, dimX, dimY, dimZ); + copyIm(V2, V2_old, dimX, dimY, dimZ); + + /* upd V*/ + UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau); + + /*get new V*/ + newU(V1, V1_old, dimX, dimY, dimZ); + newU(V2, V2_old, dimX, dimY, dimZ); + } /*end of iterations*/ + + result.append(npU); + } + + + + return result; } @@ -997,8 +1043,9 @@ BOOST_PYTHON_MODULE(regularizers) np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); - def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); def("LLT_model", LLT_model); + def("PatchBased_Regul", PatchBased_Regul); + def("TGV_PD", TGV_PD); } \ No newline at end of file -- cgit v1.2.3 From 4534a11d1c32a65484f4f38348c27a7bb2d9ad19 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 17:21:54 +0100 Subject: added TGV_PD --- src/Python/setup.py | 1 + src/Python/test_regularizers.py | 195 ++++++++++++++++++++++++++++++++++------ 2 files changed, 168 insertions(+), 28 deletions(-) diff --git a/src/Python/setup.py b/src/Python/setup.py index a4eed14..0468722 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -53,6 +53,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", + 
"..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6abfba4..6a34749 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -47,6 +47,8 @@ class Regularizer(): SplitBregman_TV = regularizers.SplitBregman_TV FGP_TV = regularizers.FGP_TV LLT_model = regularizers.LLT_model + PatchBased_Regul = regularizers.PatchBased_Regul + TGV_PD = regularizers.TGV_PD # Algorithm class TotalVariationPenalty(Enum): @@ -55,13 +57,17 @@ class Regularizer(): # TotalVariationPenalty def __init__(self , algorithm): - + self.setAlgorithm ( algorithm ) + # __init__ + + def setAlgorithm(self, algorithm): self.algorithm = algorithm self.pars = self.parsForAlgorithm(algorithm) - # __init__ + # setAlgorithm def parsForAlgorithm(self, algorithm): pars = dict() + if algorithm == Regularizer.Algorithm.SplitBregman_TV : pars['algorithm'] = algorithm pars['input'] = None @@ -69,6 +75,7 @@ class Regularizer(): pars['number_of_iterations'] = 35 pars['tolerance_constant'] = 0.0001 pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.FGP_TV : pars['algorithm'] = algorithm pars['input'] = None @@ -76,6 +83,7 @@ class Regularizer(): pars['number_of_iterations'] = 50 pars['tolerance_constant'] = 0.001 pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.LLT_model: pars['algorithm'] = algorithm pars['input'] = None @@ -85,6 +93,24 @@ class Regularizer(): pars['tolerance_constant'] = None pars['restrictive_Z_smoothing'] = 0 + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + + return pars # parsForAlgorithm @@ -98,6 +124,8 @@ class Regularizer(): self.pars['regularization_parameter'] = regularization_parameter #for key, value in self.pars.items(): # print("{0} = {1}".format(key, value)) + if None in self.pars: + raise Exception("Not all parameters have been provided") if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : return self.algorithm(input, regularization_parameter, @@ -112,15 +140,27 @@ class Regularizer(): elif self.algorithm == Regularizer.Algorithm.LLT_model : #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default - if None in self.pars: - raise Exception("Not all parameters have been provided") - else: - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + 
#LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # __call__ @@ -142,13 +182,40 @@ class Regularizer(): @staticmethod def LLT_model(input, regularization_parameter , time_step, number_of_iterations, tolerance_constant, restrictive_Z_smoothing=0): - reg = Regularizer(Regularizer.Algorithm.FGP_TV) + reg = Regularizer(Regularizer.Algorithm.LLT_model) out = list( reg(input, regularization_parameter, time_step=time_step, number_of_iterations=number_of_iterations, tolerance_constant=tolerance_constant, restrictive_Z_smoothing=restrictive_Z_smoothing) ) out.append(reg.pars) return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + return out #Example: @@ -171,17 +238,17 @@ u0 = Im + (perc* np.random.normal(size=np.shape(Im))) f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) u0 = f(u0).astype('float32') -# plot +## plot fig = plt.figure() -a=fig.add_subplot(2,3,1) -a.set_title('Original') -imgplot = plt.imshow(Im) +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) -a=fig.add_subplot(2,3,2) +a=fig.add_subplot(2,3,1) a.set_title('noise') imgplot = plt.imshow(u0) - +reg_output = [] ############################################################################## # Call regularizer @@ -199,8 +266,9 @@ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., numbe TV_Penalty=Regularizer.TotalVariationPenalty.l1) out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) pars = out2[2] +reg_output.append(out2) -a=fig.add_subplot(2,3,3) +a=fig.add_subplot(2,3,2) a.set_title('SplitBregman_TV') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' textstr = textstr % (pars['regularization_parameter'], @@ -213,7 +281,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); @@ -221,7 +289,9 @@ out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, number_of_iterations=10) pars = out2[-1] -a=fig.add_subplot(2,3,4) +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) a.set_title('FGP_TV') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' textstr = textstr % (pars['regularization_parameter'], @@ -234,18 +304,23 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) ###################### LLT_model ######################################### # * u0 = Im + .03*randn(size(Im)); % adding noise # [Den] = LLT_model(single(u0), 10, 0.1, 1); -out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., - time_step=0.1, - tolerance_constant=1e-4, - number_of_iterations=10) +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +#input, regularization_parameter , time_step, number_of_iterations, +# tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) pars = out2[-1] -a=fig.add_subplot(2,3,5) +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) a.set_title('LLT_model') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' textstr = textstr % (pars['regularization_parameter'], @@ -259,7 +334,71 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) + +###################### PatchBased_Regul ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) +a.set_title('PatchBased_Regul') +textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' +textstr = textstr % (pars['regularization_parameter'], + pars['searching_window_ratio'], + pars['similarity_window_ratio'], + pars['PB_filtering_parameter']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text 
box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) + + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) +a.set_title('TGV_PD') +textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' +textstr = textstr % (pars['regularization_parameter'], + pars['first_order_term'], + pars['second_order_term'], + pars['number_of_iterations']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) -- cgit v1.2.3 From 3fffd568589137b17d1fbe44e55a757e3745a3b1 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:42:05 +0100 Subject: added simple_astra_test.py --- src/Python/test/simple_astra_test.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/Python/test/simple_astra_test.py diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py new file mode 100644 index 0000000..905eeea --- /dev/null +++ b/src/Python/test/simple_astra_test.py @@ -0,0 +1,25 @@ +import astra +import numpy + +detectorSpacingX = 1.0 +detectorSpacingY = 1.0 +det_row_count = 128 +det_col_count = 128 + +angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +image_size_x = 64 +image_size_y = 64 +image_size_z = 32 + +vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) + +x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) +sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) -- cgit v1.2.3 From 0611d34c31fa1e706c3bcd7e17651f7555469e00 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 17 Aug 2017 16:33:09 +0100 Subject: initial revision --- src/Python/test/simple_astra_test.py | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 src/Python/test/simple_astra_test.py diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py deleted file mode 100644 index 905eeea..0000000 --- a/src/Python/test/simple_astra_test.py +++ /dev/null @@ -1,25 +0,0 @@ -import astra -import numpy - -detectorSpacingX = 1.0 -detectorSpacingY = 1.0 -det_row_count = 128 -det_col_count = 128 - -angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. 
* numpy.pi - -proj_geom = astra.creators.create_proj_geom('parallel3d', - detectorSpacingX, - detectorSpacingY, - det_row_count, - det_col_count, - angles_rad) - -image_size_x = 64 -image_size_y = 64 -image_size_z = 32 - -vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) - -x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) -sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) -- cgit v1.2.3 From bc29e0690d856ad9dd147b435d34c5761556a1e5 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 12:55:19 +0100 Subject: Regularizer.pyfirst commit --- src/Python/Regularizer.py | 322 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/Regularizer.py diff --git a/src/Python/Regularizer.py b/src/Python/Regularizer.py new file mode 100644 index 0000000..15dbbb4 --- /dev/null +++ b/src/Python/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +import regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) + + ''' + class Algorithm(Enum): + SplitBregman_TV = regularizers.SplitBregman_TV + FGP_TV = regularizers.FGP_TV + LLT_model = regularizers.LLT_model + PatchBased_Regul = regularizers.PatchBased_Regul + TGV_PD = regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). 
+ ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + 
out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From 48a4d5315b4b6ca62eaa931912b6a02993979688 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 12:56:09 +0100 Subject: Test module for Boost Python currently can pass a function to the C++ layer to be evaluated. 
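A minimal usage sketch of the callback mechanism added by this commit, for clarity. The module name `prova`, the `doSomething(input, output_callback, calculate_callback)` signature and the layout of the returned list are taken from the diff below; the helper names used here (`log_value`, `subtract_seven`) are illustrative stand-ins for the callbacks defined in test.py, so treat this as a sketch rather than a reference:

    import numpy as np
    import prova  # Boost.Python test module built via setup_test.py

    # small integer test volume, as in test.py
    a = np.asarray([i for i in range(1 * 2 * 3)])
    a = a.reshape([1, 2, 3])

    def log_value(value):
        # invoked from the C++ layer for every element it visits
        print("f: {0}".format(value))

    def subtract_seven(value):
        # evaluated in C++ per element; the result is written into the
        # float output array
        return value - 7

    # passing None skips the corresponding callback on the C++ side
    prova.doSomething(a, log_value, None)
    out = prova.doSomething(a, None, subtract_seven)
    print(out[5])  # element 5 of the returned list is the float result array

Both callbacks are received as borrowed PyObject handles on the C++ side and are only invoked when they are not None.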
--- src/Python/Matlab2Python_utils.cpp | 68 +++++++++++++++++++++++++++++++++++++- src/Python/setup_test.py | 6 ++-- src/Python/test.py | 34 ++++++++++++++++--- 3 files changed, 99 insertions(+), 9 deletions(-) diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index 6aaad90..e15d738 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -175,6 +175,71 @@ bp::list mexFunction( np::ndarray input ) { } + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} +bp::list doSomething(np::ndarray input, PyObject *pyobj , PyObject *pyobj2) { + + boost::python::object output(boost::python::handle<>(boost::python::borrowed(pyobj))); + int isOutput = !(output == boost::python::api::object()); + + boost::python::object calculate(boost::python::handle<>(boost::python::borrowed(pyobj2))); + int isCalculate = !(calculate == boost::python::api::object()); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast(input.get_data()); + int * B = reinterpret_cast(zz.get_data()); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = sqrt((float)val); + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + // if the PyObj is not None evaluate the function + if (isOutput) + output(fval); + if (isCalculate) { + float nfval = (float)bp::extract(calculate(val)); + if (isOutput) + output(nfval); + std::memcpy(C + index, &nfval, sizeof(float)); + } + } + } + } + + bp::list result; result.append(number_of_dims); @@ -196,7 +261,7 @@ BOOST_PYTHON_MODULE(prova) //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "prova"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -207,4 +272,5 @@ BOOST_PYTHON_MODULE(prova) //numpy_boost_python_register_type(); //numpy_boost_python_register_type(); def("mexFunction", mexFunction); + def("doSomething", doSomething); } \ No newline at end of file diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py index ffb9c02..7c86175 100644 --- a/src/Python/setup_test.py +++ b/src/Python/setup_test.py @@ -30,13 +30,13 @@ extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x' extra_libraries = [] if platform.system() == 'Windows': extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] - extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + #extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] if sys.version_info.major == 3 : extra_libraries += ['boost_python3-vc140-mt-1_64', 
'boost_numpy3-vc140-mt-1_64'] else: extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] else: - extra_include_dirs += ["../ContourTree/", "../Core/","."] + #extra_include_dirs += ["../ContourTree/", "../Core/","."] if sys.version_info.major == 3: extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] else: @@ -47,7 +47,7 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("fista", + ext_modules = [Extension("prova", sources=[ "Matlab2Python_utils.cpp", ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), diff --git a/src/Python/test.py b/src/Python/test.py index e283f89..db47380 100644 --- a/src/Python/test.py +++ b/src/Python/test.py @@ -5,14 +5,38 @@ Created on Thu Aug 3 14:08:09 2017 @author: ofn77899 """ -import fista +import prova import numpy as np -a = np.asarray([i for i in range(3*4*5)]) -a = a.reshape([3,4,5]) +a = np.asarray([i for i in range(1*2*3)]) +a = a.reshape([1,2,3]) print (a) -b = fista.mexFunction(a) +b = prova.mexFunction(a) #print (b) print (b[4].shape) print (b[4]) -print (b[5]) \ No newline at end of file +print (b[5]) + +def print_element(input): + print ("f: {0}".format(input)) + +prova.doSomething(a, print_element, None) + +c = [] +def append_to_list(input, shouldPrint=False): + c.append(input) + if shouldPrint: + print ("{0} appended to list {1}".format(input, c)) + +def element_wise_algebra(input, shouldPrint=True): + ret = input - 7 + if shouldPrint: + print ("element_wise {0}".format(ret)) + return ret + +prova.doSomething(a, append_to_list, None) +#print ("this is c: {0}".format(c)) + +b = prova.doSomething(a, None, element_wise_algebra) +#print (a) +print (b[5]) -- cgit v1.2.3 From c28385d0dd5efcb32bd2c33e4bd93ba61f959b3f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:27:28 +0100 Subject: updated test for regularizer API --- src/Python/test_regularizers.py | 590 ++++++++++++++++++++-------------------- 1 file changed, 290 insertions(+), 300 deletions(-) diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6a34749..5d25f02 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -8,216 +8,37 @@ Created on Fri Aug 4 11:10:05 2017 from ccpi.viewer.CILViewer2D import Converter import vtk -import regularizers import matplotlib.pyplot as plt import numpy as np import os from enum import Enum - -class Regularizer(): - '''Class to handle regularizer algorithms to be used during reconstruction - - Currently 5 regularization algorithms are available: - - 1) SplitBregman_TV - 2) FGP_TV - 3) - 4) - 5) - - Usage: - the regularizer can be invoked as object or as static method - Depending on the actual regularizer the input parameter may vary, and - a different default setting is defined. - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - - out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., - number_of_iterations=30, tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - A number of optional parameters can be passed or skipped - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) - - ''' - class Algorithm(Enum): - SplitBregman_TV = regularizers.SplitBregman_TV - FGP_TV = regularizers.FGP_TV - LLT_model = regularizers.LLT_model - PatchBased_Regul = regularizers.PatchBased_Regul - TGV_PD = regularizers.TGV_PD - # Algorithm - - class TotalVariationPenalty(Enum): - isotropic = 0 - l1 = 1 - # TotalVariationPenalty - - def __init__(self , algorithm): - self.setAlgorithm ( algorithm ) - # __init__ - - def setAlgorithm(self, algorithm): - self.algorithm = algorithm - self.pars = self.parsForAlgorithm(algorithm) - # setAlgorithm - - def parsForAlgorithm(self, algorithm): - pars = dict() - - if algorithm == Regularizer.Algorithm.SplitBregman_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 35 - pars['tolerance_constant'] = 0.0001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.FGP_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 50 - pars['tolerance_constant'] = 0.001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.LLT_model: - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['time_step'] = None - pars['number_of_iterations'] = None - pars['tolerance_constant'] = None - pars['restrictive_Z_smoothing'] = 0 - - elif algorithm == Regularizer.Algorithm.PatchBased_Regul: - pars['algorithm'] = algorithm - pars['input'] = None - pars['searching_window_ratio'] = None - pars['similarity_window_ratio'] = None - pars['PB_filtering_parameter'] = None - pars['regularization_parameter'] = None - - elif algorithm == Regularizer.Algorithm.TGV_PD: - pars['algorithm'] = algorithm - pars['input'] = None - pars['first_order_term'] = None - pars['second_order_term'] = None - pars['number_of_iterations'] = None - pars['regularization_parameter'] = None - - - - return pars - # parsForAlgorithm - - def __call__(self, input, regularization_parameter, **kwargs): - - if kwargs is not None: - for key, value in kwargs.items(): - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - self.pars['input'] = input - self.pars['regularization_parameter'] = regularization_parameter - #for key, value in self.pars.items(): - # print("{0} = {1}".format(key, value)) - if None in self.pars: - raise Exception("Not all parameters have been provided") - - if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.FGP_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.LLT_model : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) - elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return 
self.algorithm(input, regularization_parameter, - self.pars['searching_window_ratio'] , - self.pars['similarity_window_ratio'] , - self.pars['PB_filtering_parameter']) - elif self.algorithm == Regularizer.Algorithm.TGV_PD : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, regularization_parameter, - self.pars['first_order_term'] , - self.pars['second_order_term'] , - self.pars['number_of_iterations']) - - - - # __call__ - - @staticmethod - def SplitBregman_TV(input, regularization_parameter , **kwargs): - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - return out - - @staticmethod - def FGP_TV(input, regularization_parameter , **kwargs): - reg = Regularizer(Regularizer.Algorithm.FGP_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - return out - - @staticmethod - def LLT_model(input, regularization_parameter , time_step, number_of_iterations, - tolerance_constant, restrictive_Z_smoothing=0): - reg = Regularizer(Regularizer.Algorithm.LLT_model) - out = list( reg(input, regularization_parameter, time_step=time_step, - number_of_iterations=number_of_iterations, - tolerance_constant=tolerance_constant, - restrictive_Z_smoothing=restrictive_Z_smoothing) ) - out.append(reg.pars) - return out - - @staticmethod - def PatchBased_Regul(input, regularization_parameter, - searching_window_ratio, - similarity_window_ratio, - PB_filtering_parameter): - reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) - out = list( reg(input, - regularization_parameter, - searching_window_ratio=searching_window_ratio, - similarity_window_ratio=similarity_window_ratio, - PB_filtering_parameter=PB_filtering_parameter ) - ) - out.append(reg.pars) - return out - - @staticmethod - def TGV_PD(input, regularization_parameter , first_order_term, - second_order_term, number_of_iterations): - - reg = Regularizer(Regularizer.Algorithm.TGV_PD) - out = list( reg(input, regularization_parameter, - first_order_term=first_order_term, - second_order_term=second_order_term, - number_of_iterations=number_of_iterations) ) - out.append(reg.pars) - return out - - +import timeit + +from Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE: +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. 
+ +def nrmse(im1, im2): + a, b = im1.shape + rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) + max_val = max(np.max(im1), np.max(im2)) + min_val = min(np.min(im1), np.min(im2)) + return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +# 2D Regularizers +# +############################################################################### #Example: # figure; # Im = double(imread('lena_gray_256.tif'))/255; % loading image @@ -255,49 +76,55 @@ reg_output = [] ####################### SplitBregman_TV ##################################### # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +use_object = True +if use_object: + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + # or + # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + #TV_Penalty=Regularizer.TotalVariationPenalty.l1) + plotme = reg() [0] + pars = reg.pars + textstr = reg.printParametersToString() + + #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + # TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) -out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, - #tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) -pars = out2[2] -reg_output.append(out2) +else: + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) + pars = out2[2] + reg_output.append(out2) + plotme = reg_output[-1][0] + textstr = out2[-1] a=fig.add_subplot(2,3,2) -a.set_title('SplitBregman_TV') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['TV_penalty'].name) + # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) +imgplot = plt.imshow(plotme) ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); -out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, - number_of_iterations=10) -pars = out2[-1] +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, + number_of_iterations=200) +pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,3) -a.set_title('FGP_TV') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['TV_penalty'].name) + +textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) @@ -316,50 +143,12 @@ out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, time_step=0.0003, tolerance_constant=0.0001, number_of_iterations=300) -pars = out2[-1] +pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) -a.set_title('LLT_model') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['time_step'] - ) - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) - -###################### PatchBased_Regul ######################################### -# Quick 2D denoising example in Matlab: -# Im = double(imread('lena_gray_256.tif'))/255; % loading image -# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); - -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-1] -reg_output.append(out2) - -a=fig.add_subplot(2,3,5) -a.set_title('PatchBased_Regul') -textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' -textstr = textstr % (pars['regularization_parameter'], - pars['searching_window_ratio'], - pars['similarity_window_ratio'], - pars['PB_filtering_parameter']) - - - - +textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords @@ -367,6 +156,215 @@ a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0]) +# ###################### PatchBased_Regul 
######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + # searching_window_ratio=3, + # similarity_window_ratio=1, + # PB_filtering_parameter=0.08) +# pars = out2[-2] +# reg_output.append(out2) + +# a=fig.add_subplot(2,3,5) + + +# textstr = out2[-1] + +# # these are matplotlib.patch.Patch properties +# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# # place a text box in upper left in axes coords +# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + # verticalalignment='top', bbox=props) +# imgplot = plt.imshow(reg_output[-1][0]) + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + # first_order_term=1.3, + # second_order_term=1, + # number_of_iterations=550) +# pars = out2[-2] +# reg_output.append(out2) + +# a=fig.add_subplot(2,3,6) + + +# textstr = out2[-1] + + +# # these are matplotlib.patch.Patch properties +# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# # place a text box in upper left in axes coords +# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + # verticalalignment='top', bbox=props) +# imgplot = plt.imshow(reg_output[-1][0]) + + +plt.show() + +################################################################################ +## +## 3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" +# +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +# reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# +# +### plot +#fig3D = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +#sliceNo = 32 +# +#a=fig3D.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] +# +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV 
##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +# +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +## #tolerance_constant=1e-4, +## TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +# number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +##input, regularization_parameter , time_step, number_of_iterations, +## tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +# time_step=0.0003, +# tolerance_constant=0.0001, +# number_of_iterations=300) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab: +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +## ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +# searching_window_ratio=3, +# similarity_window_ratio=1, +# PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# ###################### TGV_PD ######################################### # Quick 2D 
denoising example in Matlab: @@ -375,30 +373,22 @@ imgplot = plt.imshow(reg_output[-1][0]) # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - first_order_term=1.3, - second_order_term=1, - number_of_iterations=550) -pars = out2[-1] -reg_output.append(out2) - -a=fig.add_subplot(2,3,6) -a.set_title('TGV_PD') -textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' -textstr = textstr % (pars['regularization_parameter'], - pars['first_order_term'], - pars['second_order_term'], - pars['number_of_iterations']) - - - - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) - - - +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +# first_order_term=1.3, +# second_order_term=1, +# number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) -- cgit v1.2.3 From db45d96898f23c3bc97e4c19e834fa976ec301c8 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:31:16 +0100 Subject: initial commit of Reconstructor.py --- src/Python/ccpi/reconstruction/Reconstructor.py | 598 ++++++++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/Reconstructor.py diff --git a/src/Python/ccpi/reconstruction/Reconstructor.py b/src/Python/ccpi/reconstruction/Reconstructor.py new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/reconstruction/Reconstructor.py @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = 
self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT 
term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one 
slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else: + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line?
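# -----------------------------------------------------------------------------
# A minimal illustrative sketch, not taken from the patch above: the loop it sits
# next to is a power iteration, so normalising x1 by its norm each pass (the
# 'x1 = x1/s' lines flagged with '### this line?') is what makes s converge to the
# largest singular value of the weighted operator A^T W A, i.e. the Lipschitz
# constant FISTA needs. The same idea without ASTRA, using a dense matrix A and
# weights w that are purely hypothetical stand-ins:
import numpy

def lipschitz_power_method(A, w, niter=15, seed=0):
    """Estimate L = ||A^T W A||_2 by power iteration (sketch only)."""
    rng = numpy.random.RandomState(seed)
    x = rng.rand(A.shape[1])
    sqw = numpy.sqrt(w)              # square-root weights, applied twice per pass
    s = 0.0
    for _ in range(niter):
        y = sqw * A.dot(x)           # weighted forward projection
        x = A.T.dot(sqw * y)         # weighted backprojection
        s = numpy.linalg.norm(x)     # current spectral-norm estimate
        x = x / s                    # normalise the iterate
    return s

# tiny usage example with random data
A = numpy.random.rand(20, 10)
print(lipschitz_power_method(A, numpy.ones(20)))
# -----------------------------------------------------------------------------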
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 
15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") 
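# -----------------------------------------------------------------------------
# A minimal illustrative sketch, not taken from the patch above: the per-voxel
# loop in NumpyToVTKImageData can usually be replaced by a vectorised conversion
# through vtk.util.numpy_support, as already hinted by the commented line
# '#self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) )'. Assumes the
# same [x][y][z] indexing as that loop; sketch only, not a drop-in replacement.
import numpy
import vtk
from vtk.util import numpy_support

def numpy_to_vtk_image(arr):
    """Wrap a 3D numpy array indexed [x][y][z] as vtkImageData (sketch only)."""
    img = vtk.vtkImageData()
    img.SetDimensions(arr.shape[0], arr.shape[1], arr.shape[2])
    img.SetOrigin(0, 0, 0)
    img.SetSpacing(1, 1, 1)
    # VTK stores point scalars with x varying fastest, hence Fortran-order ravel
    flat = numpy.ravel(arr, order='F').astype(numpy.float64)
    scalars = numpy_support.numpy_to_vtk(flat, deep=True, array_type=vtk.VTK_DOUBLE)
    img.GetPointData().SetScalars(scalars)
    return img
# -----------------------------------------------------------------------------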
+#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() -- cgit v1.2.3 From c3b58791b906aa6a3b99f32fa5f69a09bb075527 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:56:08 +0100 Subject: module rename to cpu_regularizers --- src/Python/setup.py | 4 ++-- src/Python/test_regularizers.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Python/setup.py b/src/Python/setup.py index 0468722..94467c4 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -47,7 +47,7 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("regularizers", + ext_modules = [Extension("ccpi.imaging.cpu_regularizers", sources=["fista_module.cpp", "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", @@ -60,5 +60,5 @@ setup( ], zip_safe = False, - packages = {'ccpi','ccpi.reconstruction'}, + packages = {'ccpi','ccpi.fistareconstruction'}, ) diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 5d25f02..755804a 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -14,7 +14,8 @@ import os from enum import Enum import timeit -from Regularizer import Regularizer +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer ############################################################################### #https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 -- cgit v1.2.3 From 70d03d2c7567fac409086f015ca9e2ac47b0fc20 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:58:11 +0100 Subject: changed the backward slash to forward --- src/Python/setup.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Python/setup.py b/src/Python/setup.py index 94467c4..154f979 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -49,12 +49,12 @@ setup( cmdclass = {'build_ext': build_ext}, ext_modules = [Extension("ccpi.imaging.cpu_regularizers", sources=["fista_module.cpp", - "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", - "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", - "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", - "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", - "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c", - "..\\..\\main_func\\regularizers_CPU\\utils.c" + "../../main_func/regularizers_CPU/FGP_TV_core.c", + "../../main_func/regularizers_CPU/SplitBregman_TV_core.c", + "../../main_func/regularizers_CPU/LLT_model_core.c", + "../../main_func/regularizers_CPU/PatchBased_Regul_core.c", + "../../main_func/regularizers_CPU/TGV_PD_core.c", + "../../main_func/regularizers_CPU/utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 396c11bd2c8bde1197b708062590a9e3b95538bd Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:01:28 +0100 Subject: added viewer for testing --- src/Python/ccpi/viewer/CILViewer.py | 361 +++++++ src/Python/ccpi/viewer/CILViewer2D.py | 1126 ++++++++++++++++++++ src/Python/ccpi/viewer/QVTKWidget.py | 340 ++++++ src/Python/ccpi/viewer/QVTKWidget2.py | 84 ++ src/Python/ccpi/viewer/__init__.py | 1 + .../viewer/__pycache__/CILViewer.cpython-35.pyc | Bin 0 -> 10542 bytes 
.../viewer/__pycache__/CILViewer2D.cpython-35.pyc | Bin 0 -> 35633 bytes .../viewer/__pycache__/QVTKWidget.cpython-35.pyc | Bin 0 -> 10099 bytes .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc | Bin 0 -> 1316 bytes .../viewer/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 210 bytes src/Python/ccpi/viewer/embedvtk.py | 75 ++ 11 files changed, 1987 insertions(+) create mode 100644 src/Python/ccpi/viewer/CILViewer.py create mode 100644 src/Python/ccpi/viewer/CILViewer2D.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py create mode 100644 src/Python/ccpi/viewer/__init__.py create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/embedvtk.py diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py new file mode 100644 index 0000000..efcf8be --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
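# -----------------------------------------------------------------------------
# A minimal illustrative sketch, not taken from the patch above: the two setup.py
# commits earlier in this series rename the extension module and swap '\\' path
# separators for '/'. Building the source list with os.path.join avoids hard-coding
# either separator; the file names below simply mirror the ones listed in that diff.
import os

REGULARIZER_DIR = os.path.join('..', '..', 'main_func', 'regularizers_CPU')
sources = ['fista_module.cpp'] + [
    os.path.join(REGULARIZER_DIR, name) for name in (
        'FGP_TV_core.c', 'SplitBregman_TV_core.c', 'LLT_model_core.c',
        'PatchBased_Regul_core.c', 'TGV_PD_core.c', 'utils.c')
]
print(sources)
# -----------------------------------------------------------------------------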
+ +import vtk +import numpy +import math +from vtk.util import numpy_support + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + + + +class CILViewer(): + '''Simple 3D Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600): + '''creates the rendering pipeline''' + + # create a rendering window and renderer + self.ren = vtk.vtkRenderer() + self.renWin = vtk.vtkRenderWindow() + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + # img 3D as slice + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceActor = None + self.voi = None + self.wl = None + self.ia = None + self.sliceActorNo = 0 + # create a renderwindowinteractor + self.iren = vtk.vtkRenderWindowInteractor() + self.iren.SetRenderWindow(self.renWin) + + self.style = vtk.vtkInteractorStyleTrackballCamera() + self.iren.SetInteractorStyle(self.style) + + self.ren.SetBackground(.1, .2, .4) + + self.actors = {} + self.iren.RemoveObservers('MouseWheelForwardEvent') + self.iren.RemoveObservers('MouseWheelBackwardEvent') + + self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0) + self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0) + + self.iren.RemoveObservers('KeyPressEvent') + self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0) + + + self.iren.Initialize() + + + + def getRenderer(self): + '''returns the renderer''' + return self.ren + + def getRenderWindow(self): + '''returns the render window''' + return self.renWin + + def getInteractor(self): + '''returns the render window interactor''' + return self.iren + + def getCamera(self): + '''returns the active camera''' + return self.ren.GetActiveCamera() + + def createPolyDataActor(self, polydata): + '''returns an actor for a given polydata''' + mapper = vtk.vtkPolyDataMapper() + if vtk.VTK_MAJOR_VERSION <= 5: + mapper.SetInput(polydata) + else: + mapper.SetInputData(polydata) + + # actor + actor = vtk.vtkActor() + actor.SetMapper(mapper) + #actor.GetProperty().SetOpacity(0.8) + return actor + + def setPolyDataActor(self, actor): + '''displays the given polydata''' + + self.ren.AddActor(actor) + + self.actors[len(self.actors)+1] = [actor, True] + self.iren.Initialize() + self.renWin.Render() + + def displayPolyData(self, polydata): + self.setPolyDataActor(self.createPolyDataActor(polydata)) + + def hideActor(self, actorno): + '''Hides an actor identified by its number in the list of actors''' + try: + if self.actors[actorno][1]: + self.ren.RemoveActor(self.actors[actorno][0]) + self.actors[actorno][1] = False + except KeyError as ke: + print ("Warning Actor not present") + + def showActor(self, actorno, actor = None): + '''Shows hidden actor identified by its number in the list of actors''' + try: + if not self.actors[actorno][1]: + self.ren.AddActor(self.actors[actorno][0]) + self.actors[actorno][1] = True + return actorno + except KeyError as ke: + # adds it to the actors if not there already + if actor != None: + self.ren.AddActor(actor) + self.actors[len(self.actors)+1] = [actor, True] + return len(self.actors) + + def addActor(self, actor): + '''Adds an actor to the render''' + return self.showActor(0, actor) + + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % 
(filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + + def startRenderLoop(self): + self.iren.Start() + + + def setupObservers(self, interactor): + interactor.RemoveObservers('LeftButtonPressEvent') + interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction) + interactor.Initialize() + + + def mouseInteraction(self, interactor, event): + if event == 'MouseWheelForwardEvent': + maxSlice = self.img3D.GetDimensions()[self.sliceOrientation] + if (self.sliceno + 1 < maxSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno + 1 + self.displaySliceActor(self.sliceno) + else: + minSlice = 0 + if (self.sliceno - 1 > minSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno - 1 + self.displaySliceActor(self.sliceno) + + + def keyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "x": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_YZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "y": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "z": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceno = int(self.img3D.GetDimensions()[2] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + if interactor.GetKeyCode() == "X": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("x") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("y") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("z") + self.keyPress(interactor, 
event) + else : + print ("Unhandled event %s" % interactor.GetKeyCode()) + + + + def setInput3DData(self, imageData): + self.img3D = imageData + + def setInputAsNumpy(self, numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + + def displaySliceActor(self, sliceno = 0): + self.sliceno = sliceno + first = False + + self.sliceActor , self.voi, self.wl , self.ia = \ + self.getSliceActor(self.img3D, + sliceno, + self.sliceActor, + self.voi, + self.wl, + self.ia) + no = self.showActor(self.sliceActorNo, self.sliceActor) + self.sliceActorNo = no + + self.iren.Initialize() + self.renWin.Render() + + return self.sliceActorNo + + + def getSliceActor(self, + imageData , + sliceno=0, + imageActor=None , + voi=None, + windowLevel=None, + imageAccumulate=None): + '''Slices a 3D volume and then creates an actor to be rendered''' + if (voi==None): + voi = vtk.vtkExtractVOI() + #voi = vtk.vtkImageClip() + voi.SetInputData(imageData) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = sliceno + extent[self.sliceOrientation * 2 + 1] = sliceno + voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + voi.Update() + # set window/level for all slices + if imageAccumulate == None: + imageAccumulate = vtk.vtkImageAccumulate() + + if (windowLevel == None): + windowLevel = vtk.vtkImageMapToWindowLevelColors() + imageAccumulate.SetInputData(imageData) + imageAccumulate.Update() + cmax = imageAccumulate.GetMax()[0] + cmin = imageAccumulate.GetMin()[0] + windowLevel.SetLevel((cmax+cmin)/2) + windowLevel.SetWindow(cmax-cmin) + + windowLevel.SetInputData(voi.GetOutput()) + windowLevel.Update() + + if imageActor == None: + imageActor = vtk.vtkImageActor() + imageActor.SetInputData(windowLevel.GetOutput()) + imageActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + imageActor.Update() + return (imageActor , voi, windowLevel, imageAccumulate) + + + # Set interpolation on + def setInterpolateOn(self): + self.sliceActor.SetInterpolate(True) + self.renWin.Render() + + # Set interpolation off + def setInterpolateOff(self): + self.sliceActor.SetInterpolate(False) + self.renWin.Render() \ No newline at end of file diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py new file mode 100644 index 0000000..c1629af --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer2D.py @@ -0,0 +1,1126 @@ +# 
-*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +from vtk.util import numpy_support , vtkImageImportFromArray +from enum import Enum + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + +CONTROL_KEY = 8 +SHIFT_KEY = 4 +ALT_KEY = -128 + + +# Converter class +class Converter(): + + # Utility functions to transform numpy arrays to vtkImageData and viceversa + @staticmethod + def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Creates a vtkImageImportFromArray object and returns it. + + It handles the different axis order from numpy to VTK''' + importer = vtkImageImportFromArray.vtkImageImportFromArray() + importer.SetArray(numpy.transpose(nparray).copy()) + importer.SetDataSpacing(spacing) + importer.SetDataOrigin(origin) + return importer + + @staticmethod + def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Converts a 3D numpy array to a vtkImageData''' + importer = Converter.numpy2vtkImporter(nparray, spacing, origin) + importer.Update() + return importer.GetOutput() + + @staticmethod + def vtk2numpy(imgdata): + '''Converts the VTK data to 3D numpy array''' + img_data = numpy_support.vtk_to_numpy( + imgdata.GetPointData().GetScalars()) + + dims = imgdata.GetDimensions() + dims = (dims[2],dims[1],dims[0]) + data3d = numpy.reshape(img_data, dims) + + return numpy.transpose(data3d).copy() + + @staticmethod + def tiffStack2numpy(filename, indices, + extent = None , sampleRate = None ,\ + flatField = None, darkField = None): + '''Converts a stack of TIFF files to numpy array. + + filename must contain the whole path. The filename is supposed to be named and + have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif + + indices are the suffix, generally an increasing number + + Optionally extracts only a selection of the 2D images and (optionally) + normalizes. 
+ ''' + + stack = vtk.vtkImageData() + reader = vtk.vtkTIFFReader() + voi = vtk.vtkExtractVOI() + + #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\" + + stack_image = numpy.asarray([]) + nreduced = len(indices) + + for num in range(len(indices)): + fn = filename % indices[num] + print ("resampling %s" % ( fn ) ) + reader.SetFileName(fn) + reader.Update() + print (reader.GetOutput().GetScalarTypeAsString()) + if num == 0: + if (extent == None): + sliced = reader.GetOutput().GetExtent() + stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1) + else: + sliced = extent + voi.SetVOI(extent) + + if sampleRate is not None: + voi.SetSampleRate(sampleRate) + ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int) + print ("ext {0}".format(ext)) + stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1) + else: + stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1) + if (flatField != None and darkField != None): + stack.AllocateScalars(vtk.VTK_FLOAT, 1) + else: + stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1) + print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) )) + stack_image = Converter.vtk2numpy(stack) + print ("Stack shape %s" % str(numpy.shape(stack_image))) + + if extent!=None: + voi.SetInputData(reader.GetOutput()) + voi.Update() + img = voi.GetOutput() + else: + img = reader.GetOutput() + + theSlice = Converter.vtk2numpy(img).T[0] + if darkField != None and flatField != None: + print("Try to normalize") + #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice): + theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01) + print (theSlice.dtype) + + + print ("Slice shape %s" % str(numpy.shape(theSlice))) + stack_image.T[num] = theSlice.copy() + + return stack_image + + @staticmethod + def normalize(projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + + +## Utility functions to transform numpy arrays to vtkImageData and viceversa +#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtkImporter(nparray, spacing, origin) +# +#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtk(nparray, spacing, origin) +# +#def vtk2numpy(imgdata): +# return Converter.vtk2numpy(imgdata) +# +#def tiffStack2numpy(filename, indices): +# return Converter.tiffStack2numpy(filename, indices) + +class ViewerEvent(Enum): + # left button + PICK_EVENT = 0 + # alt + right button + move + WINDOW_LEVEL_EVENT = 1 + # shift + right button + ZOOM_EVENT = 2 + # control + right button + PAN_EVENT = 3 + # control + left button + CREATE_ROI_EVENT = 4 + # alt + left button + DELETE_ROI_EVENT = 5 + # release button + NO_EVENT = -1 + + +#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): +class CILInteractorStyle(vtk.vtkInteractorStyleImage): + + def __init__(self, callback): + vtk.vtkInteractorStyleImage.__init__(self) + self.callback = callback + self._viewer = callback + priority = 1.0 + +# self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority) +# self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority) +# self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority) +# 
self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority) +# self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority) +# self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority) +# self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority) +# self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority) + + self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority) + self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority) + self.AddObserver('KeyPressEvent', self.OnKeyPress, priority) + self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority) + self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority) + self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority) + self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority) + self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority) + + self.InitialEventPosition = (0,0) + + + def SetInitialEventPosition(self, xy): + self.InitialEventPosition = xy + + def GetInitialEventPosition(self): + return self.InitialEventPosition + + def GetKeyCode(self): + return self.GetInteractor().GetKeyCode() + + def SetKeyCode(self, keycode): + self.GetInteractor().SetKeyCode(keycode) + + def GetControlKey(self): + return self.GetInteractor().GetControlKey() == CONTROL_KEY + + def GetShiftKey(self): + return self.GetInteractor().GetShiftKey() == SHIFT_KEY + + def GetAltKey(self): + return self.GetInteractor().GetAltKey() == ALT_KEY + + def GetEventPosition(self): + return self.GetInteractor().GetEventPosition() + + def GetEventPositionInWorldCoordinates(self): + pass + + def GetDeltaEventPosition(self): + x,y = self.GetInteractor().GetEventPosition() + return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1]) + + def Dolly(self, factor): + self.callback.camera.Dolly(factor) + self.callback.ren.ResetCameraClippingRange() + + def GetDimensions(self): + return self._viewer.img3D.GetDimensions() + + def GetInputData(self): + return self._viewer.img3D + + def GetSliceOrientation(self): + return self._viewer.sliceOrientation + + def SetSliceOrientation(self, orientation): + self._viewer.sliceOrientation = orientation + + def GetActiveSlice(self): + return self._viewer.sliceno + + def SetActiveSlice(self, sliceno): + self._viewer.sliceno = sliceno + + def UpdatePipeline(self, reset = False): + self._viewer.updatePipeline(reset) + + def GetActiveCamera(self): + return self._viewer.ren.GetActiveCamera() + + def SetActiveCamera(self, camera): + self._viewer.ren.SetActiveCamera(camera) + + def ResetCamera(self): + self._viewer.ren.ResetCamera() + + def Render(self): + self._viewer.renWin.Render() + + def UpdateSliceActor(self): + self._viewer.sliceActor.Update() + + def AdjustCamera(self): + self._viewer.AdjustCamera() + + def SaveRender(self, filename): + self._viewer.SaveRender(filename) + + def GetRenderWindow(self): + return self._viewer.renWin + + def GetRenderer(self): + return self._viewer.ren + + def GetROIWidget(self): + return self._viewer.ROIWidget + + def SetViewerEvent(self, event): + self._viewer.event = event + + def GetViewerEvent(self): + return self._viewer.event + + def SetInitialCameraPosition(self, position): + self._viewer.InitialCameraPosition = position + + def GetInitialCameraPosition(self): + return self._viewer.InitialCameraPosition + + def 
SetInitialLevel(self, level): + self._viewer.InitialLevel = level + + def GetInitialLevel(self): + return self._viewer.InitialLevel + + def SetInitialWindow(self, window): + self._viewer.InitialWindow = window + + def GetInitialWindow(self): + return self._viewer.InitialWindow + + def GetWindowLevel(self): + return self._viewer.wl + + def SetROI(self, roi): + self._viewer.ROI = roi + + def GetROI(self): + return self._viewer.ROI + + def UpdateCornerAnnotation(self, text, corner): + self._viewer.updateCornerAnnotation(text, corner) + + def GetPicker(self): + return self._viewer.picker + + def GetCornerAnnotation(self): + return self._viewer.cornerAnnotation + + def UpdateROIHistogram(self): + self._viewer.updateROIHistogram() + + + ############### Handle events + def OnMouseWheelForward(self, interactor, event): + maxSlice = self.GetDimensions()[self.GetSliceOrientation()] + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + + if (self.GetActiveSlice() + advance < maxSlice): + self.SetActiveSlice(self.GetActiveSlice() + advance) + + self.UpdatePipeline() + else: + print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 )) + + def OnMouseWheelBackward(self, interactor, event): + minSlice = 0 + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + if (self.GetActiveSlice() - advance >= minSlice): + self.SetActiveSlice( self.GetActiveSlice() - advance) + self.UpdatePipeline() + else: + print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 )) + + def OnKeyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "X": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_YZ ) + self.SetActiveSlice( int(self.GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Y": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XZ ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Z": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XY ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) ) + self.UpdatePipeline(True) + if interactor.GetKeyCode() == "x": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("X") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("Y") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "z": + # Change 
the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,1,0) + self.SetActiveCamera(camera) + self.ResetCamera() + self.Render() + interactor.SetKeyCode("Z") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "a": + # reset color/window + cmax = self._viewer.ia.GetMax()[0] + cmin = self._viewer.ia.GetMin()[0] + + self.SetInitialLevel( (cmax+cmin)/2 ) + self.SetInitialWindow( cmax-cmin ) + + self.GetWindowLevel().SetLevel(self.GetInitialLevel()) + self.GetWindowLevel().SetWindow(self.GetInitialWindow()) + + self.GetWindowLevel().Update() + + self.UpdateSliceActor() + self.AdjustCamera() + self.Render() + + elif interactor.GetKeyCode() == "s": + filename = "current_render" + self.SaveRender(filename) + elif interactor.GetKeyCode() == "q": + print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), )) + interactor.SetKeyCode("e") + self.OnKeyPress(interactor, event) + else : + #print ("Unhandled event %s" % (interactor.GetKeyCode(), ))) + pass + + def OnLeftButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + if ctrl and not (alt and shift): + self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT ) + wsize = self.GetRenderWindow().GetSize() + position = interactor.GetEventPosition() + self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05)) + self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1)) + + self.GetROIWidget().On() + self.SetDisplayHistogram(True) + self.Render() + print ("Event %s is CREATE_ROI_EVENT" % (event)) + elif alt and not (shift and ctrl): + self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT ) + self.GetROIWidget().Off() + self._viewer.updateCornerAnnotation("", 1, False) + self.SetDisplayHistogram(False) + self.Render() + print ("Event %s is DELETE_ROI_EVENT" % (event)) + elif not (ctrl and alt and shift): + self.SetViewerEvent ( ViewerEvent.PICK_EVENT ) + self.HandlePickEvent(interactor, event) + print ("Event %s is PICK_EVENT" % (event)) + + + def SetDisplayHistogram(self, display): + if display: + if (self._viewer.displayHistogram == 0): + self.GetRenderer().AddActor(self._viewer.histogramPlotActor) + self.firstHistogram = 1 + self.Render() + + self._viewer.histogramPlotActor.VisibilityOn() + self._viewer.displayHistogram = True + else: + self._viewer.histogramPlotActor.VisibilityOff() + self._viewer.displayHistogram = False + + + def OnLeftButtonReleaseEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT: + #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate() + #print (bc.GetValue()) + self.OnROIModifiedEvent(interactor, event) + + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def 
OnRightButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + + if alt and not (ctrl and shift): + self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT ) + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif shift and not (ctrl and alt): + self.SetViewerEvent( ViewerEvent.ZOOM_EVENT ) + self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition()) + print ("Event %s is ZOOM_EVENT" % (event)) + elif ctrl and not (shift and alt): + self.SetViewerEvent (ViewerEvent.PAN_EVENT ) + self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() ) + print ("Event %s is PAN_EVENT" % (event)) + + def OnRightButtonReleaseEvent(self, interactor, event): + print (event) + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + self.SetInitialLevel( self.GetWindowLevel().GetLevel() ) + self.SetInitialWindow ( self.GetWindowLevel().GetWindow() ) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \ + self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.SetInitialCameraPosition( () ) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + + def OnROIModifiedEvent(self, interactor, event): + + #print ("ROI EVENT " + event) + p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate() + p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate() + wsize = self.GetRenderWindow().GetSize() + + #print (p1.GetValue()) + #print (p2.GetValue()) + pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0] + pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0] + vox1 = self.viewport2imageCoordinate(pp1) + vox2 = self.viewport2imageCoordinate(pp2) + + self.SetROI( (vox1 , vox2) ) + roi = self.GetROI() + print ("Pixel1 %d,%d,%d Value %f" % vox1 ) + print ("Pixel2 %d,%d,%d Value %f" % vox2 ) + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][2] - roi[0][2]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + x = abs(roi[1][1] - roi[0][1]) + y = abs(roi[1][2] - roi[0][2]) + + text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.) 
+ print (text) + self.UpdateCornerAnnotation(text, 1) + self.UpdateROIHistogram() + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.GetPicker().GetPickPosition()) + pickPosition[self.GetSliceOrientation()] = \ + self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \ + self.GetInputData().GetOrigin()[self.GetSliceOrientation()] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.GetInputData().GetDimensions() + print (dims) + spac = self.GetInputData().GetSpacing() + orig = self.GetInputData().GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + + def OnMouseMoveEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT: + self.HandleZoomEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.HandlePanEvent(interactor, event) + + + def HandleZoomEvent(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + size = self.GetRenderWindow().GetSize() + dy = - 4 * dy / size[1] + + print ("distance: " + str(self.GetActiveCamera().GetDistance())) + + print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy)) + + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) + newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy ) + #print ("new position " + str(newposition)) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + print ("distance after: " + str(self.GetActiveCamera().GetDistance())) + + def HandlePanEvent(self, interactor, event): + x,y = interactor.GetEventPosition() + x0,y0 = interactor.GetInitialEventPosition() + + ic = self.viewport2imageCoordinate((x,y)) + ic0 = self.viewport2imageCoordinate((x0,y0)) + + dx = 4 *( ic[0] - ic0[0]) + dy = 4* (ic[1] - ic0[1]) + + camera = vtk.vtkCamera() + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + newposition[0] -= dx + newposition[1] -= dy + newfocalpoint[0] = newposition[0] 
+ newfocalpoint[1] = newposition[1] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[0] -= dx + newposition[2] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[2] = newposition[2] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[1] -= dx + newposition[2] -= dy + newfocalpoint[2] = newposition[2] + newfocalpoint[1] = newposition[1] + #print ("new position " + str(newposition)) + camera.SetFocalPoint(newfocalpoint) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + def HandleWindowLevel(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + print ("Event delta %d %d" % (dx,dy)) + size = self.GetRenderWindow().GetSize() + + dx = 4 * dx / size[0] + dy = 4 * dy / size[1] + window = self.GetInitialWindow() + level = self.GetInitialLevel() + + if abs(window) > 0.01: + dx = dx * window + else: + dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window); + + if abs(level) > 0.01: + dy = dy * level + else: + dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level) + + + # Abs so that direction does not flip + + if window < 0.0: + dx = -1*dx + if level < 0.0: + dy = -1*dy + + # Compute new window level + + newWindow = dx + window + newLevel = level - dy + + # Stay away from zero and really + + if abs(newWindow) < 0.01: + newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow) + + if abs(newLevel) < 0.01: + newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel) + + self.GetWindowLevel().SetWindow(newWindow) + self.GetWindowLevel().SetLevel(newLevel) + + self.GetWindowLevel().Update() + self.UpdateSliceActor() + self.AdjustCamera() + + self.Render() + + def HandlePickEvent(self, interactor, event): + position = interactor.GetEventPosition() + #print ("PICK " + str(position)) + vox = self.viewport2imageCoordinate(position) + #print ("Pixel %d,%d,%d Value %f" % vox ) + self._viewer.cornerAnnotation.VisibilityOn() + self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0) + self.Render() + +############################################################################### + + + +class CILViewer2D(): + '''Simple Interactive Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None): + '''creates the rendering pipeline''' + # create a rendering window and renderer + if ren == None: + self.ren = vtk.vtkRenderer() + else: + self.ren = ren + if renWin == None: + self.renWin = vtk.vtkRenderWindow() + else: + self.renWin = renWin + if iren == None: + self.iren = vtk.vtkRenderWindowInteractor() + else: + self.iren = iren + + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + self.style = CILInteractorStyle(self) + + self.iren.SetInteractorStyle(self.style) + self.iren.SetRenderWindow(self.renWin) + self.iren.Initialize() + self.ren.SetBackground(.1, .2, .4) + + self.camera = vtk.vtkCamera() + self.camera.ParallelProjectionOn() + self.ren.SetActiveCamera(self.camera) + + # data + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + + #Actors + self.sliceActor = vtk.vtkImageActor() + self.voi = vtk.vtkExtractVOI() + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia = vtk.vtkImageAccumulate() + self.sliceActorNo = 0 + + #initial Window/Level + self.InitialLevel = 0 + self.InitialWindow = 0 + + #ViewerEvent + self.event = ViewerEvent.NO_EVENT + + # ROI Widget + self.ROIWidget = vtk.vtkBorderWidget() + self.ROIWidget.SetInteractor(self.iren) + self.ROIWidget.CreateDefaultRepresentation() + 
self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0) + self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0) + + # edge points of the ROI + self.ROI = () + + #picker + self.picker = vtk.vtkPropPicker() + self.picker.PickFromListOn() + self.picker.AddPickList(self.sliceActor) + + self.iren.SetPicker(self.picker) + + # corner annotation + self.cornerAnnotation = vtk.vtkCornerAnnotation() + self.cornerAnnotation.SetMaximumFontSize(12); + self.cornerAnnotation.PickableOff(); + self.cornerAnnotation.VisibilityOff(); + self.cornerAnnotation.GetTextProperty().ShadowOn(); + self.cornerAnnotation.SetLayerNumber(1); + + + + # cursor doesn't show up + self.cursor = vtk.vtkCursor2D() + self.cursorMapper = vtk.vtkPolyDataMapper2D() + self.cursorActor = vtk.vtkActor2D() + self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0) + self.cursor.SetFocalPoint(0, 0, 0) + self.cursor.AllOff() + self.cursor.AxesOn() + self.cursorActor.PickableOff() + self.cursorActor.VisibilityOn() + self.cursorActor.GetProperty().SetColor(1, 1, 1) + self.cursorActor.SetLayerNumber(1) + self.cursorMapper.SetInputData(self.cursor.GetOutput()) + self.cursorActor.SetMapper(self.cursorMapper) + + # Zoom + self.InitialCameraPosition = () + + # XY Plot actor for histogram + self.displayHistogram = False + self.firstHistogram = 0 + self.roiIA = vtk.vtkImageAccumulate() + self.roiVOI = vtk.vtkExtractVOI() + self.histogramPlotActor = vtk.vtkXYPlotActor() + self.histogramPlotActor.ExchangeAxesOff(); + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetAdjustXLabels(3) + self.histogramPlotActor.SetXTitle( "Level" ) + self.histogramPlotActor.SetYTitle( "N" ) + self.histogramPlotActor.SetXValuesToValue() + self.histogramPlotActor.SetPlotColor(0, (0,1,1) ) + self.histogramPlotActor.SetPosition(0.6,0.6) + self.histogramPlotActor.SetPosition2(0.4,0.4) + + + + def GetInteractor(self): + return self.iren + + def GetRenderer(self): + return self.ren + + def setInput3DData(self, imageData): + self.img3D = imageData + self.installPipeline() + + def setInputAsNumpy(self, numpyarray, origin=(0,0,0), spacing=(1.,1.,1.), + rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT): + importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin) + importer.Update() + + if rescale: + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(importer.GetOutput()) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + if (iMax - iMin == 0): + scale = 1 + else: + if dtype == vtk.VTK_UNSIGNED_SHORT: + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + elif dtype == vtk.VTK_UNSIGNED_INT: + scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(importer.GetOutput()) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(-iMin) + shiftScaler.SetOutputScalarType(dtype) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + else: + self.img3D = importer.GetOutput() + + self.installPipeline() + + def displaySlice(self, sliceno = 0): + self.sliceno = sliceno + + self.updatePipeline() + + self.renWin.Render() + + return self.sliceActorNo + + def updatePipeline(self, resetcamera = False): + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], 
extent[5]) + + self.voi.Update() + self.ia.Update() + self.wl.Update() + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + + self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation])) + + if self.displayHistogram: + self.updateROIHistogram() + + self.AdjustCamera(resetcamera) + + self.renWin.Render() + + + def installPipeline(self): + '''Slices a 3D volume and then creates an actor to be rendered''' + + self.ren.AddViewProp(self.cornerAnnotation) + + self.voi.SetInputData(self.img3D) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + # set window/level for current slices + + + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia.SetInputData(self.voi.GetOutput()) + self.ia.Update() + cmax = self.ia.GetMax()[0] + cmin = self.ia.GetMin()[0] + + self.InitialLevel = (cmax+cmin)/2 + self.InitialWindow = cmax-cmin + + + self.wl.SetLevel(self.InitialLevel) + self.wl.SetWindow(self.InitialWindow) + + self.wl.SetInputData(self.voi.GetOutput()) + self.wl.Update() + + self.sliceActor.SetInputData(self.wl.GetOutput()) + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + self.sliceActor.SetInterpolate(False) + self.ren.AddActor(self.sliceActor) + self.ren.ResetCamera() + self.ren.Render() + + self.AdjustCamera() + + self.ren.AddViewProp(self.cursorActor) + self.cursorActor.VisibilityOn() + + self.iren.Initialize() + self.renWin.Render() + #self.iren.Start() + + def AdjustCamera(self, resetcamera = False): + self.ren.ResetCameraClippingRange() + if resetcamera: + self.ren.ResetCamera() + + + def getROI(self): + return self.ROI + + def getROIExtent(self): + p0 = self.ROI[0] + p1 = self.ROI[1] + return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2]) + + ############### Handle events are moved to the interactor style + + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.picker.GetPickPosition()) + pickPosition[self.sliceOrientation] = \ + self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \ + self.img3D.GetOrigin()[self.sliceOrientation] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.img3D.GetDimensions() + print (dims) + spac = self.img3D.GetSpacing() + orig = self.img3D.GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + def GetRenderWindow(self): + return self.renWin + + + def startRenderLoop(self): + self.iren.Start() + + def GetSliceOrientation(self): + return self.sliceOrientation + + def GetActiveSlice(self): + return self.sliceno + + def updateCornerAnnotation(self, text , idx=0, visibility=True): + if visibility: + self.cornerAnnotation.VisibilityOn() + else: + self.cornerAnnotation.VisibilityOff() + + self.cornerAnnotation.SetText(idx, text) + self.iren.Render() + + def saveRender(self, 
filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + def updateROIHistogram(self): + + extent = [0 for i in range(6)] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + extent[4] = self.GetActiveSlice() + extent[5] = self.GetActiveSlice()+1 + #y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + #x = abs(roi[1][0] - roi[0][0]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[2] = self.GetActiveSlice() + extent[3] = self.GetActiveSlice()+1 + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + #x = abs(roi[1][1] - roi[0][1]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[0] = self.GetActiveSlice() + extent[1] = self.GetActiveSlice()+1 + + self.roiVOI.SetVOI(extent) + self.roiVOI.SetInputData(self.img3D) + self.roiVOI.Update() + irange = self.roiVOI.GetOutput().GetScalarRange() + + self.roiIA.SetInputData(self.roiVOI.GetOutput()) + self.roiIA.IgnoreZeroOff() + self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 ) + self.roiIA.SetComponentOrigin( int(irange[0]),0,0 ); + self.roiIA.SetComponentSpacing( 1,0,0 ); + self.roiIA.Update() + + self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort()) + self.histogramPlotActor.SetXRange(irange[0],irange[1]) + + self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() ) + + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py new file mode 100644 index 0000000..906786b --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget.py @@ -0,0 +1,340 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter + +class QVTKWidget(QtWidgets.QWidget): + + """ A QVTKWidget for Python and Qt.""" + + # Map between VTK and Qt cursors. + _CURSOR_MAP = { + 0: QtCore.Qt.ArrowCursor, # VTK_CURSOR_DEFAULT + 1: QtCore.Qt.ArrowCursor, # VTK_CURSOR_ARROW + 2: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZENE + 3: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZENWSE + 4: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZESW + 5: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZESE + 6: QtCore.Qt.SizeVerCursor, # VTK_CURSOR_SIZENS + 7: QtCore.Qt.SizeHorCursor, # VTK_CURSOR_SIZEWE + 8: QtCore.Qt.SizeAllCursor, # VTK_CURSOR_SIZEALL + 9: QtCore.Qt.PointingHandCursor, # VTK_CURSOR_HAND + 10: QtCore.Qt.CrossCursor, # VTK_CURSOR_CROSSHAIR + } + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + # the current button + self._ActiveButton = QtCore.Qt.NoButton + + # private attributes + self.__oldFocus = None + self.__saveX = 0 + self.__saveY = 0 + self.__saveModifiers = QtCore.Qt.NoModifier + self.__saveButtons = QtCore.Qt.NoButton + self.__timeframe = 0 + + # create qt-level widget + QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC) + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D() + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + + self._Iren.Register(self._RenderWindow) + self._Iren.SetRenderWindow(self._RenderWindow) + self._RenderWindow.SetWindowInfo(str(int(self.winId()))) + + # do all the necessary qt setup + self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent) + self.setAttribute(QtCore.Qt.WA_PaintOnScreen) + self.setMouseTracking(True) # get all mouse events + self.setFocusPolicy(QtCore.Qt.WheelFocus) + self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)) + + self._Timer = QtCore.QTimer(self) + #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent) + + self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer) + self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer) + self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent', + self.CursorChangedEvent) + + # Destructor + def __del__(self): + self._Iren.UnRegister(self._RenderWindow) + #QtWidgets.QWidget.__del__(self) + + 
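+    # Editorial note (not in the original patch): the constructor above links Qt
+    # and VTK by handing the Qt widget's native window id to the VTK render
+    # window via SetWindowInfo() and by registering the CILViewer2D interactor.
+    # The Qt event handlers further down translate mouse and key events into VTK
+    # interactor events with SetEventInformationFlipY() + InvokeEvent(), so the
+    # viewer's interactor style receives them as ordinary VTK events.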
# Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + + # GetInteractor + def GetInteractor(self): + return self._Iren + + # Display image data + def GetPyveViewer(self): + return self._PyveViewer + + def __getattr__(self, attr): + """Makes the object behave like a vtkGenericRenderWindowInteractor""" + print (attr) + if attr == '__vtk__': + return lambda t=self._Iren: t + elif hasattr(self._Iren, attr): + return getattr(self._Iren, attr) +# else: +# raise AttributeError( self.__class__.__name__ + \ +# " has no attribute named " + attr ) + + def CreateTimer(self, obj, evt): + self._Timer.start(10) + + def DestroyTimer(self, obj, evt): + self._Timer.stop() + return 1 + + def TimerEvent(self): + self._Iren.InvokeEvent("TimerEvent") + + def CursorChangedEvent(self, obj, evt): + """Called when the CursorChangedEvent fires on the render window.""" + # This indirection is needed since when the event fires, the current + # cursor is not yet set so we defer this by which time the current + # cursor should have been set. + QtCore.QTimer.singleShot(0, self.ShowCursor) + + def HideCursor(self): + """Hides the cursor.""" + self.setCursor(QtCore.Qt.BlankCursor) + + def ShowCursor(self): + """Shows the cursor.""" + vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor() + qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor) + self.setCursor(qt_cursor) + + def sizeHint(self): + return QtCore.QSize(400, 400) + + def paintEngine(self): + return None + + def paintEvent(self, ev): + self._RenderWindow.Render() + + def resizeEvent(self, ev): + self._RenderWindow.Render() + w = self.width() + h = self.height() + + self._RenderWindow.SetSize(w, h) + self._Iren.SetSize(w, h) + + def _GetCtrlShiftAlt(self, ev): + ctrl = shift = alt = False + + if hasattr(ev, 'modifiers'): + if ev.modifiers() & QtCore.Qt.ShiftModifier: + shift = True + if ev.modifiers() & QtCore.Qt.ControlModifier: + ctrl = True + if ev.modifiers() & QtCore.Qt.AltModifier: + alt = True + else: + if self.__saveModifiers & QtCore.Qt.ShiftModifier: + shift = True + if self.__saveModifiers & QtCore.Qt.ControlModifier: + ctrl = True + if self.__saveModifiers & QtCore.Qt.AltModifier: + alt = True + + return ctrl, shift, alt + + def enterEvent(self, ev): + if not self.hasFocus(): + self.__oldFocus = self.focusWidget() + self.setFocus() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("EnterEvent") + + def leaveEvent(self, ev): + if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus: + self.__oldFocus.setFocus() + self.__oldFocus = None + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("LeaveEvent") + + def mousePressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + repeat = 0 + if ev.type() == QtCore.QEvent.MouseButtonDblClick: + repeat = 1 + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), repeat, None) + + self._Iren.SetAltKey(alt) + self._ActiveButton = ev.button() + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + 
self._Iren.InvokeEvent("MiddleButtonPressEvent") + + def mouseReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonReleaseEvent") + + def mouseMoveEvent(self, ev): + self.__saveModifiers = ev.modifiers() + self.__saveButtons = ev.buttons() + self.__saveX = ev.x() + self.__saveY = ev.y() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("MouseMoveEvent") + + def keyPressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = str(ev.text()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyPressEvent") + self._Iren.InvokeEvent("CharEvent") + + def keyReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = chr(ev.key()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyReleaseEvent") + + def wheelEvent(self, ev): + print ("angleDeltaX %d" % ev.angleDelta().x()) + print ("angleDeltaY %d" % ev.angleDelta().y()) + if ev.angleDelta().y() >= 0: + self._Iren.InvokeEvent("MouseWheelForwardEvent") + else: + self._Iren.InvokeEvent("MouseWheelBackwardEvent") + + def GetRenderWindow(self): + return self._RenderWindow + + def Render(self): + self.update() + + +def QVTKExample(): + """A simple example that uses the QVTKWidget class.""" + + # every QT app needs an app + app = QtWidgets.QApplication(['PyVE QVTKWidget Example']) + page_VTK = QtWidgets.QWidget() + page_VTK.resize(500,500) + layout = QtWidgets.QVBoxLayout(page_VTK) + # create the widget + widget = QVTKWidget(parent=None) + layout.addWidget(widget) + + #reader = vtk.vtkPNGReader() + #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png") + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + widget.SetInput(reader.GetOutput()) + + # show the widget + page_VTK.show() + # start event processing + app.exec_() + +if __name__ == "__main__": + QVTKExample() diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py new file mode 100644 index 0000000..e32e1c2 --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget2.py @@ -0,0 +1,84 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor + +class QVTKWidget(QVTKRenderWindowInteractor): + + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + kw = dict() + super().__init__(parent, **kw) + + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D(400,400) + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + kw['iren'] = self._Iren + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + kw['rw'] = self._RenderWindow + + + + + def GetInteractor(self): + return self._Iren + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py new file mode 100644 index 0000000..946188b --- /dev/null +++ b/src/Python/ccpi/viewer/__init__.py @@ -0,0 +1 @@ +from ccpi.viewer.CILViewer import CILViewer \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc new file mode 100644 index 0000000..711f77a Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc new file mode 100644 index 0000000..77c2ca8 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc differ diff --git 
a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc new file mode 100644 index 0000000..3d11b87 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc new file mode 100644 index 0000000..2fa2eaf Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..fcea537 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py new file mode 100644 index 0000000..b5eb0a7 --- /dev/null +++ b/src/Python/ccpi/viewer/embedvtk.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Jul 27 12:18:58 2017 + +@author: ofn77899 +""" + +#!/usr/bin/env python + +import sys +import vtk +from PyQt5 import QtCore, QtWidgets +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor +import QVTKWidget2 + +class MainWindow(QtWidgets.QMainWindow): + + def __init__(self, parent = None): + QtWidgets.QMainWindow.__init__(self, parent) + + self.frame = QtWidgets.QFrame() + + self.vl = QtWidgets.QVBoxLayout() +# self.vtkWidget = QVTKRenderWindowInteractor(self.frame) + + self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame) + self.iren = self.vtkWidget.GetInteractor() + self.vl.addWidget(self.vtkWidget) + + + + + self.ren = vtk.vtkRenderer() + self.vtkWidget.GetRenderWindow().AddRenderer(self.ren) +# self.iren = self.vtkWidget.GetRenderWindow().GetInteractor() +# +# # Create source +# source = vtk.vtkSphereSource() +# source.SetCenter(0, 0, 0) +# source.SetRadius(5.0) +# +# # Create a mapper +# mapper = vtk.vtkPolyDataMapper() +# mapper.SetInputConnection(source.GetOutputPort()) +# +# # Create an actor +# actor = vtk.vtkActor() +# actor.SetMapper(mapper) +# +# self.ren.AddActor(actor) +# +# self.ren.ResetCamera() +# + self.frame.setLayout(self.vl) + self.setCentralWidget(self.frame) + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + self.vtkWidget.SetInput(reader.GetOutput()) + + #self.vktWidget.Initialize() + #self.vktWidget.Start() + + self.show() + #self.iren.Initialize() + + +if __name__ == "__main__": + + app = QtWidgets.QApplication(sys.argv) + + window = MainWindow() + + sys.exit(app.exec_()) \ No newline at end of file -- cgit v1.2.3 From 1a841b967e1db92a04e8e12c52b83489da27be1c Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:08:32 +0100 Subject: initial revision --- src/Python/ccpi/imaging/Regularizer.py | 322 +++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/ccpi/imaging/Regularizer.py diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py new file mode 100644 index 0000000..fb9ae08 --- /dev/null +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +from ccpi.imaging import cpu_regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer 
algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = cpu_regularizers.SplitBregman_TV + FGP_TV = cpu_regularizers.FGP_TV + LLT_model = cpu_regularizers.LLT_model + PatchBased_Regul = cpu_regularizers.PatchBased_Regul + TGV_PD = cpu_regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) 
+ + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + 
self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From 9f8fb57e1e89c1ad200d9c7eada5c653be34db66 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:10:04 +0100 Subject: module rename --- src/Python/fista_module.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 
eacda3d..c36329e 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -1032,13 +1032,13 @@ bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_al return result; } -BOOST_PYTHON_MODULE(regularizers) +BOOST_PYTHON_MODULE(cpu_regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "regularizers"; + package.attr("__path__") = "cpu_regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); -- cgit v1.2.3 From a203949c84484fe2641e39451f033d20d445b1f3 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:51:18 +0100 Subject: export/import data from hdf5 Added file to export the data from DemoRD2.m to HDF5 to pass it to Python. Added file to import the data from DemoRD2.m from HDF5. --- demos/exportDemoRD2Data.m | 35 +++++++++++++++++++++++++++++++++++ src/Python/test/readhd5.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 demos/exportDemoRD2Data.m create mode 100644 src/Python/test/readhd5.py diff --git a/demos/exportDemoRD2Data.m b/demos/exportDemoRD2Data.m new file mode 100644 index 0000000..028353b --- /dev/null +++ b/demos/exportDemoRD2Data.m @@ -0,0 +1,35 @@ +clear all +close all +%% +% % adding paths +addpath('../data/'); +addpath('../main_func/'); addpath('../main_func/regularizers_CPU/'); +addpath('../supp/'); + +load('DendrRawData.mat') % load raw data of 3D dendritic set +angles_rad = angles*(pi/180); % conversion to radians +size_det = size(data_raw3D,1); % detectors dim +angSize = size(data_raw3D, 2); % angles dim +slices_tot = size(data_raw3D, 3); % no of slices +recon_size = 950; % reconstruction size + +Sino3D = zeros(size_det, angSize, slices_tot, 'single'); % log-corrected sino +% normalizing the data +for jj = 1:slices_tot + sino = data_raw3D(:,:,jj); + for ii = 1:angSize + Sino3D(:,ii,jj) = log((flats_ar(:,jj)-darks_ar(:,jj))./(single(sino(:,ii)) - darks_ar(:,jj))); + end +end + +Sino3D = Sino3D.*1000; +Weights3D = single(data_raw3D); % weights for PW model +clear data_raw3D + +hdf5write('DendrData.h5', '/Weights3D', Weights3D) +hdf5write('DendrData.h5', '/Sino3D', Sino3D, 'WriteMode', 'append') +hdf5write('DendrData.h5', '/angles_rad', angles_rad, 'WriteMode', 'append') +hdf5write('DendrData.h5', '/size_det', size_det, 'WriteMode', 'append') +hdf5write('DendrData.h5', '/angSize', angSize, 'WriteMode', 'append') +hdf5write('DendrData.h5', '/slices_tot', slices_tot, 'WriteMode', 'append') +hdf5write('DendrData.h5', '/recon_size', recon_size, 'WriteMode', 'append') \ No newline at end of file diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py new file mode 100644 index 0000000..1e19e14 --- /dev/null +++ b/src/Python/test/readhd5.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +""" + +import h5py +import numpy + +def getEntry(nx, location): + for item in nx[location].keys(): + print (item) + +filename = r'C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\Demos\DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = 
numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] \ No newline at end of file -- cgit v1.2.3 From 8d53e078d3dabf7107982a8d25b4d66b1d0e73ce Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:54:59 +0100 Subject: initial revision for testing --- .../ccpi/reconstruction/FISTAReconstructor.py | 354 +++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/FISTAReconstructor.py diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py new file mode 100644 index 0000000..ea96b53 --- /dev/null +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -0,0 +1,354 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use 
for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + 
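+            # Editorial note (not in the original patch): the loop below implements
+            # the power method  x <- A^T W A x / ||A^T W A x||, with A the ASTRA
+            # forward projector and W = diag(weights). sqrt(weights) is applied to
+            # the projection data twice (after the forward projection and again
+            # before the backprojection), which is equivalent to applying W once.
+            # The norm s returned approximates the largest eigenvalue of A^T W A,
+            # i.e. the Lipschitz constant of the gradient of the weighted
+            # least-squares data term.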
vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else: + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = 
[stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + -- cgit v1.2.3 From 05bd227b56ec43c97c81630f50c3b741ef86ddcd Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:39:37 +0100 Subject: bugfix --- src/Python/Matlab2Python_utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index e15d738..ee76bc7 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -123,7 +123,7 @@ T * mxGetData(const np::ndarray pm) { probably this would work. 
A = reinterpret_cast(prhs[0]); */ - return reinterpret_cast(prhs[0]); + //return reinterpret_cast(prhs[0]); } template @@ -273,4 +273,4 @@ BOOST_PYTHON_MODULE(prova) //numpy_boost_python_register_type(); def("mexFunction", mexFunction); def("doSomething", doSomething); -} \ No newline at end of file +} -- cgit v1.2.3 From 879c6723969eaea8e00f97291612fe22443c69f3 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:41:10 +0100 Subject: initial facility to test the FISTA --- src/Python/test_reconstructor.py | 179 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 src/Python/test_reconstructor.py diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py new file mode 100644 index 0000000..0fd08f5 --- /dev/null +++ b/src/Python/test_reconstructor.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor +import astra + +##def getEntry(nx, location): +## for item in nx[location].keys(): +## print (item) + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 3 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, + image_size_y, + image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D ) + #N = params.vol_geom.GridColCount + +pars = dict() +pars['projector_geometry'] = proj_geom +pars['output_geometry'] = vol_geom +pars['input_sinogram'] = Sino3D +sliceZ , nangles , detectors = numpy.shape(Sino3D) +pars['detectors'] = detectors +pars['number_of_angles'] = nangles +pars['SlicesZ'] = sliceZ + + +pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) + +N = pars['output_geometry']['GridColCount'] +proj_geom = pars['projector_geometry'] +vol_geom = pars['output_geometry'] +weights = pars['weights'] +SlicesZ = pars['SlicesZ'] + +if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + print('Calculating Lipschitz constant for parallel beam geometry...') + niter = 15;# % number of iterations for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + import matplotlib.pyplot as plt + fig = plt.figure() + + #a.set_title('Lipschitz') + for i in range(niter): +# [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); +# s = norm(x1(:)); +# x1 = x1/s; +# [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); +# y = sqweight.*y; +# astra_mex_data3d('delete', sino_id); +# astra_mex_data3d('delete', id); + print ("iteration {0}".format(i)) + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + #a=fig.add_subplot(2,1,1) + #imgplot = plt.imshow(y[0]) + + y = sqweight * y # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geomT, + vol_geomT); + print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1))) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + print ("x1 {0}".format(x1.T[:4].T)) + +# ### this line? +# sino_id, y = astra.creators.create_sino3d_gpu(x1, +# proj_geomT, +# vol_geomT); +# y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT +else: + #% divergent beam geometry + print('Calculating Lipschitz constant for divergent beam geometry...') + niter = 8; #% number of iterations for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? 
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 -- cgit v1.2.3 From e58f774938edd3664dfb1f3905964b3add050bc9 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:05 +0100 Subject: initial revision --- src/Python/ccpi/imaging/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/imaging/__init__.py diff --git a/src/Python/ccpi/imaging/__init__.py b/src/Python/ccpi/imaging/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 9c974a26c0fc8060008745796fbe9f7ef5c250eb Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:27 +0100 Subject: initial revision --- src/Python/ccpi/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/__init__.py diff --git a/src/Python/ccpi/__init__.py b/src/Python/ccpi/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From c8693f530e95e140a3fba85fc65d879b51b79e6d Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:11:00 +0100 Subject: table with regularizers output --- src/Python/test_regularizers.py | 66 ++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 755804a..5804897 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -163,52 +163,52 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise # # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); -# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - # searching_window_ratio=3, - # similarity_window_ratio=1, - # PB_filtering_parameter=0.08) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,5) +a=fig.add_subplot(2,3,5) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) -# ###################### TGV_PD ######################################### -# # Quick 2D denoising example in Matlab: -# # Im = double(imread('lena_gray_256.tif'))/255; % loading image -# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + 
.03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - # first_order_term=1.3, - # second_order_term=1, - # number_of_iterations=550) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,6) +a=fig.add_subplot(2,3,6) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) plt.show() -- cgit v1.2.3 From a32469e9d96c6039b0a5e8e1b22f577e87484c40 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:11:56 +0100 Subject: comment out clear of data_raw3D --- demos/exportDemoRD2Data.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/exportDemoRD2Data.m b/demos/exportDemoRD2Data.m index 028353b..0ca036b 100644 --- a/demos/exportDemoRD2Data.m +++ b/demos/exportDemoRD2Data.m @@ -24,7 +24,7 @@ end Sino3D = Sino3D.*1000; Weights3D = single(data_raw3D); % weights for PW model -clear data_raw3D +%clear data_raw3D hdf5write('DendrData.h5', '/Weights3D', Weights3D) hdf5write('DendrData.h5', '/Sino3D', Sino3D, 'WriteMode', 'append') -- cgit v1.2.3 From 776070e22bf95491275a023f3a5ac00cea356714 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:12:42 +0100 Subject: read and plot the hdf5 --- src/Python/test/readhd5.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py index 1e19e14..b042341 100644 --- a/src/Python/test/readhd5.py +++ b/src/Python/test/readhd5.py @@ -25,4 +25,17 @@ angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] angles_rad = numpy.asarray(nx.get('/angles_rad')) recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] -slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] \ No newline at end of file +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +#from ccpi.viewer.CILViewer2D import CILViewer2D +#v = CILViewer2D() +#v.setInputAsNumpy(Weights3D) +#v.startRenderLoop() + +import matplotlib.pyplot as plt +fig = plt.figure() + +a=fig.add_subplot(1,1,1) +a.set_title('noise') +imgplot = plt.imshow(Weights3D[0].T) +plt.show() -- cgit v1.2.3 From 5c978b706192bc5885c7e5001a4bc4626f63d29f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:49:18 +0100 Subject: initial revision --- src/Python/test/simple_astra_test.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/Python/test/simple_astra_test.py diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py new file mode 100644 index 0000000..905eeea --- /dev/null +++ b/src/Python/test/simple_astra_test.py @@ -0,0 
+1,25 @@ +import astra +import numpy + +detectorSpacingX = 1.0 +detectorSpacingY = 1.0 +det_row_count = 128 +det_col_count = 128 + +angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +image_size_x = 64 +image_size_y = 64 +image_size_z = 32 + +vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) + +x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) +sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) -- cgit v1.2.3
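
Note on the Lipschitz-constant estimate exercised by test_reconstructor.py above (and by FISTA_REC.m): the loop is the standard power iteration for the largest eigenvalue of A^T W A, where A is the projector and W = diag(weights); that eigenvalue is the Lipschitz constant of the gradient of the weighted least-squares term, and the step queried with "### this line?" is the usual re-normalisation of the iterate by its norm. Below is a minimal, ASTRA-free sketch of the same scheme, using a dense numpy matrix as a stand-in for the forward projector; the names estimate_lipschitz, A and w are illustrative assumptions, not part of the repository.

    import numpy

    def estimate_lipschitz(A, w, niter=15, seed=0):
        # Power iteration for the largest eigenvalue of A^T diag(w) A,
        # i.e. the Lipschitz constant of x -> A^T W (A x - b).
        rng = numpy.random.RandomState(seed)
        x = rng.rand(A.shape[1])
        sqw = numpy.sqrt(w)
        s = 0.0
        for _ in range(niter):
            y = sqw * A.dot(x)        # weighted forward projection
            x = A.T.dot(sqw * y)      # weighted backprojection: x <- A^T W A x
            s = numpy.linalg.norm(x)  # current estimate of the largest eigenvalue
            x = x / s                 # re-normalise the iterate (the '### this line?' step)
        return s

    # toy usage with a random system standing in for the ASTRA projector
    A = numpy.random.rand(64, 32)
    w = numpy.ones(64)
    print(estimate_lipschitz(A, w))

In the GPU scripts the two matrix products correspond to astra.creators.create_sino3d_gpu and astra.creators.create_backprojection3d_gpu, and each data object returned by those calls is released with astra.matlab.data3d('delete', id) once it has been read back.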