summaryrefslogtreecommitdiffstats
path: root/python/astra/utils.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'python/astra/utils.pyx')
-rw-r--r--python/astra/utils.pyx82
1 files changed, 77 insertions, 5 deletions
diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx
index 9871ac6..07727ce 100644
--- a/python/astra/utils.pyx
+++ b/python/astra/utils.pyx
@@ -29,8 +29,13 @@
cimport numpy as np
import numpy as np
import six
+if six.PY3:
+ import builtins
+else:
+ import __builtin__
from libcpp.string cimport string
from libcpp.vector cimport vector
+from libcpp.list cimport list
from cython.operator cimport dereference as deref, preincrement as inc
from cpython.version cimport PY_MAJOR_VERSION
@@ -39,9 +44,6 @@ from .PyXMLDocument cimport XMLDocument
from .PyXMLDocument cimport XMLNode
from .PyIncludes cimport *
-cdef extern from "astra/PluginAlgorithm.h" namespace "astra":
- object XMLNode2dict(XMLNode)
-
cdef Config * dictToConfig(string rootname, dc):
cdef Config * cfg = new Config()
@@ -93,7 +95,7 @@ cdef void readDict(XMLNode root, _dc):
dc = convert_item(_dc)
for item in dc:
val = dc[item]
- if isinstance(val, list) or isinstance(val, tuple):
+ if isinstance(val, __builtins__.list) or isinstance(val, tuple):
val = np.array(val,dtype=np.float64)
if isinstance(val, np.ndarray):
if val.size == 0:
@@ -129,7 +131,7 @@ cdef void readOptions(XMLNode node, dc):
val = dc[item]
if node.hasOption(item):
raise Exception('Duplicate Option: %s' % item)
- if isinstance(val, list) or isinstance(val, tuple):
+ if isinstance(val, __builtins__.list) or isinstance(val, tuple):
val = np.array(val,dtype=np.float64)
if isinstance(val, np.ndarray):
if val.size == 0:
@@ -149,3 +151,73 @@ cdef void readOptions(XMLNode node, dc):
cdef configToDict(Config *cfg):
return XMLNode2dict(cfg.self)
+
+def castString3(input):
+ return input.decode('utf-8')
+
+def castString2(input):
+ return input
+
+if six.PY3:
+ castString = castString3
+else:
+ castString = castString2
+
+def stringToPythonValue(inputIn):
+ input = castString(inputIn)
+ # matrix
+ if ';' in input:
+ row_strings = input.split(';')
+ col_strings = row_strings[0].split(',')
+ nRows = len(row_strings)
+ nCols = len(col_strings)
+
+ out = np.empty((nRows,nCols))
+ for ridx, row in enumerate(row_strings):
+ col_strings = row.split(',')
+ for cidx, col in enumerate(col_strings):
+ out[ridx,cidx] = float(col)
+ return out
+
+ # vector
+ if ',' in input:
+ items = input.split(',')
+ out = np.empty(len(items))
+ for idx,item in enumerate(items):
+ out[idx] = float(item)
+ return out
+
+ try:
+ # integer
+ return int(input)
+ except ValueError:
+ try:
+ #float
+ return float(input)
+ except ValueError:
+ # string
+ return str(input)
+
+
+cdef XMLNode2dict(XMLNode node):
+ cdef XMLNode subnode
+ cdef list[XMLNode] nodes
+ cdef list[XMLNode].iterator it
+ dct = {}
+ opts = {}
+ if node.hasAttribute(six.b('type')):
+ dct['type'] = castString(node.getAttribute(six.b('type')))
+ nodes = node.getNodes()
+ it = nodes.begin()
+ while it != nodes.end():
+ subnode = deref(it)
+ if castString(subnode.getName())=="Option":
+ if subnode.hasAttribute('value'):
+ opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getAttribute('value'))
+ else:
+ opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getContent())
+ else:
+ dct[castString(subnode.getName())] = stringToPythonValue(subnode.getContent())
+ inc(it)
+ if len(opts)>0: dct['options'] = opts
+ return dct