diff options
Diffstat (limited to 'matlab/mex')
| -rw-r--r-- | matlab/mex/astra_mex_c.cpp | 45 | ||||
| -rw-r--r--[-rwxr-xr-x] | matlab/mex/astra_mex_direct_c.cpp | 0 | ||||
| -rw-r--r-- | matlab/mex/astra_mex_plugin_c.cpp | 86 | ||||
| -rw-r--r-- | matlab/mex/mexInitFunctions.cpp | 8 | 
4 files changed, 125 insertions, 14 deletions
| diff --git a/matlab/mex/astra_mex_c.cpp b/matlab/mex/astra_mex_c.cpp index d34334c..fdf4f33 100644 --- a/matlab/mex/astra_mex_c.cpp +++ b/matlab/mex/astra_mex_c.cpp @@ -38,6 +38,7 @@ $Id$  #include "astra/Globals.h"  #ifdef ASTRA_CUDA  #include "../cuda/2d/darthelper.h" +#include "astra/CompositeGeometryManager.h"  #endif  using namespace std;  using namespace astra; @@ -83,12 +84,46 @@ void astra_mex_use_cuda(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs   * Set active GPU   */  void astra_mex_set_gpu_index(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) -{  +{  #ifdef ASTRA_CUDA -	if (nrhs >= 2) { -		bool ret = astraCUDA::setGPUIndex((int)mxGetScalar(prhs[1])); -		if (!ret) -			mexPrintf("Failed to set GPU %d\n", (int)mxGetScalar(prhs[1])); +	bool usage = false; +	if (nrhs != 2 && nrhs != 4) { +		usage = true; +	} + +	astra::SGPUParams params; +	params.memory = 0; + +	if (!usage && nrhs >= 4) { +		std::string s = mexToString(prhs[2]); +		if (s != "memory") { +			usage = true; +		} else { +			params.memory = (size_t)mxGetScalar(prhs[3]); +		} +	} + +	if (!usage && nrhs >= 2) { +		int n = mxGetN(prhs[1]) * mxGetM(prhs[1]); +		params.GPUIndices.resize(n); +		double* pdMatlabData = mxGetPr(prhs[1]); +		for (int i = 0; i < n; ++i) +			params.GPUIndices[i] = (int)pdMatlabData[i]; + + +		astra::CCompositeGeometryManager::setGlobalGPUParams(params); + + +		// Set first GPU +		if (n >= 1) { +			bool ret = astraCUDA::setGPUIndex((int)pdMatlabData[0]); +			if (!ret) +				mexPrintf("Failed to set GPU %d\n", (int)pdMatlabData[0]); +		} +	} + +	if (usage) { +		mexPrintf("Usage: astra_mex('set_gpu_index', index/indices [, 'memory', memory])");  	}  #endif  } diff --git a/matlab/mex/astra_mex_direct_c.cpp b/matlab/mex/astra_mex_direct_c.cpp index 38b3f59..38b3f59 100755..100644 --- a/matlab/mex/astra_mex_direct_c.cpp +++ b/matlab/mex/astra_mex_direct_c.cpp diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp index 177fcf4..4ed534e 100644 --- a/matlab/mex/astra_mex_plugin_c.cpp +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -37,9 +37,63 @@ $Id$  #include "astra/PluginAlgorithm.h" +#include <Python.h> +  using namespace std;  using namespace astra; +static void fixLapackLoading() +{ +    // When running in Matlab, we need to force numpy +    // to use its internal lapack library instead of +    // Matlab's MKL library to avoid errors. To do this, +    // we set Python's dlopen flags to RTLD_NOW|RTLD_DEEPBIND +    // and import 'numpy.linalg.lapack_lite' here. We reset +    // Python's dlopen flags afterwards. +    PyObject *sys = PyImport_ImportModule("sys"); +    if (sys != NULL) { +        PyObject *curFlags = PyObject_CallMethod(sys, "getdlopenflags", NULL); +        if (curFlags != NULL) { +            PyObject *retVal = PyObject_CallMethod(sys, "setdlopenflags", "i", 10); // RTLD_NOW|RTLD_DEEPBIND +            if (retVal != NULL) { +                PyObject *lapack = PyImport_ImportModule("numpy.linalg.lapack_lite"); +                if (lapack != NULL) { +                    Py_DECREF(lapack); +                } +                PyObject *retVal2 = PyObject_CallMethod(sys, "setdlopenflags", "O",curFlags); +                if (retVal2 != NULL) { +                    Py_DECREF(retVal2); +                } +                Py_DECREF(retVal); +            } +            Py_DECREF(curFlags); +        } +        Py_DECREF(sys); +    } +} + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('init'); + * + * Initialize plugin support by initializing python and importing astra + */ +void astra_mex_plugin_init() +{ +    if(!Py_IsInitialized()){ +        Py_Initialize(); +        PyEval_InitThreads(); +    } + +#ifndef _MSC_VER +    fixLapackLoading(); +#endif + +    // Importing astra may be overkill, since we only need to initialize +    // PythonPluginAlgorithmFactory from astra.plugin_c. +    PyObject *mod = PyImport_ImportModule("astra"); +    Py_XDECREF(mod); +} +  //-----------------------------------------------------------------------------------------  /** astra_mex_plugin('get_registered'); @@ -48,7 +102,11 @@ using namespace astra;   */  void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])  { -    astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); +    astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory(); +    if (!fact) { +        mexPrintf("Plugin support not initialized."); +        return; +    }      std::map<std::string, std::string> mp = fact->getRegisteredMap();      for(std::map<std::string,std::string>::iterator it=mp.begin();it!=mp.end();it++){          mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str()); @@ -62,9 +120,13 @@ void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const   */  void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])  { +    astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory(); +    if (!fact) { +        mexPrintf("Plugin support not initialized."); +        return; +    }      if (2 <= nrhs) {          string class_name = mexToString(prhs[1]); -        astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();          fact->registerPlugin(class_name);      }else{          mexPrintf("astra_mex_plugin('register', class_name);\n"); @@ -78,9 +140,13 @@ void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArra   */  void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])  { +    astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory(); +    if (!fact) { +        mexPrintf("Plugin support not initialized."); +        return; +    }      if (2 <= nrhs) {          string name = mexToString(prhs[1]); -        astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();          mexPrintf((fact->getHelp(name)+"\n").c_str());      }else{          mexPrintf("astra_mex_plugin('get_help', name);\n"); @@ -116,12 +182,14 @@ void mexFunction(int nlhs, mxArray* plhs[],  	initASTRAMex();  	// SWITCH (MODE) -	if (sMode ==  std::string("get_registered")) {  -		astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);  -    }else if (sMode ==  std::string("get_help")) {  -        astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);  -    }else if (sMode ==  std::string("register")) {  -		astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);  +	if (sMode == "init") { +		astra_mex_plugin_init(); +	} else if (sMode ==  std::string("get_registered")) { +		astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs); +	}else if (sMode ==  std::string("get_help")) { +		astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs); +	}else if (sMode ==  std::string("register")) { +		astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);  	} else {  		printHelp();  	} diff --git a/matlab/mex/mexInitFunctions.cpp b/matlab/mex/mexInitFunctions.cpp index bd3df2c..7245af2 100644 --- a/matlab/mex/mexInitFunctions.cpp +++ b/matlab/mex/mexInitFunctions.cpp @@ -23,5 +23,13 @@ void initASTRAMex(){      if(!astra::CLogger::setCallbackScreen(&logCallBack)){          mexErrMsgTxt("Error initializing mex functions.");      } +      mexIsInitialized=true; + + +    // If we have support for plugins, initialize them. +    // (NB: Call this after setting mexIsInitialized, to avoid recursively +    //      calling initASTRAMex) +    mexEvalString("if exist('astra_mex_plugin_c') == 3; astra_mex_plugin_c('init'); end"); +  } | 
