import cupy as cp
import numpy as np

+ # TODO: Add requires_grad condition to args

class Tensor:
    def __init__(self, data: Any, args=None, op=None, requires_grad: bool = True, dtype=None, device: str = "cpu"):
@@ -14,9 +15,9 @@ def __init__(self, data: Any, args=None, op=None, requires_grad: bool=True, dtyp
            self.xp = cp

        if isinstance(data, Tensor):
-            self.data: Union[np.ndarray, cp.ndarray] = self.xp.array(data.data, dtype=dtype)
+            self.data: Union[np.ndarray, cp.ndarray] = self.xp.array(data.data, dtype=dtype if dtype else np.float32)
        else:
-            self.data = self.xp.array(data, dtype=dtype)
+            self.data = self.xp.array(data, dtype=dtype if dtype else np.float32)

        self.grad = None
        self.op = op
@@ -355,6 +356,145 @@ def flip(self, axis: Any) -> 'Tensor':
            device=self.device,
        )

+    def where(self, condition: Union[Any, 'Tensor'], t: Union[Any, 'Tensor']) -> 'Tensor':
+        condition = self.tensor(condition)
+        t = self.tensor(t)
+
+        requires_grad = self.requires_grad or t.requires_grad
+        args = [self, condition, t] if requires_grad else None
+
+        return Tensor(
+            self.xp.where(condition.data, self.data, t.data),  # self.xp, not np, so the op also works on cuda tensors
+            args,
+            "where",
+            requires_grad=requires_grad,
+            device=self.device,
+        )
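+    # Usage sketch (values illustrative): elements come from self where the
+    # condition holds and from t elsewhere, mirroring numpy.where(cond, x, y):
+    #   Tensor([1.0, 2.0]).where([True, False], [9.0, 9.0]).data -> [1., 9.]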
+
+    def equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+        t = self.tensor(t)
+
+        return Tensor(
+            self.xp.equal(self.data, t.data),
+            None,
+            "equal",
+            requires_grad=False,
+            device=self.device,
+        )
+
+    def not_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+        t = self.tensor(t)
+
+        return Tensor(
+            self.xp.not_equal(self.data, t.data),
+            None,
+            "not_equal",
+            requires_grad=False,
+            device=self.device,
+        )
+
+    def greater(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+        t = self.tensor(t)
+
+        return Tensor(
+            self.xp.greater(self.data, t.data),
+            None,
+            "greater",
+            requires_grad=False,
+            device=self.device,
+        )
+
+    def greater_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+        t = self.tensor(t)
+
+        return Tensor(
+            self.xp.greater_equal(self.data, t.data),
+            None,
+            "greater_equal",
+            requires_grad=False,
+            device=self.device,
+        )
+
+    def less(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+        t = self.tensor(t)
+
+        return Tensor(
+            self.xp.less(self.data, t.data),
+            None,
+            "less",
+            requires_grad=False,
+            device=self.device,
+        )
+
+    def less_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+        t = self.tensor(t)
+
+        return Tensor(
+            self.xp.less_equal(self.data, t.data),
+            None,
+            "less_equal",
+            requires_grad=False,
+            device=self.device,
+        )
+
+    def logical_and(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+        t = self.tensor(t)
+
+        return Tensor(
+            self.xp.logical_and(self.data, t.data),
+            None,
+            "logical_and",
+            requires_grad=False,
+            device=self.device,
+        )
+
+    def logical_or(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+        t = self.tensor(t)
+
+        return Tensor(
+            self.xp.logical_or(self.data, t.data),
+            None,
+            "logical_or",
+            requires_grad=False,
+            device=self.device,
+        )
+
+    def logical_not(self) -> 'Tensor':
+        return Tensor(
+            self.xp.logical_not(self.data),
+            None,
+            "logical_not",
+            requires_grad=False,
+            device=self.device,
+        )
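+    # All comparison/logical ops above return requires_grad=False tensors:
+    # boolean outputs are piecewise constant, so no gradient can flow through
+    # them. Note the float32 default dtype turns the boolean results into
+    # 0.0/1.0 arrays, which xp.where still accepts as a truthy mask.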
470+
471+ def __eq__ (self , t : Union [Any , 'Tensor' ]) -> 'Tensor' : # type: ignore[override]
472+ return self .equal (t )
473+
474+ def __ne__ (self , t : Union [Any , 'Tensor' ]) -> 'Tensor' : # type: ignore[override]
475+ return self .not_equal (t )
476+
477+ def __gt__ (self , t : Union [Any , 'Tensor' ]) -> 'Tensor' :
478+ return self .greater (t )
479+
480+ def __ge__ (self , t : Union [Any , 'Tensor' ]) -> 'Tensor' :
481+ return self .greater_equal (t )
482+
483+ def __lt__ (self , t : Union [Any , 'Tensor' ]) -> 'Tensor' :
484+ return self .less (t )
485+
486+ def __le__ (self , t : Union [Any , 'Tensor' ]) -> 'Tensor' :
487+ return self .less_equal (t )
488+
489+ def __and__ (self , t : Union [Any , 'Tensor' ]) -> 'Tensor' :
490+ return self .logical_and (t )
491+
492+ def __or__ (self , t : Union [Any , 'Tensor' ]) -> 'Tensor' :
493+ return self .logical_or (t )
494+
495+ def __invert__ (self ) -> 'Tensor' :
496+ return self .logical_not ()
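+    # Operator sugar: (a > b) | (a == b) desugars to
+    # a.greater(b).logical_or(a.equal(b)); ~(a < b) to a.less(b).logical_not().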
+
    def __neg__(self) -> 'Tensor':
        return Tensor(
            -self.data,
@@ -667,6 +807,10 @@ def backward(

        elif self.op == "flip":
            self.args[0].backward(self.xp.flip(grad, axis=self.args[1]))
+
+        elif self.op == "where":
+            self.args[0].backward(self.xp.where(self.args[1].data, grad, self.xp.zeros_like(grad)))
+            self.args[2].backward(self.xp.where(self.args[1].data, self.xp.zeros_like(grad), grad))
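+            # d/dx where(c, x, y) is 1 where c holds and 0 elsewhere (mirrored
+            # for y), so the mask routes the upstream grad rather than scaling
+            # by it: with c = [True, False] and grad = [g0, g1], args[0]
+            # receives [g0, 0] and args[2] receives [0, g1].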

        elif self.op == "neg":
            self.args[0].backward(-grad)
@@ -680,27 +824,4 @@ def backward(

# BUGS:
# grad of X - mean does not match PyTorch; maybe NOT a BUG because of small-number manipulation (numerical stability issues)
- # softmax grads do not equal PyTorch's; place: div; maybe NOT a BUG because of small-number manipulation (numerical stability issues)????
-
-
- # def repeat_to_match_shape(self, g, shape, dtype, axis, keepdims): same
- # https://github.com/HIPS/autograd/blob/master/autograd/numpy/numpy_vjps.py
- #     """Returns the array g repeated along axis to fit vector space vs.
- #     Also returns the number of repetitions of the array."""
- #     if shape == ():
- #         return g, 1
- #     axis = list(axis) if isinstance(axis, tuple) else axis
- #     new_shape = self.xp.array(shape)
- #     new_shape[axis] = 1
- #     num_reps = self.xp.prod(self.xp.array(shape)[axis])
- #     # Can't use broadcast_to because of numpy bug: https://github.com/numpy/numpy/issues/9165
- #     # return self.xp.broadcast_to(self.xp.reshape(g, new_shape), shape), num_reps
- #     return self.xp.reshape(g, new_shape) + self.xp.zeros(shape, dtype=dtype), num_reps
-
- # elif self.op == "mean":
- #     shape = self.args[0].data.shape
- #     axis = self.args[1]
- #     dtype = self.xp.result_type(self.args[0].data)
- #     g_repeated, num_reps = self.repeat_to_match_shape(grad, shape, dtype, axis, None)
- #     print(f"g_repeated {g_repeated}")
- #     self.args[0].backward(g_repeated / num_reps)
+ # softmax grads do not equal PyTorch's; place: div; maybe NOT a BUG because of small-number manipulation (numerical stability issues)????
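
# Quick gradient check for where() (a sketch; assumes backward() seeds the
# output grad with ones and that elementwise mul behaves as usual):
#   x = Tensor([1.0, -2.0, 3.0])
#   y = x.where(x > 0, x * 0.5)   # keep positives, halve negatives
#   y.backward()
#   x.grad -> [1.0, 0.5, 1.0]  (grad 1 via the taken branch, 0.5 via x * 0.5)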