diff options
Diffstat (limited to 'src/Python')
-rw-r--r-- | src/Python/fista_module.cpp | 123 | ||||
-rw-r--r-- | src/Python/setup.py | 1 |
2 files changed, 123 insertions, 1 deletions
diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index d890b10..c2d9352 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -29,6 +29,7 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" #include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" #include "utils.h" @@ -793,7 +794,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d if (switcher != 0) { Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); }*/ - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); np::dtype dtype = np::dtype::get_builtin<float>(); @@ -865,6 +866,126 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } +bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { + // the result is in the following list + bp::list result; + + int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop; + //const int *dims; + float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda; + + numdims = input.get_nd(); + int dims[3]; + + dims[0] = input.shape(0); + dims[1] = input.shape(1); + if (numdims == 2) { + dims[2] = -1; + } + else { + dims[2] = input.shape(2); + } + /*numdims = mxGetNumberOfDimensions(prhs[0]); + dims = mxGetDimensions(prhs[0]);*/ + + N = dims[0]; + M = dims[1]; + Z = dims[2]; + + //if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); } + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + //if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter"); + + ///*Handling inputs*/ + //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ + //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ + //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ + //lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */ + + //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); + //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + + SearchW = SearchW_real + 2 * SimilW; + + /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ + /* SimilW_full = 2*SimilW + 1; */ /* the full similarity window size */ + + + padXY = SearchW + 2 * SimilW; /* padding sizes */ + newsizeX = N + 2 * (padXY); /* the X size of the padded array */ + newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ + newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ + int N_dims[] = { newsizeX, newsizeY, newsizeZ }; + + /******************************2D case ****************************/ + if (numdims == 2) { + ///*Handling output*/ + //B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + ///**************************************************************************/ + + bp::tuple shape = bp::make_tuple(N, M); + np::dtype dtype = np::dtype::get_builtin<float>(); + + np::ndarray npB = np::zeros(shape, dtype); + + shape = bp::make_tuple(newsizeX, newsizeY); + np::ndarray npAp = np::zeros(shape, dtype); + np::ndarray npBp = np::zeros(shape, dtype); + B = reinterpret_cast<float *>(npB.get_data()); + Ap = reinterpret_cast<float *>(npAp.get_data()); + Bp = reinterpret_cast<float *>(npBp.get_data()); + + /*Perform padding of image A to the size of [newsizeX * newsizeY] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append<np::ndarray>(npB); + } + else + { + /******************************3D case ****************************/ + ///*Handling output*/ + //B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + /**************************************************************************/ + bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]); + bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]); + np::dtype dtype = np::dtype::get_builtin<float>(); + + np::ndarray npB = np::zeros(shape, dtype); + np::ndarray npAp = np::zeros(shape_AB, dtype); + np::ndarray npBp = np::zeros(shape_AB, dtype); + B = reinterpret_cast<float *>(npB.get_data()); + Ap = reinterpret_cast<float *>(npAp.get_data()); + Bp = reinterpret_cast<float *>(npBp.get_data()); + /*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + result.append<np::ndarray>(npB); + } /*end else ndims*/ + + return result; +} + BOOST_PYTHON_MODULE(regularizers) { np::initialize(); diff --git a/src/Python/setup.py b/src/Python/setup.py index a8feb1c..a4eed14 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -52,6 +52,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), |