From fd496731c8e9d4975864d76dbb6574cbeee7cf98 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:37 +0100 Subject: Added 3 regularizers SplitBregman_TV FGP_TV LLT_model --- src/Python/fista_module.cpp | 266 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 232 insertions(+), 34 deletions(-) (limited to 'src/Python') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 2492884..d890b10 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,8 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "utils.h" @@ -131,6 +133,18 @@ T * mxGetData(const np::ndarray pm) { return reinterpret_cast(prhs[0]); } +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + @@ -169,7 +183,6 @@ bp::list mexFunction(np::ndarray input) { } } - bp::list result; result.append(number_of_dims); @@ -189,14 +202,14 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int *dim_array; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; + //const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -252,26 +265,26 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); - np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU = np::zeros(shape, dtype); np::ndarray npU_old = np::zeros(shape, dtype); - np::ndarray npDx = np::zeros(shape, dtype); - np::ndarray npDy = np::zeros(shape, dtype); - np::ndarray npBx = np::zeros(shape, dtype); - np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); - U = reinterpret_cast(npU.get_data()); + U = reinterpret_cast(npU.get_data()); U_old = reinterpret_cast(npU_old.get_data()); - Dx = reinterpret_cast(npDx.get_data()); - Dy = reinterpret_cast(npDy.get_data()); - Bx = reinterpret_cast(npBx.get_data()); - By = reinterpret_cast(npBy.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ - for (ll = 0; ll(npU); - result.append(ll); + } //printf("SB iterations stopped at iteration: %i\n", ll); - if (number_of_dims == 3) { + result.append(npU); + result.append(ll); + } + if (number_of_dims == 3) { /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); @@ -375,24 +390,25 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi result.append(npU); result.append(ll); } - } return result; -} + } + + bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; - float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval, mu; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -512,7 +528,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j(input.get_data()); + lambda = (float)d_lambda; + tau = (float)d_tau; + // iter is passed as parameter + epsil = (float)d_epsil; + // switcher is passed as parameter + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = 1; + + if (number_of_dims == 2) { + /*2D case*/ + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der2D(U, D1, D2, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + + result.append(npU); + } + else if (number_of_dims == 3) { + /*3D case*/ + dimZ = dim_array[2]; + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + if (switcher != 0) { + Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + }*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + np::ndarray npD3 = np::zeros(shape, dtype); + np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin()); + Map = reinterpret_cast(npMap.get_data()); + if (switcher != 0) { + //Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + + Map = reinterpret_cast(npMap.get_data()); + } + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + D3 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + + if (switcher == 1) { + /* apply restrictive smoothing */ + calcMap(U, Map, dimX, dimY, dimZ); + /*clear outliers */ + cleanMap(Map, dimX, dimY, dimZ); + } + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der3D(U, D1, D2, D3, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + result.append(npU); + if (switcher != 0) result.append(npMap); + + } + return result; +} + + +BOOST_PYTHON_MODULE(regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -682,4 +879,5 @@ BOOST_PYTHON_MODULE(fista) def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); + def("LLT_model", LLT_model); } \ No newline at end of file -- cgit v1.2.3