summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/src/cpu_regularisers.pyx15
1 files changed, 9 insertions, 6 deletions
diff --git a/Wrappers/Python/src/cpu_regularisers.pyx b/Wrappers/Python/src/cpu_regularisers.pyx
index bb55df5..bdb1eff 100644
--- a/Wrappers/Python/src/cpu_regularisers.pyx
+++ b/Wrappers/Python/src/cpu_regularisers.pyx
@@ -449,15 +449,17 @@ def NVM_INP_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData,
#****************************************************************#
#***************Calculation of TV-energy functional**************#
#****************************************************************#
-def TV_ENERGY(inputData, regularisation_parameter, typeFunctional):
+def TV_ENERGY(inputData, inputData0, regularisation_parameter, typeFunctional):
if inputData.ndim == 2:
- return TV_ENERGY_2D(inputData, regularisation_parameter, typeFunctional)
+ return TV_ENERGY_2D(inputData, inputData0, regularisation_parameter, typeFunctional)
elif inputData.ndim == 3:
- return TV_ENERGY_3D(inputData, regularisation_parameter, typeFunctional)
+ return TV_ENERGY_3D(inputData, inputData0, regularisation_parameter, typeFunctional)
def TV_ENERGY_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData,
+ np.ndarray[np.float32_t, ndim=2, mode="c"] inputData0,
float regularisation_parameter,
int typeFunctional):
+
cdef long dims[2]
dims[0] = inputData.shape[0]
dims[1] = inputData.shape[1]
@@ -466,11 +468,12 @@ def TV_ENERGY_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData,
np.zeros([1], dtype='float32')
# run function
- TV_energy2D(&inputData[0,0], &outputData[0], regularisation_parameter, typeFunctional, dims[1], dims[0])
+ TV_energy2D(&inputData[0,0], &inputData0[0,0], &outputData[0], regularisation_parameter, typeFunctional, dims[1], dims[0])
return outputData
-def TV_ENERGY_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData,
+def TV_ENERGY_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData,
+ np.ndarray[np.float32_t, ndim=3, mode="c"] inputData0,
float regularisation_parameter,
int typeFunctional):
@@ -483,6 +486,6 @@ def TV_ENERGY_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData,
np.zeros([1], dtype='float32')
# Run function
- TV_energy3D(&inputData[0,0,0], &outputData[0], regularisation_parameter, typeFunctional, dims[2], dims[1], dims[0])
+ TV_energy3D(&inputData[0,0,0], &inputData0[0,0,0], &outputData[0], regularisation_parameter, typeFunctional, dims[2], dims[1], dims[0])
return outputData