@@ -220,7 +220,6 @@ def processInst(writer: io.TextIOWrapper,
220220 conds = []
221221 op_name = instruction ["opname" ]
222222 fn_name = op_name [2 ].lower () + op_name [3 :]
223- result_types = []
224223 exts = instruction ["extensions" ] if "extensions" in instruction else []
225224
226225 if "capabilities" in instruction and len (instruction ["capabilities" ]) > 0 :
@@ -244,107 +243,97 @@ def processInst(writer: io.TextIOWrapper,
244243 conds .append ("(is_signed_v<T> || is_unsigned_v<T>)" )
245244 break
246245 case "U" :
247- fn_name = fn_name [0 :m [1 ][0 ]] + fn_name [m [1 ][1 ]:]
248- result_types = ["uint16_t" , "uint32_t" , "uint64_t" ]
246+ conds .append ("is_unsigned_v<T>" )
249247 break
250248 case "S" :
251- fn_name = fn_name [0 :m [1 ][0 ]] + fn_name [m [1 ][1 ]:]
252- result_types = ["int16_t" , "int32_t" , "int64_t" ]
249+ conds .append ("is_signed_v<T>" )
253250 break
254251 case "F" :
255- fn_name = fn_name [0 :m [1 ][0 ]] + fn_name [m [1 ][1 ]:]
256- result_types = ["float16_t" , "float32_t" , "float64_t" ]
252+ conds .append ("is_floating_point<T>" )
257253 break
258-
259- match instruction ["class" ]:
260- case "Bit" :
261- if len (result_types ) == 0 : conds .append ("(is_signed_v<T> || is_unsigned_v<T>)" )
254+ else :
255+ if instruction ["class" ] == "Bit" :
256+ conds .append ("(is_signed_v<T> || is_unsigned_v<T>)" )
262257
263258 if "operands" in instruction and instruction ["operands" ][0 ]["kind" ] == "IdResultType" :
264- if len (result_types ) == 0 :
265- if result_ty == None :
266- result_types = ["T" ]
267- else :
268- result_types = [result_ty ]
259+ if result_ty == None :
260+ result_ty = "T"
269261 else :
270- assert len (result_types ) == 0
271- result_types = ["void" ]
272-
273- for rt in result_types :
274- overload_caps = caps .copy ()
275- match rt :
276- case "uint16_t" | "int16_t" : overload_caps .append ("Int16" )
277- case "uint64_t" | "int64_t" : overload_caps .append ("Int64" )
278- case "float16_t" : overload_caps .append ("Float16" )
279- case "float64_t" : overload_caps .append ("Float64" )
280-
281- for cap in overload_caps or [None ]:
282- final_fn_name = fn_name + "_" + cap if (len (overload_caps ) > 1 ) else fn_name
283- final_templates = templates .copy ()
262+ result_ty = "void"
263+
264+ match result_ty :
265+ case "uint16_t" | "int16_t" : caps .append ("Int16" )
266+ case "uint64_t" | "int64_t" : caps .append ("Int64" )
267+ case "float16_t" : caps .append ("Float16" )
268+ case "float64_t" : caps .append ("Float64" )
269+
270+ for cap in caps or [None ]:
271+ final_fn_name = fn_name + "_" + cap if (len (caps ) > 1 ) else fn_name
272+ final_templates = templates .copy ()
273+
274+ if (not "typename T" in final_templates ) and (result_ty == "T" ):
275+ final_templates = ["typename T" ] + final_templates
276+
277+ if len (caps ) > 0 :
278+ if (("Float16" in cap and result_ty != "float16_t" ) or
279+ ("Float32" in cap and result_ty != "float32_t" ) or
280+ ("Float64" in cap and result_ty != "float64_t" ) or
281+ ("Int16" in cap and result_ty != "int16_t" and result_ty != "uint16_t" ) or
282+ ("Int64" in cap and result_ty != "int64_t" and result_ty != "uint64_t" )): continue
284283
285- if (not "typename T" in final_templates ) and (rt == "T" ):
286- final_templates = ["typename T" ] + final_templates
287-
288- if len (overload_caps ) > 0 :
289- if (("Float16" in cap and rt != "float16_t" ) or
290- ("Float32" in cap and rt != "float32_t" ) or
291- ("Float64" in cap and rt != "float64_t" ) or
292- ("Int16" in cap and rt != "int16_t" and rt != "uint16_t" ) or
293- ("Int64" in cap and rt != "int64_t" and rt != "uint64_t" )): continue
294-
295- if "Vector" in cap :
296- rt = "vector<" + rt + ", N> "
297- final_templates .append ("uint32_t N" )
298-
299- op_ty = "T"
300- if prefered_op_ty != None :
301- op_ty = prefered_op_ty
302- elif rt != "void" :
303- op_ty = rt
304-
305- args = []
306- if "operands" in instruction :
307- for operand in instruction ["operands" ]:
308- operand_name = operand ["name" ].strip ("'" ) if "name" in operand else None
309- operand_name = operand_name [0 ].lower () + operand_name [1 :] if (operand_name != None ) else ""
310- match operand ["kind" ]:
311- case "IdResult" | "IdResultType" : continue
312- case "IdRef" :
313- match operand ["name" ]:
314- case "'Pointer'" :
315- if shape == Shape .PTR_TEMPLATE :
316- args .append ("P " + operand_name )
317- elif shape == Shape .BDA :
318- if (not "typename T" in final_templates ) and (rt == "T" or op_ty == "T" ):
319- final_templates = ["typename T" ] + final_templates
320- args .append ("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
321- else :
322- if (not "typename T" in final_templates ) and (rt == "T" or op_ty == "T" ):
323- final_templates = ["typename T" ] + final_templates
324- args .append ("[[vk::ext_reference]] " + op_ty + " " + operand_name )
325- case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'" :
326- if (not "typename T" in final_templates ) and (rt == "T" or op_ty == "T" ):
284+ if "Vector" in cap :
285+ result_ty = "vector<" + result_ty + ", N> "
286+ final_templates .append ("uint32_t N" )
287+
288+ op_ty = "T"
289+ if prefered_op_ty != None :
290+ op_ty = prefered_op_ty
291+ elif result_ty != "void" :
292+ op_ty = result_ty
293+
294+ args = []
295+ if "operands" in instruction :
296+ for operand in instruction ["operands" ]:
297+ operand_name = operand ["name" ].strip ("'" ) if "name" in operand else None
298+ operand_name = operand_name [0 ].lower () + operand_name [1 :] if (operand_name != None ) else ""
299+ match operand ["kind" ]:
300+ case "IdResult" | "IdResultType" : continue
301+ case "IdRef" :
302+ match operand ["name" ]:
303+ case "'Pointer'" :
304+ if shape == Shape .PTR_TEMPLATE :
305+ args .append ("P " + operand_name )
306+ elif shape == Shape .BDA :
307+ if (not "typename T" in final_templates ) and (result_ty == "T" or op_ty == "T" ):
308+ final_templates = ["typename T" ] + final_templates
309+ args .append ("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
310+ else :
311+ if (not "typename T" in final_templates ) and (result_ty == "T" or op_ty == "T" ):
327312 final_templates = ["typename T" ] + final_templates
328- args .append (op_ty + " " + operand_name )
329- case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'" :
330- args .append ("uint32_t " + operand_name )
331- case "'Predicate'" : args .append ("bool " + operand_name )
332- case "'ClusterSize'" :
333- if "quantifier" in operand and operand ["quantifier" ] == "?" : continue # TODO: overload
334- else : return ignore (op_name ) # TODO
335- case _: return ignore (op_name ) # TODO
336- case "IdScope" : args .append ("uint32_t " + operand_name .lower () + "Scope" )
337- case "IdMemorySemantics" : args .append (" uint32_t " + operand_name )
338- case "GroupOperation" : args .append ("[[vk::ext_literal]] uint32_t " + operand_name )
339- case "MemoryAccess" :
340- assert len (overload_caps ) <= 1
341- if shape != Shape .BDA :
342- writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
343- writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
344- writeInst (writer , final_templates + ["uint32_t alignment" ], cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
345- case _: return ignore (op_name ) # TODO
346-
347- writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args )
313+ args .append ("[[vk::ext_reference]] " + op_ty + " " + operand_name )
314+ case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'" :
315+ if (not "typename T" in final_templates ) and (result_ty == "T" or op_ty == "T" ):
316+ final_templates = ["typename T" ] + final_templates
317+ args .append (op_ty + " " + operand_name )
318+ case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'" :
319+ args .append ("uint32_t " + operand_name )
320+ case "'Predicate'" : args .append ("bool " + operand_name )
321+ case "'ClusterSize'" :
322+ if "quantifier" in operand and operand ["quantifier" ] == "?" : continue # TODO: overload
323+ else : return ignore (op_name ) # TODO
324+ case _: return ignore (op_name ) # TODO
325+ case "IdScope" : args .append ("uint32_t " + operand_name .lower () + "Scope" )
326+ case "IdMemorySemantics" : args .append (" uint32_t " + operand_name )
327+ case "GroupOperation" : args .append ("[[vk::ext_literal]] uint32_t " + operand_name )
328+ case "MemoryAccess" :
329+ assert len (caps ) <= 1
330+ if shape != Shape .BDA :
331+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , result_ty , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
332+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , result_ty , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
333+ writeInst (writer , final_templates + ["uint32_t alignment" ], cap , exts , op_name , final_fn_name , conds , result_ty , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
334+ case _: return ignore (op_name ) # TODO
335+
336+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , result_ty , args )
348337
349338
350339def writeInst (writer : io .TextIOWrapper , templates , cap , exts , op_name , fn_name , conds , result_type , args ):
0 commit comments