diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-08-07 17:21:12 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2017-08-07 17:21:12 +0100 |
commit | 6589fa197d9f87f7a37f46943aa995d97f50bb46 (patch) | |
tree | 0a789dc6a41ea7d5d47e56c6adfa0ba46a1d1872 /src | |
parent | 0ce5b086d137761f2c31b9ae1c271c9aeb8d8fb8 (diff) | |
download | regularization-6589fa197d9f87f7a37f46943aa995d97f50bb46.tar.gz regularization-6589fa197d9f87f7a37f46943aa995d97f50bb46.tar.bz2 regularization-6589fa197d9f87f7a37f46943aa995d97f50bb46.tar.xz regularization-6589fa197d9f87f7a37f46943aa995d97f50bb46.zip |
added TGV_PD, removed useless code
Diffstat (limited to 'src')
-rw-r--r-- | src/Python/fista_module.cpp | 245 |
1 files changed, 146 insertions, 99 deletions
diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index c2d9352..eacda3d 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -30,6 +30,7 @@ limitations under the License. #include "FGP_TV_core.h" #include "LLT_model_core.h" #include "PatchBased_Regul_core.h" +#include "TGV_PD_core.h" #include "utils.h" @@ -103,101 +104,8 @@ If unsuccessful in a MEX file, the MEX file terminates and returns control to th enough free heap space to create the mxArray. */ -void mexErrMessageText(char* text) { - std::cerr << text << std::endl; -} - -/* -double mxGetScalar(const mxArray *pm); -args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. -Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. -*/ - -template<typename T> -double mxGetScalar(const np::ndarray plh) { - return (double)bp::extract<T>(plh[0]); -} - - - -template<typename T> -T * mxGetData(const np::ndarray pm) { - //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. - //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. - /*Access the numpy array pointer: - char * get_data() const; - Returns: Array’s raw data pointer as a char - Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. - probably this would work. - A = reinterpret_cast<float *>(prhs[0]); - */ - return reinterpret_cast<T *>(prhs[0]); -} - -template<typename T> -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape; - if (dims == 3) - shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - else if (dims == 2) - shape = bp::make_tuple(dim_array[0], dim_array[1]); - np::dtype dtype = np::dtype::get_builtin<T>(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; -} - - -bp::list mexFunction(np::ndarray input) { - int number_of_dims = input.get_nd(); - int dim_array[3]; - - dim_array[0] = input.shape(0); - dim_array[1] = input.shape(1); - if (number_of_dims == 2) { - dim_array[2] = -1; - } - else { - dim_array[2] = input.shape(2); - } - - /**************************************************************************/ - np::ndarray zz = zeros(3, dim_array, (int)0); - np::ndarray fzz = zeros(3, dim_array, (float)0); - /**************************************************************************/ - - int * A = reinterpret_cast<int *>(input.get_data()); - int * B = reinterpret_cast<int *>(zz.get_data()); - float * C = reinterpret_cast<float *>(fzz.get_data()); - - //Copy data and cast - for (int i = 0; i < dim_array[0]; i++) { - for (int j = 0; j < dim_array[1]; j++) { - for (int k = 0; k < dim_array[2]; k++) { - int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; - int val = (*(A + index)); - float fval = (float)val; - std::memcpy(B + index, &val, sizeof(int)); - std::memcpy(C + index, &fval, sizeof(float)); - } - } - } - - bp::list result; - - result.append<int>(number_of_dims); - result.append<int>(dim_array[0]); - result.append<int>(dim_array[1]); - result.append<int>(dim_array[2]); - result.append<np::ndarray>(zz); - result.append<np::ndarray>(fzz); - - //result.append<bp::tuple>(tup); - return result; - -} - bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list @@ -487,7 +395,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me np::ndarray npP1_old = np::zeros(shape, dtype); np::ndarray npP2_old = np::zeros(shape, dtype); np::ndarray npR1 = np::zeros(shape, dtype); - np::ndarray npR2 = zeros(2, dim_array, (float)0); + np::ndarray npR2 = np::zeros(shape, dtype); D = reinterpret_cast<float *>(npD.get_data()); D_old = reinterpret_cast<float *>(npD_old.get_data()); @@ -866,7 +774,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } -bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { +bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW, double d_h) { // the result is in the following list bp::list result; @@ -899,6 +807,7 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub ///*Handling inputs*/ //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + A = reinterpret_cast<float *>(input.get_data()); //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ @@ -907,6 +816,8 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + lambda = (float)d_lambda; + h = (float)d_h; SearchW = SearchW_real + 2 * SimilW; /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ @@ -918,7 +829,6 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ int N_dims[] = { newsizeX, newsizeY, newsizeZ }; - /******************************2D case ****************************/ if (numdims == 2) { ///*Handling output*/ @@ -943,12 +853,13 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub /*Perform padding of image A to the size of [newsizeX * newsizeY] */ switchpad_crop = 0; /*padding*/ pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); - + /* Do PB regularization with the padded array */ PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); - + switchpad_crop = 1; /*cropping*/ pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append<np::ndarray>(npB); } else @@ -986,6 +897,141 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub return result; } +bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) { + // the result is in the following list + bp::list result; + int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll; + //const int *dim_array; + float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/ + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + A = reinterpret_cast<float *>(input.get_data()); + + //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/ + //alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/ + //alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/ + //iter = (int)mxGetScalar(prhs[4]); /*iterations number*/ + //if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations"); + lambda = (float)d_lambda; + alpha1 = (float)d_alpha1; + alpha0 = (float)d_alpha0; + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; + + if (number_of_dims == 2) { + /*2D case*/ + dimZ = 1; + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin<float>(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npQ1 = np::zeros(shape, dtype); + np::ndarray npQ2 = np::zeros(shape, dtype); + np::ndarray npQ3 = np::zeros(shape, dtype); + np::ndarray npV1 = np::zeros(shape, dtype); + np::ndarray npV1_old = np::zeros(shape, dtype); + np::ndarray npV2 = np::zeros(shape, dtype); + np::ndarray npV2_old = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + + U = reinterpret_cast<float *>(npU.get_data()); + U_old = reinterpret_cast<float *>(npU_old.get_data()); + P1 = reinterpret_cast<float *>(npP1.get_data()); + P2 = reinterpret_cast<float *>(npP2.get_data()); + Q1 = reinterpret_cast<float *>(npQ1.get_data()); + Q2 = reinterpret_cast<float *>(npQ2.get_data()); + Q3 = reinterpret_cast<float *>(npQ3.get_data()); + V1 = reinterpret_cast<float *>(npV1.get_data()); + V1_old = reinterpret_cast<float *>(npV1_old.get_data()); + V2 = reinterpret_cast<float *>(npV2.get_data()); + V2_old = reinterpret_cast<float *>(npV2_old.get_data()); + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + /*dual variables*/ + /*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + /*printf("%i \n", i);*/ + L2 = 12.0; /*Lipshitz constant*/ + tau = 1.0 / pow(L2, 0.5); + sigma = 1.0 / pow(L2, 0.5); + + /*Copy A to U*/ + copyIm(A, U, dimX, dimY, dimZ); + /* Here primal-dual iterations begin for 2D */ + for (ll = 0; ll < iter; ll++) { + + /* Calculate Dual Variable P */ + DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for P*/ + ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1); + + /* Calculate Dual Variable Q */ + DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for Q*/ + ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0); + + /*saving U into U_old*/ + copyIm(U, U_old, dimX, dimY, dimZ); + + /*adjoint operation -> divergence and projection of P*/ + DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau); + + /*get updated solution U*/ + newU(U, U_old, dimX, dimY, dimZ); + + /*saving V into V_old*/ + copyIm(V1, V1_old, dimX, dimY, dimZ); + copyIm(V2, V2_old, dimX, dimY, dimZ); + + /* upd V*/ + UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau); + + /*get new V*/ + newU(V1, V1_old, dimX, dimY, dimZ); + newU(V2, V2_old, dimX, dimY, dimZ); + } /*end of iterations*/ + + result.append<np::ndarray>(npU); + } + + + + + return result; +} + BOOST_PYTHON_MODULE(regularizers) { np::initialize(); @@ -997,8 +1043,9 @@ BOOST_PYTHON_MODULE(regularizers) np::dtype dt1 = np::dtype::get_builtin<uint8_t>(); np::dtype dt2 = np::dtype::get_builtin<uint16_t>(); - def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); def("LLT_model", LLT_model); + def("PatchBased_Regul", PatchBased_Regul); + def("TGV_PD", TGV_PD); }
\ No newline at end of file |