diff options
| author | Willem Jan Palenstijn <wjp@usecode.org> | 2015-05-08 14:03:01 +0200 | 
|---|---|---|
| committer | Willem Jan Palenstijn <wjp@usecode.org> | 2015-05-08 14:03:01 +0200 | 
| commit | 63307fca7a82bfea7592d9c8d3a359333e622495 (patch) | |
| tree | 5c2478fdeace55cad57bedec88017e948f09f56e | |
| parent | 5e25feb74f0d810af581db32fc5f9ed0560fa841 (diff) | |
| parent | 89da933904262d6b7e80e8adf85ca9d1273881b3 (diff) | |
| download | astra-63307fca7a82bfea7592d9c8d3a359333e622495.tar.gz astra-63307fca7a82bfea7592d9c8d3a359333e622495.tar.bz2 astra-63307fca7a82bfea7592d9c8d3a359333e622495.tar.xz astra-63307fca7a82bfea7592d9c8d3a359333e622495.zip | |
Merge pull request #46 from dmpelt/spot-like-python
Add SPOT-like object for Python (overrides `__mul__` and works with scipy.sparse.linalg)
| -rw-r--r-- | python/astra/ASTRAProjector.py | 135 | ||||
| -rw-r--r-- | python/astra/__init__.py | 2 | ||||
| -rw-r--r-- | python/astra/optomo.py | 197 | ||||
| -rw-r--r-- | python/docSRC/index.rst | 2 | ||||
| -rw-r--r-- | python/docSRC/operator.rst (renamed from python/docSRC/ASTRAProjector.rst) | 4 | 
5 files changed, 201 insertions, 139 deletions
| diff --git a/python/astra/ASTRAProjector.py b/python/astra/ASTRAProjector.py deleted file mode 100644 index f282618..0000000 --- a/python/astra/ASTRAProjector.py +++ /dev/null @@ -1,135 +0,0 @@ -#----------------------------------------------------------------------- -#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam -# -#Author: Daniel M. Pelt -#Contact: D.M.Pelt@cwi.nl -#Website: http://dmpelt.github.io/pyastratoolbox/ -# -# -#This file is part of the Python interface to the -#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). -# -#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify -#it under the terms of the GNU General Public License as published by -#the Free Software Foundation, either version 3 of the License, or -#(at your option) any later version. -# -#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, -#but WITHOUT ANY WARRANTY; without even the implied warranty of -#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -#GNU General Public License for more details. -# -#You should have received a copy of the GNU General Public License -#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. -# -#----------------------------------------------------------------------- - -import math -from . import creators as ac -from . import data2d - - -class ASTRAProjector2DTranspose(): -    """Implements the ``proj.T`` functionality. - -    Do not use directly, since it can be accessed as member ``.T`` of -    an :class:`ASTRAProjector2D` object. - -    """ -    def __init__(self, parentProj): -        self.parentProj = parentProj - -    def __mul__(self, data): -        return self.parentProj.backProject(data) - - -class ASTRAProjector2D(object): -    """Helps with various common ASTRA Toolbox 2D operations. - -    This class can perform several often used toolbox operations, such as: - -    * Forward projecting -    * Back projecting -    * Reconstructing - -    Note that this class has a some computational overhead, because it -    copies a lot of data. If you use many repeated operations, directly -    using the PyAstraToolbox methods directly is faster. - -    You can use this class as an abstracted weight matrix :math:`W`: multiplying an instance -    ``proj`` of this class by an image results in a forward projection of the image, and multiplying -    ``proj.T`` by a sinogram results in a backprojection of the sinogram:: - -        proj = ASTRAProjector2D(...) -        fp = proj*image -        bp = proj.T*sinogram - -    :param proj_geom: The projection geometry. -    :type proj_geom: :class:`dict` -    :param vol_geom: The volume geometry. -    :type vol_geom: :class:`dict` -    :param proj_type: Projector type, such as ``'line'``, ``'linear'``, ... -    :type proj_type: :class:`string` -    """ - -    def __init__(self, proj_geom, vol_geom, proj_type): -        self.vol_geom = vol_geom -        self.recSize = vol_geom['GridColCount'] -        self.angles = proj_geom['ProjectionAngles'] -        self.nDet = proj_geom['DetectorCount'] -        nexpow = int(pow(2, math.ceil(math.log(2 * self.nDet, 2)))) -        self.filterSize = nexpow / 2 + 1 -        self.nProj = self.angles.shape[0] -        self.proj_geom = proj_geom -        self.proj_id = ac.create_projector(proj_type, proj_geom, vol_geom) -        self.T = ASTRAProjector2DTranspose(self) - -    def backProject(self, data): -        """Backproject a sinogram. - -        :param data: The sinogram data or ID. -        :type data: :class:`numpy.ndarray` or :class:`int` -        :returns: :class:`numpy.ndarray` -- The backprojection. - -        """ -        vol_id, vol = ac.create_backprojection( -            data, self.proj_id, returnData=True) -        data2d.delete(vol_id) -        return vol - -    def forwardProject(self, data): -        """Forward project an image. - -        :param data: The image data or ID. -        :type data: :class:`numpy.ndarray` or :class:`int` -        :returns: :class:`numpy.ndarray` -- The forward projection. - -        """ -        sin_id, sino = ac.create_sino(data, self.proj_id, returnData=True) -        data2d.delete(sin_id) -        return sino - -    def reconstruct(self, data, method, **kwargs): -        """Reconstruct an image from a sinogram. - -        :param data: The sinogram data or ID. -        :type data: :class:`numpy.ndarray` or :class:`int` -        :param method: Name of the reconstruction algorithm. -        :type method: :class:`string` -        :param kwargs: Additional named parameters to pass to :func:`astra.creators.create_reconstruction`. -        :returns: :class:`numpy.ndarray` -- The reconstruction. - -        Example of a SIRT reconstruction using CUDA:: - -            proj = ASTRAProjector2D(...) -            rec = proj.reconstruct(sinogram,'SIRT_CUDA',iterations=1000) - -        """ -        kwargs['returnData'] = True -        rec_id, rec = ac.create_reconstruction( -            method, self.proj_id, data, **kwargs) -        data2d.delete(rec_id) -        return rec - -    def __mul__(self, data): -        return self.forwardProject(data) diff --git a/python/astra/__init__.py b/python/astra/__init__.py index 063dc16..6c15d30 100644 --- a/python/astra/__init__.py +++ b/python/astra/__init__.py @@ -27,7 +27,6 @@ from . import matlab as m  from .creators import astra_dict,create_vol_geom, create_proj_geom, create_backprojection, create_sino, create_reconstruction, create_projector,create_sino3d_gpu, create_backprojection3d_gpu  from .functions import data_op, add_noise_to_sino, clear, move_vol_geom  from .extrautils import clipCircle -from .ASTRAProjector import ASTRAProjector2D  from . import data2d  from . import astra  from . import data3d @@ -36,6 +35,7 @@ from . import projector  from . import projector3d  from . import matrix  from . import log +from .optomo import OpTomo  import os  try: diff --git a/python/astra/optomo.py b/python/astra/optomo.py new file mode 100644 index 0000000..0c37353 --- /dev/null +++ b/python/astra/optomo.py @@ -0,0 +1,197 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- + +from . import data2d +from . import data3d +from . import projector +from . import projector3d +from . import creators +from . import algorithm +from . import functions +import numpy as np +from six.moves import range, reduce +import operator +import scipy.sparse.linalg + +class OpTomo(scipy.sparse.linalg.LinearOperator): +    """Object that imitates a projection matrix with a given projector. + +    This object can do forward projection by using the ``*`` operator:: + +        W = astra.OpTomo(proj_id) +        fp = W*image +        bp = W.T*sinogram + +    It can also be used in minimization methods of the :mod:`scipy.sparse.linalg` module:: + +        W = astra.OpTomo(proj_id) +        output = scipy.sparse.linalg.lsqr(W,sinogram) + +    :param proj_id: ID to a projector. +    :type proj_id: :class:`int` +    """ + +    def __init__(self,proj_id): +        self.dtype = np.float32 +        try: +            self.vg = projector.volume_geometry(proj_id) +            self.pg = projector.projection_geometry(proj_id) +            self.data_mod = data2d +            self.appendString = "" +            if projector.is_cuda(proj_id): +                self.appendString += "_CUDA" +        except Exception: +            self.vg = projector3d.volume_geometry(proj_id) +            self.pg = projector3d.projection_geometry(proj_id) +            self.data_mod = data3d +            self.appendString = "3D" +            if projector3d.is_cuda(proj_id): +                self.appendString += "_CUDA" + +        self.vshape = functions.geom_size(self.vg) +        self.vsize = reduce(operator.mul,self.vshape) +        self.sshape = functions.geom_size(self.pg) +        self.ssize = reduce(operator.mul,self.sshape) + +        self.shape = (self.ssize, self.vsize) + +        self.proj_id = proj_id + +        self.T = OpTomoTranspose(self) + +    def __checkArray(self, arr, shp): +        if len(arr.shape)==1: +            arr = arr.reshape(shp) +        if arr.dtype != np.float32: +            arr = arr.astype(np.float32) +        if arr.flags['C_CONTIGUOUS']==False: +            arr = np.ascontiguousarray(arr) +        return arr + +    def _matvec(self,v): +        """Implements the forward operator. + +        :param v: Volume to forward project. +        :type v: :class:`numpy.ndarray` +        """ +        v = self.__checkArray(v, self.vshape) +        vid = self.data_mod.link('-vol',self.vg,v) +        s = np.zeros(self.sshape,dtype=np.float32) +        sid = self.data_mod.link('-sino',self.pg,s) + +        cfg = creators.astra_dict('FP'+self.appendString) +        cfg['ProjectionDataId'] = sid +        cfg['VolumeDataId'] = vid +        cfg['ProjectorId'] = self.proj_id +        fp_id = algorithm.create(cfg) +        algorithm.run(fp_id) + +        algorithm.delete(fp_id) +        self.data_mod.delete([vid,sid]) +        return s.flatten() + +    def rmatvec(self,s): +        """Implements the transpose operator. + +        :param s: The projection data. +        :type s: :class:`numpy.ndarray` +        """ +        s = self.__checkArray(s, self.sshape) +        sid = self.data_mod.link('-sino',self.pg,s) +        v = np.zeros(self.vshape,dtype=np.float32) +        vid = self.data_mod.link('-vol',self.vg,v) + +        cfg = creators.astra_dict('BP'+self.appendString) +        cfg['ProjectionDataId'] = sid +        cfg['ReconstructionDataId'] = vid +        cfg['ProjectorId'] = self.proj_id +        bp_id = algorithm.create(cfg) +        algorithm.run(bp_id) + +        algorithm.delete(bp_id) +        self.data_mod.delete([vid,sid]) +        return v.flatten() + +    def __mul__(self,v): +        """Provides easy forward operator by *. + +        :param v: Volume to forward project. +        :type v: :class:`numpy.ndarray` +        """ +        # Catch the case of a forward projection of a 2D/3D image +        if isinstance(v, np.ndarray) and v.shape==self.vshape: +            return self._matvec(v) +        return scipy.sparse.linalg.LinearOperator.__mul__(self, v) + +    def reconstruct(self, method, s, iterations=1, extraOptions = {}): +        """Reconstruct an object. + +        :param method: Method to use for reconstruction. +        :type method: :class:`string` +        :param s: The projection data. +        :type s: :class:`numpy.ndarray` +        :param iterations: Number of iterations to use. +        :type iterations: :class:`int` +        :param extraOptions: Extra options to use during reconstruction (i.e. for cfg['option']). +        :type extraOptions: :class:`dict` +        """ +        self.__checkArray(s, self.sshape) +        sid = self.data_mod.link('-sino',self.pg,s) +        v = np.zeros(self.vshape,dtype=np.float32) +        vid = self.data_mod.link('-vol',self.vg,v) +        cfg = creators.astra_dict(method) +        cfg['ProjectionDataId'] = sid +        cfg['ReconstructionDataId'] = vid +        cfg['ProjectorId'] = self.proj_id +        cfg['option'] = extraOptions +        alg_id = algorithm.create(cfg) +        algorithm.run(alg_id,iterations) +        algorithm.delete(alg_id) +        self.data_mod.delete([vid,sid]) +        return v + +class OpTomoTranspose(scipy.sparse.linalg.LinearOperator): +    """This object provides the transpose operation (``.T``) of the OpTomo object. + +    Do not use directly, since it can be accessed as member ``.T`` of +    an :class:`OpTomo` object. +    """ +    def __init__(self,parent): +        self.parent = parent +        self.dtype = np.float32 +        self.shape = (parent.shape[1], parent.shape[0]) + +    def _matvec(self, s): +        return self.parent.rmatvec(s) + +    def rmatvec(self, v): +        return self.parent.matvec(v) + +    def __mul__(self,s): +        # Catch the case of a backprojection of 2D/3D data +        if isinstance(s, np.ndarray) and s.shape==self.parent.sshape: +            return self._matvec(s) +        return scipy.sparse.linalg.LinearOperator.__mul__(self, s) diff --git a/python/docSRC/index.rst b/python/docSRC/index.rst index 8d17a4a..b7cc6d6 100644 --- a/python/docSRC/index.rst +++ b/python/docSRC/index.rst @@ -18,7 +18,7 @@ Contents:     matrix     creators     functions -   ASTRAProjector +   operator     matlab     astra  .. astra diff --git a/python/docSRC/ASTRAProjector.rst b/python/docSRC/operator.rst index 1c267e3..f5369fa 100644 --- a/python/docSRC/ASTRAProjector.rst +++ b/python/docSRC/operator.rst @@ -1,7 +1,7 @@ -Helper class: the :mod:`ASTRAProjector` module +OpTomo class: the :mod:`operator` module  ============================================== -.. automodule:: astra.ASTRAProjector +.. automodule:: astra.operator      :members:      :undoc-members:      :show-inheritance: | 
