diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-04-12 17:28:17 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-04-12 17:28:17 +0100 |
commit | d1e26ec31df5a2b269e021e4a2c039e0e265a353 (patch) | |
tree | dbd8f0bcc582e126c8e355a58ed0ec2b6089290f /Wrappers | |
parent | 18ec4ecfe286d04868ddcd2febd94affd77e3a57 (diff) | |
download | framework-d1e26ec31df5a2b269e021e4a2c039e0e265a353.tar.gz framework-d1e26ec31df5a2b269e021e4a2c039e0e265a353.tar.bz2 framework-d1e26ec31df5a2b269e021e4a2c039e0e265a353.tar.xz framework-d1e26ec31df5a2b269e021e4a2c039e0e265a353.zip |
adds exception for incompatible geometry/array
Diffstat (limited to 'Wrappers')
-rwxr-xr-x | Wrappers/Python/ccpi/framework/framework.py | 227 | ||||
-rwxr-xr-x | Wrappers/Python/test/test_DataContainer.py | 16 |
2 files changed, 143 insertions, 100 deletions
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py index af4139b..7874813 100755 --- a/Wrappers/Python/ccpi/framework/framework.py +++ b/Wrappers/Python/ccpi/framework/framework.py @@ -772,61 +772,18 @@ class DataContainer(object): class ImageData(DataContainer): '''DataContainer for holding 2D or 3D DataContainer''' __container_priority__ = 1 + + def __init__(self, array = None, deep_copy=False, dimension_labels=None, **kwargs): - self.geometry = None + self.geometry = kwargs.get('geometry', None) if array is None: - if 'geometry' in kwargs.keys(): - geometry = kwargs['geometry'] - self.geometry = geometry - channels = geometry.channels - horiz_x = geometry.voxel_num_x - horiz_y = geometry.voxel_num_y - vert = 1 if geometry.voxel_num_z is None\ - else geometry.voxel_num_z # this should be 1 for 2D - if dimension_labels is None: - if channels > 1: - if vert > 1: - shape = (channels, vert, horiz_y, horiz_x) - dim_labels = [ImageGeometry.CHANNEL, - ImageGeometry.VERTICAL, - ImageGeometry.HORIZONTAL_Y, - ImageGeometry.HORIZONTAL_X] - else: - shape = (channels , horiz_y, horiz_x) - dim_labels = [ImageGeometry.CHANNEL, - ImageGeometry.HORIZONTAL_Y, - ImageGeometry.HORIZONTAL_X] - else: - if vert > 1: - shape = (vert, horiz_y, horiz_x) - dim_labels = [ImageGeometry.VERTICAL, - ImageGeometry.HORIZONTAL_Y, - ImageGeometry.HORIZONTAL_X] - else: - shape = (horiz_y, horiz_x) - dim_labels = [ImageGeometry.HORIZONTAL_Y, - ImageGeometry.HORIZONTAL_X] - dimension_labels = dim_labels - else: - shape = [] - for dim in dimension_labels: - if dim == ImageGeometry.CHANNEL: - shape.append(channels) - elif dim == ImageGeometry.HORIZONTAL_Y: - shape.append(horiz_y) - elif dim == ImageGeometry.VERTICAL: - shape.append(vert) - elif dim == ImageGeometry.HORIZONTAL_X: - shape.append(horiz_x) - if len(shape) != len(dimension_labels): - raise ValueError('Missing {0} axes'.format( - len(dimension_labels) - len(shape))) - shape = tuple(shape) + if self.geometry is not None: + shape, dimension_labels = self.get_shape_labels(self.geometry) array = numpy.zeros( shape , dtype=numpy.float32) super(ImageData, self).__init__(array, deep_copy, @@ -836,6 +793,11 @@ class ImageData(DataContainer): raise ValueError('Please pass either a DataContainer, ' +\ 'a numpy array or a geometry') else: + if self.geometry is not None: + shape, labels = self.get_shape_labels(self.geometry, dimension_labels) + if array.shape != shape: + raise ValueError('Shape mismatch {} {}'.format(shape, array.shape)) + if issubclass(type(array) , DataContainer): # if the array is a DataContainer get the info from there if not ( array.number_of_dimensions == 2 or \ @@ -890,11 +852,62 @@ class ImageData(DataContainer): #out.geometry = self.recalculate_geometry(dimensions , **kw) out.geometry = self.geometry return out - + + def get_shape_labels(self, geometry, dimension_labels=None): + channels = geometry.channels + horiz_x = geometry.voxel_num_x + horiz_y = geometry.voxel_num_y + vert = 1 if geometry.voxel_num_z is None\ + else geometry.voxel_num_z # this should be 1 for 2D + if dimension_labels is None: + if channels > 1: + if vert > 1: + shape = (channels, vert, horiz_y, horiz_x) + dim_labels = [ImageGeometry.CHANNEL, + ImageGeometry.VERTICAL, + ImageGeometry.HORIZONTAL_Y, + ImageGeometry.HORIZONTAL_X] + else: + shape = (channels , horiz_y, horiz_x) + dim_labels = [ImageGeometry.CHANNEL, + ImageGeometry.HORIZONTAL_Y, + ImageGeometry.HORIZONTAL_X] + else: + if vert > 1: + shape = (vert, horiz_y, horiz_x) + dim_labels = [ImageGeometry.VERTICAL, + ImageGeometry.HORIZONTAL_Y, + ImageGeometry.HORIZONTAL_X] + else: + shape = (horiz_y, horiz_x) + dim_labels = [ImageGeometry.HORIZONTAL_Y, + ImageGeometry.HORIZONTAL_X] + dimension_labels = dim_labels + else: + shape = [] + for i in range(len(dimension_labels)): + dim = dimension_labels[i] + if dim == ImageGeometry.CHANNEL: + shape.append(channels) + elif dim == ImageGeometry.HORIZONTAL_Y: + shape.append(horiz_y) + elif dim == ImageGeometry.VERTICAL: + shape.append(vert) + elif dim == ImageGeometry.HORIZONTAL_X: + shape.append(horiz_x) + if len(shape) != len(dimension_labels): + raise ValueError('Missing {0} axes {1} shape {2}'.format( + len(dimension_labels) - len(shape), dimension_labels, shape)) + shape = tuple(shape) + + return (shape, dimension_labels) + class AcquisitionData(DataContainer): '''DataContainer for holding 2D or 3D sinogram''' __container_priority__ = 1 + + def __init__(self, array = None, deep_copy=True, @@ -905,63 +918,20 @@ class AcquisitionData(DataContainer): if 'geometry' in kwargs.keys(): geometry = kwargs['geometry'] self.geometry = geometry - channels = geometry.channels - horiz = geometry.pixel_num_h - vert = geometry.pixel_num_v - angles = geometry.angles - num_of_angles = numpy.shape(angles)[0] - if dimension_labels is None: - if channels > 1: - if vert > 1: - shape = (channels, num_of_angles , vert, horiz) - dim_labels = [AcquisitionGeometry.CHANNEL, - AcquisitionGeometry.ANGLE, - AcquisitionGeometry.VERTICAL, - AcquisitionGeometry.HORIZONTAL] - else: - shape = (channels , num_of_angles, horiz) - dim_labels = [AcquisitionGeometry.CHANNEL, - AcquisitionGeometry.ANGLE, - AcquisitionGeometry.HORIZONTAL] - else: - if vert > 1: - shape = (num_of_angles, vert, horiz) - dim_labels = [AcquisitionGeometry.ANGLE, - AcquisitionGeometry.VERTICAL, - AcquisitionGeometry.HORIZONTAL - ] - else: - shape = (num_of_angles, horiz) - dim_labels = [AcquisitionGeometry.ANGLE, - AcquisitionGeometry.HORIZONTAL - ] - - dimension_labels = dim_labels - else: - shape = [] - for dim in dimension_labels: - if dim == AcquisitionGeometry.CHANNEL: - shape.append(channels) - elif dim == AcquisitionGeometry.ANGLE: - shape.append(num_of_angles) - elif dim == AcquisitionGeometry.VERTICAL: - shape.append(vert) - elif dim == AcquisitionGeometry.HORIZONTAL: - shape.append(horiz) - if len(shape) != len(dimension_labels): - raise ValueError('Missing {0} axes.\nExpected{1} got {2}'\ - .format( - len(dimension_labels) - len(shape), - dimension_labels, shape) - ) - shape = tuple(shape) + shape, dimension_labels = self.get_shape_labels(geometry, dimension_labels) + array = numpy.zeros( shape , dtype=numpy.float32) super(AcquisitionData, self).__init__(array, deep_copy, dimension_labels, **kwargs) else: - + if self.geometry is not None: + shape, labels = self.get_shape_labels(self.geometry, dimension_labels) + print('Shape mismatch {} {}'.format(shape, array.shape)) + if array.shape != shape: + raise ValueError('Shape mismatch {} {}'.format(shape, array.shape)) + if issubclass(type(array) ,DataContainer): # if the array is a DataContainer get the info from there if not ( array.number_of_dimensions == 2 or \ @@ -995,6 +965,63 @@ class AcquisitionData(DataContainer): super(AcquisitionData, self).__init__(array, deep_copy, dimension_labels, **kwargs) + def get_shape_labels(self, geometry, dimension_labels=None): + channels = geometry.channels + horiz = geometry.pixel_num_h + vert = geometry.pixel_num_v + angles = geometry.angles + num_of_angles = numpy.shape(angles)[0] + + if dimension_labels is None: + if channels > 1: + if vert > 1: + shape = (channels, num_of_angles , vert, horiz) + dim_labels = [AcquisitionGeometry.CHANNEL, + AcquisitionGeometry.ANGLE, + AcquisitionGeometry.VERTICAL, + AcquisitionGeometry.HORIZONTAL] + else: + shape = (channels , num_of_angles, horiz) + dim_labels = [AcquisitionGeometry.CHANNEL, + AcquisitionGeometry.ANGLE, + AcquisitionGeometry.HORIZONTAL] + else: + if vert > 1: + shape = (num_of_angles, vert, horiz) + dim_labels = [AcquisitionGeometry.ANGLE, + AcquisitionGeometry.VERTICAL, + AcquisitionGeometry.HORIZONTAL + ] + else: + shape = (num_of_angles, horiz) + dim_labels = [AcquisitionGeometry.ANGLE, + AcquisitionGeometry.HORIZONTAL + ] + + dimension_labels = dim_labels + else: + shape = [] + for i in range(len(dimension_labels)): + dim = dimension_labels[i] + + if dim == AcquisitionGeometry.CHANNEL: + shape.append(channels) + elif dim == AcquisitionGeometry.ANGLE: + shape.append(num_of_angles) + elif dim == AcquisitionGeometry.VERTICAL: + shape.append(vert) + elif dim == AcquisitionGeometry.HORIZONTAL: + shape.append(horiz) + if len(shape) != len(dimension_labels): + raise ValueError('Missing {0} axes.\nExpected{1} got {2}'\ + .format( + len(dimension_labels) - len(shape), + dimension_labels, shape) + ) + shape = tuple(shape) + return (shape, dimension_labels) + + class DataProcessor(object): diff --git a/Wrappers/Python/test/test_DataContainer.py b/Wrappers/Python/test/test_DataContainer.py index 8edfd8b..40cd244 100755 --- a/Wrappers/Python/test/test_DataContainer.py +++ b/Wrappers/Python/test/test_DataContainer.py @@ -494,6 +494,14 @@ class TestDataContainer(unittest.TestCase): self.assertEqual(order[0], image.dimension_labels[0]) self.assertEqual(order[1], image.dimension_labels[1]) self.assertEqual(order[2], image.dimension_labels[2]) + + ig = ImageGeometry(2,3,2) + try: + z = ImageData(numpy.random.randint(10, size=(2,3)), geometry=ig) + self.assertTrue(False) + except ValueError as ve: + print (ve) + self.assertTrue(True) #vgeometry.allocate('') def test_AcquisitionGeometry_allocate(self): @@ -525,6 +533,14 @@ class TestDataContainer(unittest.TestCase): self.assertEqual(order[1], sino.dimension_labels[1]) self.assertEqual(order[2], sino.dimension_labels[2]) self.assertEqual(order[2], sino.dimension_labels[2]) + + + try: + z = AcquisitionData(numpy.random.randint(10, size=(2,3)), geometry=ageometry) + self.assertTrue(False) + except ValueError as ve: + print (ve) + self.assertTrue(True) def assertNumpyArrayEqual(self, first, second): res = True |