summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-03-11 13:31:54 -0400
committerEdoardo Pasca <edo.paskino@gmail.com>2019-03-11 13:31:54 -0400
commit78d97a226ede52ccd7386a8bf4097c9f83f6c4a6 (patch)
treedec941534dc2303d702efcd624607b707efcc394 /Wrappers/Python
parentcbb6e2ce2baa3a9c18f1d8ad537f1498348f827d (diff)
downloadframework-78d97a226ede52ccd7386a8bf4097c9f83f6c4a6.tar.gz
framework-78d97a226ede52ccd7386a8bf4097c9f83f6c4a6.tar.bz2
framework-78d97a226ede52ccd7386a8bf4097c9f83f6c4a6.tar.xz
framework-78d97a226ede52ccd7386a8bf4097c9f83f6c4a6.zip
deprecate grad and prox
Norm2sq fixes for memopt
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/funcs.py45
1 files changed, 29 insertions, 16 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py
index 99af275..4f84889 100755
--- a/Wrappers/Python/ccpi/optimisation/funcs.py
+++ b/Wrappers/Python/ccpi/optimisation/funcs.py
@@ -20,6 +20,7 @@
from ccpi.optimisation.ops import Identity, FiniteDiff2D
import numpy
from ccpi.framework import DataContainer
+import warnings
def isSizeCorrect(data1 ,data2):
@@ -40,8 +41,12 @@ class Function(object):
def __init__(self):
self.L = None
def __call__(self,x, out=None): raise NotImplementedError
- def grad(self, x): raise NotImplementedError
- def prox(self, x, tau): raise NotImplementedError
+ def grad(self, x):
+ warnings.warn("grad method is deprecated. use gradient instead", DeprecationWarning)
+ return self.gradient(x, out=None)
+ def prox(self, x, tau):
+ warnings.warn("prox method is deprecated. use proximal instead", DeprecationWarning)
+ return self.proximal(x,tau,out=None)
def gradient(self, x, out=None): raise NotImplementedError
def proximal(self, x, tau, out=None): raise NotImplementedError
@@ -141,12 +146,20 @@ class Norm2sq(Function):
self.A = A # Should be an operator, default identity
self.b = b # Default zero DataSet?
self.c = c # Default 1.
- self.memopt = memopt
if memopt:
- #self.direct_placehold = A.adjoint(b)
- self.direct_placehold = A.allocate_direct()
- self.adjoint_placehold = A.allocate_adjoint()
-
+ try:
+ self.adjoint_placehold = A.range_geometry().allocate()
+ self.direct_placehold = A.domain_geometry().allocate()
+ self.memopt = True
+ except NameError as ne:
+ warnings.warn(str(ne))
+ self.memopt = False
+ except NotImplementedError as nie:
+ print (nie)
+ warnings.warn(str(nie))
+ self.memopt = False
+ else:
+ self.memopt = False
# Compute the Lipschitz parameter from the operator if possible
# Leave it initialised to None otherwise
@@ -157,10 +170,9 @@ class Norm2sq(Function):
except NotImplementedError as noe:
pass
- def grad(self,x):
- #return 2*self.c*self.A.adjoint( self.A.direct(x) - self.b )
- return (2.0*self.c)*self.A.adjoint( self.A.direct(x) - self.b )
-
+ #def grad(self,x):
+ # return self.gradient(x, out=None)
+
def __call__(self,x):
#return self.c* np.sum(np.square((self.A.direct(x) - self.b).ravel()))
#if out is None:
@@ -178,12 +190,13 @@ class Norm2sq(Function):
self.A.direct(x, out=self.adjoint_placehold)
self.adjoint_placehold.__isub__( self.b )
self.A.adjoint(self.adjoint_placehold, out=self.direct_placehold)
- self.direct_placehold.__imul__(2.0 * self.c)
- # can this be avoided?
- out.fill(self.direct_placehold)
+ #self.direct_placehold.__imul__(2.0 * self.c)
+ ## can this be avoided?
+ #out.fill(self.direct_placehold)
+ self.direct_placehold.multiply(2.0*self.c, out=out)
else:
- return self.grad(x)
-
+ return (2.0*self.c)*self.A.adjoint( self.A.direct(x) - self.b )
+
class ZeroFun(Function):