summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/CMakeLists.txt183
-rw-r--r--Wrappers/Python/FindAnacondaEnvironment.cmake154
-rw-r--r--Wrappers/Python/ccpi/reconstruction/AstraDevice.py95
-rw-r--r--Wrappers/Python/ccpi/reconstruction/DeviceModel.py63
-rw-r--r--Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py882
-rw-r--r--Wrappers/Python/ccpi/reconstruction/Reconstructor.py598
-rw-r--r--Wrappers/Python/compile-fista.bat.in7
-rw-r--r--Wrappers/Python/compile-fista.sh.in9
-rw-r--r--Wrappers/Python/compile.bat.in7
-rw-r--r--Wrappers/Python/compile.sh.in9
-rw-r--r--Wrappers/Python/conda-recipe/bld.bat14
-rw-r--r--Wrappers/Python/conda-recipe/build.sh14
-rw-r--r--Wrappers/Python/conda-recipe/meta.yaml30
-rw-r--r--Wrappers/Python/fista-recipe/bld.bat11
-rw-r--r--Wrappers/Python/fista-recipe/build.sh10
-rw-r--r--Wrappers/Python/fista-recipe/meta.yaml29
-rw-r--r--Wrappers/Python/fista_module.cpp1047
-rw-r--r--Wrappers/Python/setup-fista.py.in27
-rw-r--r--Wrappers/Python/setup.py.in69
-rw-r--r--Wrappers/Python/test/astra_test.py85
-rw-r--r--Wrappers/Python/test/create_phantom_projections.py49
-rw-r--r--Wrappers/Python/test/readhd5.py42
-rw-r--r--Wrappers/Python/test/simple_astra_test.py25
-rw-r--r--Wrappers/Python/test/test_reconstructor-os_phantom.py480
-rw-r--r--Wrappers/Python/test/test_reconstructor.py359
-rw-r--r--Wrappers/Python/test/test_regularizers.py412
-rw-r--r--Wrappers/Python/test/test_regularizers_3d.py425
-rw-r--r--Wrappers/Python/test/view_result.py12
28 files changed, 5147 insertions, 0 deletions
diff --git a/Wrappers/Python/CMakeLists.txt b/Wrappers/Python/CMakeLists.txt
new file mode 100644
index 0000000..506159a
--- /dev/null
+++ b/Wrappers/Python/CMakeLists.txt
@@ -0,0 +1,183 @@
+# Copyright 2017 Edoardo Pasca
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# variables that must be set for conda compilation
+
+#PREFIX=C:\Apps\Miniconda2\envs\cil\Library
+#LIBRARY_INC=C:\\Apps\\Miniconda2\\envs\\cil\\Library\\include
+set (NUMPY_VERSION 1.12)
+
+## Tries to parse the output of conda env list to determine the current
+## active conda environment
+message ("Trying to determine your active conda environment...")
+execute_process(COMMAND "conda" "env" "list"
+ OUTPUT_VARIABLE _CONDA_ENVS
+ RESULT_VARIABLE _CONDA_RESULT
+ ERROR_VARIABLE _CONDA_ERR)
+ if(NOT _CONDA_RESULT)
+ string(REPLACE "\n" ";" ENV_LIST ${_CONDA_ENVS})
+ foreach(line ${ENV_LIST})
+ string(REGEX MATCHALL "(.+)[*](.+)" match ${line})
+ if (NOT ${match} EQUAL "")
+ #message("MATCHED " ${CMAKE_MATCH_0})
+ #message("MATCHED " ${CMAKE_MATCH_1})
+ #message("MATCHED " ${CMAKE_MATCH_2})
+ string(STRIP ${CMAKE_MATCH_1} CONDA_ENVIRONMENT)
+ string(STRIP ${CMAKE_MATCH_2} CONDA_ENVIRONMENT_PATH)
+ endif()
+ endforeach()
+ else()
+ message(FATAL_ERROR "ERROR with conda command " ${_CONDA_ERR})
+ endif()
+
+if (${CONDA_ENVIRONMENT} AND ${CONDA_ENVIRONMENT_PATH})
+ message (FATAL_ERROR "CONDA NOT FOUND")
+else()
+ message("**********************************************************")
+ message("Using current conda environmnet " ${CONDA_ENVIRONMENT})
+ message("Using current conda environmnet path " ${CONDA_ENVIRONMENT_PATH})
+endif()
+
+message("CIL VERSION " ${CIL_VERSION})
+
+# set the Python variables for the Conda environment
+include(FindAnacondaEnvironment.cmake)
+findPythonForAnacondaEnvironment(${CONDA_ENVIRONMENT_PATH})
+
+message("Python found " ${PYTHON_VERSION_STRING})
+message("Python found Major " ${PYTHON_VERSION_MAJOR})
+message("Python found Minor " ${PYTHON_VERSION_MINOR})
+
+findPythonPackagesPath()
+message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH})
+
+## CACHE SOME VARIABLES ##
+set (CONDA_ENVIRONMENT ${CONDA_ENVIRONMENT} CACHE INTERNAL "active conda environment" FORCE)
+set (CONDA_ENVIRONMENT_PATH ${CONDA_ENVIRONMENT_PATH} CACHE INTERNAL "active conda environment" FORCE)
+
+set (PYTHON_VERSION_STRING ${PYTHON_VERSION_STRING} CACHE INTERNAL "conda environment Python version string" FORCE)
+set (PYTHON_VERSION_MAJOR ${PYTHON_VERSION_MAJOR} CACHE INTERNAL "conda environment Python version major" FORCE)
+set (PYTHON_VERSION_MINOR ${PYTHON_VERSION_MINOR} CACHE INTERNAL "conda environment Python version minor" FORCE)
+set (PYTHON_VERSION_PATCH ${PYTHON_VERSION_PATCH} CACHE INTERNAL "conda environment Python version patch" FORCE)
+set (PYTHON_PACKAGES_PATH ${PYTHON_PACKAGES_PATH} CACHE INTERNAL "conda environment Python packages path" FORCE)
+
+if (WIN32)
+ #set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory")
+ set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir")
+ set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}\\include" CACHE PATH "env dir")
+elseif (UNIX)
+ #set (CONDA_ENVIRONMENT_PATH "/apps/anaconda/2.4/envs/${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory")
+ set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}/lib/python${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}" CACHE PATH "env dir")
+ set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}/include" CACHE PATH "env dir")
+endif()
+
+######### CONFIGURE REGULARIZER PACKAGE #############
+
+# copy the Pyhon files of the package regularizer
+file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi)
+# regularizers
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging)
+
+# Copy and configure the relative conda build and recipes
+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py)
+file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe)
+
+if (WIN32)
+
+ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/)
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat)
+
+elseif(UNIX)
+
+ message ("We are on UNIX")
+ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/)
+ # assumes we will use bash
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh)
+
+endif()
+
+########## CONFIGURE FISTA RECONSTRUCTOR PACKAGE
+# fista reconstructor
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/FISTAReconstructor.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/DeviceModel.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/AstraDevice.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction)
+
+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup-fista.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup-fista.py)
+file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe)
+file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe)
+
+if (WIN32)
+
+ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/)
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.bat)
+
+elseif(UNIX)
+ message ("We are on UNIX")
+ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/)
+ # assumes we will use bash
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.sh)
+endif()
+
+############################# TARGETS
+
+########################## REGULARIZER PACKAGE ###############################
+
+# runs cmake on the build tree to update the code from source
+add_custom_target(update_code
+ COMMAND ${CMAKE_COMMAND}
+ ARGS ${CMAKE_SOURCE_DIR}
+ WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
+ )
+
+
+add_custom_target(fista
+ COMMAND bash
+ compile-fista.sh
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ DEPENDS ${update_code}
+ )
+
+add_custom_target(regularizers
+ COMMAND bash
+ compile.sh
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ DEPENDS update_code
+ )
+
+add_custom_target(install-fista
+ COMMAND ${CONDA_EXECUTABLE}
+ install --force --use-local ccpi-fista=${CIL_VERSION} -c ccpi -c conda-forge
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ )
+
+add_custom_target(install-regularizers
+ COMMAND ${CONDA_EXECUTABLE}
+ install --force --use-local ccpi-regularizers=${CIL_VERSION} -c ccpi -c conda-forge
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ )
+### add tests
+
+#add_executable(RegularizersTest )
+#find_package(tiff)
+#if (TIFF_FOUND)
+# message("LibTIFF Found")
+# message("TIFF_INCLUDE_DIR " ${TIFF_INCLUDE_DIR})
+# message("TIFF_LIBRARIES" ${TIFF_LIBRARIES})
+#else()
+# message("LibTIFF not found")
+#endif()
diff --git a/Wrappers/Python/FindAnacondaEnvironment.cmake b/Wrappers/Python/FindAnacondaEnvironment.cmake
new file mode 100644
index 0000000..6475128
--- /dev/null
+++ b/Wrappers/Python/FindAnacondaEnvironment.cmake
@@ -0,0 +1,154 @@
+# Copyright 2017 Edoardo Pasca
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# #.rst:
+# FindAnacondaEnvironment
+# --------------
+#
+# Find Python executable and library for a specific Anaconda environment
+#
+# This module finds the Python interpreter for a specific Anaconda enviroment,
+# if installed and determines where the include files and libraries are.
+# This code sets the following variables:
+#
+# ::
+# PYTHONINTERP_FOUND - if the Python interpret has been found
+# PYTHON_EXECUTABLE - the Python interpret found
+# PYTHON_LIBRARY - path to the python library
+# PYTHON_INCLUDE_PATH - path to where Python.h is found (deprecated)
+# PYTHON_INCLUDE_DIRS - path to where Python.h is found
+# PYTHONLIBS_VERSION_STRING - version of the Python libs found (since CMake 2.8.8)
+# PYTHON_VERSION_MAJOR - major Python version
+# PYTHON_VERSION_MINOR - minor Python version
+# PYTHON_VERSION_PATCH - patch Python version
+
+
+
+function (findPythonForAnacondaEnvironment env)
+ if (WIN32)
+ file(TO_CMAKE_PATH ${env}/python.exe PYTHON_EXECUTABLE)
+ elseif (UNIX)
+ file(TO_CMAKE_PATH ${env}/bin/python PYTHON_EXECUTABLE)
+ endif()
+
+
+ message("findPythonForAnacondaEnvironment Found Python Executable" ${PYTHON_EXECUTABLE})
+ ####### FROM FindPythonInterpr ########
+ # determine python version string
+ if(PYTHON_EXECUTABLE)
+ execute_process(COMMAND "${PYTHON_EXECUTABLE}" -c
+ "import sys; sys.stdout.write(';'.join([str(x) for x in sys.version_info[:3]]))"
+ OUTPUT_VARIABLE _VERSION
+ RESULT_VARIABLE _PYTHON_VERSION_RESULT
+ ERROR_QUIET)
+ if(NOT _PYTHON_VERSION_RESULT)
+ string(REPLACE ";" "." _PYTHON_VERSION_STRING "${_VERSION}")
+ list(GET _VERSION 0 _PYTHON_VERSION_MAJOR)
+ list(GET _VERSION 1 _PYTHON_VERSION_MINOR)
+ list(GET _VERSION 2 _PYTHON_VERSION_PATCH)
+ if(PYTHON_VERSION_PATCH EQUAL 0)
+ # it's called "Python 2.7", not "2.7.0"
+ string(REGEX REPLACE "\\.0$" "" _PYTHON_VERSION_STRING "${PYTHON_VERSION_STRING}")
+ endif()
+ else()
+ # sys.version predates sys.version_info, so use that
+ execute_process(COMMAND "${PYTHON_EXECUTABLE}" -c "import sys; sys.stdout.write(sys.version)"
+ OUTPUT_VARIABLE _VERSION
+ RESULT_VARIABLE _PYTHON_VERSION_RESULT
+ ERROR_QUIET)
+ if(NOT _PYTHON_VERSION_RESULT)
+ string(REGEX REPLACE " .*" "" _PYTHON_VERSION_STRING "${_VERSION}")
+ string(REGEX REPLACE "^([0-9]+)\\.[0-9]+.*" "\\1" _PYTHON_VERSION_MAJOR "${PYTHON_VERSION_STRING}")
+ string(REGEX REPLACE "^[0-9]+\\.([0-9])+.*" "\\1" _PYTHON_VERSION_MINOR "${PYTHON_VERSION_STRING}")
+ if(PYTHON_VERSION_STRING MATCHES "^[0-9]+\\.[0-9]+\\.([0-9]+)")
+ set(PYTHON_VERSION_PATCH "${CMAKE_MATCH_1}")
+ else()
+ set(PYTHON_VERSION_PATCH "0")
+ endif()
+ else()
+ # sys.version was first documented for Python 1.5, so assume
+ # this is older.
+ set(PYTHON_VERSION_STRING "1.4" PARENT_SCOPE)
+ set(PYTHON_VERSION_MAJOR "1" PARENT_SCOPE)
+ set(PYTHON_VERSION_MINOR "4" PARENT_SCOPE)
+ set(PYTHON_VERSION_PATCH "0" PARENT_SCOPE)
+ endif()
+ endif()
+ unset(_PYTHON_VERSION_RESULT)
+ unset(_VERSION)
+ endif()
+ ###############################################
+
+ set (PYTHON_EXECUTABLE ${PYTHON_EXECUTABLE} PARENT_SCOPE)
+ set (PYTHONINTERP_FOUND "ON" PARENT_SCOPE)
+ set (PYTHON_VERSION_STRING ${_PYTHON_VERSION_STRING} PARENT_SCOPE)
+ set (PYTHON_VERSION_MAJOR ${_PYTHON_VERSION_MAJOR} PARENT_SCOPE)
+ set (PYTHON_VERSION_MINOR ${_PYTHON_VERSION_MINOR} PARENT_SCOPE)
+ set (PYTHON_VERSION_PATCH ${_PYTHON_VERSION_PATCH} PARENT_SCOPE)
+ message("My version found " ${PYTHON_VERSION_STRING})
+ ## find conda executable
+ if (WIN32)
+ set (CONDA_EXECUTABLE ${env}/Script/conda PARENT_SCOPE)
+ elseif(UNIX)
+ set (CONDA_EXECUTABLE ${env}/bin/conda PARENT_SCOPE)
+ endif()
+endfunction()
+
+
+
+set(Python_ADDITIONAL_VERSIONS 3.5)
+
+find_package(PythonInterp)
+if (PYTHONINTERP_FOUND)
+
+ message("Found interpret " ${PYTHON_EXECUTABLE})
+ message("Python Library " ${PYTHON_LIBRARY})
+ message("Python Include Dir " ${PYTHON_INCLUDE_DIR})
+ message("Python Include Path " ${PYTHON_INCLUDE_PATH})
+
+ foreach(pv ${PYTHON_VERSION_STRING})
+ message("Found interpret " ${pv})
+ endforeach()
+endif()
+
+
+
+find_package(PythonLibs)
+if (PYTHONLIB_FOUND)
+ message("Found PythonLibs PYTHON_LIBRARIES " ${PYTHON_LIBRARIES})
+ message("Found PythonLibs PYTHON_INCLUDE_PATH " ${PYTHON_INCLUDE_PATH})
+ message("Found PythonLibs PYTHON_INCLUDE_DIRS " ${PYTHON_INCLUDE_DIRS})
+ message("Found PythonLibs PYTHONLIBS_VERSION_STRING " ${PYTHONLIBS_VERSION_STRING} )
+else()
+ message("No PythonLibs Found")
+endif()
+
+
+
+
+function(findPythonPackagesPath)
+ execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "from distutils.sysconfig import *; print (get_python_lib())"
+ RESULT_VARIABLE PYTHON_CVPY_PROCESS
+ OUTPUT_VARIABLE PYTHON_STD_PACKAGES_PATH
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ #message("STD_PACKAGES " ${PYTHON_STD_PACKAGES_PATH})
+ if("${PYTHON_STD_PACKAGES_PATH}" MATCHES "site-packages")
+ set(_PYTHON_PACKAGES_PATH "python${PYTHON_VERSION_MAJOR_MINOR}/site-packages")
+ endif()
+
+ SET(PYTHON_PACKAGES_PATH "${PYTHON_STD_PACKAGES_PATH}" PARENT_SCOPE)
+
+endfunction()
+
+
diff --git a/Wrappers/Python/ccpi/reconstruction/AstraDevice.py b/Wrappers/Python/ccpi/reconstruction/AstraDevice.py
new file mode 100644
index 0000000..57435f8
--- /dev/null
+++ b/Wrappers/Python/ccpi/reconstruction/AstraDevice.py
@@ -0,0 +1,95 @@
+import astra
+from ccpi.reconstruction.DeviceModel import DeviceModel
+import numpy
+
+class AstraDevice(DeviceModel):
+ '''Concrete class for Astra Device'''
+
+ def __init__(self,
+ device_type,
+ data_aquisition_geometry,
+ reconstructed_volume_geometry):
+
+ super(AstraDevice, self).__init__(device_type,
+ data_aquisition_geometry,
+ reconstructed_volume_geometry)
+
+ self.type = device_type
+ self.proj_geom = astra.creators.create_proj_geom(
+ device_type,
+ self.acquisition_data_geometry['detectorSpacingX'],
+ self.acquisition_data_geometry['detectorSpacingY'],
+ self.acquisition_data_geometry['cameraX'],
+ self.acquisition_data_geometry['cameraY'],
+ self.acquisition_data_geometry['angles'],
+ )
+
+ self.vol_geom = astra.creators.create_vol_geom(
+ self.reconstructed_volume_geometry['X'],
+ self.reconstructed_volume_geometry['Y'],
+ self.reconstructed_volume_geometry['Z']
+ )
+
+ def doForwardProject(self, volume):
+ '''Forward projects the volume according to the device geometry
+
+Uses Astra-toolbox
+'''
+
+ try:
+ sino_id, y = astra.creators.create_sino3d_gpu(
+ volume, self.proj_geom, self.vol_geom)
+ astra.matlab.data3d('delete', sino_id)
+ return y
+ except Exception as e:
+ print(e)
+ print("Value Error: ", self.proj_geom, self.vol_geom)
+
+ def doBackwardProject(self, projections):
+ '''Backward projects the projections according to the device geometry
+
+Uses Astra-toolbox
+'''
+ idx, volume = \
+ astra.creators.create_backprojection3d_gpu(
+ projections,
+ self.proj_geom,
+ self.vol_geom)
+
+ astra.matlab.data3d('delete', idx)
+ return volume
+
+ def createReducedDevice(self, proj_par={'cameraY' : 1} , vol_par={'Z':1}):
+ '''Create a new device based on the current device by changing some parameter
+
+VERY RISKY'''
+ acquisition_data_geometry = self.acquisition_data_geometry.copy()
+ for k,v in proj_par.items():
+ if k in acquisition_data_geometry.keys():
+ acquisition_data_geometry[k] = v
+ proj_geom = [
+ acquisition_data_geometry['cameraX'],
+ acquisition_data_geometry['cameraY'],
+ acquisition_data_geometry['detectorSpacingX'],
+ acquisition_data_geometry['detectorSpacingY'],
+ acquisition_data_geometry['angles']
+ ]
+
+ reconstructed_volume_geometry = self.reconstructed_volume_geometry.copy()
+ for k,v in vol_par.items():
+ if k in reconstructed_volume_geometry.keys():
+ reconstructed_volume_geometry[k] = v
+
+ vol_geom = [
+ reconstructed_volume_geometry['X'],
+ reconstructed_volume_geometry['Y'],
+ reconstructed_volume_geometry['Z']
+ ]
+ return AstraDevice(self.type, proj_geom, vol_geom)
+
+
+
+if __name__=="main":
+ a = AstraDevice()
+
+
diff --git a/Wrappers/Python/ccpi/reconstruction/DeviceModel.py b/Wrappers/Python/ccpi/reconstruction/DeviceModel.py
new file mode 100644
index 0000000..eeb9a34
--- /dev/null
+++ b/Wrappers/Python/ccpi/reconstruction/DeviceModel.py
@@ -0,0 +1,63 @@
+from abc import ABCMeta, abstractmethod
+from enum import Enum
+
+class DeviceModel(metaclass=ABCMeta):
+ '''Abstract class that defines the device for projection and backprojection
+
+This class defines the methods that must be implemented by concrete classes.
+
+ '''
+
+ class DeviceType(Enum):
+ '''Type of device
+PARALLEL BEAM
+PARALLEL BEAM 3D
+CONE BEAM
+HELICAL'''
+
+ PARALLEL = 'parallel'
+ PARALLEL3D = 'parallel3d'
+ CONE_BEAM = 'cone-beam'
+ HELICAL = 'helical'
+
+ def __init__(self,
+ device_type,
+ data_aquisition_geometry,
+ reconstructed_volume_geometry):
+ '''Initializes the class
+
+Mandatory parameters are:
+device_type from DeviceType Enum
+data_acquisition_geometry: tuple (camera_X, camera_Y, detectorSpacingX,
+ detectorSpacingY, angles)
+reconstructed_volume_geometry: tuple (dimX,dimY,dimZ)
+'''
+ self.device_geometry = device_type
+ self.acquisition_data_geometry = {
+ 'cameraX': data_aquisition_geometry[0],
+ 'cameraY': data_aquisition_geometry[1],
+ 'detectorSpacingX' : data_aquisition_geometry[2],
+ 'detectorSpacingY' : data_aquisition_geometry[3],
+ 'angles' : data_aquisition_geometry[4],}
+ self.reconstructed_volume_geometry = {
+ 'X': reconstructed_volume_geometry[0] ,
+ 'Y': reconstructed_volume_geometry[1] ,
+ 'Z': reconstructed_volume_geometry[2] }
+
+ @abstractmethod
+ def doForwardProject(self, volume):
+ '''Forward projects the volume according to the device geometry'''
+ return NotImplemented
+
+
+ @abstractmethod
+ def doBackwardProject(self, projections):
+ '''Backward projects the projections according to the device geometry'''
+ return NotImplemented
+
+ @abstractmethod
+ def createReducedDevice(self):
+ '''Create a Device to do forward/backward projections on 2D slices'''
+ return NotImplemented
+
+
diff --git a/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py b/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py
new file mode 100644
index 0000000..e40ad24
--- /dev/null
+++ b/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py
@@ -0,0 +1,882 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+#from ccpi.reconstruction.parallelbeam import alg
+
+#from ccpi.imaging.Regularizer import Regularizer
+from enum import Enum
+
+import astra
+from ccpi.reconstruction.AstraDevice import AstraDevice
+
+
+
+class FISTAReconstructor():
+ '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+
+ '''
+ # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+ # ___Input___:
+ # params.[] file:
+ # - .proj_geom (geometry of the projector) [required]
+ # - .vol_geom (geometry of the reconstructed object) [required]
+ # - .sino (vectorized in 2D or 3D sinogram) [required]
+ # - .iterFISTA (iterations for the main loop, default 40)
+ # - .L_const (Lipschitz constant, default Power method) )
+ # - .X_ideal (ideal image, if given)
+ # - .weights (statisitcal weights, size of the sinogram)
+ # - .ROI (Region-of-interest, only if X_ideal is given)
+ # - .initialize (a 'warm start' using SIRT method from ASTRA)
+ #----------------Regularization choices------------------------
+ # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+ # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+ # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+ # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+ # - .Regul_Iterations (iterations for the selected penalty, default 25)
+ # - .Regul_tauLLT (time step parameter for LLT term)
+ # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+ # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+ #----------------Visualization parameters------------------------
+ # - .show (visualize reconstruction 1/0, (0 default))
+ # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+ # - .slice (for 3D volumes - slice number to imshow)
+ # ___Output___:
+ # 1. X - reconstructed image/volume
+ # 2. output - a structure with
+ # - .Resid_error - residual error (if X_ideal is given)
+ # - .objective: value of the objective function
+ # - .L_const: Lipshitz constant to avoid recalculations
+
+ # References:
+ # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+ # Problems" by A. Beck and M Teboulle
+ # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+ # 3. "A novel tomographic reconstruction method based on the robust
+ # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+ # D. Kazantsev, 2016-17
+ def __init__(self, projector_geometry,
+ output_geometry,
+ input_sinogram,
+ device,
+ **kwargs):
+ # handle parmeters:
+ # obligatory parameters
+ self.pars = dict()
+ self.pars['projector_geometry'] = projector_geometry # proj_geom
+ self.pars['output_geometry'] = output_geometry # vol_geom
+ self.pars['input_sinogram'] = input_sinogram # sino
+ sliceZ, nangles, detectors = numpy.shape(input_sinogram)
+ self.pars['detectors'] = detectors
+ self.pars['number_of_angles'] = nangles
+ self.pars['SlicesZ'] = sliceZ
+ self.pars['output_volume'] = None
+ self.pars['device_model'] = device
+
+ self.use_device = True
+
+ print (self.pars)
+ # handle optional input parameters (at instantiation)
+
+ # Accepted input keywords
+ kw = (
+ # mandatory fields
+ 'projector_geometry',
+ 'output_geometry',
+ 'input_sinogram',
+ 'detectors',
+ 'number_of_angles',
+ 'SlicesZ',
+ # optional fields
+ 'number_of_iterations',
+ 'Lipschitz_constant' ,
+ 'ideal_image' ,
+ 'weights' ,
+ 'region_of_interest' ,
+ 'initialize' ,
+ 'regularizer' ,
+ 'ring_lambda_R_L1',
+ 'ring_alpha',
+ 'subsets',
+ 'output_volume',
+ 'os_subsets',
+ 'os_indices',
+ 'os_bins',
+ 'device_model',
+ 'reduced_device_model')
+ self.acceptedInputKeywords = list(kw)
+
+ # handle keyworded parameters
+ if kwargs is not None:
+ for key, value in kwargs.items():
+ if key in kw:
+ #print("{0} = {1}".format(key, value))
+ self.pars[key] = value
+
+ # set the default values for the parameters if not set
+ if 'number_of_iterations' in kwargs.keys():
+ self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+ else:
+ self.pars['number_of_iterations'] = 40
+ if 'weights' in kwargs.keys():
+ self.pars['weights'] = kwargs['weights']
+ else:
+ self.pars['weights'] = \
+ numpy.ones(numpy.shape(
+ self.pars['input_sinogram']))
+ if 'Lipschitz_constant' in kwargs.keys():
+ self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+ else:
+ self.pars['Lipschitz_constant'] = None
+
+ if not 'ideal_image' in kwargs.keys():
+ self.pars['ideal_image'] = None
+
+ if not 'region_of_interest'in kwargs.keys() :
+ if self.pars['ideal_image'] == None:
+ self.pars['region_of_interest'] = None
+ else:
+ ## nonzero if the image is larger than m
+ fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1)
+
+ self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0)
+
+ # the regularizer must be a correctly instantiated object
+ if not 'regularizer' in kwargs.keys() :
+ self.pars['regularizer'] = None
+
+ #RING REMOVAL
+ if not 'ring_lambda_R_L1' in kwargs.keys():
+ self.pars['ring_lambda_R_L1'] = 0
+ if not 'ring_alpha' in kwargs.keys():
+ self.pars['ring_alpha'] = 1
+
+ # ORDERED SUBSET
+ if not 'subsets' in kwargs.keys():
+ self.pars['subsets'] = 0
+ else:
+ self.createOrderedSubsets()
+
+ if not 'initialize' in kwargs.keys():
+ self.pars['initialize'] = False
+
+ reduced_device = device.createReducedDevice()
+ self.setParameter(reduced_device_model=reduced_device)
+
+
+
+ def setParameter(self, **kwargs):
+ '''set named parameter for the reconstructor engine
+
+ raises Exception if the named parameter is not recognized
+
+ '''
+ for key , value in kwargs.items():
+ if key in self.acceptedInputKeywords:
+ self.pars[key] = value
+ else:
+ raise Exception('Wrong parameter {0} for '.format(key) +
+ 'reconstructor')
+ # setParameter
+
+ def getParameter(self, key):
+ if type(key) is str:
+ if key in self.acceptedInputKeywords:
+ return self.pars[key]
+ else:
+ raise Exception('Unrecongnised parameter: {0} '.format(key) )
+ elif type(key) is list:
+ outpars = []
+ for k in key:
+ outpars.append(self.getParameter(k))
+ return outpars
+ else:
+ raise Exception('Unhandled input {0}' .format(str(type(key))))
+
+
+ def calculateLipschitzConstantWithPowerMethod(self):
+ ''' using Power method (PM) to establish L constant'''
+
+ N = self.pars['output_geometry']['GridColCount']
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
+ weights = self.pars['weights']
+ SlicesZ = self.pars['SlicesZ']
+
+
+
+ if (proj_geom['type'] == 'parallel') or \
+ (proj_geom['type'] == 'parallel3d'):
+ #% for parallel geometry we can do just one slice
+ #print('Calculating Lipshitz constant for parallel beam geometry...')
+ niter = 5;# % number of iteration for the PM
+ #N = params.vol_geom.GridColCount;
+ #x1 = rand(N,N,1);
+ x1 = numpy.random.rand(1,N,N)
+ #sqweight = sqrt(weights(:,:,1));
+ sqweight = numpy.sqrt(weights[0:1,:,:])
+ proj_geomT = proj_geom.copy();
+ proj_geomT['DetectorRowCount'] = 1;
+ vol_geomT = vol_geom.copy();
+ vol_geomT['GridSliceCount'] = 1;
+
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+
+
+ for i in range(niter):
+ # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT);
+ # s = norm(x1(:));
+ # x1 = x1/s;
+ # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ # y = sqweight.*y;
+ # astra_mex_data3d('delete', sino_id);
+ # astra_mex_data3d('delete', id);
+ #print ("iteration {0}".format(i))
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geomT,
+ vol_geomT)
+
+ y = (sqweight * y).copy() # element wise multiplication
+
+ #b=fig.add_subplot(2,1,2)
+ #imgplot = plt.imshow(x1[0])
+ #plt.show()
+
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+ del x1
+
+ idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),
+ proj_geomT,
+ vol_geomT)
+ del y
+
+
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = (x1/s).copy();
+
+ # ### this line?
+ # sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ # proj_geomT,
+ # vol_geomT);
+ # y = sqweight * y;
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx)
+ print ("iteration {0} s= {1}".format(i,s))
+
+ #end
+ del proj_geomT
+ del vol_geomT
+ #plt.show()
+ else:
+ #% divergen beam geometry
+ print('Calculating Lipshitz constant for divergen beam geometry...')
+ niter = 8; #% number of iteration for PM
+ x1 = numpy.random.rand(SlicesZ , N , N);
+ #sqweight = sqrt(weights);
+ sqweight = numpy.sqrt(weights[0])
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id);
+
+ for i in range(niter):
+ #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
+ proj_geom,
+ vol_geom)
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geom,
+ vol_geom);
+
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ #astra_mex_data3d('delete', id);
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ #clear x1
+ del x1
+
+
+ return s
+
+
+ def setRegularizer(self, regularizer):
+ if regularizer is not None:
+ self.pars['regularizer'] = regularizer
+
+
+ def initialize(self):
+ # convenience variable storage
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
+ sino = self.pars['input_sinogram']
+
+ # a 'warm start' with SIRT method
+ # Create a data object for the reconstruction
+ rec_id = astra.matlab.data3d('create', '-vol',
+ vol_geom);
+
+ #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino);
+ sinogram_id = astra.matlab.data3d('create', '-proj3d',
+ proj_geom,
+ sino)
+
+ sirt_config = astra.astra_dict('SIRT3D_CUDA')
+ sirt_config['ReconstructionDataId' ] = rec_id
+ sirt_config['ProjectionDataId'] = sinogram_id
+
+ sirt = astra.algorithm.create(sirt_config)
+ astra.algorithm.run(sirt, iterations=35)
+ X = astra.matlab.data3d('get', rec_id)
+
+ # clean up memory
+ astra.matlab.data3d('delete', rec_id)
+ astra.matlab.data3d('delete', sinogram_id)
+ astra.algorithm.delete(sirt)
+
+
+
+ return X
+
+ def createOrderedSubsets(self, subsets=None):
+ if subsets is None:
+ try:
+ subsets = self.getParameter('subsets')
+ except Exception():
+ subsets = 0
+ #return subsets
+ else:
+ self.setParameter(subsets=subsets)
+
+
+ angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+ #binEdges = numpy.linspace(angles.min(),
+ # angles.max(),
+ # subsets + 1)
+ binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
+ # get rearranged subset indices
+ IndicesReorg = numpy.zeros((numpy.shape(angles)), dtype=numpy.int32)
+ counterM = 0
+ for ii in range(binsDiscr.max()):
+ counter = 0
+ for jj in range(subsets):
+ curr_index = ii + jj + counter
+ #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
+ if binsDiscr[jj] > ii:
+ if (counterM < numpy.size(IndicesReorg)):
+ IndicesReorg[counterM] = curr_index
+ counterM = counterM + 1
+
+ counter = counter + binsDiscr[jj] - 1
+
+ # store the OS in parameters
+ self.setParameter(os_subsets=subsets,
+ os_bins=binsDiscr,
+ os_indices=IndicesReorg)
+
+
+ def prepareForIteration(self):
+ print ("FISTA Reconstructor: prepare for iteration")
+
+ self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
+ self.objective = numpy.zeros((self.pars['number_of_iterations']))
+
+ #2D array (for 3D data) of sparse "ring"
+ detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram'])
+ self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float)
+ # another ring variable
+ self.r_x = self.r.copy()
+
+ self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
+
+ if self.getParameter('Lipschitz_constant') is None:
+ self.pars['Lipschitz_constant'] = \
+ self.calculateLipschitzConstantWithPowerMethod()
+ # errors vector (if the ground truth is given)
+ self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations')));
+ # objective function values vector
+ self.objective = numpy.zeros((self.getParameter('number_of_iterations')));
+
+
+ # prepareForIteration
+
+ def iterate (self, Xin=None):
+ if self.getParameter('subsets') == 0:
+ return self.iterateStandard(Xin)
+ else:
+ return self.iterateOrderedSubsets(Xin)
+
+ def iterateStandard(self, Xin=None):
+ print ("FISTA Reconstructor: iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+ # convenience variable storage
+ proj_geom , vol_geom, sino , \
+ SlicesZ , ring_lambda_R_L1 , weights = \
+ self.getParameter([ 'projector_geometry' ,
+ 'output_geometry',
+ 'input_sinogram',
+ 'SlicesZ' ,
+ 'ring_lambda_R_L1',
+ 'weights'])
+
+ t = 1
+
+ device = self.getParameter('device_model')
+ reduced_device = self.getParameter('reduced_device_model')
+
+ for i in range(self.getParameter('number_of_iterations')):
+ print("iteration", i)
+ X_old = X.copy()
+ t_old = t
+ r_old = self.r.copy()
+ pg = self.getParameter('projector_geometry')['type']
+ if pg == 'parallel' or \
+ pg == 'fanflat' or \
+ pg == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+
+ if self.use_device :
+ self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+
+ for kkk in range(SlicesZ):
+ self.sino_updt[kkk] = \
+ reduced_device.doForwardProject( X_t[kkk:kkk+1] )
+ else:
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+ for kkk in range(SlicesZ):
+ sino_id, self.sino_updt[kkk] = \
+ astra.creators.create_sino3d_gpu(
+ X_t[kkk:kkk+1], proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', sino_id)
+ else:
+ # for divergent 3D geometry (watch the GPU memory overflow in
+ # ASTRA versions < 1.8)
+ #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
+
+ if self.use_device:
+ self.sino_updt = device.doForwardProject(X_t)
+ else:
+ sino_id, self.sino_updt = astra.creators.create_sino3d_gpu(
+ X_t, proj_geom, vol_geom)
+ astra.matlab.data3d('delete', sino_id)
+
+
+ ## RING REMOVAL
+ if ring_lambda_R_L1 != 0:
+ self.ringRemoval(i)
+ else:
+ self.residual = weights * (self.sino_updt - sino)
+ self.objective[i] = 0.5 * numpy.linalg.norm(self.residual)
+ #objective(i) = 0.5*norm(residual(:)); % for the objective function output
+ ## Projection/Backprojection Routine
+ X, X_t = self.projectionBackprojection(X, X_t)
+
+ ## REGULARIZATION
+ Y = self.regularize(X)
+ X = Y.copy()
+ ## Update Loop
+ X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
+
+ print ("t" , t)
+ print ("X min {0} max {1}".format(X_t.min(),X_t.max()))
+ self.setParameter(output_volume=X)
+ return X
+ ## iterate
+
+ def ringRemoval(self, i):
+ print ("FISTA Reconstructor: ring removal")
+ residual = self.residual
+ lambdaR_L1 , alpha_ring , weights , L_const , sino= \
+ self.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant',
+ 'input_sinogram'])
+ r_x = self.r_x
+ sino_updt = self.sino_updt
+
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ if lambdaR_L1 > 0 :
+ for kkk in range(anglesNumb):
+
+ residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+ vec = residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:].squeeze()
+ self.r = (r_x - (1./L_const) * vec).copy()
+ self.objective[i] = (0.5 * (residual ** 2).sum())
+
+ def projectionBackprojection(self, X, X_t):
+ print ("FISTA Reconstructor: projection-backprojection routine")
+
+ # a few useful variables
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ residual = self.residual
+ proj_geom , vol_geom , L_const = \
+ self.getParameter(['projector_geometry' ,
+ 'output_geometry',
+ 'Lipschitz_constant'])
+
+ device, reduced_device = self.getParameter(['device_model',
+ 'reduced_device_model'])
+
+ if self.getParameter('projector_geometry')['type'] == 'parallel' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
+
+ if self.use_device:
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residual[kkk:kkk+1],
+ proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ for kkk in range(SliceZ):
+ x_temp[kkk] = \
+ reduced_device.doBackwardProject(residual[kkk:kkk+1])
+ else:
+ if self.use_device:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residual, proj_geom, vol_geom)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ x_temp = \
+ device.doBackwardProject(residual)
+
+
+ X = X_t - (1/L_const) * x_temp
+ #astra.matlab.data3d('delete', sino_id)
+ return (X , X_t)
+
+
+ def regularize(self, X , output_all=False):
+ #print ("FISTA Reconstructor: regularize")
+
+ regularizer = self.getParameter('regularizer')
+ if regularizer is not None:
+ return regularizer(input=X,
+ output_all=output_all)
+ else:
+ return X
+
+ def updateLoop(self, i, X, X_old, r_old, t, t_old):
+ print ("FISTA Reconstructor: update loop")
+ lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
+
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+
+ if lambdaR_L1 > 0:
+ self.r = numpy.max(
+ numpy.abs(self.r) - lambdaR_L1 , 0) * \
+ numpy.sign(self.r)
+ self.r_x = self.r + \
+ (((t_old-1)/t) * (self.r - r_old))
+
+ if self.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, self.objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], self.objective[i]))
+ return (X , X_t, t)
+
+ def iterateOrderedSubsets(self, Xin=None):
+ print ("FISTA Reconstructor: Ordered Subsets iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+
+ # some useful constants
+ proj_geom , vol_geom, sino , \
+ SlicesZ, weights , alpha_ring ,\
+ lambdaR_L1 , L_const , iterFISTA = self.getParameter(
+ ['projector_geometry' , 'output_geometry', 'input_sinogram',
+ 'SlicesZ' , 'weights', 'ring_alpha' ,
+ 'ring_lambda_R_L1', 'Lipschitz_constant',
+ 'number_of_iterations'])
+
+
+ # errors vector (if the ground truth is given)
+ Resid_error = numpy.zeros((iterFISTA));
+ # objective function values vector
+ #objective = numpy.zeros((iterFISTA));
+ objective = self.objective
+
+
+ t = 1
+
+ ## additional for
+ proj_geomSUB = proj_geom.copy()
+ self.residual2 = numpy.zeros(numpy.shape(sino))
+ residual2 = self.residual2
+ sino_updt_FULL = self.residual.copy()
+ r_x = self.r.copy()
+
+ print ("starting iterations")
+ ## % Outer FISTA iterations loop
+ for i in range(self.getParameter('number_of_iterations')):
+ # With OS approach it becomes trickier to correlate independent
+ # subsets, hence additional work is required one solution is to
+ # work with a full sinogram at times
+
+ r_old = self.r.copy()
+ t_old = t
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
+ if (i > 1 and lambdaR_L1 > 0) :
+ for kkk in range(anglesNumb):
+
+ residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt_FULL[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+
+ vec = self.residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:] # 1 or 0?
+ r_x = self.r_x
+ # update ring variable
+ self.r = (r_x - (1./L_const) * vec).copy()
+
+ # subset loop
+ counterInd = 1
+ geometry_type = self.getParameter('projector_geometry')['type']
+ angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+ for ss in range(self.getParameter('subsets')):
+ #print ("Subset {0}".format(ss))
+ X_old = X.copy()
+ t_old = t
+
+ # the number of projections per subset
+ numProjSub = self.getParameter('os_bins')[ss]
+ CurrSubIndices = self.getParameter('os_indices')\
+ [counterInd:counterInd+numProjSub]
+ #print ("Len CurrSubIndices {0}".format(numProjSub))
+ mask = numpy.zeros(numpy.shape(angles), dtype=bool)
+ #cc = 0
+ for j in range(len(CurrSubIndices)):
+ mask[int(CurrSubIndices[j])] = True
+ proj_geomSUB['ProjectionAngles'] = angles[mask]
+
+ if self.use_device:
+ device = self.getParameter('device_model')\
+ .createReducedDevice(
+ proj_par={'angles':angles[mask]},
+ vol_par={})
+
+ shape = list(numpy.shape(self.getParameter('input_sinogram')))
+ shape[1] = numProjSub
+ sino_updt_Sub = numpy.zeros(shape)
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+
+ for kkk in range(SlicesZ):
+ if self.use_device:
+ sinoT = device.doForwardProject(X_t[kkk:kkk+1])
+ else:
+ sino_id, sinoT = astra.creators.create_sino3d_gpu (
+ X_t[kkk:kkk+1] , proj_geomSUB, vol_geom)
+ astra.matlab.data3d('delete', sino_id)
+ sino_updt_Sub[kkk] = sinoT.T.copy()
+
+ else:
+ # for 3D geometry (watch the GPU memory overflow in
+ # ASTRA < 1.8)
+ if self.use_device:
+ sino_updt_Sub = device.doForwardProject(X_t)
+
+ else:
+ sino_id, sino_updt_Sub = \
+ astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', sino_id)
+
+ #print ("shape(sino_updt_Sub)",numpy.shape(sino_updt_Sub))
+ if lambdaR_L1 > 0 :
+ ## RING REMOVAL
+ #print ("ring removal")
+ residualSub , sino_updt_Sub, sino_updt_FULL = \
+ self.ringRemovalOrderedSubsets(ss,
+ counterInd,
+ sino_updt_Sub,
+ sino_updt_FULL)
+ else:
+ #PWLS model
+ #print ("PWLS model")
+ residualSub = weights[:,CurrSubIndices,:] * \
+ ( sino_updt_Sub - \
+ sino[:,CurrSubIndices,:].squeeze() )
+ objective[i] = 0.5 * numpy.linalg.norm(residualSub)
+
+ # projection/backprojection routine
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+ # if geometry is 2D use slice-by-slice projection-backprojection
+ # routine
+ x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
+ for kkk in range(SlicesZ):
+ if self.use_device:
+ x_temp[kkk] = device.doBackwardProject(
+ residualSub[kkk:kkk+1])
+ else:
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub[kkk:kkk+1],
+ proj_geomSUB, vol_geom)
+ astra.matlab.data3d('delete', x_id)
+
+ else:
+ if self.use_device:
+ x_temp = device.doBackwardProject(
+ residualSub)
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', x_id)
+
+ X = X_t - (1/L_const) * x_temp
+
+ ## REGULARIZATION
+ X = self.regularize(X)
+
+ ## Update subset Loop
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+ # FINAL
+ ## update iteration loop
+ if lambdaR_L1 > 0:
+ self.r = numpy.max(
+ numpy.abs(self.r) - lambdaR_L1 , 0) * \
+ numpy.sign(self.r)
+ self.r_x = self.r + \
+ (((t_old-1)/t) * (self.r - r_old))
+
+ if self.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, self.objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], self.objective[i]))
+ print("X min {0} max {1}".format(X.min(),X.max()))
+ self.setParameter(output_volume=X)
+ counterInd = counterInd + numProjSub
+
+ return X
+
+ def ringRemovalOrderedSubsets(self, ss,counterInd,
+ sino_updt_Sub, sino_updt_FULL):
+ residual = self.residual
+ r_x = self.r_x
+ weights , alpha_ring , sino = \
+ self.getParameter( ['weights', 'ring_alpha', 'input_sinogram'])
+ numProjSub = self.getParameter('os_bins')[ss]
+ CurrSubIndices = self.getParameter('os_indices')\
+ [counterInd:counterInd+numProjSub]
+
+ shape = list(numpy.shape(self.getParameter('input_sinogram')))
+ shape[1] = numProjSub
+
+ residualSub = numpy.zeros(shape)
+
+ for kkk in range(numProjSub):
+ #print ("ring removal indC ... {0}".format(kkk))
+ indC = int(CurrSubIndices[kkk])
+ residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
+ (sino_updt_Sub[:,kkk,:].squeeze() - \
+ sino[:,indC,:].squeeze() - alpha_ring * r_x)
+ # filling the full sinogram
+ sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
+
+ return (residualSub , sino_updt_Sub, sino_updt_FULL)
+
+
diff --git a/Wrappers/Python/ccpi/reconstruction/Reconstructor.py b/Wrappers/Python/ccpi/reconstruction/Reconstructor.py
new file mode 100644
index 0000000..ba67327
--- /dev/null
+++ b/Wrappers/Python/ccpi/reconstruction/Reconstructor.py
@@ -0,0 +1,598 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+import h5py
+from ccpi.reconstruction.parallelbeam import alg
+
+from Regularizer import Regularizer
+from enum import Enum
+
+import astra
+
+
+class Reconstructor:
+
+ class Algorithm(Enum):
+ CGLS = alg.cgls
+ CGLS_CONV = alg.cgls_conv
+ SIRT = alg.sirt
+ MLEM = alg.mlem
+ CGLS_TICHONOV = alg.cgls_tikhonov
+ CGLS_TVREG = alg.cgls_TVreg
+ FISTA = 'fista'
+
+ def __init__(self, algorithm = None, projection_data = None,
+ angles = None, center_of_rotation = None ,
+ flat_field = None, dark_field = None,
+ iterations = None, resolution = None, isLogScale = False, threads = None,
+ normalized_projection = None):
+
+ self.pars = dict()
+ self.pars['algorithm'] = algorithm
+ self.pars['projection_data'] = projection_data
+ self.pars['normalized_projection'] = normalized_projection
+ self.pars['angles'] = angles
+ self.pars['center_of_rotation'] = numpy.double(center_of_rotation)
+ self.pars['flat_field'] = flat_field
+ self.pars['iterations'] = iterations
+ self.pars['dark_field'] = dark_field
+ self.pars['resolution'] = resolution
+ self.pars['isLogScale'] = isLogScale
+ self.pars['threads'] = threads
+ if (iterations != None):
+ self.pars['iterationValues'] = numpy.zeros((iterations))
+
+ if projection_data != None and dark_field != None and flat_field != None:
+ norm = self.normalize(projection_data, dark_field, flat_field, 0.1)
+ self.pars['normalized_projection'] = norm
+
+
+ def setPars(self, parameters):
+ keys = ['algorithm','projection_data' ,'normalized_projection', \
+ 'angles' , 'center_of_rotation' , 'flat_field', \
+ 'iterations','dark_field' , 'resolution', 'isLogScale' , \
+ 'threads' , 'iterationValues', 'regularize']
+
+ for k in keys:
+ if k not in parameters.keys():
+ self.pars[k] = None
+ else:
+ self.pars[k] = parameters[k]
+
+
+ def sanityCheck(self):
+ projection_data = self.pars['projection_data']
+ dark_field = self.pars['dark_field']
+ flat_field = self.pars['flat_field']
+ angles = self.pars['angles']
+
+ if projection_data != None and dark_field != None and \
+ angles != None and flat_field != None:
+ data_shape = numpy.shape(projection_data)
+ angle_shape = numpy.shape(angles)
+
+ if angle_shape[0] != data_shape[0]:
+ #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \
+ # (angle_shape[0] , data_shape[0]) )
+ return (False , 'Projections and angles dimensions do not match: %d vs %d' % \
+ (angle_shape[0] , data_shape[0]) )
+
+ if data_shape[1:] != numpy.shape(flat_field):
+ #raise Exception('Projection and flat field dimensions do not match')
+ return (False , 'Projection and flat field dimensions do not match')
+ if data_shape[1:] != numpy.shape(dark_field):
+ #raise Exception('Projection and dark field dimensions do not match')
+ return (False , 'Projection and dark field dimensions do not match')
+
+ return (True , '' )
+ elif self.pars['normalized_projection'] != None:
+ data_shape = numpy.shape(self.pars['normalized_projection'])
+ angle_shape = numpy.shape(angles)
+
+ if angle_shape[0] != data_shape[0]:
+ #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \
+ # (angle_shape[0] , data_shape[0]) )
+ return (False , 'Projections and angles dimensions do not match: %d vs %d' % \
+ (angle_shape[0] , data_shape[0]) )
+ else:
+ return (True , '' )
+ else:
+ return (False , 'Not enough data')
+
+ def reconstruct(self, parameters = None):
+ if parameters != None:
+ self.setPars(parameters)
+
+ go , reason = self.sanityCheck()
+ if go:
+ return self._reconstruct()
+ else:
+ raise Exception(reason)
+
+
+ def _reconstruct(self, parameters=None):
+ if parameters!=None:
+ self.setPars(parameters)
+ parameters = self.pars
+
+ if parameters['algorithm'] != None and \
+ parameters['normalized_projection'] != None and \
+ parameters['angles'] != None and \
+ parameters['center_of_rotation'] != None and \
+ parameters['iterations'] != None and \
+ parameters['resolution'] != None and\
+ parameters['threads'] != None and\
+ parameters['isLogScale'] != None:
+
+
+ if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS,
+ Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT):
+ #store parameters
+ self.pars = parameters
+ result = parameters['algorithm'](
+ parameters['normalized_projection'] ,
+ parameters['angles'],
+ parameters['center_of_rotation'],
+ parameters['resolution'],
+ parameters['iterations'],
+ parameters['threads'] ,
+ parameters['isLogScale']
+ )
+ return result
+ elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV,
+ Reconstructor.Algorithm.CGLS_TICHONOV,
+ Reconstructor.Algorithm.CGLS_TVREG) :
+ self.pars = parameters
+ result = parameters['algorithm'](
+ parameters['normalized_projection'] ,
+ parameters['angles'],
+ parameters['center_of_rotation'],
+ parameters['resolution'],
+ parameters['iterations'],
+ parameters['threads'] ,
+ parameters['regularize'],
+ numpy.zeros((parameters['iterations'])),
+ parameters['isLogScale']
+ )
+
+ elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA:
+ pass
+
+ else:
+ if parameters['projection_data'] != None and \
+ parameters['dark_field'] != None and \
+ parameters['flat_field'] != None:
+ norm = self.normalize(parameters['projection_data'],
+ parameters['dark_field'],
+ parameters['flat_field'], 0.1)
+ self.pars['normalized_projection'] = norm
+ return self._reconstruct(parameters)
+
+
+
+ def _normalize(self, projection, dark, flat, def_val=0):
+ a = (projection - dark)
+ b = (flat-dark)
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ c = numpy.true_divide( a, b )
+ c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
+ return c
+
+ def normalize(self, projections, dark, flat, def_val=0):
+ norm = [self._normalize(projection, dark, flat, def_val) for projection in projections]
+ return numpy.asarray (norm, dtype=numpy.float32)
+
+
+
+class FISTA():
+ '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+
+ '''
+ # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+ # ___Input___:
+ # params.[] file:
+ # - .proj_geom (geometry of the projector) [required]
+ # - .vol_geom (geometry of the reconstructed object) [required]
+ # - .sino (vectorized in 2D or 3D sinogram) [required]
+ # - .iterFISTA (iterations for the main loop, default 40)
+ # - .L_const (Lipschitz constant, default Power method) )
+ # - .X_ideal (ideal image, if given)
+ # - .weights (statisitcal weights, size of the sinogram)
+ # - .ROI (Region-of-interest, only if X_ideal is given)
+ # - .initialize (a 'warm start' using SIRT method from ASTRA)
+ #----------------Regularization choices------------------------
+ # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+ # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+ # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+ # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+ # - .Regul_Iterations (iterations for the selected penalty, default 25)
+ # - .Regul_tauLLT (time step parameter for LLT term)
+ # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+ # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+ #----------------Visualization parameters------------------------
+ # - .show (visualize reconstruction 1/0, (0 default))
+ # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+ # - .slice (for 3D volumes - slice number to imshow)
+ # ___Output___:
+ # 1. X - reconstructed image/volume
+ # 2. output - a structure with
+ # - .Resid_error - residual error (if X_ideal is given)
+ # - .objective: value of the objective function
+ # - .L_const: Lipshitz constant to avoid recalculations
+
+ # References:
+ # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+ # Problems" by A. Beck and M Teboulle
+ # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+ # 3. "A novel tomographic reconstruction method based on the robust
+ # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+ # D. Kazantsev, 2016-17
+ def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
+ self.params = dict()
+ self.params['projector_geometry'] = projector_geometry
+ self.params['output_geometry'] = output_geometry
+ self.params['input_sinogram'] = input_sinogram
+ detectors, nangles, sliceZ = numpy.shape(input_sinogram)
+ self.params['detectors'] = detectors
+ self.params['number_og_angles'] = nangles
+ self.params['SlicesZ'] = sliceZ
+
+ # Accepted input keywords
+ kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' ,
+ 'weights' , 'region_of_interest' , 'initialize' ,
+ 'regularizer' ,
+ 'ring_lambda_R_L1',
+ 'ring_alpha')
+
+ # handle keyworded parameters
+ if kwargs is not None:
+ for key, value in kwargs.items():
+ if key in kw:
+ #print("{0} = {1}".format(key, value))
+ self.pars[key] = value
+
+ # set the default values for the parameters if not set
+ if 'number_of_iterations' in kwargs.keys():
+ self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+ else:
+ self.pars['number_of_iterations'] = 40
+ if 'weights' in kwargs.keys():
+ self.pars['weights'] = kwargs['weights']
+ else:
+ self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
+ if 'Lipschitz_constant' in kwargs.keys():
+ self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+ else:
+ self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+
+ if not self.pars['ideal_image'] in kwargs.keys():
+ self.pars['ideal_image'] = None
+
+ if not self.pars['region_of_interest'] :
+ if self.pars['ideal_image'] == None:
+ pass
+ else:
+ self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
+
+ if not self.pars['regularizer'] :
+ self.pars['regularizer'] = None
+ else:
+ # the regularizer must be a correctly instantiated object
+ if not self.pars['ring_lambda_R_L1']:
+ self.pars['ring_lambda_R_L1'] = 0
+ if not self.pars['ring_alpha']:
+ self.pars['ring_alpha'] = 1
+
+
+
+
+ def calculateLipschitzConstantWithPowerMethod(self):
+ ''' using Power method (PM) to establish L constant'''
+
+ #N = params.vol_geom.GridColCount
+ N = self.pars['output_geometry'].GridColCount
+ proj_geom = self.params['projector_geometry']
+ vol_geom = self.params['output_geometry']
+ weights = self.pars['weights']
+ SlicesZ = self.pars['SlicesZ']
+
+ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+ #% for parallel geometry we can do just one slice
+ #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
+ niter = 15;# % number of iteration for the PM
+ #N = params.vol_geom.GridColCount;
+ #x1 = rand(N,N,1);
+ x1 = numpy.random.rand(1,N,N)
+ #sqweight = sqrt(weights(:,:,1));
+ sqweight = numpy.sqrt(weights.T[0])
+ proj_geomT = proj_geom.copy();
+ proj_geomT.DetectorRowCount = 1;
+ vol_geomT = vol_geom.copy();
+ vol_geomT['GridSliceCount'] = 1;
+
+
+ for i in range(niter):
+ if i == 0:
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+ y = sqweight * y # element wise multiplication
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ y = sqweight*y;
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ del proj_geomT
+ del vol_geomT
+ else
+ #% divergen beam geometry
+ #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
+ niter = 8; #% number of iteration for PM
+ x1 = numpy.random.rand(SlicesZ , N , N);
+ #sqweight = sqrt(weights);
+ sqweight = numpy.sqrt(weights.T[0])
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id);
+
+ for i in range(niter):
+ #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
+ proj_geom,
+ vol_geom)
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geom,
+ vol_geom);
+
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ #astra_mex_data3d('delete', id);
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ #clear x1
+ del x1
+
+ return s
+
+
+ def setRegularizer(self, regularizer):
+ if regularizer
+ self.pars['regularizer'] = regularizer
+
+
+
+
+
+def getEntry(location):
+ for item in nx[location].keys():
+ print (item)
+
+
+print ("Loading Data")
+
+##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
+####ind = [i * 1049 for i in range(360)]
+#### use only 360 images
+##images = 200
+##ind = [int(i * 1049 / images) for i in range(images)]
+##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
+
+#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
+fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
+nx = h5py.File(fname, "r")
+
+# the data are stored in a particular location in the hdf5
+for item in nx['entry1/tomo_entry/data'].keys():
+ print (item)
+
+data = nx.get('entry1/tomo_entry/data/rotation_angle')
+angles = numpy.zeros(data.shape)
+data.read_direct(angles)
+print (angles)
+# angles should be in degrees
+
+data = nx.get('entry1/tomo_entry/data/data')
+stack = numpy.zeros(data.shape)
+data.read_direct(stack)
+print (data.shape)
+
+print ("Data Loaded")
+
+
+# Normalize
+data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
+itype = numpy.zeros(data.shape)
+data.read_direct(itype)
+# 2 is dark field
+darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
+dark = darks[0]
+for i in range(1, len(darks)):
+ dark += darks[i]
+dark = dark / len(darks)
+#dark[0][0] = dark[0][1]
+
+# 1 is flat field
+flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
+flat = flats[0]
+for i in range(1, len(flats)):
+ flat += flats[i]
+flat = flat / len(flats)
+#flat[0][0] = dark[0][1]
+
+
+# 0 is projection data
+proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
+angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
+angle_proj = numpy.asarray (angle_proj)
+angle_proj = angle_proj.astype(numpy.float32)
+
+# normalized data are
+# norm = (projection - dark)/(flat-dark)
+
+def normalize(projection, dark, flat, def_val=0.1):
+ a = (projection - dark)
+ b = (flat-dark)
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ c = numpy.true_divide( a, b )
+ c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
+ return c
+
+
+norm = [normalize(projection, dark, flat) for projection in proj]
+norm = numpy.asarray (norm)
+norm = norm.astype(numpy.float32)
+
+#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm,
+# angles = angle_proj, center_of_rotation = 86.2 ,
+# flat_field = flat, dark_field = dark,
+# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
+
+#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj,
+# angles = angle_proj, center_of_rotation = 86.2 ,
+# flat_field = flat, dark_field = dark,
+# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
+#img_cgls = recon.reconstruct()
+#
+#pars = dict()
+#pars['algorithm'] = Reconstructor.Algorithm.SIRT
+#pars['projection_data'] = proj
+#pars['angles'] = angle_proj
+#pars['center_of_rotation'] = numpy.double(86.2)
+#pars['flat_field'] = flat
+#pars['iterations'] = 15
+#pars['dark_field'] = dark
+#pars['resolution'] = 1
+#pars['isLogScale'] = False
+#pars['threads'] = 3
+#
+#img_sirt = recon.reconstruct(pars)
+#
+#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM
+#img_mlem = recon.reconstruct()
+
+############################################################
+############################################################
+#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV
+#recon.pars['regularize'] = numpy.double(0.1)
+#img_cgls_conv = recon.reconstruct()
+
+niterations = 15
+threads = 3
+
+img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+
+iteration_values = numpy.zeros((niterations,))
+img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ iteration_values, False)
+print ("iteration values %s" % str(iteration_values))
+
+iteration_values = numpy.zeros((niterations,))
+img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ numpy.double(1e-5), iteration_values , False)
+print ("iteration values %s" % str(iteration_values))
+iteration_values = numpy.zeros((niterations,))
+img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ numpy.double(1e-5), iteration_values , False)
+print ("iteration values %s" % str(iteration_values))
+
+
+##numpy.save("cgls_recon.npy", img_data)
+import matplotlib.pyplot as plt
+fig, ax = plt.subplots(1,6,sharey=True)
+ax[0].imshow(img_cgls[80])
+ax[0].axis('off') # clear x- and y-axes
+ax[1].imshow(img_sirt[80])
+ax[1].axis('off') # clear x- and y-axes
+ax[2].imshow(img_mlem[80])
+ax[2].axis('off') # clear x- and y-axesplt.show()
+ax[3].imshow(img_cgls_conv[80])
+ax[3].axis('off') # clear x- and y-axesplt.show()
+ax[4].imshow(img_cgls_tikhonov[80])
+ax[4].axis('off') # clear x- and y-axesplt.show()
+ax[5].imshow(img_cgls_TVreg[80])
+ax[5].axis('off') # clear x- and y-axesplt.show()
+
+
+plt.show()
+
+#viewer = edo.CILViewer()
+#viewer.setInputAsNumpy(img_cgls2)
+#viewer.displaySliceActor(0)
+#viewer.startRenderLoop()
+
+import vtk
+
+def NumpyToVTKImageData(numpyarray):
+ if (len(numpy.shape(numpyarray)) == 3):
+ doubleImg = vtk.vtkImageData()
+ shape = numpy.shape(numpyarray)
+ doubleImg.SetDimensions(shape[0], shape[1], shape[2])
+ doubleImg.SetOrigin(0,0,0)
+ doubleImg.SetSpacing(1,1,1)
+ doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1)
+ #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation())
+ doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1)
+
+ for i in range(shape[0]):
+ for j in range(shape[1]):
+ for k in range(shape[2]):
+ doubleImg.SetScalarComponentFromDouble(
+ i,j,k,0, numpyarray[i][j][k])
+ #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) )
+ # rescale to appropriate VTK_UNSIGNED_SHORT
+ stats = vtk.vtkImageAccumulate()
+ stats.SetInputData(doubleImg)
+ stats.Update()
+ iMin = stats.GetMin()[0]
+ iMax = stats.GetMax()[0]
+ scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin)
+
+ shiftScaler = vtk.vtkImageShiftScale ()
+ shiftScaler.SetInputData(doubleImg)
+ shiftScaler.SetScale(scale)
+ shiftScaler.SetShift(iMin)
+ shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT)
+ shiftScaler.Update()
+ return shiftScaler.GetOutput()
+
+#writer = vtk.vtkMetaImageWriter()
+#writer.SetFileName(alg + "_recon.mha")
+#writer.SetInputData(NumpyToVTKImageData(img_cgls2))
+#writer.Write()
diff --git a/Wrappers/Python/compile-fista.bat.in b/Wrappers/Python/compile-fista.bat.in
new file mode 100644
index 0000000..b1db686
--- /dev/null
+++ b/Wrappers/Python/compile-fista.bat.in
@@ -0,0 +1,7 @@
+set CIL_VERSION=@CIL_VERSION@
+
+set PREFIX=@CONDA_ENVIRONMENT_PREFIX@
+set LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@
+
+REM activate @CONDA_ENVIRONMENT@
+conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge
diff --git a/Wrappers/Python/compile-fista.sh.in b/Wrappers/Python/compile-fista.sh.in
new file mode 100644
index 0000000..267f014
--- /dev/null
+++ b/Wrappers/Python/compile-fista.sh.in
@@ -0,0 +1,9 @@
+#!/bin/sh
+# compile within the right conda environment
+#module load python/anaconda
+#source activate @CONDA_ENVIRONMENT@
+
+export CIL_VERSION=@CIL_VERSION@
+export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@
+
+conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi
diff --git a/Wrappers/Python/compile.bat.in b/Wrappers/Python/compile.bat.in
new file mode 100644
index 0000000..e5342ed
--- /dev/null
+++ b/Wrappers/Python/compile.bat.in
@@ -0,0 +1,7 @@
+set CIL_VERSION=@CIL_VERSION@
+
+set PREFIX=@CONDA_ENVIRONMENT_PREFIX@
+set LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@
+
+REM activate @CONDA_ENVIRONMENT@
+conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge \ No newline at end of file
diff --git a/Wrappers/Python/compile.sh.in b/Wrappers/Python/compile.sh.in
new file mode 100644
index 0000000..93fdba2
--- /dev/null
+++ b/Wrappers/Python/compile.sh.in
@@ -0,0 +1,9 @@
+#!/bin/sh
+# compile within the right conda environment
+#module load python/anaconda
+#source activate @CONDA_ENVIRONMENT@
+
+export CIL_VERSION=@CIL_VERSION@
+export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@
+
+conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi
diff --git a/Wrappers/Python/conda-recipe/bld.bat b/Wrappers/Python/conda-recipe/bld.bat
new file mode 100644
index 0000000..69491de
--- /dev/null
+++ b/Wrappers/Python/conda-recipe/bld.bat
@@ -0,0 +1,14 @@
+IF NOT DEFINED CIL_VERSION (
+ECHO CIL_VERSION Not Defined.
+exit 1
+)
+
+mkdir "%SRC_DIR%\ccpi"
+xcopy /e "%RECIPE_DIR%\..\.." "%SRC_DIR%\ccpi"
+
+cd %SRC_DIR%\ccpi\Python
+
+%PYTHON% setup.py build_ext
+if errorlevel 1 exit 1
+%PYTHON% setup.py install
+if errorlevel 1 exit 1
diff --git a/Wrappers/Python/conda-recipe/build.sh b/Wrappers/Python/conda-recipe/build.sh
new file mode 100644
index 0000000..855047f
--- /dev/null
+++ b/Wrappers/Python/conda-recipe/build.sh
@@ -0,0 +1,14 @@
+
+if [ -z "$CIL_VERSION" ]; then
+ echo "Need to set CIL_VERSION"
+ exit 1
+fi
+mkdir "$SRC_DIR/ccpi"
+cp -r "$RECIPE_DIR/../.." "$SRC_DIR/ccpi"
+
+cd $SRC_DIR/ccpi/Python
+
+$PYTHON setup.py build_ext
+$PYTHON setup.py install
+
+
diff --git a/Wrappers/Python/conda-recipe/meta.yaml b/Wrappers/Python/conda-recipe/meta.yaml
new file mode 100644
index 0000000..7068e9d
--- /dev/null
+++ b/Wrappers/Python/conda-recipe/meta.yaml
@@ -0,0 +1,30 @@
+package:
+ name: ccpi-regularizers
+ version: {{ environ['CIL_VERSION'] }}
+
+
+build:
+ preserve_egg_dir: False
+ script_env:
+ - CIL_VERSION
+# number: 0
+
+requirements:
+ build:
+ - python
+ - numpy
+ - setuptools
+ - boost ==1.64
+ - boost-cpp ==1.64
+ - cython
+
+ run:
+ - python
+ - numpy
+ - boost ==1.64
+
+
+about:
+ home: http://www.ccpi.ac.uk
+ license: BSD license
+ summary: 'CCPi Core Imaging Library Quantification Toolbox'
diff --git a/Wrappers/Python/fista-recipe/bld.bat b/Wrappers/Python/fista-recipe/bld.bat
new file mode 100644
index 0000000..69c2afe
--- /dev/null
+++ b/Wrappers/Python/fista-recipe/bld.bat
@@ -0,0 +1,11 @@
+IF NOT DEFINED CIL_VERSION (
+ECHO CIL_VERSION Not Defined.
+exit 1
+)
+
+xcopy /e "%RECIPE_DIR%\.." "%SRC_DIR%"
+
+%PYTHON% setup.py -q bdist_egg
+:: %PYTHON% setup.py install --single-version-externally-managed --record=record.txt
+%PYTHON% setup.py install
+if errorlevel 1 exit 1
diff --git a/Wrappers/Python/fista-recipe/build.sh b/Wrappers/Python/fista-recipe/build.sh
new file mode 100644
index 0000000..e3f3552
--- /dev/null
+++ b/Wrappers/Python/fista-recipe/build.sh
@@ -0,0 +1,10 @@
+if [ -z "$CIL_VERSION" ]; then
+ echo "Need to set CIL_VERSION"
+ exit 1
+fi
+mkdir "$SRC_DIR/ccpifista"
+cp -r "$RECIPE_DIR/.." "$SRC_DIR/ccpifista"
+
+cd $SRC_DIR/ccpifista
+
+$PYTHON setup-fista.py install
diff --git a/Wrappers/Python/fista-recipe/meta.yaml b/Wrappers/Python/fista-recipe/meta.yaml
new file mode 100644
index 0000000..265541f
--- /dev/null
+++ b/Wrappers/Python/fista-recipe/meta.yaml
@@ -0,0 +1,29 @@
+package:
+ name: ccpi-fista
+ version: {{ environ['CIL_VERSION'] }}
+
+
+build:
+ preserve_egg_dir: False
+ script_env:
+ - CIL_VERSION
+# number: 0
+
+requirements:
+ build:
+ - python
+ - numpy
+ - setuptools
+
+ run:
+ - python
+ - numpy
+ #- astra-toolbox
+ - ccpi-regularizers
+
+
+
+about:
+ home: http://www.ccpi.ac.uk
+ license: Apache v.2.0 license
+ summary: 'CCPi Core Imaging Library (Viewer)'
diff --git a/Wrappers/Python/fista_module.cpp b/Wrappers/Python/fista_module.cpp
new file mode 100644
index 0000000..f3add76
--- /dev/null
+++ b/Wrappers/Python/fista_module.cpp
@@ -0,0 +1,1047 @@
+/*
+This work is part of the Core Imaging Library developed by
+Visual Analytics and Imaging System Group of the Science Technology
+Facilities Council, STFC
+
+Copyright 2017 Daniil Kazantsev
+Copyright 2017 Srikanth Nagella, Edoardo Pasca
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+
+#include <iostream>
+#include <cmath>
+
+#include <boost/python.hpp>
+#include <boost/python/numpy.hpp>
+#include "boost/tuple/tuple.hpp"
+
+#include "SplitBregman_TV_core.h"
+#include "FGP_TV_core.h"
+#include "LLT_model_core.h"
+#include "PatchBased_Regul_core.h"
+#include "TGV_PD_core.h"
+#include "utils.h"
+
+
+
+#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64)
+#include <windows.h>
+// this trick only if compiler is MSVC
+__if_not_exists(uint8_t) { typedef __int8 uint8_t; }
+__if_not_exists(uint16_t) { typedef __int8 uint16_t; }
+#endif
+
+namespace bp = boost::python;
+namespace np = boost::python::numpy;
+
+/*! in the Matlab implementation this is called as
+void mexFunction(
+int nlhs, mxArray *plhs[],
+int nrhs, const mxArray *prhs[])
+where:
+prhs Array of pointers to the INPUT mxArrays
+nrhs int number of INPUT mxArrays
+
+nlhs Array of pointers to the OUTPUT mxArrays
+plhs int number of OUTPUT mxArrays
+
+***********************************************************
+
+***********************************************************
+double mxGetScalar(const mxArray *pm);
+args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
+Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double.
+***********************************************************
+char *mxArrayToString(const mxArray *array_ptr);
+args: array_ptr Pointer to mxCHAR array.
+Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array.
+Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string.
+***********************************************************
+mxClassID mxGetClassID(const mxArray *pm);
+args: pm Pointer to an mxArray
+Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types,
+mxGetClassId returns a unique value identifying the class of the array contents.
+Use mxIsClass to determine whether an array is of a specific user-defined type.
+
+mxClassID Value MATLAB Type MEX Type C Primitive Type
+mxINT8_CLASS int8 int8_T char, byte
+mxUINT8_CLASS uint8 uint8_T unsigned char, byte
+mxINT16_CLASS int16 int16_T short
+mxUINT16_CLASS uint16 uint16_T unsigned short
+mxINT32_CLASS int32 int32_T int
+mxUINT32_CLASS uint32 uint32_T unsigned int
+mxINT64_CLASS int64 int64_T long long
+mxUINT64_CLASS uint64 uint64_T unsigned long long
+mxSINGLE_CLASS single float float
+mxDOUBLE_CLASS double double double
+
+****************************************************************
+double *mxGetPr(const mxArray *pm);
+args: pm Pointer to an mxArray of type double
+Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data.
+****************************************************************
+mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims,
+mxClassID classid, mxComplexity ComplexFlag);
+args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2.
+dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension.
+For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array.
+classid Identifier for the class of the array, which determines the way the numerical data is represented in memory.
+For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer.
+ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran).
+Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran).
+If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not
+enough free heap space to create the mxArray.
+*/
+
+
+
+bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) {
+
+ // the result is in the following list
+ bp::list result;
+
+ int number_of_dims, dimX, dimY, dimZ, ll, j, count;
+ //const int *dim_array;
+ float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old;
+
+ //number_of_dims = mxGetNumberOfDimensions(prhs[0]);
+ //dim_array = mxGetDimensions(prhs[0]);
+
+ number_of_dims = input.get_nd();
+ int dim_array[3];
+
+ dim_array[0] = input.shape(0);
+ dim_array[1] = input.shape(1);
+ if (number_of_dims == 2) {
+ dim_array[2] = -1;
+ }
+ else {
+ dim_array[2] = input.shape(2);
+ }
+
+ // Parameter handling is be done in Python
+ ///*Handling Matlab input data*/
+ //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')");
+
+ ///*Handling Matlab input data*/
+ //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */
+ A = reinterpret_cast<float *>(input.get_data());
+
+ //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */
+ mu = (float)d_mu;
+
+ //iter = 35; /* default iterations number */
+
+ //epsil = 0.0001; /* default tolerance constant */
+ epsil = (float)d_epsil;
+ //methTV = 0; /* default isotropic TV penalty */
+ //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */
+ //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */
+ //if (nrhs == 5) {
+ // char *penalty_type;
+ // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */
+ // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',");
+ // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */
+ // mxFree(penalty_type);
+ //}
+ //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); }
+
+ lambda = 2.0f*mu;
+ count = 1;
+ re_old = 0.0f;
+ /*Handling Matlab output data*/
+ dimY = dim_array[0]; dimX = dim_array[1]; dimZ = dim_array[2];
+
+ if (number_of_dims == 2) {
+ dimZ = 1; /*2D case*/
+ //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ //U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ //Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ //Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ //Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ //By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npU = np::zeros(shape, dtype);
+ np::ndarray npU_old = np::zeros(shape, dtype);
+ np::ndarray npDx = np::zeros(shape, dtype);
+ np::ndarray npDy = np::zeros(shape, dtype);
+ np::ndarray npBx = np::zeros(shape, dtype);
+ np::ndarray npBy = np::zeros(shape, dtype);
+
+ U = reinterpret_cast<float *>(npU.get_data());
+ U_old = reinterpret_cast<float *>(npU_old.get_data());
+ Dx = reinterpret_cast<float *>(npDx.get_data());
+ Dy = reinterpret_cast<float *>(npDy.get_data());
+ Bx = reinterpret_cast<float *>(npBx.get_data());
+ By = reinterpret_cast<float *>(npBy.get_data());
+
+
+
+ copyIm(A, U, dimX, dimY, dimZ); /*initialize */
+
+ /* begin outer SB iterations */
+ for (ll = 0; ll < iter; ll++) {
+
+ /*storing old values*/
+ copyIm(U, U_old, dimX, dimY, dimZ);
+
+ /*GS iteration */
+ gauss_seidel2D(U, A, Dx, Dy, Bx, By, dimX, dimY, lambda, mu);
+
+ if (methTV == 1) updDxDy_shrinkAniso2D(U, Dx, Dy, Bx, By, dimX, dimY, lambda);
+ else updDxDy_shrinkIso2D(U, Dx, Dy, Bx, By, dimX, dimY, lambda);
+
+ updBxBy2D(U, Dx, Dy, Bx, By, dimX, dimY);
+
+ /* calculate norm to terminate earlier */
+ re = 0.0f; re1 = 0.0f;
+ for (j = 0; j < dimX*dimY*dimZ; j++)
+ {
+ re += pow(U_old[j] - U[j], 2);
+ re1 += pow(U_old[j], 2);
+ }
+ re = sqrt(re) / sqrt(re1);
+ if (re < epsil) count++;
+ if (count > 4) break;
+
+ /* check that the residual norm is decreasing */
+ if (ll > 2) {
+ if (re > re_old) break;
+ }
+ re_old = re;
+ /*printf("%f %i %i \n", re, ll, count); */
+
+ /*copyIm(U_old, U, dimX, dimY, dimZ); */
+
+ }
+ //printf("SB iterations stopped at iteration: %i\n", ll);
+ result.append<np::ndarray>(npU);
+ result.append<int>(ll);
+ }
+ if (number_of_dims == 3) {
+ /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npU = np::zeros(shape, dtype);
+ np::ndarray npU_old = np::zeros(shape, dtype);
+ np::ndarray npDx = np::zeros(shape, dtype);
+ np::ndarray npDy = np::zeros(shape, dtype);
+ np::ndarray npDz = np::zeros(shape, dtype);
+ np::ndarray npBx = np::zeros(shape, dtype);
+ np::ndarray npBy = np::zeros(shape, dtype);
+ np::ndarray npBz = np::zeros(shape, dtype);
+
+ U = reinterpret_cast<float *>(npU.get_data());
+ U_old = reinterpret_cast<float *>(npU_old.get_data());
+ Dx = reinterpret_cast<float *>(npDx.get_data());
+ Dy = reinterpret_cast<float *>(npDy.get_data());
+ Dz = reinterpret_cast<float *>(npDz.get_data());
+ Bx = reinterpret_cast<float *>(npBx.get_data());
+ By = reinterpret_cast<float *>(npBy.get_data());
+ Bz = reinterpret_cast<float *>(npBz.get_data());
+
+ copyIm(A, U, dimX, dimY, dimZ); /*initialize */
+
+ /* begin outer SB iterations */
+ for (ll = 0; ll<iter; ll++) {
+
+ /*storing old values*/
+ copyIm(U, U_old, dimX, dimY, dimZ);
+
+ /*GS iteration */
+ gauss_seidel3D(U, A, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda, mu);
+
+ if (methTV == 1) updDxDyDz_shrinkAniso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda);
+ else updDxDyDz_shrinkIso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda);
+
+ updBxByBz3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ);
+
+ /* calculate norm to terminate earlier */
+ re = 0.0f; re1 = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++)
+ {
+ re += pow(U[j] - U_old[j], 2);
+ re1 += pow(U[j], 2);
+ }
+ re = sqrt(re) / sqrt(re1);
+ if (re < epsil) count++;
+ if (count > 4) break;
+
+ /* check that the residual norm is decreasing */
+ if (ll > 2) {
+ if (re > re_old) break;
+ }
+ /*printf("%f %i %i \n", re, ll, count); */
+ re_old = re;
+ }
+ //printf("SB iterations stopped at iteration: %i\n", ll);
+ result.append<np::ndarray>(npU);
+ result.append<int>(ll);
+ }
+ return result;
+
+ }
+
+
+
+bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) {
+
+ // the result is in the following list
+ bp::list result;
+
+ int number_of_dims, dimX, dimY, dimZ, ll, j, count;
+ float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL;
+ float lambda, tk, tkp1, re, re1, re_old, epsil, funcval;
+
+ //number_of_dims = mxGetNumberOfDimensions(prhs[0]);
+ //dim_array = mxGetDimensions(prhs[0]);
+
+ number_of_dims = input.get_nd();
+ int dim_array[3];
+
+ dim_array[0] = input.shape(0);
+ dim_array[1] = input.shape(1);
+ if (number_of_dims == 2) {
+ dim_array[2] = -1;
+ }
+ else {
+ dim_array[2] = input.shape(2);
+ }
+ // Parameter handling is be done in Python
+ ///*Handling Matlab input data*/
+ //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')");
+
+ ///*Handling Matlab input data*/
+ //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */
+ A = reinterpret_cast<float *>(input.get_data());
+
+ //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */
+ lambda = (float)d_mu;
+
+ //iter = 35; /* default iterations number */
+
+ //epsil = 0.0001; /* default tolerance constant */
+ epsil = (float)d_epsil;
+ //methTV = 0; /* default isotropic TV penalty */
+ //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */
+ //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */
+ //if (nrhs == 5) {
+ // char *penalty_type;
+ // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */
+ // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',");
+ // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */
+ // mxFree(penalty_type);
+ //}
+ //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); }
+
+ //plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL);
+ bp::tuple shape1 = bp::make_tuple(dim_array[0], dim_array[1]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+ np::ndarray out1 = np::zeros(shape1, dtype);
+
+ //float *funcvalA = (float *)mxGetData(plhs[1]);
+ float * funcvalA = reinterpret_cast<float *>(out1.get_data());
+ //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); }
+
+ /*Handling Matlab output data*/
+ dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2];
+
+ tk = 1.0f;
+ tkp1 = 1.0f;
+ count = 1;
+ re_old = 0.0f;
+
+ if (number_of_dims == 2) {
+ dimZ = 1; /*2D case*/
+ /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ D_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ P1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ P2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ R1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ R2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/
+
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+
+ np::ndarray npD = np::zeros(shape, dtype);
+ np::ndarray npD_old = np::zeros(shape, dtype);
+ np::ndarray npP1 = np::zeros(shape, dtype);
+ np::ndarray npP2 = np::zeros(shape, dtype);
+ np::ndarray npP1_old = np::zeros(shape, dtype);
+ np::ndarray npP2_old = np::zeros(shape, dtype);
+ np::ndarray npR1 = np::zeros(shape, dtype);
+ np::ndarray npR2 = np::zeros(shape, dtype);
+
+ D = reinterpret_cast<float *>(npD.get_data());
+ D_old = reinterpret_cast<float *>(npD_old.get_data());
+ P1 = reinterpret_cast<float *>(npP1.get_data());
+ P2 = reinterpret_cast<float *>(npP2.get_data());
+ P1_old = reinterpret_cast<float *>(npP1_old.get_data());
+ P2_old = reinterpret_cast<float *>(npP2_old.get_data());
+ R1 = reinterpret_cast<float *>(npR1.get_data());
+ R2 = reinterpret_cast<float *>(npR2.get_data());
+
+ /* begin iterations */
+ for (ll = 0; ll<iter; ll++) {
+ /* computing the gradient of the objective function */
+ Obj_func2D(A, D, R1, R2, lambda, dimX, dimY);
+
+ /*Taking a step towards minus of the gradient*/
+ Grad_func2D(P1, P2, D, R1, R2, lambda, dimX, dimY);
+
+
+
+
+ /*updating R and t*/
+ tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f;
+ Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY);
+
+ /* calculate norm */
+ re = 0.0f; re1 = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++)
+ {
+ re += pow(D[j] - D_old[j], 2);
+ re1 += pow(D[j], 2);
+ }
+ re = sqrt(re) / sqrt(re1);
+ if (re < epsil) count++;
+ if (count > 3) {
+ Obj_func2D(A, D, P1, P2, lambda, dimX, dimY);
+ funcval = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+ //funcvalA[0] = sqrt(funcval);
+ float fv = sqrt(funcval);
+ std::memcpy(funcvalA, &fv, sizeof(float));
+ break;
+ }
+
+ /* check that the residual norm is decreasing */
+ if (ll > 2) {
+ if (re > re_old) {
+ Obj_func2D(A, D, P1, P2, lambda, dimX, dimY);
+ funcval = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+ //funcvalA[0] = sqrt(funcval);
+ float fv = sqrt(funcval);
+ std::memcpy(funcvalA, &fv, sizeof(float));
+ break;
+ }
+ }
+ re_old = re;
+ /*printf("%f %i %i \n", re, ll, count); */
+
+ /*storing old values*/
+ copyIm(D, D_old, dimX, dimY, dimZ);
+ copyIm(P1, P1_old, dimX, dimY, dimZ);
+ copyIm(P2, P2_old, dimX, dimY, dimZ);
+ tk = tkp1;
+
+ /* calculating the objective function value */
+ if (ll == (iter - 1)) {
+ Obj_func2D(A, D, P1, P2, lambda, dimX, dimY);
+ funcval = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+ //funcvalA[0] = sqrt(funcval);
+ float fv = sqrt(funcval);
+ std::memcpy(funcvalA, &fv, sizeof(float));
+ }
+ }
+ //printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]);
+ result.append<np::ndarray>(npD);
+ result.append<np::ndarray>(out1);
+ result.append<int>(ll);
+ }
+ if (number_of_dims == 3) {
+ /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ D_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ P1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ P2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ P3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ R1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ R2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ R3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npD = np::zeros(shape, dtype);
+ np::ndarray npD_old = np::zeros(shape, dtype);
+ np::ndarray npP1 = np::zeros(shape, dtype);
+ np::ndarray npP2 = np::zeros(shape, dtype);
+ np::ndarray npP3 = np::zeros(shape, dtype);
+ np::ndarray npP1_old = np::zeros(shape, dtype);
+ np::ndarray npP2_old = np::zeros(shape, dtype);
+ np::ndarray npP3_old = np::zeros(shape, dtype);
+ np::ndarray npR1 = np::zeros(shape, dtype);
+ np::ndarray npR2 = np::zeros(shape, dtype);
+ np::ndarray npR3 = np::zeros(shape, dtype);
+
+ D = reinterpret_cast<float *>(npD.get_data());
+ D_old = reinterpret_cast<float *>(npD_old.get_data());
+ P1 = reinterpret_cast<float *>(npP1.get_data());
+ P2 = reinterpret_cast<float *>(npP2.get_data());
+ P3 = reinterpret_cast<float *>(npP3.get_data());
+ P1_old = reinterpret_cast<float *>(npP1_old.get_data());
+ P2_old = reinterpret_cast<float *>(npP2_old.get_data());
+ P3_old = reinterpret_cast<float *>(npP3_old.get_data());
+ R1 = reinterpret_cast<float *>(npR1.get_data());
+ R2 = reinterpret_cast<float *>(npR2.get_data());
+ R3 = reinterpret_cast<float *>(npR3.get_data());
+ /* begin iterations */
+ for (ll = 0; ll<iter; ll++) {
+ /* computing the gradient of the objective function */
+ Obj_func3D(A, D, R1, R2, R3, lambda, dimX, dimY, dimZ);
+ /*Taking a step towards minus of the gradient*/
+ Grad_func3D(P1, P2, P3, D, R1, R2, R3, lambda, dimX, dimY, dimZ);
+
+ /* projection step */
+ Proj_func3D(P1, P2, P3, dimX, dimY, dimZ);
+
+ /*updating R and t*/
+ tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f;
+ Rupd_func3D(P1, P1_old, P2, P2_old, P3, P3_old, R1, R2, R3, tkp1, tk, dimX, dimY, dimZ);
+
+ /* calculate norm - stopping rules*/
+ re = 0.0f; re1 = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++)
+ {
+ re += pow(D[j] - D_old[j], 2);
+ re1 += pow(D[j], 2);
+ }
+ re = sqrt(re) / sqrt(re1);
+ /* stop if the norm residual is less than the tolerance EPS */
+ if (re < epsil) count++;
+ if (count > 3) {
+ Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ);
+ funcval = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+ //funcvalA[0] = sqrt(funcval);
+ float fv = sqrt(funcval);
+ std::memcpy(funcvalA, &fv, sizeof(float));
+ break;
+ }
+
+ /* check that the residual norm is decreasing */
+ if (ll > 2) {
+ if (re > re_old) {
+ Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ);
+ funcval = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+ //funcvalA[0] = sqrt(funcval);
+ float fv = sqrt(funcval);
+ std::memcpy(funcvalA, &fv, sizeof(float));
+ break;
+ }
+ }
+
+ re_old = re;
+ /*printf("%f %i %i \n", re, ll, count); */
+
+ /*storing old values*/
+ copyIm(D, D_old, dimX, dimY, dimZ);
+ copyIm(P1, P1_old, dimX, dimY, dimZ);
+ copyIm(P2, P2_old, dimX, dimY, dimZ);
+ copyIm(P3, P3_old, dimX, dimY, dimZ);
+ tk = tkp1;
+
+ if (ll == (iter - 1)) {
+ Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ);
+ funcval = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+ //funcvalA[0] = sqrt(funcval);
+ float fv = sqrt(funcval);
+ std::memcpy(funcvalA, &fv, sizeof(float));
+ }
+
+ }
+ //printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]);
+ result.append<np::ndarray>(npD);
+ result.append<np::ndarray>(out1);
+ result.append<int>(ll);
+ }
+
+ return result;
+}
+
+bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) {
+ // the result is in the following list
+ bp::list result;
+
+ int number_of_dims, dimX, dimY, dimZ, ll, j, count;
+ //const int *dim_array;
+ float *U0, *U = NULL, *U_old = NULL, *D1 = NULL, *D2 = NULL, *D3 = NULL, lambda, tau, re, re1, epsil, re_old;
+ unsigned short *Map = NULL;
+
+ number_of_dims = input.get_nd();
+ int dim_array[3];
+
+ dim_array[0] = input.shape(0);
+ dim_array[1] = input.shape(1);
+ if (number_of_dims == 2) {
+ dim_array[2] = -1;
+ }
+ else {
+ dim_array[2] = input.shape(2);
+ }
+
+ ///*Handling Matlab input data*/
+ //U0 = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/
+ //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+ //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/
+ //tau = (float)mxGetScalar(prhs[2]); /* time-step */
+ //iter = (int)mxGetScalar(prhs[3]); /*iterations number*/
+ //epsil = (float)mxGetScalar(prhs[4]); /* tolerance constant */
+ //switcher = (int)mxGetScalar(prhs[5]); /*switch on (1) restrictive smoothing in Z dimension*/
+
+ U0 = reinterpret_cast<float *>(input.get_data());
+ lambda = (float)d_lambda;
+ tau = (float)d_tau;
+ // iter is passed as parameter
+ epsil = (float)d_epsil;
+ // switcher is passed as parameter
+ /*Handling Matlab output data*/
+ dimX = dim_array[0]; dimY = dim_array[1]; dimZ = 1;
+
+ if (number_of_dims == 2) {
+ /*2D case*/
+ /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/
+
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+
+ np::ndarray npU = np::zeros(shape, dtype);
+ np::ndarray npU_old = np::zeros(shape, dtype);
+ np::ndarray npD1 = np::zeros(shape, dtype);
+ np::ndarray npD2 = np::zeros(shape, dtype);
+
+
+ U = reinterpret_cast<float *>(npU.get_data());
+ U_old = reinterpret_cast<float *>(npU_old.get_data());
+ D1 = reinterpret_cast<float *>(npD1.get_data());
+ D2 = reinterpret_cast<float *>(npD2.get_data());
+
+ /*Copy U0 to U*/
+ copyIm(U0, U, dimX, dimY, dimZ);
+
+ count = 1;
+ re_old = 0.0f;
+
+ for (ll = 0; ll < iter; ll++) {
+
+ copyIm(U, U_old, dimX, dimY, dimZ);
+
+ /*estimate inner derrivatives */
+ der2D(U, D1, D2, dimX, dimY, dimZ);
+ /* calculate div^2 and update */
+ div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau);
+
+ /* calculate norm to terminate earlier */
+ re = 0.0f; re1 = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++)
+ {
+ re += pow(U_old[j] - U[j], 2);
+ re1 += pow(U_old[j], 2);
+ }
+ re = sqrt(re) / sqrt(re1);
+ if (re < epsil) count++;
+ if (count > 4) break;
+
+ /* check that the residual norm is decreasing */
+ if (ll > 2) {
+ if (re > re_old) break;
+ }
+ re_old = re;
+
+ } /*end of iterations*/
+ //printf("HO iterations stopped at iteration: %i\n", ll);
+
+ result.append<np::ndarray>(npU);
+ }
+ else if (number_of_dims == 3) {
+ /*3D case*/
+ dimZ = dim_array[2];
+ /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+ if (switcher != 0) {
+ Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL));
+ }*/
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+
+ np::ndarray npU = np::zeros(shape, dtype);
+ np::ndarray npU_old = np::zeros(shape, dtype);
+ np::ndarray npD1 = np::zeros(shape, dtype);
+ np::ndarray npD2 = np::zeros(shape, dtype);
+ np::ndarray npD3 = np::zeros(shape, dtype);
+ np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin<unsigned short>());
+ Map = reinterpret_cast<unsigned short *>(npMap.get_data());
+ if (switcher != 0) {
+ //Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL));
+
+ Map = reinterpret_cast<unsigned short *>(npMap.get_data());
+ }
+
+ U = reinterpret_cast<float *>(npU.get_data());
+ U_old = reinterpret_cast<float *>(npU_old.get_data());
+ D1 = reinterpret_cast<float *>(npD1.get_data());
+ D2 = reinterpret_cast<float *>(npD2.get_data());
+ D3 = reinterpret_cast<float *>(npD2.get_data());
+
+ /*Copy U0 to U*/
+ copyIm(U0, U, dimX, dimY, dimZ);
+
+ count = 1;
+ re_old = 0.0f;
+
+
+ if (switcher == 1) {
+ /* apply restrictive smoothing */
+ calcMap(U, Map, dimX, dimY, dimZ);
+ /*clear outliers */
+ cleanMap(Map, dimX, dimY, dimZ);
+ }
+ for (ll = 0; ll < iter; ll++) {
+
+ copyIm(U, U_old, dimX, dimY, dimZ);
+
+ /*estimate inner derrivatives */
+ der3D(U, D1, D2, D3, dimX, dimY, dimZ);
+ /* calculate div^2 and update */
+ div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau);
+
+ /* calculate norm to terminate earlier */
+ re = 0.0f; re1 = 0.0f;
+ for (j = 0; j<dimX*dimY*dimZ; j++)
+ {
+ re += pow(U_old[j] - U[j], 2);
+ re1 += pow(U_old[j], 2);
+ }
+ re = sqrt(re) / sqrt(re1);
+ if (re < epsil) count++;
+ if (count > 4) break;
+
+ /* check that the residual norm is decreasing */
+ if (ll > 2) {
+ if (re > re_old) break;
+ }
+ re_old = re;
+
+ } /*end of iterations*/
+ //printf("HO iterations stopped at iteration: %i\n", ll);
+ result.append<np::ndarray>(npU);
+ if (switcher != 0) result.append<np::ndarray>(npMap);
+
+ }
+ return result;
+}
+
+
+bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW, double d_h) {
+ // the result is in the following list
+ bp::list result;
+
+ int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop;
+ //const int *dims;
+ float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda;
+
+ numdims = input.get_nd();
+ int dims[3];
+
+ dims[0] = input.shape(0);
+ dims[1] = input.shape(1);
+ if (numdims == 2) {
+ dims[2] = -1;
+ }
+ else {
+ dims[2] = input.shape(2);
+ }
+ /*numdims = mxGetNumberOfDimensions(prhs[0]);
+ dims = mxGetDimensions(prhs[0]);*/
+
+ N = dims[0];
+ M = dims[1];
+ Z = dims[2];
+
+ //if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); }
+ //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+
+ //if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter");
+
+ ///*Handling inputs*/
+ //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */
+ A = reinterpret_cast<float *>(input.get_data());
+ //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */
+ //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */
+ //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */
+ //lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */
+
+ //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0");
+ //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0");
+
+ lambda = (float)d_lambda;
+ h = (float)d_h;
+ SearchW = SearchW_real + 2 * SimilW;
+
+ /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */
+ /* SimilW_full = 2*SimilW + 1; */ /* the full similarity window size */
+
+
+ padXY = SearchW + 2 * SimilW; /* padding sizes */
+ newsizeX = N + 2 * (padXY); /* the X size of the padded array */
+ newsizeY = M + 2 * (padXY); /* the Y size of the padded array */
+ newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */
+ int N_dims[] = { newsizeX, newsizeY, newsizeZ };
+ /******************************2D case ****************************/
+ if (numdims == 2) {
+ ///*Handling output*/
+ //B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL));
+ ///*allocating memory for the padded arrays */
+ //Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL));
+ //Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL));
+ ///**************************************************************************/
+
+ bp::tuple shape = bp::make_tuple(N, M);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npB = np::zeros(shape, dtype);
+
+ shape = bp::make_tuple(newsizeX, newsizeY);
+ np::ndarray npAp = np::zeros(shape, dtype);
+ np::ndarray npBp = np::zeros(shape, dtype);
+ B = reinterpret_cast<float *>(npB.get_data());
+ Ap = reinterpret_cast<float *>(npAp.get_data());
+ Bp = reinterpret_cast<float *>(npBp.get_data());
+
+ /*Perform padding of image A to the size of [newsizeX * newsizeY] */
+ switchpad_crop = 0; /*padding*/
+ pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+
+ /* Do PB regularization with the padded array */
+ PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda);
+
+ switchpad_crop = 1; /*cropping*/
+ pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+
+ result.append<np::ndarray>(npB);
+ }
+ else
+ {
+ /******************************3D case ****************************/
+ ///*Handling output*/
+ //B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL));
+ ///*allocating memory for the padded arrays */
+ //Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL));
+ //Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL));
+ /**************************************************************************/
+ bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]);
+ bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npB = np::zeros(shape, dtype);
+ np::ndarray npAp = np::zeros(shape_AB, dtype);
+ np::ndarray npBp = np::zeros(shape_AB, dtype);
+ B = reinterpret_cast<float *>(npB.get_data());
+ Ap = reinterpret_cast<float *>(npAp.get_data());
+ Bp = reinterpret_cast<float *>(npBp.get_data());
+ /*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */
+ switchpad_crop = 0; /*padding*/
+ pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop);
+
+ /* Do PB regularization with the padded array */
+ PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda);
+
+ switchpad_crop = 1; /*cropping*/
+ pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop);
+
+ result.append<np::ndarray>(npB);
+ } /*end else ndims*/
+
+ return result;
+}
+
+bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) {
+ // the result is in the following list
+ bp::list result;
+ int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll;
+ //const int *dim_array;
+ float *A, *U, *U_old, *P1, *P2, *Q1, *Q2, *Q3, *V1, *V1_old, *V2, *V2_old, lambda, L2, tau, sigma, alpha1, alpha0;
+
+ //number_of_dims = mxGetNumberOfDimensions(prhs[0]);
+ //dim_array = mxGetDimensions(prhs[0]);
+ number_of_dims = input.get_nd();
+ int dim_array[3];
+
+ dim_array[0] = input.shape(0);
+ dim_array[1] = input.shape(1);
+ if (number_of_dims == 2) {
+ dim_array[2] = -1;
+ }
+ else {
+ dim_array[2] = input.shape(2);
+ }
+ /*Handling Matlab input data*/
+ //A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/
+ //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+
+ A = reinterpret_cast<float *>(input.get_data());
+
+ //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/
+ //alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/
+ //alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/
+ //iter = (int)mxGetScalar(prhs[4]); /*iterations number*/
+ //if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations");
+ lambda = (float)d_lambda;
+ alpha1 = (float)d_alpha1;
+ alpha0 = (float)d_alpha0;
+
+ /*Handling Matlab output data*/
+ dimX = dim_array[0]; dimY = dim_array[1];
+
+ if (number_of_dims == 2) {
+ /*2D case*/
+ dimZ = 1;
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npU = np::zeros(shape, dtype);
+ np::ndarray npP1 = np::zeros(shape, dtype);
+ np::ndarray npP2 = np::zeros(shape, dtype);
+ np::ndarray npQ1 = np::zeros(shape, dtype);
+ np::ndarray npQ2 = np::zeros(shape, dtype);
+ np::ndarray npQ3 = np::zeros(shape, dtype);
+ np::ndarray npV1 = np::zeros(shape, dtype);
+ np::ndarray npV1_old = np::zeros(shape, dtype);
+ np::ndarray npV2 = np::zeros(shape, dtype);
+ np::ndarray npV2_old = np::zeros(shape, dtype);
+ np::ndarray npU_old = np::zeros(shape, dtype);
+
+ U = reinterpret_cast<float *>(npU.get_data());
+ U_old = reinterpret_cast<float *>(npU_old.get_data());
+ P1 = reinterpret_cast<float *>(npP1.get_data());
+ P2 = reinterpret_cast<float *>(npP2.get_data());
+ Q1 = reinterpret_cast<float *>(npQ1.get_data());
+ Q2 = reinterpret_cast<float *>(npQ2.get_data());
+ Q3 = reinterpret_cast<float *>(npQ3.get_data());
+ V1 = reinterpret_cast<float *>(npV1.get_data());
+ V1_old = reinterpret_cast<float *>(npV1_old.get_data());
+ V2 = reinterpret_cast<float *>(npV2.get_data());
+ V2_old = reinterpret_cast<float *>(npV2_old.get_data());
+ //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ /*dual variables*/
+ /*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/
+ /*printf("%i \n", i);*/
+ L2 = 12.0; /*Lipshitz constant*/
+ tau = 1.0 / pow(L2, 0.5);
+ sigma = 1.0 / pow(L2, 0.5);
+
+ /*Copy A to U*/
+ copyIm(A, U, dimX, dimY, dimZ);
+ /* Here primal-dual iterations begin for 2D */
+ for (ll = 0; ll < iter; ll++) {
+
+ /* Calculate Dual Variable P */
+ DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma);
+
+ /*Projection onto convex set for P*/
+ ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1);
+
+ /* Calculate Dual Variable Q */
+ DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma);
+
+ /*Projection onto convex set for Q*/
+ ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0);
+
+ /*saving U into U_old*/
+ copyIm(U, U_old, dimX, dimY, dimZ);
+
+ /*adjoint operation -> divergence and projection of P*/
+ DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau);
+
+ /*get updated solution U*/
+ newU(U, U_old, dimX, dimY, dimZ);
+
+ /*saving V into V_old*/
+ copyIm(V1, V1_old, dimX, dimY, dimZ);
+ copyIm(V2, V2_old, dimX, dimY, dimZ);
+
+ /* upd V*/
+ UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau);
+
+ /*get new V*/
+ newU(V1, V1_old, dimX, dimY, dimZ);
+ newU(V2, V2_old, dimX, dimY, dimZ);
+ } /*end of iterations*/
+
+ result.append<np::ndarray>(npU);
+ }
+
+
+
+
+ return result;
+}
+
+BOOST_PYTHON_MODULE(cpu_regularizers)
+{
+ np::initialize();
+
+ //To specify that this module is a package
+ bp::object package = bp::scope();
+ package.attr("__path__") = "cpu_regularizers";
+
+ np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
+ np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
+
+ def("SplitBregman_TV", SplitBregman_TV);
+ def("FGP_TV", FGP_TV);
+ def("LLT_model", LLT_model);
+ def("PatchBased_Regul", PatchBased_Regul);
+ def("TGV_PD", TGV_PD);
+}
diff --git a/Wrappers/Python/setup-fista.py.in b/Wrappers/Python/setup-fista.py.in
new file mode 100644
index 0000000..c5c9f4d
--- /dev/null
+++ b/Wrappers/Python/setup-fista.py.in
@@ -0,0 +1,27 @@
+from distutils.core import setup
+#from setuptools import setup, find_packages
+import os
+
+cil_version=os.environ['CIL_VERSION']
+if cil_version == '':
+ print("Please set the environmental variable CIL_VERSION")
+ sys.exit(1)
+
+setup(
+ name="ccpi-fista",
+ version=cil_version,
+ packages=['ccpi','ccpi.reconstruction'],
+ install_requires=['numpy'],
+
+ zip_safe = False,
+
+ # metadata for upload to PyPI
+ author="Edoardo Pasca",
+ author_email="edo.paskino@gmail.com",
+ description='CCPi Core Imaging Library - FISTA Reconstructor module',
+ license="Apache v2.0",
+ keywords="tomography interative reconstruction",
+ url="http://www.ccpi.ac.uk", # project home page, if any
+
+ # could also include long_description, download_url, classifiers, etc.
+)
diff --git a/Wrappers/Python/setup.py.in b/Wrappers/Python/setup.py.in
new file mode 100644
index 0000000..12e8af1
--- /dev/null
+++ b/Wrappers/Python/setup.py.in
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+
+import setuptools
+from distutils.core import setup
+from distutils.extension import Extension
+from Cython.Distutils import build_ext
+
+import os
+import sys
+import numpy
+import platform
+
+cil_version=@CIL_VERSION@
+
+library_include_path = ""
+library_lib_path = ""
+try:
+ library_include_path = os.environ['LIBRARY_INC']
+ library_lib_path = os.environ['LIBRARY_LIB']
+except:
+ library_include_path = os.environ['PREFIX']+'/include'
+ pass
+
+extra_include_dirs = [numpy.get_include(), library_include_path]
+extra_library_dirs = [os.path.join(library_include_path, "..", "lib")]
+extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x']
+extra_libraries = []
+extra_include_dirs += [os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU"),
+ os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_GPU") ,
+ "@CMAKE_CURRENT_SOURCE_DIR@"]
+
+if platform.system() == 'Windows':
+ extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ]
+
+ if sys.version_info.major == 3 :
+ extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64']
+ else:
+ extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64']
+else:
+ if sys.version_info.major == 3:
+ extra_libraries += ['boost_python3', 'boost_numpy3','gomp']
+ else:
+ extra_libraries += ['boost_python', 'boost_numpy','gomp']
+
+setup(
+ name='ccpi',
+ description='CCPi Core Imaging Library - Image Regularizers',
+ version=cil_version,
+ cmdclass = {'build_ext': build_ext},
+ ext_modules = [Extension("ccpi.imaging.cpu_regularizers",
+ sources=[os.path.join("@CMAKE_CURRENT_SOURCE_DIR@" , "fista_module.cpp" ),
+ os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "FGP_TV_core.c"),
+ os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "SplitBregman_TV_core.c"),
+ os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "LLT_model_core.c"),
+ os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "PatchBased_Regul_core.c"),
+ os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "TGV_PD_core.c"),
+ os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "utils.c")
+ ],
+ include_dirs=extra_include_dirs,
+ library_dirs=extra_library_dirs,
+ extra_compile_args=extra_compile_args,
+ libraries=extra_libraries ),
+
+ ],
+ zip_safe = False,
+ packages = {'ccpi','ccpi.imaging'},
+)
+
+
diff --git a/Wrappers/Python/test/astra_test.py b/Wrappers/Python/test/astra_test.py
new file mode 100644
index 0000000..42c375a
--- /dev/null
+++ b/Wrappers/Python/test/astra_test.py
@@ -0,0 +1,85 @@
+import astra
+import numpy
+import filefun
+
+
+# read in the same data as the DemoRD2
+angles = filefun.dlmread("DemoRD2/angles.csv")
+darks_ar = filefun.dlmread("DemoRD2/darks_ar.csv", separator=",")
+flats_ar = filefun.dlmread("DemoRD2/flats_ar.csv", separator=",")
+
+if True:
+ Sino3D = numpy.load("DemoRD2/Sino3D.npy")
+else:
+ sino = filefun.dlmread("DemoRD2/sino_01.csv", separator=",")
+ a = map (lambda x:x, numpy.shape(sino))
+ a.append(20)
+
+ Sino3D = numpy.zeros(tuple(a), dtype="float")
+
+ for i in range(1,numpy.shape(Sino3D)[2]+1):
+ print("Read file DemoRD2/sino_%02d.csv" % i)
+ sino = filefun.dlmread("DemoRD2/sino_%02d.csv" % i, separator=",")
+ Sino3D.T[i-1] = sino.T
+
+Weights3D = numpy.asarray(Sino3D, dtype="float")
+
+##angles_rad = angles*(pi/180); % conversion to radians
+##size_det = size(data_raw3D,1); % detectors dim
+##angSize = size(data_raw3D, 2); % angles dim
+##slices_tot = size(data_raw3D, 3); % no of slices
+##recon_size = 950; % reconstruction size
+
+
+angles_rad = angles * numpy.pi /180.
+size_det, angSize, slices_tot = numpy.shape(Sino3D)
+size_det, angSize, slices_tot = [int(i) for i in numpy.shape(Sino3D)]
+recon_size = 950
+Z_slices = 3;
+det_row_count = Z_slices;
+
+#proj_geom = astra_create_proj_geom('parallel3d', 1, 1,
+# det_row_count, size_det, angles_rad);
+
+detectorSpacingX = 1.0
+detectorSpacingY = detectorSpacingX
+proj_geom = astra.create_proj_geom('parallel3d',
+ detectorSpacingX,
+ detectorSpacingY,
+ det_row_count,
+ size_det,
+ angles_rad)
+
+#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices);
+vol_geom = astra.create_vol_geom(recon_size,recon_size,Z_slices);
+
+sino = numpy.zeros((size_det, angSize, slices_tot), dtype="float")
+
+#weights = ones(size(sino));
+weights = numpy.ones(numpy.shape(sino))
+
+#####################################################################
+## PowerMethod for Lipschitz constant
+
+N = vol_geom['GridColCount']
+x1 = numpy.random.rand(1,N,N)
+#sqweight = sqrt(weights(:,:,1));
+sqweight = numpy.sqrt(weights.T[0]).T
+##proj_geomT = proj_geom;
+proj_geomT = proj_geom.copy()
+##proj_geomT.DetectorRowCount = 1;
+proj_geomT['DetectorRowCount'] = 1
+##vol_geomT = vol_geom;
+vol_geomT = vol_geom.copy()
+##vol_geomT.GridSliceCount = 1;
+vol_geomT['GridSliceCount'] = 1
+
+##[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+
+#sino_id, y = astra.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+sino_id, y = astra.create_sino(x1, proj_geomT, vol_geomT);
+
+##y = sqweight.*y;
+##astra_mex_data3d('delete', sino_id);
+
+
diff --git a/Wrappers/Python/test/create_phantom_projections.py b/Wrappers/Python/test/create_phantom_projections.py
new file mode 100644
index 0000000..20a9278
--- /dev/null
+++ b/Wrappers/Python/test/create_phantom_projections.py
@@ -0,0 +1,49 @@
+from ccpi.reconstruction.AstraDevice import AstraDevice
+from ccpi.reconstruction.DeviceModel import DeviceModel
+import h5py
+import numpy
+import matplotlib.pyplot as plt
+
+nx = h5py.File('phant3D_256.h5', "r")
+phantom = numpy.asarray(nx.get('/dataset1'))
+pX,pY,pZ = numpy.shape(phantom)
+
+filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5'
+nxa = h5py.File(filename, "r")
+#getEntry(nx, '/')
+# I have exported the entries as children of /
+entries = [entry for entry in nxa['/'].keys()]
+print (entries)
+
+angles_rad = numpy.asarray(nxa.get('/angles_rad'), dtype="float32")
+
+
+device = AstraDevice(
+ DeviceModel.DeviceType.PARALLEL3D.value,
+ [ pX , pY , 1., 1., angles_rad],
+ [ pX, pY, pZ ] )
+
+
+proj = device.doForwardProject(phantom)
+stack = [proj[:,i,:] for i in range(len(angles_rad))]
+stack = numpy.asarray(stack)
+
+
+fig = plt.figure()
+a=fig.add_subplot(1,2,1)
+a.set_title('proj')
+imgplot = plt.imshow(proj[:,100,:])
+a=fig.add_subplot(1,2,2)
+a.set_title('stack')
+imgplot = plt.imshow(stack[100])
+plt.show()
+
+pf = h5py.File("phantom3D256_projections.h5" , "w")
+pf.create_dataset("/projections", data=stack)
+pf.create_dataset("/sinogram", data=proj)
+pf.create_dataset("/angles", data=angles_rad)
+pf.create_dataset("/reconstruction_volume" , data=numpy.asarray([pX, pY, pZ]))
+pf.create_dataset("/camera/size" , data=numpy.asarray([pX , pY ]))
+pf.create_dataset("/camera/spacing" , data=numpy.asarray([1.,1.]))
+pf.flush()
+pf.close()
diff --git a/Wrappers/Python/test/readhd5.py b/Wrappers/Python/test/readhd5.py
new file mode 100644
index 0000000..eff6c43
--- /dev/null
+++ b/Wrappers/Python/test/readhd5.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Aug 23 16:34:49 2017
+
+@author: ofn77899
+"""
+
+import h5py
+import numpy
+
+def getEntry(nx, location):
+ for item in nx[location].keys():
+ print (item)
+
+filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5'
+nx = h5py.File(filename, "r")
+#getEntry(nx, '/')
+# I have exported the entries as children of /
+entries = [entry for entry in nx['/'].keys()]
+print (entries)
+
+Sino3D = numpy.asarray(nx.get('/Sino3D'))
+Weights3D = numpy.asarray(nx.get('/Weights3D'))
+angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
+angles_rad = numpy.asarray(nx.get('/angles_rad'))
+recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0]
+size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0]
+
+slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
+
+#from ccpi.viewer.CILViewer2D import CILViewer2D
+#v = CILViewer2D()
+#v.setInputAsNumpy(Weights3D)
+#v.startRenderLoop()
+
+import matplotlib.pyplot as plt
+fig = plt.figure()
+
+a=fig.add_subplot(1,1,1)
+a.set_title('noise')
+imgplot = plt.imshow(Weights3D[0].T)
+plt.show()
diff --git a/Wrappers/Python/test/simple_astra_test.py b/Wrappers/Python/test/simple_astra_test.py
new file mode 100644
index 0000000..905eeea
--- /dev/null
+++ b/Wrappers/Python/test/simple_astra_test.py
@@ -0,0 +1,25 @@
+import astra
+import numpy
+
+detectorSpacingX = 1.0
+detectorSpacingY = 1.0
+det_row_count = 128
+det_col_count = 128
+
+angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi
+
+proj_geom = astra.creators.create_proj_geom('parallel3d',
+ detectorSpacingX,
+ detectorSpacingY,
+ det_row_count,
+ det_col_count,
+ angles_rad)
+
+image_size_x = 64
+image_size_y = 64
+image_size_z = 32
+
+vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z)
+
+x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x)
+sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom)
diff --git a/Wrappers/Python/test/test_reconstructor-os_phantom.py b/Wrappers/Python/test/test_reconstructor-os_phantom.py
new file mode 100644
index 0000000..01f1354
--- /dev/null
+++ b/Wrappers/Python/test/test_reconstructor-os_phantom.py
@@ -0,0 +1,480 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Aug 23 16:34:49 2017
+
+@author: ofn77899
+Based on DemoRD2.m
+"""
+
+import h5py
+import numpy
+
+from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor
+import astra
+import matplotlib.pyplot as plt
+from ccpi.imaging.Regularizer import Regularizer
+from ccpi.reconstruction.AstraDevice import AstraDevice
+from ccpi.reconstruction.DeviceModel import DeviceModel
+
+#from ccpi.viewer.CILViewer2D import *
+
+
+def RMSE(signal1, signal2):
+ '''RMSE Root Mean Squared Error'''
+ if numpy.shape(signal1) == numpy.shape(signal2):
+ err = (signal1 - signal2)
+ err = numpy.sum( err * err )/numpy.size(signal1); # MSE
+ err = sqrt(err); # RMSE
+ return err
+ else:
+ raise Exception('Input signals must have the same shape')
+
+filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/src/Python/test/phantom3D256_projections.h5'
+nx = h5py.File(filename, "r")
+#getEntry(nx, '/')
+# I have exported the entries as children of /
+entries = [entry for entry in nx['/'].keys()]
+print (entries)
+
+projections = numpy.asarray(nx.get('/projections'), dtype="float32")
+#Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32")
+#angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
+angles_rad = numpy.asarray(nx.get('/angles'), dtype="float32")
+angSize = numpy.size(angles_rad)
+image_size_x, image_size_y, image_size_z = \
+ numpy.asarray(nx.get('/reconstruction_volume'), dtype=int)
+det_col_count, det_row_count = \
+ numpy.asarray(nx.get('/camera/size'))
+#slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
+detectorSpacingX, detectorSpacingY = numpy.asarray(nx.get('/camera/spacing'), dtype=int)
+
+Z_slices = 20
+#det_row_count = image_size_y
+# next definition is just for consistency of naming
+#det_col_count = image_size_x
+
+detectorSpacingX = 1.0
+detectorSpacingY = detectorSpacingX
+
+
+proj_geom = astra.creators.create_proj_geom('parallel3d',
+ detectorSpacingX,
+ detectorSpacingY,
+ det_row_count,
+ det_col_count,
+ angles_rad)
+
+#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices);
+##image_size_x = recon_size
+##image_size_y = recon_size
+##image_size_z = Z_slices
+vol_geom = astra.creators.create_vol_geom( image_size_x,
+ image_size_y,
+ image_size_z)
+
+## First pass the arguments to the FISTAReconstructor and test the
+## Lipschitz constant
+astradevice = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value,
+ [proj_geom['DetectorRowCount'] ,
+ proj_geom['DetectorColCount'] ,
+ proj_geom['DetectorSpacingX'] ,
+ proj_geom['DetectorSpacingY'] ,
+ proj_geom['ProjectionAngles']
+ ],
+ [
+ vol_geom['GridColCount'],
+ vol_geom['GridRowCount'],
+ vol_geom['GridSliceCount'] ] )
+## create the sinogram
+Sino3D = numpy.transpose(projections, axes=[1,0,2])
+
+fistaRecon = FISTAReconstructor(proj_geom,
+ vol_geom,
+ Sino3D ,
+ #weights=Weights3D,
+ device=astradevice)
+
+print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+fistaRecon.setParameter(number_of_iterations = 4)
+#fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
+fistaRecon.setParameter(ring_alpha = 21)
+fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
+#fistaRecon.setParameter(ring_lambda_R_L1 = 0)
+subsets = 8
+fistaRecon.setParameter(subsets=subsets)
+
+
+#reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+#reg.setParameter(regularization_parameter=0.005,
+# number_of_iterations=50)
+reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+reg.setParameter(regularization_parameter=5e6,
+ tolerance_constant=0.0001,
+ number_of_iterations=50)
+
+#fistaRecon.setParameter(regularizer=reg)
+#lc = fistaRecon.getParameter('Lipschitz_constant')
+#reg.setParameter(regularization_parameter=5e6/lc)
+
+## Ordered subset
+if True:
+ #subsets = 8
+ fistaRecon.setParameter(subsets=subsets)
+ fistaRecon.createOrderedSubsets()
+else:
+ angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles']
+ #binEdges = numpy.linspace(angles.min(),
+ # angles.max(),
+ # subsets + 1)
+ binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
+ # get rearranged subset indices
+ IndicesReorg = numpy.zeros((numpy.shape(angles)))
+ counterM = 0
+ for ii in range(binsDiscr.max()):
+ counter = 0
+ for jj in range(subsets):
+ curr_index = ii + jj + counter
+ #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
+ if binsDiscr[jj] > ii:
+ if (counterM < numpy.size(IndicesReorg)):
+ IndicesReorg[counterM] = curr_index
+ counterM = counterM + 1
+
+ counter = counter + binsDiscr[jj] - 1
+
+
+if True:
+ print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+ print ("prepare for iteration")
+ fistaRecon.prepareForIteration()
+
+
+
+ print("initializing ...")
+ if True:
+ # if X doesn't exist
+ #N = params.vol_geom.GridColCount
+ N = vol_geom['GridColCount']
+ print ("N " + str(N))
+ X = numpy.asarray(numpy.ones((image_size_x,image_size_y,image_size_z)),
+ dtype=numpy.float) * 0.001
+ X = numpy.asarray(numpy.zeros((image_size_x,image_size_y,image_size_z)),
+ dtype=numpy.float)
+ else:
+ #X = fistaRecon.initialize()
+ X = numpy.load("X.npy")
+
+ print (numpy.shape(X))
+ X_t = X.copy()
+ print ("initialized")
+ proj_geom , vol_geom, sino , \
+ SlicesZ, weights , alpha_ring = fistaRecon.getParameter(
+ ['projector_geometry' , 'output_geometry',
+ 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha'])
+ lambdaR_L1 , alpha_ring , weights , L_const= \
+ fistaRecon.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant'])
+
+ #fistaRecon.setParameter(number_of_iterations = 3)
+ iterFISTA = fistaRecon.getParameter('number_of_iterations')
+ # errors vector (if the ground truth is given)
+ Resid_error = numpy.zeros((iterFISTA));
+ # objective function values vector
+ objective = numpy.zeros((iterFISTA));
+
+
+ t = 1
+
+
+ ## additional for
+ proj_geomSUB = proj_geom.copy()
+ fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram']))
+ residual2 = fistaRecon.residual2
+ sino_updt_FULL = fistaRecon.residual.copy()
+ r_x = fistaRecon.r.copy()
+
+ results = []
+ print ("starting iterations")
+## % Outer FISTA iterations loop
+ for i in range(fistaRecon.getParameter('number_of_iterations')):
+## % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required
+## % one solution is to work with a full sinogram at times
+## if ((i >= 3) && (lambdaR_L1 > 0))
+## [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom);
+## astra_mex_data3d('delete', sino_id2);
+## end
+ # With OS approach it becomes trickier to correlate independent subsets,
+ # hence additional work is required one solution is to work with a full
+ # sinogram at times
+
+
+ #t_old = t
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(fistaRecon.getParameter('input_sinogram'))
+ ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
+ r_old = fistaRecon.r.copy()
+
+ if (i > 1 and lambdaR_L1 > 0) :
+ for kkk in range(anglesNumb):
+
+ residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt_FULL[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+ #r_old = fistaRecon.r.copy()
+ vec = fistaRecon.residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:] # 1 or 0?
+ r_x = fistaRecon.r_x
+ # update ring variable
+ fistaRecon.r = (r_x - (1./L_const) * vec)
+
+ # subset loop
+ counterInd = 1
+ geometry_type = fistaRecon.getParameter('projector_geometry')['type']
+ angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles']
+
+## if geometry_type == 'parallel' or \
+## geometry_type == 'fanflat' or \
+## geometry_type == 'fanflat_vec' :
+##
+## for kkk in range(SlicesZ):
+## sino_id, sinoT[kkk] = \
+## astra.creators.create_sino3d_gpu(
+## X_t[kkk:kkk+1], proj_geomSUB, vol_geom)
+## sino_updt_Sub[kkk] = sinoT.T.copy()
+##
+## else:
+## sino_id, sino_updt_Sub = \
+## astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom)
+##
+## astra.matlab.data3d('delete', sino_id)
+
+ for ss in range(fistaRecon.getParameter('subsets')):
+ print ("Subset {0}".format(ss))
+ X_old = X.copy()
+ t_old = t
+ print ("X[0][0][0] {0} t {1}".format(X[0][0][0], t))
+
+ # the number of projections per subset
+ numProjSub = fistaRecon.getParameter('os_bins')[ss]
+ CurrSubIndices = fistaRecon.getParameter('os_indices')\
+ [counterInd:counterInd+numProjSub]
+ shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram')))
+ shape[1] = numProjSub
+ sino_updt_Sub = numpy.zeros(shape)
+
+ #print ("Len CurrSubIndices {0}".format(numProjSub))
+ mask = numpy.zeros(numpy.shape(angles), dtype=bool)
+ cc = 0
+ for j in range(len(CurrSubIndices)):
+ mask[int(CurrSubIndices[j])] = True
+
+ ## this is a reduced device
+ rdev = fistaRecon.getParameter('device_model')\
+ .createReducedDevice(proj_par={'angles' : angles[mask]},
+ vol_par={})
+ proj_geomSUB['ProjectionAngles'] = angles[mask]
+
+
+
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+
+ for kkk in range(SlicesZ):
+ sino_id, sinoT = astra.creators.create_sino3d_gpu (
+ X_t[kkk:kkk+1] , proj_geomSUB, vol_geom)
+ sino_updt_Sub[kkk] = sinoT.T.copy()
+ astra.matlab.data3d('delete', sino_id)
+ else:
+ # for 3D geometry (watch the GPU memory overflow in ASTRA < 1.8)
+ sino_id, sino_updt_Sub = \
+ astra.creators.create_sino3d_gpu (X_t,
+ proj_geomSUB,
+ vol_geom)
+
+ astra.matlab.data3d('delete', sino_id)
+
+
+
+
+ ## RING REMOVAL
+ residual = fistaRecon.residual
+
+
+ if lambdaR_L1 > 0 :
+ print ("ring removal")
+ residualSub = numpy.zeros(shape)
+ ## for a chosen subset
+ ## for kkk = 1:numProjSub
+ ## indC = CurrSubIndeces(kkk);
+ ## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
+ ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram
+ ## end
+ for kkk in range(numProjSub):
+ #print ("ring removal indC ... {0}".format(kkk))
+ indC = int(CurrSubIndices[kkk])
+ residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
+ (sino_updt_Sub[:,kkk,:].squeeze() - \
+ sino[:,indC,:].squeeze() - alpha_ring * r_x)
+ # filling the full sinogram
+ sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
+
+ else:
+ #PWLS model
+ # I guess we need to use mask here instead
+ residualSub = weights[:,CurrSubIndices,:] * \
+ ( sino_updt_Sub - \
+ sino[:,CurrSubIndices,:].squeeze() )
+ # it seems that in the original code the following like is not
+ # calculated in the case of ring removal
+ objective[i] = 0.5 * numpy.linalg.norm(residualSub)
+
+ #backprojection
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+ # if geometry is 2D use slice-by-slice projection-backprojection
+ # routine
+ x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub[kkk:kkk+1],
+ proj_geomSUB, vol_geom)
+ astra.matlab.data3d('delete', x_id)
+
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', x_id)
+
+ X = X_t - (1/L_const) * x_temp
+
+
+
+ ## REGULARIZATION
+ ## SKIPPING FOR NOW
+ ## Should be simpli
+ # regularizer = fistaRecon.getParameter('regularizer')
+ # for slices:
+ # out = regularizer(input=X)
+ print ("regularizer")
+ reg = fistaRecon.getParameter('regularizer')
+
+ if reg is not None:
+ X = reg(input=X,
+ output_all=False)
+
+ t = (1 + numpy.sqrt(1 + 4 * t **2))/2
+ X_t = X + (((t_old -1)/t) * (X-X_old))
+ counterInd = counterInd + numProjSub - 1
+ if i == 1:
+ r_old = fistaRecon.r.copy()
+
+ ## FINAL
+ print ("final")
+ lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1')
+ if lambdaR_L1 > 0:
+ fistaRecon.r = numpy.max(
+ numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \
+ numpy.sign(fistaRecon.r)
+ # updating r
+ r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old)
+
+
+ if fistaRecon.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], objective[i]))
+
+ results.append(X[10])
+ numpy.save("X_out_os.npy", X)
+
+else:
+
+
+
+ astradevice = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value,
+ [proj_geom['DetectorRowCount'] ,
+ proj_geom['DetectorColCount'] ,
+ proj_geom['DetectorSpacingX'] ,
+ proj_geom['DetectorSpacingY'] ,
+ proj_geom['ProjectionAngles']
+ ],
+ [
+ vol_geom['GridColCount'],
+ vol_geom['GridRowCount'],
+ vol_geom['GridSliceCount'] ] )
+ regul = Regularizer(Regularizer.Algorithm.FGP_TV)
+ regul.setParameter(regularization_parameter=5e6,
+ number_of_iterations=50,
+ tolerance_constant=1e-4,
+ TV_penalty=Regularizer.TotalVariationPenalty.isotropic)
+
+ fistaRecon = FISTAReconstructor(proj_geom,
+ vol_geom,
+ Sino3D ,
+ weights=Weights3D,
+ device=astradevice,
+ #regularizer = regul,
+ subsets=8)
+
+ print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+ fistaRecon.setParameter(number_of_iterations = 1)
+ fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
+ fistaRecon.setParameter(ring_alpha = 21)
+ fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
+ #fistaRecon.setParameter(subsets=8)
+
+ #lc = fistaRecon.getParameter('Lipschitz_constant')
+ #fistaRecon.getParameter('regularizer').setParameter(regularization_parameter=5e6/lc)
+
+ fistaRecon.prepareForIteration()
+ X = fistaRecon.iterate(numpy.load("X.npy"))
+
+
+# plot
+fig = plt.figure()
+cols = 3
+
+## add the difference
+rd = []
+for i in range(1,len(results)):
+ rd.append(results[i-1])
+ rd.append(results[i])
+ rd.append(results[i] - results[i-1])
+
+rows = (lambda x: int(numpy.floor(x/cols) + 1) if x%cols != 0 else int(x/cols)) \
+ (len (rd))
+for i in range(len (results)):
+ a=fig.add_subplot(rows,cols,i+1)
+ imgplot = plt.imshow(results[i], vmin=0, vmax=1)
+ a.text(0.05, 0.95, "iteration {0}".format(i),
+ verticalalignment='top')
+## i = i + 1
+## a=fig.add_subplot(rows,cols,i+1)
+## imgplot = plt.imshow(results[i], vmin=0, vmax=10)
+## a.text(0.05, 0.95, "iteration {0}".format(i),
+## verticalalignment='top')
+
+## a=fig.add_subplot(rows,cols,i+2)
+## imgplot = plt.imshow(results[i]-results[i-1], vmin=0, vmax=10)
+## a.text(0.05, 0.95, "difference {0}-{1}".format(i, i-1),
+## verticalalignment='top')
+
+
+
+plt.show()
diff --git a/Wrappers/Python/test/test_reconstructor.py b/Wrappers/Python/test/test_reconstructor.py
new file mode 100644
index 0000000..40065e7
--- /dev/null
+++ b/Wrappers/Python/test/test_reconstructor.py
@@ -0,0 +1,359 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Aug 23 16:34:49 2017
+
+@author: ofn77899
+Based on DemoRD2.m
+"""
+
+import h5py
+import numpy
+
+from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor
+import astra
+import matplotlib.pyplot as plt
+from ccpi.imaging.Regularizer import Regularizer
+from ccpi.reconstruction.AstraDevice import AstraDevice
+from ccpi.reconstruction.DeviceModel import DeviceModel
+
+def RMSE(signal1, signal2):
+ '''RMSE Root Mean Squared Error'''
+ if numpy.shape(signal1) == numpy.shape(signal2):
+ err = (signal1 - signal2)
+ err = numpy.sum( err * err )/numpy.size(signal1); # MSE
+ err = sqrt(err); # RMSE
+ return err
+ else:
+ raise Exception('Input signals must have the same shape')
+
+def createAstraDevice(projector_geometry, output_geometry):
+ '''TODO remove'''
+
+ device = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value,
+ [projector_geometry['DetectorRowCount'] ,
+ projector_geometry['DetectorColCount'] ,
+ projector_geometry['DetectorSpacingX'] ,
+ projector_geometry['DetectorSpacingY'] ,
+ projector_geometry['ProjectionAngles']
+ ],
+ [
+ output_geometry['GridColCount'],
+ output_geometry['GridRowCount'],
+ output_geometry['GridSliceCount'] ] )
+ return device
+
+filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5'
+nx = h5py.File(filename, "r")
+#getEntry(nx, '/')
+# I have exported the entries as children of /
+entries = [entry for entry in nx['/'].keys()]
+print (entries)
+
+Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32")
+Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32")
+angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
+angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32")
+recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0]
+size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0]
+slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
+
+Z_slices = 20
+det_row_count = Z_slices
+# next definition is just for consistency of naming
+det_col_count = size_det
+
+detectorSpacingX = 1.0
+detectorSpacingY = detectorSpacingX
+
+
+proj_geom = astra.creators.create_proj_geom('parallel3d',
+ detectorSpacingX,
+ detectorSpacingY,
+ det_row_count,
+ det_col_count,
+ angles_rad)
+
+#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices);
+image_size_x = recon_size
+image_size_y = recon_size
+image_size_z = Z_slices
+vol_geom = astra.creators.create_vol_geom( image_size_x,
+ image_size_y,
+ image_size_z)
+
+## First pass the arguments to the FISTAReconstructor and test the
+## Lipschitz constant
+
+##fistaRecon = FISTAReconstructor(proj_geom,
+## vol_geom,
+## Sino3D ,
+## weights=Weights3D)
+##
+##print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+##fistaRecon.setParameter(number_of_iterations = 12)
+##fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
+##fistaRecon.setParameter(ring_alpha = 21)
+##fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
+##
+##reg = Regularizer(Regularizer.Algorithm.LLT_model)
+##reg.setParameter(regularization_parameter=25,
+## time_step=0.0003,
+## tolerance_constant=0.0001,
+## number_of_iterations=300)
+##fistaRecon.setParameter(regularizer=reg)
+
+## Ordered subset
+if False:
+ subsets = 16
+ angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles']
+ #binEdges = numpy.linspace(angles.min(),
+ # angles.max(),
+ # subsets + 1)
+ binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
+ # get rearranged subset indices
+ IndicesReorg = numpy.zeros((numpy.shape(angles)))
+ counterM = 0
+ for ii in range(binsDiscr.max()):
+ counter = 0
+ for jj in range(subsets):
+ curr_index = ii + jj + counter
+ #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
+ if binsDiscr[jj] > ii:
+ if (counterM < numpy.size(IndicesReorg)):
+ IndicesReorg[counterM] = curr_index
+ counterM = counterM + 1
+
+ counter = counter + binsDiscr[jj] - 1
+
+
+if False:
+ print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+ print ("prepare for iteration")
+ fistaRecon.prepareForIteration()
+
+
+
+ print("initializing ...")
+ if False:
+ # if X doesn't exist
+ #N = params.vol_geom.GridColCount
+ N = vol_geom['GridColCount']
+ print ("N " + str(N))
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ #X = fistaRecon.initialize()
+ X = numpy.load("X.npy")
+
+ print (numpy.shape(X))
+ X_t = X.copy()
+ print ("initialized")
+ proj_geom , vol_geom, sino , \
+ SlicesZ = fistaRecon.getParameter(['projector_geometry' ,
+ 'output_geometry',
+ 'input_sinogram',
+ 'SlicesZ'])
+
+ #fistaRecon.setParameter(number_of_iterations = 3)
+ iterFISTA = fistaRecon.getParameter('number_of_iterations')
+ # errors vector (if the ground truth is given)
+ Resid_error = numpy.zeros((iterFISTA));
+ # objective function values vector
+ objective = numpy.zeros((iterFISTA));
+
+
+ t = 1
+
+
+ print ("starting iterations")
+## % Outer FISTA iterations loop
+ for i in range(fistaRecon.getParameter('number_of_iterations')):
+ X_old = X.copy()
+ t_old = t
+ r_old = fistaRecon.r.copy()
+ if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' :
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+ for kkk in range(SlicesZ):
+ sino_id, sino_updt[kkk] = \
+ astra.creators.create_sino3d_gpu(
+ X_t[kkk:kkk+1], proj_geom, vol_geom)
+ astra.matlab.data3d('delete', sino_id)
+ else:
+ # for divergent 3D geometry (watch the GPU memory overflow in
+ # ASTRA versions < 1.8)
+ #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
+ sino_id, sino_updt = astra.creators.create_sino3d_gpu(
+ X_t, proj_geom, vol_geom)
+
+ ## RING REMOVAL
+ residual = fistaRecon.residual
+ lambdaR_L1 , alpha_ring , weights , L_const= \
+ fistaRecon.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant'])
+ r_x = fistaRecon.r_x
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(fistaRecon.getParameter('input_sinogram'))
+ if lambdaR_L1 > 0 :
+ print ("ring removal")
+ for kkk in range(anglesNumb):
+
+ residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+ vec = residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:].squeeze()
+ fistaRecon.r = (r_x - (1./L_const) * vec).copy()
+ objective[i] = (0.5 * (residual ** 2).sum())
+## % the ring removal part (Group-Huber fidelity)
+## for kkk = 1:anglesNumb
+## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).*
+## (squeeze(sino_updt(:,kkk,:)) -
+## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x));
+## end
+## vec = sum(residual,2);
+## if (SlicesZ > 1)
+## vec = squeeze(vec(:,1,:));
+## end
+## r = r_x - (1./L_const).*vec;
+## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output
+
+
+
+ # Projection/Backprojection Routine
+ if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
+ print ("Projection/Backprojection Routine")
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residual[kkk:kkk+1],
+ proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residual, proj_geom, vol_geom)
+
+ X = X_t - (1/L_const) * x_temp
+ astra.matlab.data3d('delete', sino_id)
+ astra.matlab.data3d('delete', x_id)
+
+
+ ## REGULARIZATION
+ ## SKIPPING FOR NOW
+ ## Should be simpli
+ # regularizer = fistaRecon.getParameter('regularizer')
+ # for slices:
+ # out = regularizer(input=X)
+ print ("skipping regularizer")
+
+
+ ## FINAL
+ print ("final")
+ lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1')
+ if lambdaR_L1 > 0:
+ fistaRecon.r = numpy.max(
+ numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \
+ numpy.sign(fistaRecon.r)
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+
+ if lambdaR_L1 > 0:
+ fistaRecon.r_x = fistaRecon.r + \
+ (((t_old-1)/t) * (fistaRecon.r - r_old))
+
+ if fistaRecon.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], objective[i]))
+
+## if (lambdaR_L1 > 0)
+## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector
+## end
+##
+## t = (1 + sqrt(1 + 4*t^2))/2; % updating t
+## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X
+##
+## if (lambdaR_L1 > 0)
+## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r
+## end
+##
+## if (show == 1)
+## figure(10); imshow(X(:,:,slice), [0 maxvalplot]);
+## if (lambdaR_L1 > 0)
+## figure(11); plot(r); title('Rings offset vector')
+## end
+## pause(0.01);
+## end
+## if (strcmp(X_ideal, 'none' ) == 0)
+## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI));
+## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i));
+## else
+## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i));
+## end
+else:
+
+ # create a device for forward/backprojection
+ #astradevice = createAstraDevice(proj_geom, vol_geom)
+
+ astradevice = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value,
+ [proj_geom['DetectorRowCount'] ,
+ proj_geom['DetectorColCount'] ,
+ proj_geom['DetectorSpacingX'] ,
+ proj_geom['DetectorSpacingY'] ,
+ proj_geom['ProjectionAngles']
+ ],
+ [
+ vol_geom['GridColCount'],
+ vol_geom['GridRowCount'],
+ vol_geom['GridSliceCount'] ] )
+
+ regul = Regularizer(Regularizer.Algorithm.FGP_TV)
+ regul.setParameter(regularization_parameter=5e6,
+ number_of_iterations=50,
+ tolerance_constant=1e-4,
+ TV_penalty=Regularizer.TotalVariationPenalty.isotropic)
+
+ fistaRecon = FISTAReconstructor(proj_geom,
+ vol_geom,
+ Sino3D ,
+ device = astradevice,
+ weights=Weights3D,
+ regularizer = regul
+ )
+
+ print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+ fistaRecon.setParameter(number_of_iterations = 18)
+ fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
+ fistaRecon.setParameter(ring_alpha = 21)
+ fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
+
+
+
+ fistaRecon.prepareForIteration()
+ X = numpy.load("X.npy")
+
+
+ X = fistaRecon.iterate(X)
+ #numpy.save("X_out.npy", X)
diff --git a/Wrappers/Python/test/test_regularizers.py b/Wrappers/Python/test/test_regularizers.py
new file mode 100644
index 0000000..27e4ed3
--- /dev/null
+++ b/Wrappers/Python/test/test_regularizers.py
@@ -0,0 +1,412 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Aug 4 11:10:05 2017
+
+@author: ofn77899
+"""
+
+#from ccpi.viewer.CILViewer2D import Converter
+#import vtk
+
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from enum import Enum
+import timeit
+#from PIL import Image
+#from Regularizer import Regularizer
+from ccpi.imaging.Regularizer import Regularizer
+
+###############################################################################
+#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956
+#NRMSE a normalization of the root of the mean squared error
+#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum
+# intensity from the two images being compared, and respectively the same for
+# minval. RMSE is given by the square root of MSE:
+# sqrt[(sum(A - B) ** 2) / |A|],
+# where |A| means the number of elements in A. By doing this, the maximum value
+# given by RMSE is maxval.
+
+def nrmse(im1, im2):
+ a, b = im1.shape
+ rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b))
+ max_val = max(np.max(im1), np.max(im2))
+ min_val = min(np.min(im1), np.min(im2))
+ return 1 - (rmse / (max_val - min_val))
+###############################################################################
+
+###############################################################################
+#
+# 2D Regularizers
+#
+###############################################################################
+#Example:
+# figure;
+# Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0;
+# u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+
+
+#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif"
+filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif"
+#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif'
+
+#reader = vtk.vtkTIFFReader()
+#reader.SetFileName(os.path.normpath(filename))
+#reader.Update()
+Im = plt.imread(filename)
+#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255
+#img.show()
+Im = np.asarray(Im, dtype='float32')
+
+
+
+
+#imgplot = plt.imshow(Im)
+perc = 0.05
+u0 = Im + (perc* np.random.normal(size=np.shape(Im)))
+# map the u0 u0->u0>0
+f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
+u0 = f(u0).astype('float32')
+
+## plot
+fig = plt.figure()
+#a=fig.add_subplot(3,3,1)
+#a.set_title('Original')
+#imgplot = plt.imshow(Im)
+
+a=fig.add_subplot(2,3,1)
+a.set_title('noise')
+imgplot = plt.imshow(u0,cmap="gray")
+
+reg_output = []
+##############################################################################
+# Call regularizer
+
+####################### SplitBregman_TV #####################################
+# u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+
+use_object = True
+if use_object:
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ print (reg.pars)
+ reg.setParameter(input=u0)
+ reg.setParameter(regularization_parameter=10.)
+ # or
+ # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30,
+ #tolerance_constant=1e-4,
+ #TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+ plotme = reg() [0]
+ pars = reg.pars
+ textstr = reg.printParametersToString()
+
+ #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30,
+ #tolerance_constant=1e-4,
+ # TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30,
+# tolerance_constant=1e-4,
+# TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+else:
+ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
+ pars = out2[2]
+ reg_output.append(out2)
+ plotme = reg_output[-1][0]
+ textstr = out2[-1]
+
+a=fig.add_subplot(2,3,2)
+
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(plotme,cmap="gray")
+
+###################### FGP_TV #########################################
+# u = FGP_TV(single(u0), 0.05, 100, 1e-04);
+out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.0005,
+ number_of_iterations=50)
+pars = out2[-2]
+
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,3)
+
+textstr = out2[-1]
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0])
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0],cmap="gray")
+
+###################### LLT_model #########################################
+# * u0 = Im + .03*randn(size(Im)); % adding noise
+# [Den] = LLT_model(single(u0), 10, 0.1, 1);
+#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);
+#input, regularization_parameter , time_step, number_of_iterations,
+# tolerance_constant, restrictive_Z_smoothing=0
+out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25,
+ time_step=0.0003,
+ tolerance_constant=0.0001,
+ number_of_iterations=300)
+pars = out2[-2]
+
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,4)
+
+textstr = out2[-1]
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0],cmap="gray")
+
+
+# ###################### PatchBased_Regul #########################################
+# # Quick 2D denoising example in Matlab:
+# # Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);
+
+out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+ searching_window_ratio=3,
+ similarity_window_ratio=1,
+ PB_filtering_parameter=0.08)
+pars = out2[-2]
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,5)
+
+
+textstr = out2[-1]
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0],cmap="gray")
+
+
+# ###################### TGV_PD #########################################
+# # Quick 2D denoising example in Matlab:
+# # Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+
+
+out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+ first_order_term=1.3,
+ second_order_term=1,
+ number_of_iterations=550)
+pars = out2[-2]
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,6)
+
+
+textstr = out2[-1]
+
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0],cmap="gray")
+
+
+plt.show()
+
+################################################################################
+##
+## 3D Regularizers
+##
+################################################################################
+##Example:
+## figure;
+## Im = double(imread('lena_gray_256.tif'))/255; % loading image
+## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0;
+## u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+#
+##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha"
+#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha"
+#
+#reader = vtk.vtkMetaImageReader()
+#reader.SetFileName(os.path.normpath(filename))
+#reader.Update()
+##vtk returns 3D images, let's take just the one slice there is as 2D
+#Im = Converter.vtk2numpy(reader.GetOutput())
+#Im = Im.astype('float32')
+##imgplot = plt.imshow(Im)
+#perc = 0.05
+#u0 = Im + (perc* np.random.normal(size=np.shape(Im)))
+## map the u0 u0->u0>0
+#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
+#u0 = f(u0).astype('float32')
+#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(),
+# reader.GetOutput().GetOrigin())
+#converter.Update()
+#writer = vtk.vtkMetaImageWriter()
+#writer.SetInputData(converter.GetOutput())
+#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha")
+##writer.Write()
+#
+#
+### plot
+#fig3D = plt.figure()
+##a=fig.add_subplot(3,3,1)
+##a.set_title('Original')
+##imgplot = plt.imshow(Im)
+#sliceNo = 32
+#
+#a=fig3D.add_subplot(2,3,1)
+#a.set_title('noise')
+#imgplot = plt.imshow(u0.T[sliceNo])
+#
+#reg_output3d = []
+#
+###############################################################################
+## Call regularizer
+#
+######################## SplitBregman_TV #####################################
+## u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+#
+##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+#
+##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30,
+## #tolerance_constant=1e-4,
+## TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+#
+#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30,
+# tolerance_constant=1e-4,
+# TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+#
+#
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### FGP_TV #########################################
+## u = FGP_TV(single(u0), 0.05, 100, 1e-04);
+#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005,
+# number_of_iterations=200)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### LLT_model #########################################
+## * u0 = Im + .03*randn(size(Im)); % adding noise
+## [Den] = LLT_model(single(u0), 10, 0.1, 1);
+##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);
+##input, regularization_parameter , time_step, number_of_iterations,
+## tolerance_constant, restrictive_Z_smoothing=0
+#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25,
+# time_step=0.0003,
+# tolerance_constant=0.0001,
+# number_of_iterations=300)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### PatchBased_Regul #########################################
+## Quick 2D denoising example in Matlab:
+## Im = double(imread('lena_gray_256.tif'))/255; % loading image
+## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+## ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);
+#
+#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+# searching_window_ratio=3,
+# similarity_window_ratio=1,
+# PB_filtering_parameter=0.08)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+
+###################### TGV_PD #########################################
+# Quick 2D denoising example in Matlab:
+# Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+
+
+#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+# first_order_term=1.3,
+# second_order_term=1,
+# number_of_iterations=550)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
diff --git a/Wrappers/Python/test/test_regularizers_3d.py b/Wrappers/Python/test/test_regularizers_3d.py
new file mode 100644
index 0000000..2d11a7e
--- /dev/null
+++ b/Wrappers/Python/test/test_regularizers_3d.py
@@ -0,0 +1,425 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Aug 4 11:10:05 2017
+
+@author: ofn77899
+"""
+
+#from ccpi.viewer.CILViewer2D import Converter
+#import vtk
+
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from enum import Enum
+import timeit
+#from PIL import Image
+#from Regularizer import Regularizer
+from ccpi.imaging.Regularizer import Regularizer
+
+###############################################################################
+#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956
+#NRMSE a normalization of the root of the mean squared error
+#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum
+# intensity from the two images being compared, and respectively the same for
+# minval. RMSE is given by the square root of MSE:
+# sqrt[(sum(A - B) ** 2) / |A|],
+# where |A| means the number of elements in A. By doing this, the maximum value
+# given by RMSE is maxval.
+
+def nrmse(im1, im2):
+ a, b = im1.shape
+ rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b))
+ max_val = max(np.max(im1), np.max(im2))
+ min_val = min(np.min(im1), np.min(im2))
+ return 1 - (rmse / (max_val - min_val))
+###############################################################################
+
+###############################################################################
+#
+# 2D Regularizers
+#
+###############################################################################
+#Example:
+# figure;
+# Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0;
+# u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+
+
+#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif"
+filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif"
+#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif'
+
+#reader = vtk.vtkTIFFReader()
+#reader.SetFileName(os.path.normpath(filename))
+#reader.Update()
+Im = plt.imread(filename)
+#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255
+#img.show()
+Im = np.asarray(Im, dtype='float32')
+
+# create a 3D image by stacking N of this images
+
+
+#imgplot = plt.imshow(Im)
+perc = 0.05
+u_n = Im + (perc* np.random.normal(size=np.shape(Im)))
+y,z = np.shape(u_n)
+u_n = np.reshape(u_n , (1,y,z))
+
+u0 = u_n.copy()
+for i in range (19):
+ u_n = Im + (perc* np.random.normal(size=np.shape(Im)))
+ u_n = np.reshape(u_n , (1,y,z))
+
+ u0 = np.vstack ( (u0, u_n) )
+
+# map the u0 u0->u0>0
+f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
+u0 = f(u0).astype('float32')
+
+print ("Passed image shape {0}".format(np.shape(u0)))
+
+## plot
+fig = plt.figure()
+#a=fig.add_subplot(3,3,1)
+#a.set_title('Original')
+#imgplot = plt.imshow(Im)
+sliceno = 10
+
+a=fig.add_subplot(2,3,1)
+a.set_title('noise')
+imgplot = plt.imshow(u0[sliceno],cmap="gray")
+
+reg_output = []
+##############################################################################
+# Call regularizer
+
+####################### SplitBregman_TV #####################################
+# u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+
+use_object = True
+if use_object:
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ print (reg.pars)
+ reg.setParameter(input=u0)
+ reg.setParameter(regularization_parameter=10.)
+ # or
+ # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30,
+ #tolerance_constant=1e-4,
+ #TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+ plotme = reg() [0]
+ pars = reg.pars
+ textstr = reg.printParametersToString()
+
+ #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30,
+ #tolerance_constant=1e-4,
+ # TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30,
+# tolerance_constant=1e-4,
+# TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+else:
+ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
+ pars = out2[2]
+ reg_output.append(out2)
+ plotme = reg_output[-1][0]
+ textstr = out2[-1]
+
+a=fig.add_subplot(2,3,2)
+
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(plotme[sliceno],cmap="gray")
+
+###################### FGP_TV #########################################
+# u = FGP_TV(single(u0), 0.05, 100, 1e-04);
+out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.0005,
+ number_of_iterations=50)
+pars = out2[-2]
+
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,3)
+
+textstr = out2[-1]
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0][sliceno])
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray")
+
+###################### LLT_model #########################################
+# * u0 = Im + .03*randn(size(Im)); % adding noise
+# [Den] = LLT_model(single(u0), 10, 0.1, 1);
+#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);
+#input, regularization_parameter , time_step, number_of_iterations,
+# tolerance_constant, restrictive_Z_smoothing=0
+out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25,
+ time_step=0.0003,
+ tolerance_constant=0.0001,
+ number_of_iterations=300)
+pars = out2[-2]
+
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,4)
+
+textstr = out2[-1]
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray")
+
+
+# ###################### PatchBased_Regul #########################################
+# # Quick 2D denoising example in Matlab:
+# # Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);
+
+out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+ searching_window_ratio=3,
+ similarity_window_ratio=1,
+ PB_filtering_parameter=0.08)
+pars = out2[-2]
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,5)
+
+
+textstr = out2[-1]
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray")
+
+
+# ###################### TGV_PD #########################################
+# # Quick 2D denoising example in Matlab:
+# # Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+
+
+out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+ first_order_term=1.3,
+ second_order_term=1,
+ number_of_iterations=550)
+pars = out2[-2]
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,6)
+
+
+textstr = out2[-1]
+
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray")
+
+
+plt.show()
+
+################################################################################
+##
+## 3D Regularizers
+##
+################################################################################
+##Example:
+## figure;
+## Im = double(imread('lena_gray_256.tif'))/255; % loading image
+## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0;
+## u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+#
+##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha"
+#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha"
+#
+#reader = vtk.vtkMetaImageReader()
+#reader.SetFileName(os.path.normpath(filename))
+#reader.Update()
+##vtk returns 3D images, let's take just the one slice there is as 2D
+#Im = Converter.vtk2numpy(reader.GetOutput())
+#Im = Im.astype('float32')
+##imgplot = plt.imshow(Im)
+#perc = 0.05
+#u0 = Im + (perc* np.random.normal(size=np.shape(Im)))
+## map the u0 u0->u0>0
+#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
+#u0 = f(u0).astype('float32')
+#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(),
+# reader.GetOutput().GetOrigin())
+#converter.Update()
+#writer = vtk.vtkMetaImageWriter()
+#writer.SetInputData(converter.GetOutput())
+#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha")
+##writer.Write()
+#
+#
+### plot
+#fig3D = plt.figure()
+##a=fig.add_subplot(3,3,1)
+##a.set_title('Original')
+##imgplot = plt.imshow(Im)
+#sliceNo = 32
+#
+#a=fig3D.add_subplot(2,3,1)
+#a.set_title('noise')
+#imgplot = plt.imshow(u0.T[sliceNo])
+#
+#reg_output3d = []
+#
+###############################################################################
+## Call regularizer
+#
+######################## SplitBregman_TV #####################################
+## u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+#
+##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+#
+##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30,
+## #tolerance_constant=1e-4,
+## TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+#
+#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30,
+# tolerance_constant=1e-4,
+# TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+#
+#
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### FGP_TV #########################################
+## u = FGP_TV(single(u0), 0.05, 100, 1e-04);
+#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005,
+# number_of_iterations=200)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### LLT_model #########################################
+## * u0 = Im + .03*randn(size(Im)); % adding noise
+## [Den] = LLT_model(single(u0), 10, 0.1, 1);
+##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);
+##input, regularization_parameter , time_step, number_of_iterations,
+## tolerance_constant, restrictive_Z_smoothing=0
+#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25,
+# time_step=0.0003,
+# tolerance_constant=0.0001,
+# number_of_iterations=300)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### PatchBased_Regul #########################################
+## Quick 2D denoising example in Matlab:
+## Im = double(imread('lena_gray_256.tif'))/255; % loading image
+## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+## ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);
+#
+#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+# searching_window_ratio=3,
+# similarity_window_ratio=1,
+# PB_filtering_parameter=0.08)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+
+###################### TGV_PD #########################################
+# Quick 2D denoising example in Matlab:
+# Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+
+
+#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+# first_order_term=1.3,
+# second_order_term=1,
+# number_of_iterations=550)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+# verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
diff --git a/Wrappers/Python/test/view_result.py b/Wrappers/Python/test/view_result.py
new file mode 100644
index 0000000..f89a90c
--- /dev/null
+++ b/Wrappers/Python/test/view_result.py
@@ -0,0 +1,12 @@
+import numpy
+from ccpi.viewer.CILViewer2D import *
+import sys
+#reader = vtk.vtkMetaImageReader()
+#reader.SetFileName("X_out_os_s.mhd")
+#reader.Update()
+
+X = numpy.load(sys.argv[1])
+
+v = CILViewer2D()
+v.setInputAsNumpy(X)
+v.startRenderLoop()