/*
-----------------------------------------------------------------------
Copyright: 2010-2021, imec Vision Lab, University of Antwerp
2014-2021, CWI, Amsterdam
Contact: astra@astra-toolbox.com
Website: http://www.astra-toolbox.com/
This file is part of the ASTRA Toolbox.
The ASTRA Toolbox is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
The ASTRA Toolbox is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
-----------------------------------------------------------------------
*/
#include "astra/cuda/2d/fft.h"
#include "astra/cuda/2d/util.h"
#include "astra/Logging.h"
#include "astra/Fourier.h"
#include
#include
#include
#include
using namespace astra;
// TODO: evaluate what we want to do in these situations:
// CHECK_ERROR(errorMessage): wait for all outstanding device work to finish
// and terminate the process if the runtime reports an error.
//   errorMessage - C string logged in front of the CUDA error description.
// cudaDeviceSynchronize() is used in place of the deprecated
// cudaThreadSynchronize(). The local is named err_ so that it cannot shadow
// a variable called err at the expansion site.
// NOTE(review): only the result of the synchronization is inspected; a
// launch-configuration error would have to be read with cudaGetLastError().
#define CHECK_ERROR(errorMessage) do { \
    cudaError_t err_ = cudaDeviceSynchronize(); \
    if (cudaSuccess != err_) { \
        ASTRA_ERROR("Cuda error %s : %s", \
                    errorMessage, cudaGetErrorString(err_)); \
        exit(EXIT_FAILURE); \
    } } while (0)
// SAFE_CALL(call): evaluate a CUDA runtime call exactly once, then wait for
// the device to become idle. If either step reports an error, the error is
// logged and the process is terminated with EXIT_FAILURE.
// The argument is parenthesized so that any expression can be passed, and
// cudaDeviceSynchronize() replaces the deprecated cudaThreadSynchronize().
// The local is named err_ so that it cannot shadow a variable called err at
// the expansion site.
#define SAFE_CALL(call) do { \
    cudaError_t err_ = (call); \
    if (cudaSuccess != err_) { \
        ASTRA_ERROR("Cuda error: %s ", \
                    cudaGetErrorString(err_)); \
        exit(EXIT_FAILURE); \
    } \
    err_ = cudaDeviceSynchronize(); \
    if (cudaSuccess != err_) { \
        ASTRA_ERROR("Cuda error: %s : ", \
                    cudaGetErrorString(err_)); \
        exit(EXIT_FAILURE); \
    } } while (0)
namespace astraCUDA {
// Translate a cuFFT status code into a boolean.
//   err - status returned by a cuFFT call.
//   msg - text logged in front of the numeric error code on failure.
// Returns true when err is CUFFT_SUCCESS; otherwise logs through
// ASTRA_ERROR and returns false.
bool checkCufft(cufftResult err, const char *msg)
{
    if (err == CUFFT_SUCCESS)
        return true;

    ASTRA_ERROR("%s: CUFFT error %d.", msg, err);
    return false;
}
// In-place complex multiplication of a sinogram spectrum by a filter.
// Each thread handles one complex sample: _pSinogram[i] *= _pFilter[i].
// Both arrays are indexed with the same flat index, so they are expected to
// share the layout _iProjectionCount rows of _iFreqBinCount samples.
// Launch layout: one-dimensional; only the x components of threadIdx,
// blockIdx and blockDim are used.
__global__ static void applyFilter_kernel(int _iProjectionCount,
                                          int _iFreqBinCount,
                                          cufftComplex * _pSinogram,
                                          cufftComplex * _pFilter)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;

    // Threads that fall beyond the last projection row have no work.
    if (idx / _iFreqBinCount >= _iProjectionCount)
        return;

    // Read both operands completely before writing the product back.
    const cufftComplex sample = _pSinogram[idx];
    const cufftComplex weight = _pFilter[idx];

    // (a + bi) * (c + di) = (ac - bd) + (ad + cb)i
    cufftComplex product;
    product.x = sample.x * weight.x - sample.y * weight.y;
    product.y = sample.x * weight.y + weight.x * sample.y;

    _pSinogram[idx] = product;
}
// Divide every sample of a (_iProjectionCount x _iDetectorCount) float array
// by _iDetectorCount, in place. One thread per sample.
// Launch layout: one-dimensional; only the x components of threadIdx,
// blockIdx and blockDim are used.
__global__ static void rescaleInverseFourier_kernel(int _iProjectionCount,
                                                    int _iDetectorCount,
                                                    float* _pfInFourierOutput)
{
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    const int row = idx / _iDetectorCount;

    if (row < _iProjectionCount)
    {
        // row * _iDetectorCount + (idx % _iDetectorCount) is idx itself, so
        // the flat index addresses the sample directly.
        _pfInFourierOutput[idx] /= (float)_iDetectorCount;
    }
}
// Host wrapper that launches rescaleInverseFourier_kernel over all
// _iProjectionCount * _iDetectorCount samples of the device array
// _pfInFourierOutput, then waits for completion via CHECK_ERROR (which
// terminates the process on failure).
static void rescaleInverseFourier(int _iProjectionCount, int _iDetectorCount,
                                  float * _pfInFourierOutput)
{
    const int threadsPerBlock = 256;
    const int sampleCount = _iProjectionCount * _iDetectorCount;

    // Round up so that a partially filled last block is still launched.
    const int blocksNeeded = (sampleCount + threadsPerBlock - 1) / threadsPerBlock;

    rescaleInverseFourier_kernel<<< blocksNeeded, threadsPerBlock >>>(
        _iProjectionCount, _iDetectorCount, _pfInFourierOutput);

    CHECK_ERROR("rescaleInverseFourier_kernel failed");
}
// Multiply the complex array _pSinogram in place by _pFilter on the device.
// Both pointers are handed to applyFilter_kernel, which is launched with one
// thread per sample for _iProjectionCount * _iFreqBinCount samples.
// Waits for completion via CHECK_ERROR, which terminates the process on
// failure.
void applyFilter(int _iProjectionCount, int _iFreqBinCount,
                 cufftComplex * _pSinogram, cufftComplex * _pFilter)
{
    const int threadsPerBlock = 256;
    const int sampleCount = _iProjectionCount * _iFreqBinCount;

    // Round up so that a partially filled last block is still launched.
    const int blocksNeeded = (sampleCount + threadsPerBlock - 1) / threadsPerBlock;

    applyFilter_kernel<<< blocksNeeded, threadsPerBlock >>>(
        _iProjectionCount, _iFreqBinCount, _pSinogram, _pFilter);

    CHECK_ERROR("applyFilter_kernel failed");
}
// Run a batched 1D real-to-complex cuFFT transform.
//   _iProjectionCount  - batch size (number of transforms).
//   _iDetectorCount    - length of each real input signal.
//   _pfDevSource       - real input passed to cufftExecR2C.
//   _pDevTargetComplex - complex output passed to cufftExecR2C.
// Returns false (after logging) if creating or executing the plan fails.
static bool invokeCudaFFT(int _iProjectionCount, int _iDetectorCount,
                          const float * _pfDevSource,
                          cufftComplex * _pDevTargetComplex)
{
    cufftHandle plan;

    const cufftResult planStatus =
        cufftPlan1d(&plan, _iDetectorCount, CUFFT_R2C, _iProjectionCount);
    if (!checkCufft(planStatus, "invokeCudaFFT plan"))
        return false;

    // cufftExecR2C takes a non-const input pointer, hence the cast.
    const cufftResult execStatus =
        cufftExecR2C(plan, (cufftReal *)_pfDevSource, _pDevTargetComplex);
    const bool bExecuted = checkCufft(execStatus, "invokeCudaFFT exec");

    // The plan is released whether or not the transform succeeded.
    cufftDestroy(plan);
    return bExecuted;
}
// Run a batched 1D complex-to-real cuFFT transform.
//   _iProjectionCount  - batch size (number of transforms).
//   _iDetectorCount    - length of each real output signal.
//   _pDevSourceComplex - complex input passed to cufftExecC2R.
//   _pfDevTarget       - real output passed to cufftExecC2R.
// Returns false (after logging) if creating or executing the plan fails.
static bool invokeCudaIFFT(int _iProjectionCount, int _iDetectorCount,
                           const cufftComplex * _pDevSourceComplex,
                           float * _pfDevTarget)
{
    cufftHandle plan;

    const cufftResult planStatus =
        cufftPlan1d(&plan, _iDetectorCount, CUFFT_C2R, _iProjectionCount);
    if (!checkCufft(planStatus, "invokeCudaIFFT plan"))
        return false;

    // The const qualifier is cast away because cufftExecC2R is declared
    // with a non-const input pointer.
    const cufftResult execStatus =
        cufftExecC2R(plan, (cufftComplex *)_pDevSourceComplex,
                     (cufftReal *)_pfDevTarget);
    const bool bExecuted = checkCufft(execStatus, "invokeCudaIFFT exec");

    // The plan is released whether or not the transform succeeded.
    cufftDestroy(plan);
    return bExecuted;
}
// Allocate device memory for _iProjectionCount * _iDetectorCount complex
// values and store the resulting pointer in *_ppDevComplex.
// Always returns true: SAFE_CALL terminates the process if the allocation
// fails.
bool allocateComplexOnDevice(int _iProjectionCount, int _iDetectorCount,
                             cufftComplex ** _ppDevComplex)
{
    const size_t byteCount =
        sizeof(cufftComplex) * _iProjectionCount * _iDetectorCount;

    void ** ppRaw = (void **)_ppDevComplex;
    SAFE_CALL(cudaMalloc(ppRaw, byteCount));

    return true;
}
// Release a complex buffer with cudaFree.
// Always returns true: SAFE_CALL terminates the process if cudaFree reports
// an error.
bool freeComplexOnDevice(cufftComplex * _pDevComplex)
{
    void * pRaw = _pDevComplex;
    SAFE_CALL(cudaFree(pRaw));

    return true;
}
// Copy _iProjectionCount * _iDetectorCount complex values from host memory
// (_pHostComplexSource) to device memory (_pDevComplexTarget).
// Always returns true: SAFE_CALL terminates the process if the copy fails.
bool uploadComplexArrayToDevice(int _iProjectionCount, int _iDetectorCount,
                                cufftComplex * _pHostComplexSource,
                                cufftComplex * _pDevComplexTarget)
{
    const size_t byteCount =
        sizeof(cufftComplex) * _iProjectionCount * _iDetectorCount;

    SAFE_CALL(cudaMemcpy(_pDevComplexTarget,
                         _pHostComplexSource,
                         byteCount,
                         cudaMemcpyHostToDevice));

    return true;
}
// Zero-pad each projection row to the FFT length and run a batched
// real-to-complex FFT on the padded data.
//   _iProjectionCount         - number of projection rows (FFT batch size).
//   _pfDevRealSource          - device pointer to the real input rows.
//   _iSourcePitch             - distance, in floats, between the starts of
//                               consecutive input rows.
//   _iProjDets                - number of floats copied from each input row.
//   _iFFTRealDetectorCount    - length of each padded row / FFT length.
//   _iFFTFourierDetectorCount - not used by this function.
//   _pDevTargetComplex        - device pointer receiving the FFT output.
// Returns false if the FFT itself fails. Failures of the CUDA memory calls
// wrapped in SAFE_CALL terminate the process.
bool runCudaFFT(int _iProjectionCount, const float * _pfDevRealSource,
                int _iSourcePitch, int _iProjDets,
                int _iFFTRealDetectorCount, int _iFFTFourierDetectorCount,
                cufftComplex * _pDevTargetComplex)
{
    float * pfDevRealFFTSource = NULL;
    const size_t bufferMemSize = sizeof(float) * _iProjectionCount * _iFFTRealDetectorCount;

    // Temporary padded buffer; cleared first so that everything beyond the
    // first _iProjDets samples of each row stays zero.
    SAFE_CALL(cudaMalloc((void **)&pfDevRealFFTSource, bufferMemSize));
    SAFE_CALL(cudaMemset(pfDevRealFFTSource, 0, bufferMemSize));

    for (int iProjectionIndex = 0; iProjectionIndex < _iProjectionCount; iProjectionIndex++)
    {
        const float * pfSourceLocation = _pfDevRealSource + iProjectionIndex * _iSourcePitch;
        float * pfTargetLocation = pfDevRealFFTSource + iProjectionIndex * _iFFTRealDetectorCount;

        SAFE_CALL(cudaMemcpy(pfTargetLocation, pfSourceLocation, sizeof(float) * _iProjDets, cudaMemcpyDeviceToDevice));
    }

    const bool bResult = invokeCudaFFT(_iProjectionCount, _iFFTRealDetectorCount,
                                       pfDevRealFFTSource, _pDevTargetComplex);
    if (!bResult)
    {
        // Release the temporary buffer on the failure path as well, so it is
        // not leaked. The result is deliberately ignored (no SAFE_CALL) so
        // the caller still gets `false` instead of a process exit.
        cudaFree(pfDevRealFFTSource);
        return false;
    }

    SAFE_CALL(cudaFree(pfDevRealFFTSource));
    return true;
}
// Run a batched complex-to-real inverse FFT, rescale the result by the FFT
// length and copy the leading samples of each row into the target array.
//   _iProjectionCount         - number of projection rows (FFT batch size).
//   _pDevSourceComplex        - device pointer to the complex input.
//   _pfRealTarget             - device pointer to the real output rows; the
//                               first _iProjectionCount * _iTargetPitch
//                               floats are cleared before the copy.
//   _iTargetPitch             - distance, in floats, between the starts of
//                               consecutive output rows.
//   _iProjDets                - number of floats copied into each output row.
//   _iFFTRealDetectorCount    - length of each inverse-FFT output row.
//   _iFFTFourierDetectorCount - not used by this function.
// Returns false if the inverse FFT itself fails. Failures of the CUDA calls
// wrapped in SAFE_CALL / CHECK_ERROR terminate the process.
bool runCudaIFFT(int _iProjectionCount, const cufftComplex* _pDevSourceComplex,
                 float * _pfRealTarget,
                 int _iTargetPitch, int _iProjDets,
                 int _iFFTRealDetectorCount, int _iFFTFourierDetectorCount)
{
    float * pfDevRealFFTTarget = NULL;
    const size_t bufferMemSize = sizeof(float) * _iProjectionCount * _iFFTRealDetectorCount;

    // Temporary buffer that receives the full-length inverse transform.
    SAFE_CALL(cudaMalloc((void **)&pfDevRealFFTTarget, bufferMemSize));

    const bool bResult = invokeCudaIFFT(_iProjectionCount, _iFFTRealDetectorCount,
                                        _pDevSourceComplex, pfDevRealFFTTarget);
    if (!bResult)
    {
        // Release the temporary buffer on the failure path as well, so it is
        // not leaked. The result is deliberately ignored (no SAFE_CALL) so
        // the caller still gets `false` instead of a process exit.
        cudaFree(pfDevRealFFTTarget);
        return false;
    }

    // Divide every sample by the FFT length.
    rescaleInverseFourier(_iProjectionCount, _iFFTRealDetectorCount,
                          pfDevRealFFTTarget);

    SAFE_CALL(cudaMemset(_pfRealTarget, 0, sizeof(float) * _iProjectionCount * _iTargetPitch));

    for (int iProjectionIndex = 0; iProjectionIndex < _iProjectionCount; iProjectionIndex++)
    {
        const float * pfSourceLocation = pfDevRealFFTTarget + iProjectionIndex * _iFFTRealDetectorCount;
        float * pfTargetLocation = _pfRealTarget + iProjectionIndex * _iTargetPitch;

        SAFE_CALL(cudaMemcpy(pfTargetLocation, pfSourceLocation, sizeof(float) * _iProjDets, cudaMemcpyDeviceToDevice));
    }

    SAFE_CALL(cudaFree(pfDevRealFFTTarget));
    return true;
}
// Fill _pFilter with the identity filter: every complex entry becomes
// 1 + 0i. The array is written as _iProjectionCount rows of
// _iFFTFourierDetectorCount entries. _iFFTRealDetectorCount is not used.
// The pointer is dereferenced directly by this host function.
void genIdenFilter(int _iProjectionCount, cufftComplex * _pFilter,
                   int _iFFTRealDetectorCount, int _iFFTFourierDetectorCount)
{
    for (int iRow = 0; iRow < _iProjectionCount; ++iRow)
    {
        const int iRowStart = iRow * _iFFTFourierDetectorCount;

        for (int iBin = 0; iBin < _iFFTFourierDetectorCount; ++iBin)
        {
            cufftComplex & entry = _pFilter[iRowStart + iBin];
            entry.x = 1.0f;
            entry.y = 0.0f;
        }
    }
}
// Build a complex filter array from the real filter produced by
// astra::genFilter: every one of the _iProjectionCount rows of _pFilter
// receives the same _iFFTFourierDetectorCount values as its real part, with
// the imaginary part set to zero. The array returned by astra::genFilter is
// released with delete[] before returning.
// Assumes astra::genFilter returns at least _iFFTFourierDetectorCount floats
// - TODO confirm against its declaration.
void genCuFFTFilter(const SFilterConfig &_cfg, int _iProjectionCount,
                    cufftComplex * _pFilter, int _iFFTRealDetectorCount,
                    int _iFFTFourierDetectorCount)
{
    float * pfResponse = astra::genFilter(_cfg,
                                          _iFFTRealDetectorCount,
                                          _iFFTFourierDetectorCount);

    // Replicate the one-dimensional response into every projection row.
    for (int iRow = 0; iRow < _iProjectionCount; ++iRow)
    {
        const int iRowStart = iRow * _iFFTFourierDetectorCount;

        for (int iBin = 0; iBin < _iFFTFourierDetectorCount; ++iBin)
        {
            cufftComplex & entry = _pFilter[iRowStart + iBin];
            entry.x = pfResponse[iBin];
            entry.y = 0.0f;
        }
    }

    delete[] pfResponse;
}
}