@@ -138,6 +138,10 @@ def quantized_add_meta(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_add: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
     broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
     return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
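Note that the meta ("fake") kernel above only propagates shapes and dtypes during tracing/export; no arithmetic runs. With the new guard in place, `torch.broadcast_shapes` reduces to an identity on the (equal) input shapes. A minimal illustration with made-up shapes:

import torch

# Equal shapes are the only case the assert admits; broadcast_shapes then
# simply returns that shape, so the meta output matches the inputs.
assert torch.broadcast_shapes((4, 8), (4, 8)) == torch.Size([4, 8])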
@@ -156,6 +160,10 @@ def quantized_add_impl(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_add: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
     self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
     self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift)
@@ -168,6 +176,68 @@ def quantized_add_impl(
     return result
 
 
+# ===================================================================
+# QUANTIZED MUL OPERATION DEFINITION
+# ===================================================================
+lib.define(
+    "quantized_mul("
+    "Tensor self, Scalar self_zero_point, "
+    "Tensor other, Scalar other_zero_point, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
+)
+lib.define(
+    "quantized_mul.out("
+    "Tensor self, Scalar self_zero_point, "
+    "Tensor other, Scalar other_zero_point, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, "
+    "*, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::quantized_mul")
+def quantized_mul_meta(
+    self: torch.Tensor,
+    self_zero_point: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+) -> torch.Tensor:
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
+    # Shapes are asserted equal, so this broadcast is a no-op for now; it
+    # keeps the output-shape logic ready for when broadcasting is supported.
+    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
+    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
+
+
+@impl(lib, "quantized_mul", "CompositeExplicitAutograd")
+def quantized_mul_impl(
+    self: torch.Tensor,
+    self_zero_point: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+) -> torch.Tensor:
+    # The CMSIS-NN kernel multiplies the zero-point-adjusted int8 operands
+    # directly and only uses the output multiplier/shift for rescaling.
+    # Mirror that here to keep the composite implementation numerically
+    # aligned with the backend.
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
+    self_int = self.to(torch.int32) - self_zero_point
+    other_int = other.to(torch.int32) - other_zero_point
+    result_fp = self_int * other_int
+    result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift)
+    result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8)
+    return result
+
+
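The composite consumes a precomputed `output_multiplier`/`output_shift` pair; the pass that computes them is not part of this diff. A common convention (TFLite/CMSIS-NN style, assumed here) derives the pair from the float scales with `math.frexp`, after which a call site might look like the following sketch (all scales, zero points, and shapes are illustrative):

import math

import torch

def quantize_multiplier(real_scale: float) -> tuple[int, int]:
    # Decompose real_scale = mantissa * 2**shift with mantissa in [0.5, 1),
    # then store the mantissa as a Q31 fixed-point integer.
    mantissa, shift = math.frexp(real_scale)
    q = round(mantissa * (1 << 31))
    if q == (1 << 31):  # rounding pushed the mantissa up to 1.0
        q //= 2
        shift += 1
    return q, shift

# For mul, the effective rescale is scale_self * scale_other / scale_out.
mult, shift = quantize_multiplier(0.02 * 0.05 / 0.1)

# Requires the op registrations above to have been imported.
a = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
b = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
out = torch.ops.cortex_m.quantized_mul(a, 3, b, -5, 0, mult, shift)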
 # ===================================================================
 # QUANTIZED LINEAR OPERATION DEFINITION
 # ===================================================================