From c65291e6b987283e4767a8ad2bd2d2433ca3782e Mon Sep 17 00:00:00 2001 From: Daniil Kazantsev Date: Thu, 28 Nov 2019 23:01:03 +0000 Subject: all work completed on gpu version of pdtv --- test/test_CPU_regularisers.py | 11 +++++- test/test_run_test.py | 88 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_CPU_regularisers.py b/test/test_CPU_regularisers.py index 95e2a3f..266ca8a 100644 --- a/test/test_CPU_regularisers.py +++ b/test/test_CPU_regularisers.py @@ -3,7 +3,7 @@ import unittest import os #import timeit import numpy as np -from ccpi.filters.regularisers import FGP_TV, SB_TV, TGV, LLT_ROF, FGP_dTV, NDF, Diff4th, ROF_TV +from ccpi.filters.regularisers import FGP_TV, SB_TV, TGV, LLT_ROF, FGP_dTV, NDF, Diff4th, ROF_TV, PD_TV from testroutines import BinReader, rmse ############################################################################### @@ -39,6 +39,15 @@ class TestRegularisers(unittest.TestCase): self.assertAlmostEqual(rms,0.02,delta=0.01) + def test_PD_TV_CPU(self): + Im,input,ref = self.getPars() + + pd_cpu,info = PD_TV(input, 0.02, 300, 0.0, 0, 0, 8, 0.0025, 'cpu'); + + rms = rmse(Im, pd_cpu) + + self.assertAlmostEqual(rms,0.02,delta=0.01) + def test_TV_ROF_CPU(self): # set parameters Im, input,ref = self.getPars() diff --git a/test/test_run_test.py b/test/test_run_test.py index e693e03..1707aec 100755 --- a/test/test_run_test.py +++ b/test/test_run_test.py @@ -164,6 +164,94 @@ class TestRegularisers(unittest.TestCase): self.assertLessEqual(diff_im.sum() , 1) + def test_PD_TV_CPU_vs_GPU(self): + print(__name__) + #filename = os.path.join("test","lena_gray_512.tif") + #plt = TiffReader() + filename = os.path.join("test","test_imageLena.bin") + plt = BinReader() + # read image + Im = plt.imread(filename) + Im = np.asarray(Im, dtype='float32') + + Im = Im/255 + perc = 0.05 + u0 = Im + np.random.normal(loc = 0 , + scale = perc * Im , + size = np.shape(Im)) + u_ref = Im + np.random.normal(loc = 0 , + scale = 0.01 * Im , + size = np.shape(Im)) + + # map the u0 u0->u0>0 + # f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) + u0 = u0.astype('float32') + u_ref = u_ref.astype('float32') + + print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") + print ("____________PD-TV bench___________________") + print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") + + + pars = {'algorithm' : PD_TV, \ + 'input' : u0,\ + 'regularisation_parameter':0.02, \ + 'number_of_iterations' :1500 ,\ + 'tolerance_constant':0.0,\ + 'methodTV': 0 ,\ + 'nonneg': 0, + 'lipschitz_const' : 8, + 'tau' : 0.0025} + + print ("#############PD TV CPU####################") + start_time = timeit.default_timer() + (pd_cpu,info_vec_cpu) = PD_TV(pars['input'], + pars['regularisation_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['methodTV'], + pars['nonneg'], + pars['lipschitz_const'], + pars['tau'],'cpu') + + rms = rmse(Im, pd_cpu) + pars['rmse'] = rms + + txtstr = printParametersToString(pars) + txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + print (txtstr) + + print ("##############PD TV GPU##################") + start_time = timeit.default_timer() + try: + (pd_gpu,info_vec_gpu) = PD_TV(pars['input'], + pars['regularisation_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['methodTV'], + pars['nonneg'], + pars['lipschitz_const'], + pars['tau'],'gpu') + + except ValueError as ve: + self.skipTest("Results not comparable. GPU computing error.") + + rms = rmse(Im, pd_gpu) + pars['rmse'] = rms + pars['algorithm'] = PD_TV + txtstr = printParametersToString(pars) + txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + print (txtstr) + + print ("--------Compare the results--------") + tolerance = 1e-05 + diff_im = np.zeros(np.shape(pd_cpu)) + diff_im = abs(pd_cpu - pd_gpu) + diff_im[diff_im > tolerance] = 1 + + self.assertLessEqual(diff_im.sum() , 1) + + def test_SB_TV_CPU_vs_GPU(self): print(__name__) #filename = os.path.join("test","lena_gray_512.tif") -- cgit v1.2.3