diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-04-11 15:45:56 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-04-11 15:45:56 +0100 |
commit | 350a889c38805dcda98a299315af1ab64510fa5b (patch) | |
tree | c7aadfcc77e18b95b130087083db71332dc69a33 | |
parent | e80f8c108871245f06dc3e570502e95a4acba64b (diff) | |
download | framework-350a889c38805dcda98a299315af1ab64510fa5b.tar.gz framework-350a889c38805dcda98a299315af1ab64510fa5b.tar.bz2 framework-350a889c38805dcda98a299315af1ab64510fa5b.tar.xz framework-350a889c38805dcda98a299315af1ab64510fa5b.zip |
add test for mixed L21 Norm
-rw-r--r-- | Wrappers/Python/test/test_functions.py | 70 |
1 files changed, 66 insertions, 4 deletions
diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py index 1891afd..bc1f034 100644 --- a/Wrappers/Python/test/test_functions.py +++ b/Wrappers/Python/test/test_functions.py @@ -20,7 +20,7 @@ from ccpi.optimisation.operators import Gradient #from ccpi.optimisation.functions import SimpleL2NormSq from ccpi.optimisation.functions import L2NormSquared #from ccpi.optimisation.functions import SimpleL1Norm -from ccpi.optimisation.functions import L1Norm +from ccpi.optimisation.functions import L1Norm, MixedL21Norm from ccpi.optimisation.funcs import Norm2sq # from ccpi.optimisation.functions.L2NormSquared import SimpleL2NormSq, L2NormSq @@ -29,13 +29,46 @@ from ccpi.optimisation.funcs import Norm2sq from ccpi.optimisation.functions import ZeroFun from ccpi.optimisation.functions import FunctionOperatorComposition -import unittest -import numpy +import unittest +import numpy # class TestFunction(unittest.TestCase): + def assertBlockDataContainerEqual(self, container1, container2): + print ("assert Block Data Container Equal") + self.assertTrue(issubclass(container1.__class__, container2.__class__)) + for col in range(container1.shape[0]): + if issubclass(container1.get_item(col).__class__, DataContainer): + print ("Checking col ", col) + self.assertNumpyArrayEqual( + container1.get_item(col).as_array(), + container2.get_item(col).as_array() + ) + else: + self.assertBlockDataContainerEqual(container1.get_item(col),container2.get_item(col)) + + def assertNumpyArrayEqual(self, first, second): + res = True + try: + numpy.testing.assert_array_equal(first, second) + except AssertionError as err: + res = False + print(err) + self.assertTrue(res) + + def assertNumpyArrayAlmostEqual(self, first, second, decimal=6): + res = True + try: + numpy.testing.assert_array_almost_equal(first, second, decimal) + except AssertionError as err: + res = False + print(err) + print("expected " , second) + print("actual " , first) + + self.assertTrue(res) def test_Function(self): @@ -280,8 +313,37 @@ class TestFunction(unittest.TestCase): ynew = new_chisq.gradient(u) numpy.testing.assert_array_equal(yold.as_array(), ynew.as_array()) + def test_mixedL12Norm(self): + M, N, K = 2,3,5 + ig = ImageGeometry(voxel_num_x=M, voxel_num_y = N) + u1 = ig.allocate('random_int') + u2 = ig.allocate('random_int') + + U = BlockDataContainer(u1, u2, shape=(2,1)) + + # Define no scale and scaled + f_no_scaled = MixedL21Norm() + #f_scaled = 0.5 * MixedL21Norm() + + # call + + # a1 = f_no_scaled(U) + # a2 = f_scaled(U) + # self.assertBlockDataContainerEqual(a1,a2) + tmp = [ el**2 for el in U.containers ] + self.assertBlockDataContainerEqual(BlockDataContainer(*tmp), + U.power(2)) + + z1 = f_no_scaled.proximal_conjugate(U, 1) + u3 = ig.allocate('random_int') + u4 = ig.allocate('random_int') + + z3 = BlockDataContainer(u3, u4, shape=(2,1)) + + + f_no_scaled.proximal_conjugate(U, 1, out=z3) + self.assertBlockDataContainerEqual(z3,z1) - # # f1 = L2NormSq(alpha=1, b=noisy_data) # print(f1(noisy_data)) |