diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-04-11 15:43:31 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-04-11 15:43:31 +0100 |
commit | 1dec48d390df5cbc3436832cedf559b64f4651bc (patch) | |
tree | 58ebafbd3e505b0c5201aa335f3c5b2cb898ee04 | |
parent | a7bb88da8e8d4e94a3dbeb04f95928cb7d1fbd48 (diff) | |
download | framework-1dec48d390df5cbc3436832cedf559b64f4651bc.tar.gz framework-1dec48d390df5cbc3436832cedf559b64f4651bc.tar.bz2 framework-1dec48d390df5cbc3436832cedf559b64f4651bc.tar.xz framework-1dec48d390df5cbc3436832cedf559b64f4651bc.zip |
fix out call
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py | 58 |
1 files changed, 34 insertions, 24 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py index ed1d5e5..c6b6e95 100755 --- a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py +++ b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py @@ -21,6 +21,7 @@ import numpy as np from ccpi.optimisation.functions import Function, ScaledFunction from ccpi.framework import DataContainer, ImageData, \ ImageGeometry, BlockDataContainer +import functools ############################ mixed_L1,2NORM FUNCTIONS ##################### class MixedL21Norm(Function): @@ -36,7 +37,9 @@ class MixedL21Norm(Function): :param: x is a BlockDataContainer - ''' + ''' + if not isinstance(x, BlockDataContainer): + raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x))) if self.SymTensor: param = [1]*x.shape[0] @@ -73,7 +76,6 @@ class MixedL21Norm(Function): different form L2NormSquared which acts on DC ''' - pass def proximal_conjugate(self, x, tau, out=None): @@ -88,29 +90,37 @@ class MixedL21Norm(Function): return res else: -# pass + if out is None: + tmp = [ el*el for el in x.containers] + res = sum(tmp).sqrt().maximum(1.0) + frac = [el/res for el in x.containers] + res = BlockDataContainer(*frac) + return res + else: + res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 ) + res = res1.sqrt().maximum(1.0) + + if False: + # works but not memory efficient as allocating a new BlockDataContainer + a = x / res + out.fill(a) + elif False: + # this leads to error +# File "ccpi\framework\BlockDataContainer.py", line 142, in divide +# return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape) +# File "ccpi\framework\BlockDataContainer.py", line 142, in <listcomp> +# return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape) +# File "ccpi\framework\framework.py", line 814, in divide +# return self.pixel_wise_binary(numpy.divide, other, *args, **kwargs) +# File "ccpi\framework\framework.py", line 802, in pixel_wise_binary +# raise ValueError (message(type(self), "incompatible class:" , pwop.__name__, type(out))) +# ValueError: ImageData: incompatible class: true_divide <class 'ccpi.framework.BlockDataContainer.BlockDataContainer'> + x.divide(res, out=out) + else: + for i,el in enumerate(x.containers): + #a = out.get_item(i) + el.divide(res, out=out.get_item(i)) - -# # tmp2 = np.sqrt(x.as_array()[0]**2 + x.as_array()[1]**2 + 2*x.as_array()[2]**2)/self.alpha -# # res = x.divide(ImageData(tmp2).maximum(1.0)) -# if out is None: - - tmp = [ el*el for el in x] - res = (sum(tmp).sqrt()).maximum(1.0) - frac = [x[i]/res for i in range(x.shape[0])] - res = BlockDataContainer(*frac) - - return res - # else: - # tmp = [ el*el for el in x] - # res = (sum(tmp).sqrt()).maximum(1.0) - # #frac = [x[i]/res for i in range(x.shape[0])] - # for i in range(x.shape[0]): - # a = out.get_item(i) - # b = x.get_item(i) - # b /= res - # a.fill( b ) - def __rmul__(self, scalar): return ScaledFunction(self, scalar) |