diff options
| author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-11-10 12:38:25 +0000 | 
|---|---|---|
| committer | Edoardo Pasca <edo.paskino@gmail.com> | 2018-01-19 14:26:06 +0000 | 
| commit | 2b11ca3f30580b814971fcad39110e0751161acb (patch) | |
| tree | 50e67644d2b18018825c17d0001c2ebf92b0b95b | |
| parent | 1062346b9f892b7a90ccfb8406bd61fdc44e8e82 (diff) | |
| download | regularization-2b11ca3f30580b814971fcad39110e0751161acb.tar.gz regularization-2b11ca3f30580b814971fcad39110e0751161acb.tar.bz2 regularization-2b11ca3f30580b814971fcad39110e0751161acb.tar.xz regularization-2b11ca3f30580b814971fcad39110e0751161acb.zip | |
some development
| -rw-r--r-- | src/Python/ccpi/reconstruction/FISTAReconstructor.py | 2 | ||||
| -rw-r--r-- | src/Python/test/create_phantom_projections.py | 49 | ||||
| -rw-r--r-- | src/Python/test/test_reconstructor-os_phantom.py | 480 | 
3 files changed, 530 insertions, 1 deletions
| diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index af6275f..e40ad24 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -233,7 +233,7 @@ class FISTAReconstructor():              #x1 = rand(N,N,1);              x1 = numpy.random.rand(1,N,N)              #sqweight = sqrt(weights(:,:,1)); -            sqweight = numpy.sqrt(weights[0]) +            sqweight = numpy.sqrt(weights[0:1,:,:])              proj_geomT = proj_geom.copy();              proj_geomT['DetectorRowCount'] = 1;              vol_geomT = vol_geom.copy(); diff --git a/src/Python/test/create_phantom_projections.py b/src/Python/test/create_phantom_projections.py new file mode 100644 index 0000000..20a9278 --- /dev/null +++ b/src/Python/test/create_phantom_projections.py @@ -0,0 +1,49 @@ +from ccpi.reconstruction.AstraDevice import AstraDevice +from ccpi.reconstruction.DeviceModel import DeviceModel +import h5py +import numpy +import matplotlib.pyplot as plt + +nx = h5py.File('phant3D_256.h5', "r") +phantom = numpy.asarray(nx.get('/dataset1')) +pX,pY,pZ = numpy.shape(phantom) + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nxa = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nxa['/'].keys()] +print (entries) + +angles_rad = numpy.asarray(nxa.get('/angles_rad'), dtype="float32") + + +device = AstraDevice( +    DeviceModel.DeviceType.PARALLEL3D.value, +    [ pX , pY , 1., 1., angles_rad], +    [ pX, pY, pZ ] ) + + +proj = device.doForwardProject(phantom) +stack = [proj[:,i,:] for i in range(len(angles_rad))] +stack = numpy.asarray(stack) + + +fig = plt.figure() +a=fig.add_subplot(1,2,1) +a.set_title('proj') +imgplot = plt.imshow(proj[:,100,:]) +a=fig.add_subplot(1,2,2) +a.set_title('stack') +imgplot = plt.imshow(stack[100]) +plt.show() + +pf = h5py.File("phantom3D256_projections.h5" , "w") +pf.create_dataset("/projections", data=stack) +pf.create_dataset("/sinogram", data=proj) +pf.create_dataset("/angles", data=angles_rad) +pf.create_dataset("/reconstruction_volume" , data=numpy.asarray([pX, pY, pZ])) +pf.create_dataset("/camera/size" , data=numpy.asarray([pX , pY ])) +pf.create_dataset("/camera/spacing" , data=numpy.asarray([1.,1.])) +pf.flush() +pf.close() diff --git a/src/Python/test/test_reconstructor-os_phantom.py b/src/Python/test/test_reconstructor-os_phantom.py new file mode 100644 index 0000000..01f1354 --- /dev/null +++ b/src/Python/test/test_reconstructor-os_phantom.py @@ -0,0 +1,480 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor +import astra +import matplotlib.pyplot as plt +from ccpi.imaging.Regularizer import Regularizer +from ccpi.reconstruction.AstraDevice import AstraDevice +from ccpi.reconstruction.DeviceModel import DeviceModel + +#from ccpi.viewer.CILViewer2D import * + + +def RMSE(signal1, signal2): +    '''RMSE Root Mean Squared Error''' +    if numpy.shape(signal1) == numpy.shape(signal2): +        err = (signal1 - signal2) +        err = numpy.sum( err * err )/numpy.size(signal1);  # MSE +        err = sqrt(err);                                   # RMSE +        return err +    else: +        raise Exception('Input signals must have the same shape') + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/src/Python/test/phantom3D256_projections.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +projections = numpy.asarray(nx.get('/projections'), dtype="float32") +#Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") +#angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles'), dtype="float32") +angSize = numpy.size(angles_rad) +image_size_x, image_size_y, image_size_z = \ +              numpy.asarray(nx.get('/reconstruction_volume'), dtype=int) +det_col_count, det_row_count = \ +              numpy.asarray(nx.get('/camera/size')) +#slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] +detectorSpacingX, detectorSpacingY = numpy.asarray(nx.get('/camera/spacing'), dtype=int) + +Z_slices = 20 +#det_row_count = image_size_y +# next definition is just for consistency of naming +#det_col_count = image_size_x + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', +                                            detectorSpacingX, +                                            detectorSpacingY, +                                            det_row_count, +                                            det_col_count, +                                            angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +##image_size_x = recon_size +##image_size_y = recon_size +##image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, +                                           image_size_y, +                                           image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant +astradevice = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value, +                [proj_geom['DetectorRowCount'] , +                 proj_geom['DetectorColCount'] , +                 proj_geom['DetectorSpacingX'] , +                 proj_geom['DetectorSpacingY'] , +                 proj_geom['ProjectionAngles'] +                 ], +                [ +                    vol_geom['GridColCount'], +                    vol_geom['GridRowCount'],  +                    vol_geom['GridSliceCount'] ] ) +## create the sinogram  +Sino3D = numpy.transpose(projections, axes=[1,0,2]) + +fistaRecon = FISTAReconstructor(proj_geom, +                                vol_geom, +                                Sino3D , +                                #weights=Weights3D, +                                device=astradevice) + +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 4) +#fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +#fistaRecon.setParameter(ring_lambda_R_L1 = 0) +subsets = 8 +fistaRecon.setParameter(subsets=subsets) + + +#reg = Regularizer(Regularizer.Algorithm.FGP_TV) +#reg.setParameter(regularization_parameter=0.005, +#                          number_of_iterations=50) +reg = Regularizer(Regularizer.Algorithm.FGP_TV) +reg.setParameter(regularization_parameter=5e6, +                          tolerance_constant=0.0001, +                          number_of_iterations=50) + +#fistaRecon.setParameter(regularizer=reg) +#lc = fistaRecon.getParameter('Lipschitz_constant') +#reg.setParameter(regularization_parameter=5e6/lc) + +## Ordered subset +if True: +    #subsets = 8 +    fistaRecon.setParameter(subsets=subsets) +    fistaRecon.createOrderedSubsets() +else: +    angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] +    #binEdges = numpy.linspace(angles.min(), +    #                          angles.max(), +    #                          subsets + 1) +    binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) +    # get rearranged subset indices +    IndicesReorg = numpy.zeros((numpy.shape(angles))) +    counterM = 0 +    for ii in range(binsDiscr.max()): +        counter = 0 +        for jj in range(subsets): +            curr_index = ii + jj  + counter +            #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) +            if binsDiscr[jj] > ii: +                if (counterM < numpy.size(IndicesReorg)): +                    IndicesReorg[counterM] = curr_index +                counterM = counterM + 1 +                 +            counter = counter + binsDiscr[jj] - 1 + + +if True: +    print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +    print ("prepare for iteration") +    fistaRecon.prepareForIteration() +     +     + +    print("initializing  ...") +    if True: +        # if X doesn't exist +        #N = params.vol_geom.GridColCount +        N = vol_geom['GridColCount'] +        print ("N " + str(N)) +        X = numpy.asarray(numpy.ones((image_size_x,image_size_y,image_size_z)), +                        dtype=numpy.float) * 0.001 +        X = numpy.asarray(numpy.zeros((image_size_x,image_size_y,image_size_z)), +                        dtype=numpy.float)  +    else: +        #X = fistaRecon.initialize() +        X = numpy.load("X.npy") +     +    print (numpy.shape(X)) +    X_t = X.copy() +    print ("initialized") +    proj_geom , vol_geom, sino , \ +        SlicesZ, weights , alpha_ring = fistaRecon.getParameter( +            ['projector_geometry' , 'output_geometry', +             'input_sinogram', 'SlicesZ' ,  'weights', 'ring_alpha']) +    lambdaR_L1 , alpha_ring , weights , L_const= \ +                       fistaRecon.getParameter(['ring_lambda_R_L1', +                                               'ring_alpha' , 'weights', +                                               'Lipschitz_constant']) + +    #fistaRecon.setParameter(number_of_iterations = 3) +    iterFISTA = fistaRecon.getParameter('number_of_iterations') +    # errors vector (if the ground truth is given) +    Resid_error = numpy.zeros((iterFISTA)); +    # objective function values vector +    objective = numpy.zeros((iterFISTA));  + +       +    t = 1 +     + +    ## additional for  +    proj_geomSUB = proj_geom.copy() +    fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) +    residual2 = fistaRecon.residual2 +    sino_updt_FULL = fistaRecon.residual.copy() +    r_x = fistaRecon.r.copy() + +    results = [] +    print ("starting iterations") +##    % Outer FISTA iterations loop +    for i in range(fistaRecon.getParameter('number_of_iterations')): +##        % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required +##        % one solution is to work with a full sinogram at times +##        if ((i >= 3) && (lambdaR_L1 > 0)) +##            [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom); +##            astra_mex_data3d('delete', sino_id2); +##        end +        # With OS approach it becomes trickier to correlate independent subsets, +        # hence additional work is required one solution is to work with a full +        # sinogram at times + +         +        #t_old = t +        SlicesZ, anglesNumb, Detectors = \ +                    numpy.shape(fistaRecon.getParameter('input_sinogram')) +        ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 +        r_old = fistaRecon.r.copy() +             +        if (i > 1 and lambdaR_L1 > 0) : +            for kkk in range(anglesNumb): +                  +                 residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ +                                       ((sino_updt_FULL[:,kkk,:]).squeeze() - \ +                                        (sino[:,kkk,:]).squeeze() -\ +                                        (alpha_ring * r_x) +                                        ) +            #r_old = fistaRecon.r.copy() +            vec = fistaRecon.residual.sum(axis = 1) +            #if SlicesZ > 1: +            #    vec = vec[:,1,:] # 1 or 0? +            r_x = fistaRecon.r_x +            # update ring variable +            fistaRecon.r = (r_x - (1./L_const) * vec) + +        # subset loop +        counterInd = 1 +        geometry_type = fistaRecon.getParameter('projector_geometry')['type'] +        angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] +     +##        if geometry_type == 'parallel' or \ +##           geometry_type == 'fanflat' or \ +##           geometry_type == 'fanflat_vec' : +##             +##            for kkk in range(SlicesZ): +##                sino_id, sinoT[kkk] = \ +##                         astra.creators.create_sino3d_gpu( +##                             X_t[kkk:kkk+1], proj_geomSUB, vol_geom) +##                sino_updt_Sub[kkk] = sinoT.T.copy() +##                 +##        else: +##            sino_id, sino_updt_Sub = \ +##                astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom) +##         +##        astra.matlab.data3d('delete', sino_id) +   +        for ss in range(fistaRecon.getParameter('subsets')): +            print ("Subset {0}".format(ss)) +            X_old = X.copy() +            t_old = t +            print ("X[0][0][0] {0} t {1}".format(X[0][0][0], t)) +             +            # the number of projections per subset +            numProjSub = fistaRecon.getParameter('os_bins')[ss] +            CurrSubIndices = fistaRecon.getParameter('os_indices')\ +                             [counterInd:counterInd+numProjSub] +            shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) +            shape[1] = numProjSub +            sino_updt_Sub = numpy.zeros(shape) +             +            #print ("Len CurrSubIndices {0}".format(numProjSub)) +            mask = numpy.zeros(numpy.shape(angles), dtype=bool) +            cc = 0 +            for j in range(len(CurrSubIndices)): +                mask[int(CurrSubIndices[j])] = True + +            ## this is a reduced device +            rdev = fistaRecon.getParameter('device_model')\ +                   .createReducedDevice(proj_par={'angles' : angles[mask]}, +                                        vol_par={}) +            proj_geomSUB['ProjectionAngles'] = angles[mask] + +             + +            if geometry_type == 'parallel' or \ +               geometry_type == 'fanflat' or \ +               geometry_type == 'fanflat_vec' : + +                for kkk in range(SlicesZ): +                    sino_id, sinoT = astra.creators.create_sino3d_gpu ( +                        X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) +                    sino_updt_Sub[kkk] = sinoT.T.copy() +                    astra.matlab.data3d('delete', sino_id) +            else: +                # for 3D geometry (watch the GPU memory overflow in ASTRA < 1.8) +                sino_id, sino_updt_Sub = \ +                     astra.creators.create_sino3d_gpu (X_t, +                                                       proj_geomSUB, +                                                       vol_geom) +                 +                astra.matlab.data3d('delete', sino_id) +                 +             + + +            ## RING REMOVAL +            residual = fistaRecon.residual +             +             +            if lambdaR_L1 > 0 : +                 print ("ring removal") +                 residualSub = numpy.zeros(shape) +    ##             for a chosen subset +    ##                for kkk = 1:numProjSub +    ##                    indC = CurrSubIndeces(kkk); +    ##                    residualSub(:,kkk,:) =  squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); +    ##                    sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram +    ##                end +                 for kkk in range(numProjSub): +                     #print ("ring removal indC ... {0}".format(kkk)) +                     indC = int(CurrSubIndices[kkk]) +                     residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ +                            (sino_updt_Sub[:,kkk,:].squeeze() - \ +                             sino[:,indC,:].squeeze() - alpha_ring * r_x) +                     # filling the full sinogram +                     sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + +            else: +                #PWLS model +                # I guess we need to use mask here instead +                residualSub = weights[:,CurrSubIndices,:] * \ +                              ( sino_updt_Sub - \ +                                sino[:,CurrSubIndices,:].squeeze() ) +            # it seems that in the original code the following like is not +            # calculated in the case of ring removal +            objective[i] = 0.5 * numpy.linalg.norm(residualSub) + +            #backprojection +            if geometry_type == 'parallel' or \ +               geometry_type == 'fanflat' or \ +               geometry_type == 'fanflat_vec' : +                # if geometry is 2D use slice-by-slice projection-backprojection +                # routine +                x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) +                for kkk in range(SlicesZ): +                     +                    x_id, x_temp[kkk] = \ +                             astra.creators.create_backprojection3d_gpu( +                                 residualSub[kkk:kkk+1], +                                 proj_geomSUB, vol_geom) +                    astra.matlab.data3d('delete', x_id) +                     +            else: +                x_id, x_temp = \ +                      astra.creators.create_backprojection3d_gpu( +                          residualSub, proj_geomSUB, vol_geom) + +                astra.matlab.data3d('delete', x_id) +                 +            X = X_t - (1/L_const) * x_temp + +         + +            ## REGULARIZATION +            ## SKIPPING FOR NOW +            ## Should be simpli +            # regularizer = fistaRecon.getParameter('regularizer') +            # for slices: +            # out = regularizer(input=X) +            print ("regularizer") +            reg = fistaRecon.getParameter('regularizer') + +            if reg is not None: +                X = reg(input=X, +                        output_all=False) + +            t = (1 + numpy.sqrt(1 + 4 * t **2))/2 +            X_t = X + (((t_old -1)/t) * (X-X_old)) +            counterInd = counterInd + numProjSub - 1 +        if i == 1: +            r_old = fistaRecon.r.copy() +             +        ## FINAL +        print ("final") +        lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') +        if lambdaR_L1 > 0: +            fistaRecon.r = numpy.max( +                numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ +                numpy.sign(fistaRecon.r) +            # updating r +            r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old) +         + +        if fistaRecon.getParameter('region_of_interest') is None: +            string = 'Iteration Number {0} | Objective {1} \n' +            print (string.format( i, objective[i])) +        else: +            ROI , X_ideal = fistaRecon.getParameter('region_of_interest', +                                                    'ideal_image') +             +            Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) +            string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' +            print (string.format(i,Resid_error[i], objective[i])) + +        results.append(X[10]) +    numpy.save("X_out_os.npy", X) + +else: +     +     +     +    astradevice = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value, +                [proj_geom['DetectorRowCount'] , +                 proj_geom['DetectorColCount'] , +                 proj_geom['DetectorSpacingX'] , +                 proj_geom['DetectorSpacingY'] , +                 proj_geom['ProjectionAngles'] +                 ], +                [ +                    vol_geom['GridColCount'], +                    vol_geom['GridRowCount'],  +                    vol_geom['GridSliceCount'] ] ) +    regul = Regularizer(Regularizer.Algorithm.FGP_TV) +    regul.setParameter(regularization_parameter=5e6, +                       number_of_iterations=50, +                       tolerance_constant=1e-4, +                       TV_penalty=Regularizer.TotalVariationPenalty.isotropic) + +    fistaRecon = FISTAReconstructor(proj_geom, +                                vol_geom, +                                Sino3D , +                                weights=Weights3D, +                                device=astradevice, +                                #regularizer = regul, +                                subsets=8) + +    print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +    fistaRecon.setParameter(number_of_iterations = 1) +    fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +    fistaRecon.setParameter(ring_alpha = 21) +    fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +    #fistaRecon.setParameter(subsets=8) +     +    #lc = fistaRecon.getParameter('Lipschitz_constant') +    #fistaRecon.getParameter('regularizer').setParameter(regularization_parameter=5e6/lc) +     +    fistaRecon.prepareForIteration() +    X = fistaRecon.iterate(numpy.load("X.npy")) +     + +# plot +fig = plt.figure() +cols = 3 + +## add the difference +rd = [] +for i in range(1,len(results)): +    rd.append(results[i-1]) +    rd.append(results[i]) +    rd.append(results[i] - results[i-1]) + +rows = (lambda x: int(numpy.floor(x/cols) + 1) if x%cols != 0  else int(x/cols)) \ +       (len (rd)) +for i in range(len (results)): +    a=fig.add_subplot(rows,cols,i+1) +    imgplot = plt.imshow(results[i], vmin=0, vmax=1) +    a.text(0.05, 0.95, "iteration {0}".format(i), +               verticalalignment='top') +##    i = i + 1 +##    a=fig.add_subplot(rows,cols,i+1) +##    imgplot = plt.imshow(results[i], vmin=0, vmax=10) +##    a.text(0.05, 0.95, "iteration {0}".format(i), +##               verticalalignment='top') +     +##    a=fig.add_subplot(rows,cols,i+2) +##    imgplot = plt.imshow(results[i]-results[i-1], vmin=0, vmax=10) +##    a.text(0.05, 0.95, "difference {0}-{1}".format(i, i-1), +##               verticalalignment='top') +         +         + +plt.show() | 
