@@ -644,27 +644,107 @@ def test(self):
 
 generate_tensor_like_override_tests(TestTorchFunctionOverride)
 
-class TestEinsumOverride(TestCase):
-    "Regression test for gh-38479"
-    def test_wrapper(self):
-        class Wrapper():
-            "Basic data container that knows how to unwrap itself"
-            def __init__(self, data):
-                self.data = data
+class Wrapper:
+    "Basic data container that knows how to unwrap itself"
+    def __init__(self, data):
+        self.__dict__["_data"] = data
+        self.__dict__["used_attrs"] = set()
+        self.__dict__["used_calls"] = set()
+
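+    # Attribute access is forwarded to the wrapped tensor, and every name
+    # that is read or assigned is recorded in used_attrs so tests can see
+    # exactly which parts of the Tensor API a caller touched.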
+    def __getattr__(self, name):
+        if name in self.__dict__:
+            return self.__dict__[name]
+        self.used_attrs.add(name)
+
+        val = getattr(self._data, name)
+
+        # If it's a method
+        if callable(val):
+            c = getattr(type(self._data), name)
+            # Don't append self to args if classmethod/staticmethod: for
+            # those, the instance lookup returns the same object as the
+            # class attribute, so there is no bound self to re-add.
+            if c is val:
+                return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
+            # Otherwise append self to args
+            return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))
+
+        return wrap(val)
+
+    def __setattr__(self, name, value):
+        if name in self.__dict__:
+            self.__dict__[name] = value
+
+        self.used_attrs.add(name)
+        setattr(self._data, name, unwrap(value))
+
+    def __setitem__(self, key, value):
+        self._data[unwrap(key)] = unwrap(value)
+
+    def __getitem__(self, key):
+        return wrap(self._data[unwrap(key)])
+
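+    # Dispatch hook: any torch function called with a Wrapper argument
+    # lands here. The call is recorded in used_calls, all arguments are
+    # unwrapped to plain Tensors, the real function runs, and the result
+    # is re-wrapped so wrappers propagate through computations.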
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        self.used_calls.add(func)
+        args = unwrap(tuple(args))
+        kwargs = {k: unwrap(v) for k, v in kwargs.items()}
+
+        return wrap(func(*args, **kwargs))
+
+    def __add__(self, other):
+        return self.__torch_function__(torch.add, (Wrapper,), (self, other))
+
+    def __sub__(self, other):
+        return self.__torch_function__(torch.sub, (Wrapper,), (self, other))
 
-            def __torch_function__(self, func, types, args=(), kwargs=None):
-                if kwargs is None:
-                    kwargs = {}
+    def __truediv__(self, other):
+        return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))
 
-                # unwrap inputs if necessary
-                def unwrap(v):
-                    return v.data if isinstance(v, Wrapper) else v
+    def __floordiv__(self, other):
+        return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))
 
-                args = map(unwrap, args)
-                kwargs = {k: unwrap(v) for k, v in kwargs.items()}
+    def __ge__(self, other):
+        return self.__torch_function__(torch.ge, (Wrapper,), (self, other))
 
-                return func(*args, **kwargs)
+    def __gt__(self, other):
+        return self.__torch_function__(torch.gt, (Wrapper,), (self, other))
 
+    def __lt__(self, other):
+        return self.__torch_function__(torch.lt, (Wrapper,), (self, other))
+
+    def __le__(self, other):
+        return self.__torch_function__(torch.le, (Wrapper,), (self, other))
+
+    def __eq__(self, other):
+        return self.__torch_function__(torch.eq, (Wrapper,), (self, other))
+
+    def __ne__(self, other):
+        return self.__torch_function__(torch.ne, (Wrapper,), (self, other))
+
+    def __bool__(self):
+        return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))
+
+    def __int__(self):
+        return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))
+
+
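+# The two helpers below recurse through tuples and lists so that nested
+# containers of Wrappers and Tensors round-trip correctly.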
+# unwrap inputs if necessary
+def unwrap(v):
+    if type(v) in {tuple, list}:
+        return type(v)(unwrap(vi) for vi in v)
+
+    return v._data if isinstance(v, Wrapper) else v
+
+# wrap inputs if necessary
+def wrap(v):
+    if type(v) in {tuple, list}:
+        return type(v)(wrap(vi) for vi in v)
+
+    return Wrapper(v) if isinstance(v, torch.Tensor) else v
+
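+# A minimal usage sketch of the machinery above (illustrative only, not
+# part of the test suite):
+#
+#   t = wrap(torch.arange(3))   # Wrapper around a Tensor
+#   s = t + t                   # routed through Wrapper.__torch_function__
+#   assert isinstance(s, Wrapper) and torch.add in t.used_calls
+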
+class TestEinsumOverride(TestCase):
+    "Regression test for gh-38479"
+    def test_wrapper(self):
         x = Wrapper(torch.randn(5))
         y = Wrapper(torch.randn(4))
         self.assertTrue(torch.allclose(torch.einsum('i,j->ij', x, y),
@@ -678,5 +758,51 @@ def unwrap(v):
                                        torch.nn.functional.bilinear(a, c, b)))
 
 
+class TestGradCheckOverride(TestCase):
+    "Test that wrappers work with gradcheck."
+    def test_gradcheck(self):
+        from torch.autograd import gradcheck
+
+        a = wrap(torch.tensor(5.0, dtype=torch.double))
+        b = wrap(torch.tensor(6.0, dtype=torch.double))
+
+        a.requires_grad = True
+        b.requires_grad = True
+
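+        # raise_exception=False: the outcome of the check is irrelevant
+        # here; the test only records which Tensor attributes and torch
+        # functions gradcheck exercises on the wrappers.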
+        gradcheck(torch.add, (a, b), raise_exception=False)
+
+        total_used_attrs = a.used_attrs.union(b.used_attrs)
+        total_used_calls = a.used_calls.union(b.used_calls)
+
+        # These attributes (and the functions below) may change
+        # if the gradcheck implementation changes. It's best to
+        # aim for attributes that may be commonly present on other
+        # Tensor-likes.
+        self.assertEqual(total_used_attrs, {
+            'data',
+            'dtype',
+            'is_floating_point',
+            'is_sparse',
+            'layout',
+            'nelement',
+            'new_zeros',
+            'requires_grad',
+            'retain_grad',
+            'size',
+            'stride',
+        })
+
+        self.assertEqual(total_used_calls, {
+            torch.Tensor.new_zeros,
+            torch.Tensor.size,
+            torch.Tensor.is_floating_point,
+            torch.Tensor.nelement,
+            torch.Tensor.retain_grad,
+            torch.Tensor.stride,
+            torch.autograd.grad,
+            torch.add,
+        })
+
+
 if __name__ == '__main__':
     unittest.main()