summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers')
-rw-r--r--Wrappers/Python/demos/PDHG_examples/Tomo/PDHG_TGV_Tomo2D.py82
1 files changed, 51 insertions, 31 deletions
diff --git a/Wrappers/Python/demos/PDHG_examples/Tomo/PDHG_TGV_Tomo2D.py b/Wrappers/Python/demos/PDHG_examples/Tomo/PDHG_TGV_Tomo2D.py
index 422deb9..b2af753 100644
--- a/Wrappers/Python/demos/PDHG_examples/Tomo/PDHG_TGV_Tomo2D.py
+++ b/Wrappers/Python/demos/PDHG_examples/Tomo/PDHG_TGV_Tomo2D.py
@@ -23,18 +23,20 @@
Total Generalised Variation (TGV) Tomography 2D using PDHG algorithm:
-Problem: min_{x>0} \alpha * ||\nabla x - w||_{2,1} +
- \beta * || E w ||_{2,1} +
- int A x - g log(Ax + \eta)
+Problem: min_{x>0} \alpha * ||\nabla x - w||_{2,1} + \beta * || E w ||_{2,1} +
+ \frac{1}{2}||Au - g||^{2}
+
+ min_{u>0} \alpha * ||\nabla u - w||_{2,1} + \beta * || E w ||_{2,1} +
+ int A u - g log(Au + \eta)
\alpha: Regularization parameter
\beta: Regularization parameter
\nabla: Gradient operator
E: Symmetrized Gradient operator
- A: Projection Matrix
+ A: System Matrix
- g: Noisy Data with Poisson Noise
+ g: Noisy Sinogram
K = [ \nabla, - Identity
ZeroOperator, E
@@ -58,9 +60,12 @@ from ccpi.optimisation.functions import IndicatorBox, KullbackLeibler, ZeroFunct
from ccpi.astra.ops import AstraProjectorSimple
from ccpi.framework import TestData
import os, sys
-from skimage.util import random_noise
+if int(numpy.version.version.split('.')[1]) > 12:
+ from skimage.util import random_noise
+else:
+ from demoutil import random_noise
-loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi'))
+#loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi'))
# Load Data
#N = 50
@@ -102,28 +107,31 @@ else:
Aop = AstraProjectorSimple(ig, ag, 'cpu')
sin = Aop.direct(data)
-# Create noisy data. Apply Poisson noise
-scale = 0.5
-eta = 0
-n1 = scale * np.random.poisson(eta + sin.as_array()/scale)
-
-noisy_data = AcquisitionData(n1, ag)
+# Create noisy data.
+noises = ['gaussian', 'poisson']
+noise = noises[which_noise]
+
+if noise == 'poisson':
+ scale = 5
+ eta = 0
+ noisy_data = AcquisitionData(np.random.poisson( scale * (eta + sin.as_array()))/scale, ag)
+elif noise == 'gaussian':
+ n1 = np.random.normal(0, 1, size = ag.shape)
+ noisy_data = AcquisitionData(n1 + sin.as_array(), ag)
+else:
+ raise ValueError('Unsupported Noise ', noise)
# Show Ground Truth and Noisy Data
plt.figure(figsize=(10,10))
-plt.subplot(2,1,1)
+plt.subplot(1,2,2)
plt.imshow(data.as_array())
plt.title('Ground Truth')
plt.colorbar()
-plt.subplot(2,1,2)
+plt.subplot(1,2,1)
plt.imshow(noisy_data.as_array())
plt.title('Noisy Data')
plt.colorbar()
plt.show()
-#%%
-# Regularisation Parameters
-alpha = 1
-beta = 5
# Create Operators
op11 = Gradient(ig)
@@ -136,28 +144,41 @@ op31 = Aop
op32 = ZeroOperator(op22.domain_geometry(), ag)
operator = BlockOperator(op11, -1*op12, op21, op22, op31, op32, shape=(3,2) )
+
+# Create functions
+if noise == 'poisson':
+ alpha = 1
+ beta = 5
+ f3 = KullbackLeibler(noisy_data)
+ g = BlockFunction(IndicatorBox(lower=0), ZeroFunction())
+
+ # Primal & dual stepsizes
+ sigma = 1
+ tau = 1/(sigma*normK**2)
+
+elif noise == 'gaussian':
+ alpha = 20
+ f3 = 0.5 * L2NormSquared(b=noisy_data)
+ g = BlockFunction(ZeroFunction(), ZeroFunction())
+
+ # Primal & dual stepsizes
+ sigma = 10
+ tau = 1/(sigma*normK**2)
f1 = alpha * MixedL21Norm()
-f2 = beta * MixedL21Norm()
-f3 = KullbackLeibler(noisy_data)
+f2 = beta * MixedL21Norm()
f = BlockFunction(f1, f2, f3)
-
-g = BlockFunction(IndicatorBox(lower=0), ZeroFunction())
# Compute operator Norm
normK = operator.norm()
-# Primal & dual stepsizes
-sigma = 1
-tau = 1/(sigma*normK**2)
-
-
# Setup and run the PDHG algorithm
pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
pdhg.max_iteration = 3000
pdhg.update_objective_interval = 500
pdhg.run(3000)
+#%%
plt.figure(figsize=(15,15))
plt.subplot(3,1,1)
plt.imshow(data.as_array())
@@ -172,9 +193,8 @@ plt.imshow(pdhg.get_output()[0].as_array())
plt.title('TGV Reconstruction')
plt.colorbar()
plt.show()
-##
-plt.plot(np.linspace(0,N,N), data.as_array()[int(N/2),:], label = 'GTruth')
-plt.plot(np.linspace(0,N,N), pdhg.get_output()[0].as_array()[int(N/2),:], label = 'TGV reconstruction')
+plt.plot(np.linspace(0,ig.shape[1],ig.shape[1]), data.as_array()[int(N/2),:], label = 'GTruth')
+plt.plot(np.linspace(0,ig.shape[1],ig.shape[1]), pdhg.get_output()[0].as_array()[int(N/2),:], label = 'TGV reconstruction')
plt.legend()
plt.title('Middle Line Profiles')
plt.show()