diff options
| -rw-r--r-- | python/astra/operator.py | 33 | 
1 files changed, 11 insertions, 22 deletions
| diff --git a/python/astra/operator.py b/python/astra/operator.py index a3abd5a..0c37353 100644 --- a/python/astra/operator.py +++ b/python/astra/operator.py @@ -91,7 +91,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):              arr = np.ascontiguousarray(arr)          return arr -    def matvec(self,v): +    def _matvec(self,v):          """Implements the forward operator.          :param v: Volume to forward project. @@ -135,24 +135,16 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):          self.data_mod.delete([vid,sid])          return v.flatten() -    def matmat(self,m): -        """Implements the forward operator with a matrix. - -        :param m: Volumes to forward project, arranged in columns. -        :type m: :class:`numpy.ndarray` -        """ -        out = np.zeros((self.ssize,m.shape[1]),dtype=np.float32) -        for i in range(m.shape[1]): -            out[:,i] = self.matvec(m[:,i].flatten()) -        return out -      def __mul__(self,v):          """Provides easy forward operator by *.          :param v: Volume to forward project.          :type v: :class:`numpy.ndarray`          """ -        return self.matvec(v) +        # Catch the case of a forward projection of a 2D/3D image +        if isinstance(v, np.ndarray) and v.shape==self.vshape: +            return self._matvec(v) +        return scipy.sparse.linalg.LinearOperator.__mul__(self, v)      def reconstruct(self, method, s, iterations=1, extraOptions = {}):          """Reconstruct an object. @@ -192,17 +184,14 @@ class OpTomoTranspose(scipy.sparse.linalg.LinearOperator):          self.dtype = np.float32          self.shape = (parent.shape[1], parent.shape[0]) -    def matvec(self, s): +    def _matvec(self, s):          return self.parent.rmatvec(s)      def rmatvec(self, v):          return self.parent.matvec(v) -    def matmat(self, m): -        out = np.zeros((self.vsize,m.shape[1]),dtype=np.float32) -        for i in range(m.shape[1]): -            out[:,i] = self.matvec(m[:,i].flatten()) -        return out - -    def __mul__(self,v): -        return self.matvec(v) +    def __mul__(self,s): +        # Catch the case of a backprojection of 2D/3D data +        if isinstance(s, np.ndarray) and s.shape==self.parent.sshape: +            return self._matvec(s) +        return scipy.sparse.linalg.LinearOperator.__mul__(self, s) | 
