diff options
author | Willem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl> | 2016-10-06 12:30:18 +0200 |
---|---|---|
committer | Willem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl> | 2016-10-06 12:30:18 +0200 |
commit | 0cec258c5079cc065fa75f82ae8d785986ebdf18 (patch) | |
tree | e7ca39da75ad5c9d728698295ac9c8ec32e4e499 /samples | |
parent | c2cdbc312196481edd202baa3bd668396e78534c (diff) | |
parent | 7bb42ddd9e26fc7c01734d26bc114b5a935d9110 (diff) | |
download | astra-0cec258c5079cc065fa75f82ae8d785986ebdf18.tar.gz astra-0cec258c5079cc065fa75f82ae8d785986ebdf18.tar.bz2 astra-0cec258c5079cc065fa75f82ae8d785986ebdf18.tar.xz astra-0cec258c5079cc065fa75f82ae8d785986ebdf18.zip |
Merge branch 'master' into FDK
Diffstat (limited to 'samples')
-rw-r--r-- | samples/python/s009_projection_matrix.py | 2 | ||||
-rw-r--r-- | samples/python/s015_fp_bp.py | 6 | ||||
-rw-r--r-- | samples/python/s017_OpTomo.py | 2 | ||||
-rw-r--r-- | samples/python/s018_plugin.py | 34 |
4 files changed, 26 insertions, 18 deletions
diff --git a/samples/python/s009_projection_matrix.py b/samples/python/s009_projection_matrix.py index c4c4557..e20d58c 100644 --- a/samples/python/s009_projection_matrix.py +++ b/samples/python/s009_projection_matrix.py @@ -46,7 +46,7 @@ W = astra.matrix.get(matrix_id) # Manually use this projection matrix to do a projection: import scipy.io P = scipy.io.loadmat('phantom.mat')['phantom256'] -s = W.dot(P.flatten()) +s = W.dot(P.ravel()) s = np.reshape(s, (len(proj_geom['ProjectionAngles']),proj_geom['DetectorCount'])) import pylab diff --git a/samples/python/s015_fp_bp.py b/samples/python/s015_fp_bp.py index fa0bf86..ff0b30a 100644 --- a/samples/python/s015_fp_bp.py +++ b/samples/python/s015_fp_bp.py @@ -46,12 +46,12 @@ class astra_wrap(object): def matvec(self,v): sid, s = astra.create_sino(np.reshape(v,(vol_geom['GridRowCount'],vol_geom['GridColCount'])),self.proj_id) astra.data2d.delete(sid) - return s.flatten() + return s.ravel() def rmatvec(self,v): bid, b = astra.create_backprojection(np.reshape(v,(len(proj_geom['ProjectionAngles']),proj_geom['DetectorCount'],)),self.proj_id) astra.data2d.delete(bid) - return b.flatten() + return b.ravel() vol_geom = astra.create_vol_geom(256, 256) proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False)) @@ -65,7 +65,7 @@ proj_id = astra.create_projector('cuda',proj_geom,vol_geom) sinogram_id, sinogram = astra.create_sino(P, proj_id) # Reshape the sinogram into a vector -b = sinogram.flatten() +b = sinogram.ravel() # Call lsqr with ASTRA FP and BP import scipy.sparse.linalg diff --git a/samples/python/s017_OpTomo.py b/samples/python/s017_OpTomo.py index 967fa64..214e9a7 100644 --- a/samples/python/s017_OpTomo.py +++ b/samples/python/s017_OpTomo.py @@ -50,7 +50,7 @@ pylab.figure(2) pylab.imshow(sinogram) # Run the lsqr linear solver -output = scipy.sparse.linalg.lsqr(W, sinogram.flatten(), iter_lim=150) +output = scipy.sparse.linalg.lsqr(W, sinogram.ravel(), iter_lim=150) rec = output[0].reshape([256, 256]) pylab.figure(3) diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 31cca95..85b5486 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -30,30 +30,38 @@ import six # Define the plugin class (has to subclass astra.plugin.base) # Note that usually, these will be defined in a separate package/module -class SIRTPlugin(astra.plugin.base): - """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. +class LandweberPlugin(astra.plugin.base): + """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm. Options: - 'rel_factor': relaxation factor (optional) + 'Relaxation': relaxation factor (optional) """ # The astra_name variable defines the name to use to # call the plugin from ASTRA - astra_name = "SIRT-PLUGIN" + astra_name = "LANDWEBER-PLUGIN" - def initialize(self,cfg, rel_factor = 1): + def initialize(self,cfg, Relaxation = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] self.sid = cfg['ProjectionDataId'] - self.rel = rel_factor + self.rel = Relaxation def run(self, its): v = astra.data2d.get_shared(self.vid) s = astra.data2d.get_shared(self.sid) + tv = np.zeros(v.shape, dtype=np.float32) + ts = np.zeros(s.shape, dtype=np.float32) W = self.W for i in range(its): - v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size + W.FP(v,out=ts) + ts -= s # ts = W*v - s + + W.BP(ts,out=tv) + tv *= self.rel / s.size + + v -= tv # v = v - rel * W'*(W*v-s) / s.size if __name__=='__main__': @@ -75,20 +83,20 @@ if __name__=='__main__': # First we import the package that contains the plugin import s018_plugin # Then, we register the plugin class with ASTRA - astra.plugin.register(s018_plugin.SIRTPlugin) + astra.plugin.register(s018_plugin.LandweberPlugin) # Get a list of registered plugins six.print_(astra.plugin.get_registered()) # To get help on a registered plugin, use get_help - six.print_(astra.plugin.get_help('SIRT-PLUGIN')) + six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN')) # Create data structures sid = astra.data2d.create('-sino', proj_geom, sinogram) vid = astra.data2d.create('-vol', vol_geom) # Create config using plugin name - cfg = astra.astra_dict('SIRT-PLUGIN') + cfg = astra.astra_dict('LANDWEBER-PLUGIN') cfg['ProjectorId'] = proj_id cfg['ProjectionDataId'] = sid cfg['ReconstructionDataId'] = vid @@ -103,18 +111,18 @@ if __name__=='__main__': rec = astra.data2d.get(vid) # Options for the plugin go in cfg['option'] - cfg = astra.astra_dict('SIRT-PLUGIN') + cfg = astra.astra_dict('LANDWEBER-PLUGIN') cfg['ProjectorId'] = proj_id cfg['ProjectionDataId'] = sid cfg['ReconstructionDataId'] = vid cfg['option'] = {} - cfg['option']['rel_factor'] = 1.5 + cfg['option']['Relaxation'] = 1.5 alg_id_rel = astra.algorithm.create(cfg) astra.algorithm.run(alg_id_rel, 100) rec_rel = astra.data2d.get(vid) # We can also use OpTomo to call the plugin - rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) + rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5}) import pylab as pl pl.gray() |