@@ -1184,7 +1184,138 @@ def maximum(x1, x2):
11841184
11851185
11861186def median (x , axis = None , keepdims = False ):
1187- raise NotImplementedError ("`median` is not supported with openvino backend" )
1187+ x = get_ov_output (x )
1188+ x_shape = x .get_partial_shape ()
1189+ rank = x_shape .rank .get_length ()
1190+
1191+ if rank == 0 :
1192+ return OpenVINOKerasTensor (x )
1193+
1194+ # Handle axis=None by flattening the input
1195+ flattened_all = False
1196+ if axis is None :
1197+ x = ov_opset .reshape (x , [- 1 ], False ).output (0 )
1198+ axis = 0
1199+ original_rank = rank
1200+ rank = 1
1201+ flattened_all = True
1202+ else :
1203+ # Handle tuple axis - for median, we only support single axis
1204+ if isinstance (axis , (tuple , list )):
1205+ if len (axis ) != 1 :
1206+ raise ValueError ("median only supports single axis reduction" )
1207+ axis = axis [0 ]
1208+
1209+ # Handle negative axis
1210+ if axis < 0 :
1211+ axis = rank + axis
1212+ original_rank = rank
1213+
1214+ # Get the size of the dimension to sort
1215+ shape_tensor = ov_opset .shape_of (x , output_type = Type .i32 ).output (0 )
1216+ k = ov_opset .gather (
1217+ shape_tensor ,
1218+ ov_opset .constant ([axis ], Type .i32 ).output (0 ),
1219+ ov_opset .constant (0 , Type .i32 ).output (0 ),
1220+ ).output (0 )
1221+
1222+ # Convert k to a scalar value
1223+ k_scalar = ov_opset .squeeze (k , [0 ]).output (0 )
1224+
1225+ # Use topk with k=size_of_axis to get all elements sorted
1226+ topk_outputs = ov_opset .topk (
1227+ x , k = k_scalar , axis = axis , mode = "min" , sort = "value" , stable = True
1228+ )
1229+
1230+ # Get the sorted values
1231+ sorted_values = topk_outputs .output (0 )
1232+
1233+ # Convert to float for median calculation
1234+ x1_type = ov_to_keras_type (sorted_values .get_element_type ())
1235+ result_type = dtypes .result_type (x1_type , float )
1236+ result_type = OPENVINO_DTYPES [result_type ]
1237+ sorted_values = ov_opset .convert (sorted_values , result_type ).output (0 )
1238+
1239+ # Calculate median indices
1240+ # For odd length: median_idx = (k-1) // 2
1241+ # For even length: we need indices (k//2 - 1) and k//2, then average
1242+
1243+ k_minus_1 = ov_opset .subtract (
1244+ k_scalar , ov_opset .constant (1 , Type .i32 ).output (0 )
1245+ ).output (0 )
1246+ k_div_2 = ov_opset .divide (
1247+ k_scalar , ov_opset .constant (2 , Type .i32 ).output (0 )
1248+ ).output (0 )
1249+ k_minus_1_div_2 = ov_opset .divide (
1250+ k_minus_1 , ov_opset .constant (2 , Type .i32 ).output (0 )
1251+ ).output (0 )
1252+
1253+ # Check if k is odd
1254+ k_mod_2 = ov_opset .mod (
1255+ k_scalar , ov_opset .constant (2 , Type .i32 ).output (0 )
1256+ ).output (0 )
1257+ is_odd = ov_opset .equal (
1258+ k_mod_2 , ov_opset .constant (1 , Type .i32 ).output (0 )
1259+ ).output (0 )
1260+
1261+ # For odd case: take the middle element
1262+ odd_idx = k_minus_1_div_2
1263+
1264+ # For even case: take average of two middle elements
1265+ even_idx1 = ov_opset .subtract (
1266+ k_div_2 , ov_opset .constant (1 , Type .i32 ).output (0 )
1267+ ).output (0 )
1268+ even_idx2 = k_div_2
1269+
1270+ # Gather elements for both cases
1271+ # Create gather indices tensor for the axis
1272+ gather_indices_odd = ov_opset .unsqueeze (odd_idx , [0 ]).output (0 )
1273+ gather_indices_even1 = ov_opset .unsqueeze (even_idx1 , [0 ]).output (0 )
1274+ gather_indices_even2 = ov_opset .unsqueeze (even_idx2 , [0 ]).output (0 )
1275+
1276+ # Gather the median elements
1277+ odd_result = ov_opset .gather (
1278+ sorted_values ,
1279+ gather_indices_odd ,
1280+ ov_opset .constant (axis , Type .i32 ).output (0 ),
1281+ ).output (0 )
1282+ even_result1 = ov_opset .gather (
1283+ sorted_values ,
1284+ gather_indices_even1 ,
1285+ ov_opset .constant (axis , Type .i32 ).output (0 ),
1286+ ).output (0 )
1287+ even_result2 = ov_opset .gather (
1288+ sorted_values ,
1289+ gather_indices_even2 ,
1290+ ov_opset .constant (axis , Type .i32 ).output (0 ),
1291+ ).output (0 )
1292+
1293+ # Average the two middle elements for even case
1294+ even_sum = ov_opset .add (even_result1 , even_result2 ).output (0 )
1295+ even_result = ov_opset .divide (
1296+ even_sum , ov_opset .constant (2.0 , result_type ).output (0 )
1297+ ).output (0 )
1298+
1299+ # Select between odd and even results
1300+ median_result = ov_opset .select (is_odd , odd_result , even_result ).output (0 )
1301+
1302+ # Remove the gathered dimension (squeeze)
1303+ median_result = ov_opset .squeeze (median_result , [axis ]).output (0 )
1304+
1305+ # Handle keepdims
1306+ if keepdims :
1307+ if flattened_all :
1308+ # When axis=None, keepdims should restore all dimensions as 1
1309+ ones_shape = ov_opset .constant (
1310+ [1 ] * original_rank , Type .i32
1311+ ).output (0 )
1312+ median_result = ov_opset .reshape (
1313+ median_result , ones_shape , False
1314+ ).output (0 )
1315+ else :
1316+ median_result = ov_opset .unsqueeze (median_result , [axis ]).output (0 )
1317+
1318+ return OpenVINOKerasTensor (median_result )
11881319
11891320
11901321def meshgrid (* x , indexing = "xy" ):
0 commit comments