@@ -1985,6 +1985,7 @@ at::Tensor AtenIpexCPUDev::dil_slice(const at::Tensor & self, int64_t dim, int64
19851985 DEBUG (" AtenIpexCPUDev::dil_slice\n " );
19861986 CHECK_DNNL_OP_PRE_COND (self);
19871987
1988+ // TODO use weight TAG to decide whether to reorder or not
19881989 dbl::comm::reorder_to_bf16_for_mix_prec (self, true );
19891990
19901991 // Port from aten/src/ATen/native/TensorShape.cpp
@@ -2023,6 +2024,22 @@ at::Tensor AtenIpexCPUDev::dil_slice(const at::Tensor & self, int64_t dim, int64
20232024 return result;
20242025}
20252026
2027+ std::vector<at::Tensor> AtenIpexCPUDev::dil_unbind (const at::Tensor &self, int64_t dim) {
2028+ DEBUG (" AtenIpexCPUDev::dil_unbind\n " );
2029+
2030+ dim = at::maybe_wrap_dim (dim, self.dim ());
2031+ int64_t size = dil_size (self, dim);
2032+ std::vector<at::Tensor> tensors (size);
2033+ for (int i = 0 ; i < size; i++) {
2034+ tensors[i] = dil_select (self, dim, i);
2035+ }
2036+ return tensors;
2037+ }
2038+
2039+ std::vector<at::Tensor>AtenIpexCPUDev::dil_unbind (const at::Tensor& self, at::Dimname dim) {
2040+ return dil_unbind (self, at::dimname_to_position (self, dim));
2041+ }
2042+
20262043at::Tensor AtenIpexCPUDev::dil_select (const at::Tensor & self, int64_t dim, int64_t index) {
20272044 DEBUG (" AtenIpexCPUDev::dil_select\n " );
20282045 CHECK_DNNL_OP_PRE_COND (self);
@@ -2119,19 +2136,43 @@ at::Tensor AtenIpexCPUDev::dil_select(const at::Tensor & self, at::Dimname dim,
21192136
21202137std::vector<at::Tensor> AtenIpexCPUDev::dil_split (const at::Tensor& self, int64_t split_size, int64_t dim) {
21212138 DEBUG (" AtenIpexCPUDev::dil_split\n " );
2139+ TORCH_CHECK (self.dim () != 0 , " split expects at least a 1-dimensional tensor" );
2140+ TORCH_CHECK (split_size >= 0 , " split expects split_size be non-negative, but got split_size=" , split_size);
2141+
21222142 CHECK_DNNL_OP_PRE_COND (self);
21232143 dim = at::maybe_wrap_dim (dim, self.dim ());
21242144 int64_t dim_size = dil_size (self, dim);
2145+ TORCH_CHECK (split_size > 0 || self.size (dim) == 0 ,
2146+ " split_size can only be 0 if dimension size is 0, "
2147+ " but got dimension size of " , dim_size);
2148+ // if split_size is 0 and dimension size is 0, there is 1 split.
21252149 int64_t num_splits = 1 ;
21262150 if (split_size != 0 ) {
21272151 // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
21282152 // (returns a single split). We might want to error here, but keep it for BC.
21292153 num_splits = std::max<int64_t >((dim_size + split_size - 1 ) / split_size, 1 );
21302154 }
2131- std::vector<int64_t > split_sizes (num_splits, split_size );
2155+ std::vector<at::Tensor> splits (num_splits);
21322156 int64_t last_split_size = split_size - (split_size * num_splits - dim_size);
2133- split_sizes[num_splits-1 ] = last_split_size;
2134- return dil_split_with_sizes (self, split_sizes, dim);
2157+
2158+ for (int64_t i = 0 ; i < num_splits; ++i) {
2159+ auto length = i < num_splits - 1 ? split_size : last_split_size;
2160+ splits[i] = _dil_narrow (self, dim, i * split_size, length);
2161+ }
2162+ return splits;
2163+ }
2164+
2165+ // TODO only used for dil_split
2166+ at::Tensor AtenIpexCPUDev::_dil_narrow (const at::Tensor& self, int64_t dim, int64_t start, int64_t length) {
2167+ // Port from aten/src/ATen/native/TensorShape.cpp
2168+ TORCH_CHECK (self.dim () > 0 , " narrow() cannot be applied to a 0-dim tensor." );
2169+ auto cur_size = self.size (dim);
2170+ if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
2171+ start = at::maybe_wrap_dim (start, cur_size);
2172+ }
2173+ TORCH_CHECK (length >= 0 && start <= cur_size - length,
2174+ " start (" , start, " ) + length (" , length, " ) exceeds dimension size (" , cur_size, " )." );
2175+ return dil_slice (self, dim, start, start + length, 1 );
21352176}
21362177
21372178at::Tensor AtenIpexCPUDev::dil_gelu (const at::Tensor& input) {
0 commit comments