summaryrefslogtreecommitdiffstats
path: root/src/Matlab
diff options
context:
space:
mode:
Diffstat (limited to 'src/Matlab')
-rw-r--r--src/Matlab/mex_compile/regularisers_GPU/TGV_GPU.cpp32
1 files changed, 17 insertions, 15 deletions
diff --git a/src/Matlab/mex_compile/regularisers_GPU/TGV_GPU.cpp b/src/Matlab/mex_compile/regularisers_GPU/TGV_GPU.cpp
index edb551d..1173282 100644
--- a/src/Matlab/mex_compile/regularisers_GPU/TGV_GPU.cpp
+++ b/src/Matlab/mex_compile/regularisers_GPU/TGV_GPU.cpp
@@ -21,18 +21,18 @@ limitations under the License.
#include "TGV_GPU_core.h"
/* CUDA implementation of Primal-Dual denoising method for
- * Total Generilized Variation (TGV)-L2 model [1] (2D case only)
+ * Total Generilized Variation (TGV)-L2 model [1] (2D/3D)
*
* Input Parameters:
- * 1. Noisy image (2D) (required)
- * 2. lambda - regularisation parameter (required)
- * 3. parameter to control the first-order term (alpha1) (default - 1)
- * 4. parameter to control the second-order term (alpha0) (default - 0.5)
- * 5. Number of Chambolle-Pock (Primal-Dual) iterations (default is 300)
+ * 1. Noisy image/volume (2D/3D)
+ * 2. lambda - regularisation parameter
+ * 3. parameter to control the first-order term (alpha1)
+ * 4. parameter to control the second-order term (alpha0)
+ * 5. Number of Chambolle-Pock (Primal-Dual) iterations
* 6. Lipshitz constant (default is 12)
*
* Output:
- * Filtered/regulariaed image
+ * Filtered/regularised image
*
* References:
* [1] K. Bredies "Total Generalized Variation"
@@ -44,7 +44,7 @@ void mexFunction(
{
int number_of_dims, iter;
- mwSize dimX, dimY;
+ mwSize dimX, dimY, dimZ;
const mwSize *dim_array;
float *Input, *Output=NULL, lambda, alpha0, alpha1, L2;
@@ -57,8 +57,8 @@ void mexFunction(
Input = (float *) mxGetData(prhs[0]); /*noisy image (2D) */
lambda = (float) mxGetScalar(prhs[1]); /* regularisation parameter */
alpha1 = 1.0f; /* parameter to control the first-order term */
- alpha0 = 0.5f; /* parameter to control the second-order term */
- iter = 300; /* Iterations number */
+ alpha0 = 2.0f; /* parameter to control the second-order term */
+ iter = 500; /* Iterations number */
L2 = 12.0f; /* Lipshitz constant */
if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) {mexErrMsgTxt("The input image must be in a single precision"); }
@@ -68,12 +68,14 @@ void mexFunction(
if (nrhs == 6) L2 = (float) mxGetScalar(prhs[5]); /* Lipshitz constant */
/*Handling Matlab output data*/
- dimX = dim_array[0]; dimY = dim_array[1];
+ dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2];
if (number_of_dims == 2) {
- Output = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
- /* running the function */
- TGV_GPU_main(Input, Output, lambda, alpha1, alpha0, iter, L2, dimX, dimY);
+ dimZ = 1; /*2D case*/
+ Output = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
}
- if (number_of_dims == 3) {mexErrMsgTxt("Only 2D images accepted");}
+ if (number_of_dims == 3) Output = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ /* running the function */
+ TGV_GPU_main(Input, Output, lambda, alpha1, alpha0, iter, L2, dimX, dimY, dimZ);
}