summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-11 15:43:31 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-11 15:43:31 +0100
commit1dec48d390df5cbc3436832cedf559b64f4651bc (patch)
tree58ebafbd3e505b0c5201aa335f3c5b2cb898ee04
parenta7bb88da8e8d4e94a3dbeb04f95928cb7d1fbd48 (diff)
downloadframework-1dec48d390df5cbc3436832cedf559b64f4651bc.tar.gz
framework-1dec48d390df5cbc3436832cedf559b64f4651bc.tar.bz2
framework-1dec48d390df5cbc3436832cedf559b64f4651bc.tar.xz
framework-1dec48d390df5cbc3436832cedf559b64f4651bc.zip
fix out call
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py58
1 files changed, 34 insertions, 24 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
index ed1d5e5..c6b6e95 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
@@ -21,6 +21,7 @@ import numpy as np
from ccpi.optimisation.functions import Function, ScaledFunction
from ccpi.framework import DataContainer, ImageData, \
ImageGeometry, BlockDataContainer
+import functools
############################ mixed_L1,2NORM FUNCTIONS #####################
class MixedL21Norm(Function):
@@ -36,7 +37,9 @@ class MixedL21Norm(Function):
:param: x is a BlockDataContainer
- '''
+ '''
+ if not isinstance(x, BlockDataContainer):
+ raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x)))
if self.SymTensor:
param = [1]*x.shape[0]
@@ -73,7 +76,6 @@ class MixedL21Norm(Function):
different form L2NormSquared which acts on DC
'''
-
pass
def proximal_conjugate(self, x, tau, out=None):
@@ -88,29 +90,37 @@ class MixedL21Norm(Function):
return res
else:
-# pass
+ if out is None:
+ tmp = [ el*el for el in x.containers]
+ res = sum(tmp).sqrt().maximum(1.0)
+ frac = [el/res for el in x.containers]
+ res = BlockDataContainer(*frac)
+ return res
+ else:
+ res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
+ res = res1.sqrt().maximum(1.0)
+
+ if False:
+ # works but not memory efficient as allocating a new BlockDataContainer
+ a = x / res
+ out.fill(a)
+ elif False:
+ # this leads to error
+# File "ccpi\framework\BlockDataContainer.py", line 142, in divide
+# return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+# File "ccpi\framework\BlockDataContainer.py", line 142, in <listcomp>
+# return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+# File "ccpi\framework\framework.py", line 814, in divide
+# return self.pixel_wise_binary(numpy.divide, other, *args, **kwargs)
+# File "ccpi\framework\framework.py", line 802, in pixel_wise_binary
+# raise ValueError (message(type(self), "incompatible class:" , pwop.__name__, type(out)))
+# ValueError: ImageData: incompatible class: true_divide <class 'ccpi.framework.BlockDataContainer.BlockDataContainer'>
+ x.divide(res, out=out)
+ else:
+ for i,el in enumerate(x.containers):
+ #a = out.get_item(i)
+ el.divide(res, out=out.get_item(i))
-
-# # tmp2 = np.sqrt(x.as_array()[0]**2 + x.as_array()[1]**2 + 2*x.as_array()[2]**2)/self.alpha
-# # res = x.divide(ImageData(tmp2).maximum(1.0))
-# if out is None:
-
- tmp = [ el*el for el in x]
- res = (sum(tmp).sqrt()).maximum(1.0)
- frac = [x[i]/res for i in range(x.shape[0])]
- res = BlockDataContainer(*frac)
-
- return res
- # else:
- # tmp = [ el*el for el in x]
- # res = (sum(tmp).sqrt()).maximum(1.0)
- # #frac = [x[i]/res for i in range(x.shape[0])]
- # for i in range(x.shape[0]):
- # a = out.get_item(i)
- # b = x.get_item(i)
- # b /= res
- # a.fill( b )
-
def __rmul__(self, scalar):
return ScaledFunction(self, scalar)