diff options
Diffstat (limited to 'python/astra')
| -rw-r--r-- | python/astra/astra.py | 4 | ||||
| -rw-r--r-- | python/astra/astra_c.pyx | 21 | 
2 files changed, 20 insertions, 5 deletions
| diff --git a/python/astra/astra.py b/python/astra/astra.py index 26b1ff0..9328b6b 100644 --- a/python/astra/astra.py +++ b/python/astra/astra.py @@ -49,10 +49,10 @@ def version(printToScreen=False):      """      return a.version(printToScreen) -def set_gpu_index(idx): +def set_gpu_index(idx, memory=0):      """Set default GPU index to use.      :param idx: GPU index      :type idx: :class:`int`      """ -    a.set_gpu_index(idx) +    a.set_gpu_index(idx, memory) diff --git a/python/astra/astra_c.pyx b/python/astra/astra_c.pyx index 6b246b6..2a9c816 100644 --- a/python/astra/astra_c.pyx +++ b/python/astra/astra_c.pyx @@ -31,6 +31,7 @@ import six  from .utils import wrap_from_bytes  from libcpp.string cimport string +from libcpp.vector cimport vector  from libcpp cimport bool  cdef extern from "astra/Globals.h" namespace "astra":      int getVersion() @@ -43,6 +44,12 @@ IF HAVE_CUDA==True:  ELSE:    def setGPUIndex():      pass +cdef extern from "astra/CompositeGeometryManager.h" namespace "astra": +    cdef cppclass SGPUParams: +        vector[int] GPUIndices +        size_t memory +cdef extern from "astra/CompositeGeometryManager.h" namespace "astra::CCompositeGeometryManager": +    void setGlobalGPUParams(SGPUParams&)  def credits():      six.print_("""The ASTRA Toolbox has been developed at the University of Antwerp and CWI, Amsterdam by @@ -70,8 +77,16 @@ def version(printToScreen=False):      else:          return getVersion() -def set_gpu_index(idx): +def set_gpu_index(idx, memory=0): +    import types +    import collections +    cdef SGPUParams params      if use_cuda()==True: -        ret = setGPUIndex(idx) +        if not isinstance(idx, collections.Iterable) or isinstance(idx, types.StringTypes): +            idx = (idx,) +        params.memory = memory +        params.GPUIndices = idx +        setGlobalGPUParams(params) +        ret = setGPUIndex(params.GPUIndices[0])          if not ret: -            six.print_("Failed to set GPU " + str(idx)) +            six.print_("Failed to set GPU " + str(params.GPUIndices[0])) | 
