diff options
-rw-r--r-- | Wrappers/Python/src/cpu_regularisers.pyx | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/Wrappers/Python/src/cpu_regularisers.pyx b/Wrappers/Python/src/cpu_regularisers.pyx index bb55df5..bdb1eff 100644 --- a/Wrappers/Python/src/cpu_regularisers.pyx +++ b/Wrappers/Python/src/cpu_regularisers.pyx @@ -449,15 +449,17 @@ def NVM_INP_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData, #****************************************************************# #***************Calculation of TV-energy functional**************# #****************************************************************# -def TV_ENERGY(inputData, regularisation_parameter, typeFunctional): +def TV_ENERGY(inputData, inputData0, regularisation_parameter, typeFunctional): if inputData.ndim == 2: - return TV_ENERGY_2D(inputData, regularisation_parameter, typeFunctional) + return TV_ENERGY_2D(inputData, inputData0, regularisation_parameter, typeFunctional) elif inputData.ndim == 3: - return TV_ENERGY_3D(inputData, regularisation_parameter, typeFunctional) + return TV_ENERGY_3D(inputData, inputData0, regularisation_parameter, typeFunctional) def TV_ENERGY_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData, + np.ndarray[np.float32_t, ndim=2, mode="c"] inputData0, float regularisation_parameter, int typeFunctional): + cdef long dims[2] dims[0] = inputData.shape[0] dims[1] = inputData.shape[1] @@ -466,11 +468,12 @@ def TV_ENERGY_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData, np.zeros([1], dtype='float32') # run function - TV_energy2D(&inputData[0,0], &outputData[0], regularisation_parameter, typeFunctional, dims[1], dims[0]) + TV_energy2D(&inputData[0,0], &inputData0[0,0], &outputData[0], regularisation_parameter, typeFunctional, dims[1], dims[0]) return outputData -def TV_ENERGY_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData, +def TV_ENERGY_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData, + np.ndarray[np.float32_t, ndim=3, mode="c"] inputData0, float regularisation_parameter, int typeFunctional): @@ -483,6 +486,6 @@ def TV_ENERGY_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData, np.zeros([1], dtype='float32') # Run function - TV_energy3D(&inputData[0,0,0], &outputData[0], regularisation_parameter, typeFunctional, dims[2], dims[1], dims[0]) + TV_energy3D(&inputData[0,0,0], &inputData0[0,0,0], &outputData[0], regularisation_parameter, typeFunctional, dims[2], dims[1], dims[0]) return outputData |