@@ -1599,7 +1599,43 @@ def size(x):
15991599
16001600
16011601def sort (x , axis = - 1 ):
1602- raise NotImplementedError ("`sort` is not supported with openvino backend" )
1602+ x = get_ov_output (x )
1603+ x_shape = x .get_partial_shape ()
1604+ rank = x_shape .rank .get_length ()
1605+
1606+ if rank == 0 :
1607+ return OpenVINOKerasTensor (x )
1608+
1609+ # Handle axis=None by flattening the input
1610+ if axis is None :
1611+ x = ov_opset .reshape (
1612+ x , ov_opset .constant ([- 1 ], Type .i32 ), False
1613+ ).output (0 )
1614+ axis = 0
1615+ # Handle negative axis
1616+ elif axis < 0 :
1617+ axis = rank + axis
1618+
1619+ # Get the size of the dimension to sort
1620+ shape_tensor = ov_opset .shape_of (x , output_type = Type .i32 ).output (0 )
1621+ k = ov_opset .gather (
1622+ shape_tensor ,
1623+ ov_opset .constant ([axis ], Type .i32 ).output (0 ),
1624+ ov_opset .constant (0 , Type .i32 ).output (0 ),
1625+ ).output (0 )
1626+
1627+ # Convert k to a scalar value
1628+ k_scalar = ov_opset .squeeze (k , ov_opset .constant ([0 ], Type .i32 )).output (0 )
1629+
1630+ # Use topk with k=size_of_axis to get all elements sorted
1631+ topk_outputs = ov_opset .topk (
1632+ x , k = k_scalar , axis = axis , mode = "min" , sort = "value" , stable = True
1633+ )
1634+
1635+ # Get the sorted values
1636+ sorted_values = topk_outputs .output (0 )
1637+
1638+ return OpenVINOKerasTensor (sorted_values )
16031639
16041640
16051641def split (x , indices_or_sections , axis = 0 ):
0 commit comments