summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-01 16:49:59 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-01 16:49:59 +0100
commitcf36fb59af5806506a6b7b75edb7a5f7bebb8070 (patch)
tree90471752421f27834048144d2c07d360552fed13 /Wrappers
parent12ccc249a722a64c02d97e8e1513c065d4a7bf48 (diff)
downloadframework-cf36fb59af5806506a6b7b75edb7a5f7bebb8070.tar.gz
framework-cf36fb59af5806506a6b7b75edb7a5f7bebb8070.tar.bz2
framework-cf36fb59af5806506a6b7b75edb7a5f7bebb8070.tar.xz
framework-cf36fb59af5806506a6b7b75edb7a5f7bebb8070.zip
jenkins errors
Diffstat (limited to 'Wrappers')
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py4
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algs.py4
-rwxr-xr-xWrappers/Python/ccpi/optimisation/funcs.py124
3 files changed, 127 insertions, 5 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index fb2bfd8..043fe38 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -6,11 +6,9 @@ Created on Mon Feb 4 16:18:06 2019
@author: evangelos
"""
from ccpi.optimisation.algorithms import Algorithm
-
-
from ccpi.framework import ImageData
import numpy as np
-import matplotlib.pyplot as plt
+#import matplotlib.pyplot as plt
import time
from ccpi.optimisation.operators import BlockOperator
from ccpi.framework import BlockDataContainer
diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py
index 15638a9..6b6ae2c 100755
--- a/Wrappers/Python/ccpi/optimisation/algs.py
+++ b/Wrappers/Python/ccpi/optimisation/algs.py
@@ -20,8 +20,8 @@
import numpy
import time
-from ccpi.optimisation.funcs import Function
-from ccpi.optimisation.funcs import ZeroFun
+from ccpi.optimisation.functions import Function
+from ccpi.optimisation.functions import ZeroFun
from ccpi.framework import ImageData
from ccpi.framework import AcquisitionData
from ccpi.optimisation.spdhg import spdhg
diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py
index 6741020..efc465c 100755
--- a/Wrappers/Python/ccpi/optimisation/funcs.py
+++ b/Wrappers/Python/ccpi/optimisation/funcs.py
@@ -22,7 +22,90 @@ import numpy
from ccpi.framework import DataContainer
import warnings
from ccpi.optimisation.functions import Function
+def isSizeCorrect(data1 ,data2):
+ if issubclass(type(data1), DataContainer) and \
+ issubclass(type(data2), DataContainer):
+ # check dimensionality
+ if data1.check_dimensions(data2):
+ return True
+ elif issubclass(type(data1) , numpy.ndarray) and \
+ issubclass(type(data2) , numpy.ndarray):
+ return data1.shape == data2.shape
+ else:
+ raise ValueError("{0}: getting two incompatible types: {1} {2}"\
+ .format('Function', type(data1), type(data2)))
+ return False
+class Norm2(Function):
+
+ def __init__(self,
+ gamma=1.0,
+ direction=None):
+ super(Norm2, self).__init__()
+ self.gamma = gamma;
+ self.direction = direction;
+
+ def __call__(self, x, out=None):
+
+ if out is None:
+ xx = numpy.sqrt(numpy.sum(numpy.square(x.as_array()), self.direction,
+ keepdims=True))
+ else:
+ if isSizeCorrect(out, x):
+ # check dimensionality
+ if issubclass(type(out), DataContainer):
+ arr = out.as_array()
+ numpy.square(x.as_array(), out=arr)
+ xx = numpy.sqrt(numpy.sum(arr, self.direction, keepdims=True))
+
+ elif issubclass(type(out) , numpy.ndarray):
+ numpy.square(x.as_array(), out=out)
+ xx = numpy.sqrt(numpy.sum(out, self.direction, keepdims=True))
+ else:
+ raise ValueError ('Wrong size: x{0} out{1}'.format(x.shape,out.shape) )
+
+ p = numpy.sum(self.gamma*xx)
+
+ return p
+
+ def prox(self, x, tau):
+
+ xx = numpy.sqrt(numpy.sum( numpy.square(x.as_array()), self.direction,
+ keepdims=True ))
+ xx = numpy.maximum(0, 1 - tau*self.gamma / xx)
+ p = x.as_array() * xx
+
+ return type(x)(p,geometry=x.geometry)
+ def proximal(self, x, tau, out=None):
+ if out is None:
+ return self.prox(x,tau)
+ else:
+ if isSizeCorrect(out, x):
+ # check dimensionality
+ if issubclass(type(out), DataContainer):
+ numpy.square(x.as_array(), out = out.as_array())
+ xx = numpy.sqrt(numpy.sum( out.as_array() , self.direction,
+ keepdims=True ))
+ xx = numpy.maximum(0, 1 - tau*self.gamma / xx)
+ x.multiply(xx, out= out.as_array())
+
+
+ elif issubclass(type(out) , numpy.ndarray):
+ numpy.square(x.as_array(), out=out)
+ xx = numpy.sqrt(numpy.sum(out, self.direction, keepdims=True))
+
+ xx = numpy.maximum(0, 1 - tau*self.gamma / xx)
+ x.multiply(xx, out= out)
+ else:
+ raise ValueError ('Wrong size: x{0} out{1}'.format(x.shape,out.shape) )
+
+class TV2D(Norm2):
+
+ def __init__(self, gamma):
+ super(TV2D,self).__init__(gamma, 0)
+ self.op = FiniteDiff2D()
+ self.L = self.op.get_max_sing_val()
+
# Define a class for squared 2-norm
class Norm2sq(Function):
@@ -146,3 +229,44 @@ class IndicatorBox(Function):
out.__imul__( self.sign_x )
+# A more interesting example, least squares plus 1-norm minimization.
+# Define class to represent 1-norm including prox function
+class Norm1(Function):
+
+ def __init__(self,gamma):
+ super(Norm1, self).__init__()
+ self.gamma = gamma
+ self.L = 1
+ self.sign_x = None
+
+ def __call__(self,x,out=None):
+ if out is None:
+ return self.gamma*(x.abs().sum())
+ else:
+ if not x.shape == out.shape:
+ raise ValueError('Norm1 Incompatible size:',
+ x.shape, out.shape)
+ x.abs(out=out)
+ return out.sum() * self.gamma
+
+ def prox(self,x,tau):
+ return (x.abs() - tau*self.gamma).maximum(0) * x.sign()
+
+ def proximal(self, x, tau, out=None):
+ if out is None:
+ return self.prox(x, tau)
+ else:
+ if isSizeCorrect(x,out):
+ # check dimensionality
+ if issubclass(type(out), DataContainer):
+ v = (x.abs() - tau*self.gamma).maximum(0)
+ x.sign(out=out)
+ out *= v
+ #out.fill(self.prox(x,tau))
+ elif issubclass(type(out) , numpy.ndarray):
+ v = (x.abs() - tau*self.gamma).maximum(0)
+ out[:] = x.sign()
+ out *= v
+ #out[:] = self.prox(x,tau)
+ else:
+ raise ValueError ('Wrong size: x{0} out{1}'.format(x.shape,out.shape) )