author    | Edoardo Pasca <edo.paskino@gmail.com> | 2019-03-11 13:31:54 -0400
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-03-11 13:31:54 -0400
commit    | 78d97a226ede52ccd7386a8bf4097c9f83f6c4a6 (patch)
tree      | dec941534dc2303d702efcd624607b707efcc394 /Wrappers/Python
parent    | cbb6e2ce2baa3a9c18f1d8ad537f1498348f827d (diff)
download  | framework-78d97a226ede52ccd7386a8bf4097c9f83f6c4a6.tar.gz
          | framework-78d97a226ede52ccd7386a8bf4097c9f83f6c4a6.tar.bz2
          | framework-78d97a226ede52ccd7386a8bf4097c9f83f6c4a6.tar.xz
          | framework-78d97a226ede52ccd7386a8bf4097c9f83f6c4a6.zip
deprecate grad and prox
Norm2sq fixes for memopt
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/funcs.py | 45
1 file changed, 29 insertions(+), 16 deletions(-)
```diff
diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py
index 99af275..4f84889 100755
--- a/Wrappers/Python/ccpi/optimisation/funcs.py
+++ b/Wrappers/Python/ccpi/optimisation/funcs.py
@@ -20,6 +20,7 @@
 from ccpi.optimisation.ops import Identity, FiniteDiff2D
 import numpy
 from ccpi.framework import DataContainer
+import warnings
 
 def isSizeCorrect(data1 ,data2):
@@ -40,8 +41,12 @@ class Function(object):
     def __init__(self):
         self.L = None
     def __call__(self,x, out=None): raise NotImplementedError
-    def grad(self, x): raise NotImplementedError
-    def prox(self, x, tau): raise NotImplementedError
+    def grad(self, x):
+        warnings.warn("grad method is deprecated. use gradient instead", DeprecationWarning)
+        return self.gradient(x, out=None)
+    def prox(self, x, tau):
+        warnings.warn("prox method is deprecated. use proximal instead", DeprecationWarning)
+        return self.proximal(x,tau,out=None)
     def gradient(self, x, out=None): raise NotImplementedError
     def proximal(self, x, tau, out=None): raise NotImplementedError
@@ -141,12 +146,20 @@ class Norm2sq(Function):
         self.A = A  # Should be an operator, default identity
         self.b = b  # Default zero DataSet?
         self.c = c  # Default 1.
-        self.memopt = memopt
         if memopt:
-            #self.direct_placehold = A.adjoint(b)
-            self.direct_placehold = A.allocate_direct()
-            self.adjoint_placehold = A.allocate_adjoint()
-
+            try:
+                self.adjoint_placehold = A.range_geometry().allocate()
+                self.direct_placehold = A.domain_geometry().allocate()
+                self.memopt = True
+            except NameError as ne:
+                warnings.warn(str(ne))
+                self.memopt = False
+            except NotImplementedError as nie:
+                print (nie)
+                warnings.warn(str(nie))
+                self.memopt = False
+        else:
+            self.memopt = False
 
         # Compute the Lipschitz parameter from the operator if possible
         # Leave it initialised to None otherwise
@@ -157,10 +170,9 @@ class Norm2sq(Function):
         except NotImplementedError as noe:
             pass
 
-    def grad(self,x):
-        #return 2*self.c*self.A.adjoint( self.A.direct(x) - self.b )
-        return (2.0*self.c)*self.A.adjoint( self.A.direct(x) - self.b )
-
+    #def grad(self,x):
+    #    return self.gradient(x, out=None)
+
     def __call__(self,x):
         #return self.c* np.sum(np.square((self.A.direct(x) - self.b).ravel()))
         #if out is None:
@@ -178,12 +190,13 @@ class Norm2sq(Function):
             self.A.direct(x, out=self.adjoint_placehold)
             self.adjoint_placehold.__isub__( self.b )
             self.A.adjoint(self.adjoint_placehold, out=self.direct_placehold)
-            self.direct_placehold.__imul__(2.0 * self.c)
-            # can this be avoided?
-            out.fill(self.direct_placehold)
+            #self.direct_placehold.__imul__(2.0 * self.c)
+            ## can this be avoided?
+            #out.fill(self.direct_placehold)
+            self.direct_placehold.multiply(2.0*self.c, out=out)
         else:
-            return self.grad(x)
-
+            return (2.0*self.c)*self.A.adjoint( self.A.direct(x) - self.b )
+
 class ZeroFun(Function):
```
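For context, the deprecation shim added to the `Function` base class can be exercised on its own. The sketch below mirrors the base class from the diff; `Square` and the `catch_warnings` harness are hypothetical additions for illustration and are not part of the ccpi package.

```python
import warnings

class Function(object):
    """Mirrors ccpi's Function base class after this commit."""
    def __init__(self):
        self.L = None  # Lipschitz constant of the gradient, None if unknown
    def __call__(self, x, out=None):
        raise NotImplementedError
    def grad(self, x):
        # Deprecated alias: warn, then forward to the new method name.
        warnings.warn("grad method is deprecated. use gradient instead",
                      DeprecationWarning)
        return self.gradient(x, out=None)
    def prox(self, x, tau):
        warnings.warn("prox method is deprecated. use proximal instead",
                      DeprecationWarning)
        return self.proximal(x, tau, out=None)
    def gradient(self, x, out=None):
        raise NotImplementedError
    def proximal(self, x, tau, out=None):
        raise NotImplementedError

class Square(Function):
    """Hypothetical stand-in for Norm2sq: f(x) = x**2 on plain floats."""
    def gradient(self, x, out=None):
        return 2.0 * x

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    value = Square().grad(3.0)  # old spelling still works...
    assert value == 6.0
    assert caught[0].category is DeprecationWarning  # ...but now warns
```

The Norm2sq changes follow the same defensive style: memopt placeholders are kept only when the operator can allocate containers for its range and domain, and the memopt branch of `gradient` now writes the result straight into `out` via `multiply(2.0*self.c, out=out)` instead of an in-place `__imul__` followed by `out.fill(...)`, saving one intermediate copy. A minimal sketch of the feature probe follows, assuming only the `range_geometry()`/`domain_geometry()`/`allocate()` names from the diff; `DenseOperator` and `probe_memopt` are hypothetical:

```python
import warnings

class DenseOperator(object):
    """Hypothetical operator that cannot describe its geometry."""
    def range_geometry(self):
        raise NotImplementedError("DenseOperator has no range geometry")
    def domain_geometry(self):
        raise NotImplementedError("DenseOperator has no domain geometry")

def probe_memopt(A):
    # Mirrors Norm2sq.__init__: enable memopt only if the operator can
    # allocate its range/domain containers; otherwise warn and fall back.
    try:
        A.range_geometry().allocate()
        A.domain_geometry().allocate()
        return True
    except (NameError, NotImplementedError) as err:
        warnings.warn(str(err))
        return False

assert probe_memopt(DenseOperator()) is False  # degrades gracefully
```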