1616import base64
1717import sys
1818import warnings
19- from typing import Type
19+ from copy import copy
20+ from typing import Type , cast
2021
2122import numpy as np
2223
2728from ..nifti1 import data_type_codes , intent_codes , xform_codes
2829from .util import KIND2FMT , array_index_order_codes , gifti_encoding_codes , gifti_endian_codes
2930
31+ GIFTI_DTYPES = (
32+ data_type_codes ['NIFTI_TYPE_UINT8' ],
33+ data_type_codes ['NIFTI_TYPE_INT32' ],
34+ data_type_codes ['NIFTI_TYPE_FLOAT32' ],
35+ )
36+
3037
3138class _GiftiMDList (list ):
3239 """List view of GiftiMetaData object that will translate most operations"""
@@ -462,11 +469,7 @@ def __init__(
462469 if datatype is None :
463470 if self .data is None :
464471 datatype = 'none'
465- elif self .data .dtype in (
466- np .dtype ('uint8' ),
467- np .dtype ('int32' ),
468- np .dtype ('float32' ),
469- ):
472+ elif data_type_codes [self .data .dtype ] in GIFTI_DTYPES :
470473 datatype = self .data .dtype
471474 else :
472475 raise ValueError (
@@ -848,20 +851,45 @@ def _to_xml_element(self):
848851 GIFTI .append (dar ._to_xml_element ())
849852 return GIFTI
850853
851- def to_xml (self , enc = 'utf-8' ) -> bytes :
854+ def to_xml (self , enc = 'utf-8' , * , mode = 'strict' ) -> bytes :
852855 """Return XML corresponding to image content"""
856+ if mode == 'strict' :
857+ if any (arr .datatype not in GIFTI_DTYPES for arr in self .darrays ):
858+ raise ValueError (
859+ 'GiftiImage contains data arrays with invalid data types; '
860+ 'use mode="compat" to automatically cast to conforming types'
861+ )
862+ elif mode == 'compat' :
863+ darrays = []
864+ for arr in self .darrays :
865+ if arr .datatype not in GIFTI_DTYPES :
866+ arr = copy (arr )
867+ # TODO: Better typing for recoders
868+ dtype = cast (np .dtype , data_type_codes .dtype [arr .datatype ])
869+ if np .issubdtype (dtype , np .floating ):
870+ arr .datatype = data_type_codes ['float32' ]
871+ elif np .issubdtype (dtype , np .integer ):
872+ arr .datatype = data_type_codes ['int32' ]
873+ else :
874+ raise ValueError (f'Cannot convert { dtype } to float32/int32' )
875+ darrays .append (arr )
876+ gii = copy (self )
877+ gii .darrays = darrays
878+ return gii .to_xml (enc = enc , mode = 'strict' )
879+ elif mode != 'force' :
880+ raise TypeError (f'Unknown mode { mode } ' )
853881 header = b"""<?xml version="1.0" encoding="UTF-8"?>
854882<!DOCTYPE GIFTI SYSTEM "http://www.nitrc.org/frs/download.php/115/gifti.dtd">
855883"""
856884 return header + super ().to_xml (enc )
857885
858886 # Avoid the indirection of going through to_file_map
859- def to_bytes (self , enc = 'utf-8' ):
860- return self .to_xml (enc = enc )
887+ def to_bytes (self , enc = 'utf-8' , * , mode = 'strict' ):
888+ return self .to_xml (enc = enc , mode = mode )
861889
862890 to_bytes .__doc__ = SerializableImage .to_bytes .__doc__
863891
864- def to_file_map (self , file_map = None , enc = 'utf-8' ):
892+ def to_file_map (self , file_map = None , enc = 'utf-8' , * , mode = 'strict' ):
865893 """Save the current image to the specified file_map
866894
867895 Parameters
@@ -877,7 +905,7 @@ def to_file_map(self, file_map=None, enc='utf-8'):
877905 if file_map is None :
878906 file_map = self .file_map
879907 with file_map ['image' ].get_prepare_fileobj ('wb' ) as f :
880- f .write (self .to_xml (enc = enc ))
908+ f .write (self .to_xml (enc = enc , mode = mode ))
881909
882910 @classmethod
883911 def from_file_map (klass , file_map , buffer_size = 35000000 , mmap = True ):
0 commit comments