summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-11 15:45:56 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-11 15:45:56 +0100
commit350a889c38805dcda98a299315af1ab64510fa5b (patch)
treec7aadfcc77e18b95b130087083db71332dc69a33
parente80f8c108871245f06dc3e570502e95a4acba64b (diff)
downloadframework-350a889c38805dcda98a299315af1ab64510fa5b.tar.gz
framework-350a889c38805dcda98a299315af1ab64510fa5b.tar.bz2
framework-350a889c38805dcda98a299315af1ab64510fa5b.tar.xz
framework-350a889c38805dcda98a299315af1ab64510fa5b.zip
add test for mixed L21 Norm
-rw-r--r--Wrappers/Python/test/test_functions.py70
1 files changed, 66 insertions, 4 deletions
diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py
index 1891afd..bc1f034 100644
--- a/Wrappers/Python/test/test_functions.py
+++ b/Wrappers/Python/test/test_functions.py
@@ -20,7 +20,7 @@ from ccpi.optimisation.operators import Gradient
#from ccpi.optimisation.functions import SimpleL2NormSq
from ccpi.optimisation.functions import L2NormSquared
#from ccpi.optimisation.functions import SimpleL1Norm
-from ccpi.optimisation.functions import L1Norm
+from ccpi.optimisation.functions import L1Norm, MixedL21Norm
from ccpi.optimisation.funcs import Norm2sq
# from ccpi.optimisation.functions.L2NormSquared import SimpleL2NormSq, L2NormSq
@@ -29,13 +29,46 @@ from ccpi.optimisation.funcs import Norm2sq
from ccpi.optimisation.functions import ZeroFun
from ccpi.optimisation.functions import FunctionOperatorComposition
-import unittest
-import numpy
+import unittest
+import numpy
#
class TestFunction(unittest.TestCase):
+ def assertBlockDataContainerEqual(self, container1, container2):
+ print ("assert Block Data Container Equal")
+ self.assertTrue(issubclass(container1.__class__, container2.__class__))
+ for col in range(container1.shape[0]):
+ if issubclass(container1.get_item(col).__class__, DataContainer):
+ print ("Checking col ", col)
+ self.assertNumpyArrayEqual(
+ container1.get_item(col).as_array(),
+ container2.get_item(col).as_array()
+ )
+ else:
+ self.assertBlockDataContainerEqual(container1.get_item(col),container2.get_item(col))
+
+ def assertNumpyArrayEqual(self, first, second):
+ res = True
+ try:
+ numpy.testing.assert_array_equal(first, second)
+ except AssertionError as err:
+ res = False
+ print(err)
+ self.assertTrue(res)
+
+ def assertNumpyArrayAlmostEqual(self, first, second, decimal=6):
+ res = True
+ try:
+ numpy.testing.assert_array_almost_equal(first, second, decimal)
+ except AssertionError as err:
+ res = False
+ print(err)
+ print("expected " , second)
+ print("actual " , first)
+
+ self.assertTrue(res)
def test_Function(self):
@@ -280,8 +313,37 @@ class TestFunction(unittest.TestCase):
ynew = new_chisq.gradient(u)
numpy.testing.assert_array_equal(yold.as_array(), ynew.as_array())
+ def test_mixedL12Norm(self):
+ M, N, K = 2,3,5
+ ig = ImageGeometry(voxel_num_x=M, voxel_num_y = N)
+ u1 = ig.allocate('random_int')
+ u2 = ig.allocate('random_int')
+
+ U = BlockDataContainer(u1, u2, shape=(2,1))
+
+ # Define no scale and scaled
+ f_no_scaled = MixedL21Norm()
+ #f_scaled = 0.5 * MixedL21Norm()
+
+ # call
+
+ # a1 = f_no_scaled(U)
+ # a2 = f_scaled(U)
+ # self.assertBlockDataContainerEqual(a1,a2)
+ tmp = [ el**2 for el in U.containers ]
+ self.assertBlockDataContainerEqual(BlockDataContainer(*tmp),
+ U.power(2))
+
+ z1 = f_no_scaled.proximal_conjugate(U, 1)
+ u3 = ig.allocate('random_int')
+ u4 = ig.allocate('random_int')
+
+ z3 = BlockDataContainer(u3, u4, shape=(2,1))
+
+
+ f_no_scaled.proximal_conjugate(U, 1, out=z3)
+ self.assertBlockDataContainerEqual(z3,z1)
-
#
# f1 = L2NormSq(alpha=1, b=noisy_data)
# print(f1(noisy_data))