|
25 | 25 |
|
26 | 26 | namespace tc { |
27 | 27 | namespace aten { |
| 28 | + |
| 29 | +// Stolen from ATen, get rid of our copy when ATen exposes the functionality |
| 30 | +// Unfortunately we need to wait for updated conda packages so we just copy |
| 31 | +// for now. |
| 32 | +inline DLDataType getDLDataType(const at::Type& type) { |
| 33 | + using at::ScalarType; |
| 34 | + |
| 35 | + DLDataType dtype; |
| 36 | + dtype.lanes = 1; |
| 37 | + dtype.bits = type.elementSizeInBytes() * 8; |
| 38 | + switch (type.scalarType()) { |
| 39 | + case ScalarType::Byte: |
| 40 | + dtype.code = DLDataTypeCode::kDLUInt; |
| 41 | + break; |
| 42 | + case ScalarType::Char: |
| 43 | + dtype.code = DLDataTypeCode::kDLInt; |
| 44 | + break; |
| 45 | + case ScalarType::Double: |
| 46 | + dtype.code = DLDataTypeCode::kDLFloat; |
| 47 | + break; |
| 48 | + case ScalarType::Float: |
| 49 | + dtype.code = DLDataTypeCode::kDLFloat; |
| 50 | + break; |
| 51 | + case ScalarType::Int: |
| 52 | + dtype.code = DLDataTypeCode::kDLInt; |
| 53 | + break; |
| 54 | + case ScalarType::Long: |
| 55 | + dtype.code = DLDataTypeCode::kDLInt; |
| 56 | + break; |
| 57 | + case ScalarType::Short: |
| 58 | + dtype.code = DLDataTypeCode::kDLInt; |
| 59 | + break; |
| 60 | + case ScalarType::Half: |
| 61 | + dtype.code = DLDataTypeCode::kDLFloat; |
| 62 | + break; |
| 63 | + case ScalarType::Undefined: |
| 64 | + throw std::logic_error("Undefined is not a valid ScalarType"); |
| 65 | + case ScalarType::NumOptions: |
| 66 | + throw std::logic_error("NumOptions is not a valid ScalarType"); |
| 67 | + } |
| 68 | + return dtype; |
| 69 | +} |
| 70 | + |
| 71 | +inline TensorInfo toTensorInfo(const at::Tensor& t) { |
| 72 | + return TensorInfo( |
| 73 | + getDLDataType(t.type()), |
| 74 | + reinterpret_cast<std::uintptr_t>(t.data_ptr()) % TensorInfo::kAlignment, |
| 75 | + t.sizes(), |
| 76 | + t.strides()); |
| 77 | +} |
| 78 | + |
28 | 79 | inline std::vector<DLTensorUPtr> makeDLTensors( |
29 | 80 | const std::vector<at::Tensor>& tensors) { |
30 | 81 | std::vector<DLTensorUPtr> dlTensors; |
|
0 commit comments