summaryrefslogtreecommitdiffstats
path: root/main_func
diff options
context:
space:
mode:
authorDaniil Kazantsev <dkazanc@hotmail.com>2017-10-18 21:46:29 +0100
committerDaniil Kazantsev <dkazanc@hotmail.com>2017-10-18 21:46:29 +0100
commitcb8ef11f00e897b6f5b3049126dd32baa3c50cf9 (patch)
tree10a16da90cd269bad73390fa91e73ebbdad26813 /main_func
parent0847a315ce744e52be3dade398fb16c58323084e (diff)
downloadregularization-cb8ef11f00e897b6f5b3049126dd32baa3c50cf9.tar.gz
regularization-cb8ef11f00e897b6f5b3049126dd32baa3c50cf9.tar.bz2
regularization-cb8ef11f00e897b6f5b3049126dd32baa3c50cf9.tar.xz
regularization-cb8ef11f00e897b6f5b3049126dd32baa3c50cf9.zip
ordered subsets fix for GH term
Diffstat (limited to 'main_func')
-rw-r--r--main_func/FISTA_REC.m179
1 files changed, 105 insertions, 74 deletions
diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m
index dde0e73..bea1860 100644
--- a/main_func/FISTA_REC.m
+++ b/main_func/FISTA_REC.m
@@ -106,14 +106,14 @@ if (isfield(params,'L_const'))
else
% using Power method (PM) to establish L constant
fprintf('%s %s %s \n', 'Calculating Lipshitz constant for',proj_geom.type, 'beam geometry...');
- if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
- % for 2D geometry we can do just one selected slice
+ if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
+ % for 2D geometry we can do just one selected slice
niter = 15; % number of iteration for the PM
x1 = rand(N,N,1);
sqweight = sqrt(weights(:,:,1));
[sino_id, y] = astra_create_sino_cuda(x1, proj_geom, vol_geom);
y = sqweight.*y';
- astra_mex_data2d('delete', sino_id);
+ astra_mex_data2d('delete', sino_id);
for i = 1:niter
[x1] = astra_create_backprojection_cuda((sqweight.*y)', proj_geom, vol_geom);
s = norm(x1(:));
@@ -121,9 +121,9 @@ else
[sino_id, y] = astra_create_sino_cuda(x1, proj_geom, vol_geom);
y = sqweight.*y';
astra_mex_data2d('delete', sino_id);
- end
+ end
elseif (strcmp(proj_geom.type,'cone') || strcmp(proj_geom.type,'parallel3d') || strcmp(proj_geom.type,'parallel3d_vec') || strcmp(proj_geom.type,'cone_vec'))
- % 3D geometry
+ % 3D geometry
niter = 8; % number of iteration for PM
x1 = rand(N,N,SlicesZ);
sqweight = sqrt(weights);
@@ -268,7 +268,7 @@ if (isfield(params,'initialize'))
X = params.initialize;
if ((size(X,1) ~= N) || (size(X,2) ~= N) || (size(X,3) ~= SlicesZ))
error('%s \n', 'The initialized volume has different dimensions!');
- end
+ end
else
X = zeros(N,N,SlicesZ, 'single'); % storage for the solution
end
@@ -320,10 +320,11 @@ if (subsets == 0)
t_old = t;
r_old = r;
- % if the geometry is 2D use slice-by-slice projection-backprojection routine
- if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
+
+ if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
+ % if geometry is 2D use slice-by-slice projection-backprojection routine
sino_updt = zeros(size(sino),'single');
- for kkk = 1:SlicesZ
+ for kkk = 1:SlicesZ
[sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geom, vol_geom);
sino_updt(:,:,kkk) = sinoT';
astra_mex_data2d('delete', sino_id);
@@ -359,12 +360,11 @@ if (subsets == 0)
else
% no ring removal (LS model)
residual = weights.*(sino_updt - sino);
- objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output
+ objective(i) = 0.5*norm(residual(:)); % for the objective function output
end
-
% if the geometry is 2D use slice-by-slice projection-backprojection routine
- if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
+ if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
x_temp = zeros(size(X),'single');
for kkk = 1:SlicesZ
[x_temp(:,:,kkk)] = astra_create_backprojection_cuda(squeeze(residual(:,:,kkk))', proj_geom, vol_geom);
@@ -373,9 +373,9 @@ if (subsets == 0)
[id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom);
astra_mex_data3d('delete', id);
end
- X = X_t - (1/L_const).*x_temp;
+ X = X_t - (1/L_const).*x_temp;
- % ----------------Regularization part------------------------
+ % ----------------Regularization part------------------------%
if (lambdaFGP_TV > 0)
% FGP-TV regularization
if ((strcmp('2D', Dimension) == 1))
@@ -484,94 +484,128 @@ if (subsets == 0)
end
end
else
- % Ordered Subsets (OS) FISTA reconstruction routine (normally one order of magnitude faster than classical)
+ % Ordered Subsets (OS) FISTA reconstruction routine (normally one order of magnitude faster than the classical version)
t = 1;
X_t = X;
- proj_geomSUB = proj_geom;
-
+ proj_geomSUB = proj_geom;
r = zeros(Detectors,SlicesZ, 'single'); % 2D array (for 3D data) of sparse "ring" vectors
r_x = r; % another ring variable
residual2 = zeros(size(sino),'single');
+ sino_updt_FULL = zeros(size(sino),'single');
% Outer FISTA iterations loop
- for i = 1:iterFISTA
+ for i = 1:iterFISTA
- % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required
- % one solution is to work with a full sinogram at times
- if ((i >= 3) && (lambdaR_L1 > 0))
- [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom);
- astra_mex_data3d('delete', sino_id2);
+ if ((i > 1) && (lambdaR_L1 > 0))
+ % in order to make Group-Huber fidelity work with ordered subsets
+ % we still need to work with full sinogram
+
+ % the offset variable must be calculated for the whole
+ % updated sinogram - sino_updt_FULL
+ for kkk = 1:anglesNumb
+ residual2(:,kkk,:) = squeeze(weights(:,kkk,:)).*(squeeze(sino_updt_FULL(:,kkk,:)) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x));
+ end
+
+ r_old = r;
+ vec = sum(residual2,2);
+ if (SlicesZ > 1)
+ vec = squeeze(vec(:,1,:));
+ end
+ r = r_x - (1./L_const).*vec; % update ring variable
end
% subsets loop
counterInd = 1;
for ss = 1:subsets
X_old = X;
- t_old = t;
- r_old = r;
+ t_old = t;
numProjSub = binsDiscr(ss); % the number of projections per subset
CurrSubIndeces = IndicesReorg(counterInd:(counterInd + numProjSub - 1)); % extract indeces attached to the subset
proj_geomSUB.ProjectionAngles = angles(CurrSubIndeces);
+ sino_updt_Sub = zeros(Detectors, numProjSub, SlicesZ,'single');
if (lambdaR_L1 > 0)
-
- % the ring removal part (Group-Huber fidelity)
- % first 2 iterations do additional work reconstructing whole dataset to ensure
- % the stablility
- if (i < 3)
- [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
- astra_mex_data3d('delete', sino_id2);
+ % Group-Huber fidelity (ring removal)
+ if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
+ % if geometry is 2D use slice-by-slice projection-backprojection routine
+ for kkk = 1:SlicesZ
+ [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom);
+ sino_updt_Sub(:,:,kkk) = sinoT';
+ astra_mex_data2d('delete', sino_id);
+ end
else
- [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom);
- end
-
- for kkk = 1:anglesNumb
- residual2(:,kkk,:) = squeeze(weights(:,kkk,:)).*(squeeze(sino_updt2(:,kkk,:)) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x));
+ % for 3D geometry (watch the GPU memory overflow in earlier ASTRA versions < 1.8)
+ [sino_id, sino_updt_Sub] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom);
+ astra_mex_data3d('delete', sino_id);
end
- residual = zeros(Detectors, numProjSub, SlicesZ,'single');
+ residualSub = zeros(Detectors, numProjSub, SlicesZ,'single'); % residual for a chosen subset
for kkk = 1:numProjSub
indC = CurrSubIndeces(kkk);
- if (i < 3)
- residual(:,kkk,:) = squeeze(residual2(:,indC,:));
- else
- residual(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
+ residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
+ sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram
+ end
+
+ elseif (studentt > 0)
+ % student t data fidelity
+ if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
+ % if geometry is 2D use slice-by-slice projection-backprojection routine
+ for kkk = 1:SlicesZ
+ [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom);
+ sino_updt_Sub(:,:,kkk) = sinoT';
+ astra_mex_data2d('delete', sino_id);
end
+ else
+ % for 3D geometry (watch the GPU memory overflow in earlier ASTRA versions < 1.8)
+ [sino_id, sino_updt_Sub] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom);
+ astra_mex_data3d('delete', sino_id);
end
- vec = sum(residual2,2);
- if (SlicesZ > 1)
- vec = squeeze(vec(:,1,:));
+
+ % artifacts removal with Students t penalty
+ residualSub = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt_Sub - squeeze(sino(:,CurrSubIndeces,:)));
+
+ for kkk = 1:SlicesZ
+ res_vec = reshape(residualSub(:,:,kkk), Detectors*numProjSub, 1); % 1D vectorized sinogram
+ %s = 100;
+ %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec);
+ [ff, gr] = studentst(res_vec, 1);
+ residualSub(:,:,kkk) = reshape(gr, Detectors, numProjSub);
end
- r = r_x - (1./L_const).*vec;
+ objective(i) = ff; % for the objective function output
else
- [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom);
-
- if (studentt == 1)
- % artifacts removal with Students t penalty
- residual = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt - squeeze(sino(:,CurrSubIndeces,:)));
-
+ % PWLS model
+ if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
+ % if geometry is 2D use slice-by-slice projection-backprojection routine
for kkk = 1:SlicesZ
- res_vec = reshape(residual(:,:,kkk), Detectors*numProjSub, 1); % 1D vectorized sinogram
- %s = 100;
- %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec);
- [ff, gr] = studentst(res_vec, 1);
- residual(:,:,kkk) = reshape(gr, Detectors, numProjSub);
+ [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom);
+ sino_updt_Sub(:,:,kkk) = sinoT';
+ astra_mex_data2d('delete', sino_id);
end
- objective(i) = ff; % for the objective function output
else
- % no ring removal (LS model)
- residual = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt - squeeze(sino(:,CurrSubIndeces,:)));
+ % for 3D geometry (watch the GPU memory overflow in earlier ASTRA versions < 1.8)
+ [sino_id, sino_updt_Sub] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom);
+ astra_mex_data3d('delete', sino_id);
end
+ residualSub = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt_Sub - squeeze(sino(:,CurrSubIndeces,:)));
+ objective(i) = 0.5*norm(residualSub(:)); % for the objective function output
end
- [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geomSUB, vol_geom);
- X = X_t - (1/L_const).*x_temp;
- astra_mex_data3d('delete', sino_id);
- astra_mex_data3d('delete', id);
+ if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec'))
+ % if geometry is 2D use slice-by-slice projection-backprojection routine
+ x_temp = zeros(size(X),'single');
+ for kkk = 1:SlicesZ
+ [x_temp(:,:,kkk)] = astra_create_backprojection_cuda(squeeze(residualSub(:,:,kkk))', proj_geomSUB, vol_geom);
+ end
+ else
+ [id, x_temp] = astra_create_backprojection3d_cuda(residualSub, proj_geomSUB, vol_geom);
+ astra_mex_data3d('delete', id);
+ end
+
+ X = X_t - (1/L_const).*x_temp;
- % regularization
+ % ----------------Regularization part------------------------%
if (lambdaFGP_TV > 0)
% FGP-TV regularization
if ((strcmp('2D', Dimension) == 1))
@@ -653,20 +687,17 @@ else
end
end
- if (lambdaR_L1 > 0)
- r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector
- end
-
t = (1 + sqrt(1 + 4*t^2))/2; % updating t
X_t = X + ((t_old-1)/t).*(X - X_old); % updating X
-
- if (lambdaR_L1 > 0)
- r_x = r + ((t_old-1)/t).*(r - r_old); % updating r
- end
-
counterInd = counterInd + numProjSub;
end
+ % working with a 'ring vector'
+ if (lambdaR_L1 > 0)
+ r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector
+ r_x = r + ((t_old-1)/t).*(r - r_old); % updating r
+ end
+
if (show == 1)
figure(10); imshow(X(:,:,slice), [0 maxvalplot]);
if (lambdaR_L1 > 0)