diff options
Diffstat (limited to 'python')
| -rw-r--r-- | python/astra/PyXMLDocument.pxd | 2 | ||||
| -rw-r--r-- | python/astra/utils.pyx | 35 | 
2 files changed, 11 insertions, 26 deletions
| diff --git a/python/astra/PyXMLDocument.pxd b/python/astra/PyXMLDocument.pxd index 57c447e..033b8ef 100644 --- a/python/astra/PyXMLDocument.pxd +++ b/python/astra/PyXMLDocument.pxd @@ -53,6 +53,8 @@ cdef extern from "astra/XMLNode.h" namespace "astra":          string getAttribute(string)          list[XMLNode] getNodes()          vector[float32] getContentNumericalArray() +        void setContent(double*, int, int, bool) +        void setContent(double*, int)          string getContent()          bool hasAttribute(string) diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx index 8f1e0b7..ddb37aa 100644 --- a/python/astra/utils.pyx +++ b/python/astra/utils.pyx @@ -26,6 +26,7 @@  # distutils: language = c++  # distutils: libraries = astra +cimport numpy as np  import numpy as np  import six  from libcpp.string cimport string @@ -85,6 +86,7 @@ cdef void readDict(XMLNode root, _dc):      cdef XMLNode itm      cdef int i      cdef int j +    cdef double* data      dc = convert_item(_dc)      for item in dc: @@ -93,21 +95,11 @@ cdef void readDict(XMLNode root, _dc):              if val.size == 0:                  break              listbase = root.addChildNode(item) -            listbase.addAttribute(< string > six.b('listsize'), < float32 > val.size) -            index = 0 +            data = <double*>np.PyArray_DATA(np.ascontiguousarray(val,dtype=np.float64))               if val.ndim == 2: -                for i in range(val.shape[0]): -                    for j in range(val.shape[1]): -                        itm = listbase.addChildNode(six.b('ListItem')) -                        itm.addAttribute(< string > six.b('index'), < float32 > index) -                        itm.addAttribute( < string > six.b('value'), < float32 > val[i, j]) -                        index += 1 +                listbase.setContent(data, val.shape[1], val.shape[0], False)              elif val.ndim == 1: -                for i in range(val.shape[0]): -                    itm = listbase.addChildNode(six.b('ListItem')) -                    itm.addAttribute(< string > six.b('index'), < float32 > index) -                    itm.addAttribute(< string > six.b('value'), < float32 > val[i]) -                    index += 1 +                listbase.setContent(data, val.shape[0])              else:                  raise Exception("Only 1 or 2 dimensions are allowed")          elif isinstance(val, dict): @@ -127,6 +119,7 @@ cdef void readOptions(XMLNode node, dc):      cdef XMLNode itm      cdef int i      cdef int j +    cdef double* data      for item in dc:          val = dc[item]          if node.hasOption(item): @@ -136,21 +129,11 @@ cdef void readOptions(XMLNode node, dc):                  break              listbase = node.addChildNode(six.b('Option'))              listbase.addAttribute(< string > six.b('key'), < string > item) -            listbase.addAttribute(< string > six.b('listsize'), < float32 > val.size) -            index = 0 +            data = <double*>np.PyArray_DATA(np.ascontiguousarray(val,dtype=np.float64))               if val.ndim == 2: -                for i in range(val.shape[0]): -                    for j in range(val.shape[1]): -                        itm = listbase.addChildNode(six.b('ListItem')) -                        itm.addAttribute(< string > six.b('index'), < float32 > index) -                        itm.addAttribute( < string > six.b('value'), < float32 > val[i, j]) -                        index += 1 +                listbase.setContent(data, val.shape[1], val.shape[0], False)              elif val.ndim == 1: -                for i in range(val.shape[0]): -                    itm = listbase.addChildNode(six.b('ListItem')) -                    itm.addAttribute(< string > six.b('index'), < float32 > index) -                    itm.addAttribute(< string > six.b('value'), < float32 > val[i]) -                    index += 1 +                listbase.setContent(data, val.shape[0])              else:                  raise Exception("Only 1 or 2 dimensions are allowed")          else: | 
