@@ -644,107 +644,27 @@ def test(self):
644644
645645generate_tensor_like_override_tests (TestTorchFunctionOverride )
646646
647- class Wrapper :
648- "Basic data container that knows how to unwrap itself"
649- def __init__ (self , data ):
650- self .__dict__ ["_data" ] = data
651- self .__dict__ ["used_attrs" ] = set ()
652- self .__dict__ ["used_calls" ] = set ()
653-
654- def __getattr__ (self , name ):
655- if name in self .__dict__ :
656- return self .__dict__ [name ]
657- self .used_attrs .add (name )
658-
659- val = getattr (self ._data , name )
660-
661- # If it's a method
662- if callable (val ):
663- c = getattr (type (self ._data ), name )
664- # Don't append self to args if classmethod/staticmethod
665- if c is val :
666- return lambda * a , ** kw : wrap (self .__torch_function__ (c , (Wrapper ,), args = a , kwargs = kw ))
667- # Otherwise append self to args
668- return lambda * a , ** kw : wrap (self .__torch_function__ (c , (Wrapper ,), args = (self ,) + a , kwargs = kw ))
669-
670- return wrap (val )
671-
672- def __setattr__ (self , name , value ):
673- if name in self .__dict__ :
674- self .__dict__ [name ] = value
675-
676- self .used_attrs .add (name )
677- setattr (self ._data , name , unwrap (value ))
678-
679- def __setitem__ (self , key , value ):
680- self ._data [unwrap (key )] = unwrap (value )
681-
682- def __getitem__ (self , key ):
683- return wrap (self ._data [unwrap (key )])
684-
685- def __torch_function__ (self , func , types , args = (), kwargs = None ):
686- if kwargs is None :
687- kwargs = {}
688- self .used_calls .add (func )
689- args = unwrap (tuple (args ))
690- kwargs = {k : unwrap (v ) for k , v in kwargs .items ()}
691-
692- return wrap (func (* args , ** kwargs ))
693-
694- def __add__ (self , other ):
695- return self .__torch_function__ (torch .add , (Wrapper ,), (self , other ))
696-
697- def __sub__ (self , other ):
698- return self .__torch_function__ (torch .sub , (Wrapper ,), (self , other ))
699-
700- def __truediv__ (self , other ):
701- return self .__torch_function__ (torch .true_divide , (Wrapper ,), (self , other ))
702-
703- def __floordiv__ (self , other ):
704- return self .__torch_function__ (torch .floor_divide , (Wrapper ,), (self , other ))
705-
706- def __ge__ (self , other ):
707- return self .__torch_function__ (torch .ge , (Wrapper ,), (self , other ))
708-
709- def __gt__ (self , other ):
710- return self .__torch_function__ (torch .gt , (Wrapper ,), (self , other ))
711-
712- def __lt__ (self , other ):
713- return self .__torch_function__ (torch .lt , (Wrapper ,), (self , other ))
714-
715- def __le__ (self , other ):
716- return self .__torch_function__ (torch .le , (Wrapper ,), (self , other ))
717-
718- def __eq__ (self , other ):
719- return self .__torch_function__ (torch .eq , (Wrapper ,), (self , other ))
720-
721- def __ne__ (self , other ):
722- return self .__torch_function__ (torch .ne , (Wrapper ,), (self , other ))
723-
724- def __bool__ (self ):
725- return self .__torch_function__ (torch .Tensor .__bool__ , (Wrapper ,), (self ,))
726-
727- def __int__ (self ):
728- return self .__torch_function__ (torch .Tensor .__int__ , (Wrapper ,), (self ,))
729-
647+ class TestEinsumOverride (TestCase ):
648+ "Regression test for gh-38479"
649+ def test_wrapper (self ):
650+ class Wrapper ():
651+ "Basic data container that knows how to unwrap itself"
652+ def __init__ (self , data ):
653+ self .data = data
730654
731- # unwrap inputs if necessary
732- def unwrap (v ):
733- if type (v ) in {tuple , list }:
734- return type (v )(unwrap (vi ) for vi in v )
655+ def __torch_function__ (self , func , types , args = (), kwargs = None ):
656+ if kwargs is None :
657+ kwargs = {}
735658
736- return v ._data if isinstance (v , Wrapper ) else v
659+ # unwrap inputs if necessary
660+ def unwrap (v ):
661+ return v .data if isinstance (v , Wrapper ) else v
737662
738- # wrap inputs if necessary
739- def wrap (v ):
740- if type (v ) in {tuple , list }:
741- return type (v )(wrap (vi ) for vi in v )
663+ args = map (unwrap , args )
664+ kwargs = {k : unwrap (v ) for k , v in kwargs .items ()}
742665
743- return Wrapper ( v ) if isinstance ( v , torch . Tensor ) else v
666+ return func ( * args , ** kwargs )
744667
745- class TestEinsumOverride (TestCase ):
746- "Regression test for gh-38479"
747- def test_wrapper (self ):
748668 x = Wrapper (torch .randn (5 ))
749669 y = Wrapper (torch .randn (4 ))
750670 self .assertTrue (torch .allclose (torch .einsum ('i,j->ij' , x , y ),
@@ -758,51 +678,5 @@ def test_wrapper(self):
758678 torch .nn .functional .bilinear (a , c , b )))
759679
760680
761- class TestGradCheckOverride (TestCase ):
762- "Test that wrappers work with gradcheck."
763- def test_gradcheck (self ):
764- from torch .autograd import gradcheck
765-
766- a = wrap (torch .tensor (5.0 , dtype = torch .double ))
767- b = wrap (torch .tensor (6.0 , dtype = torch .double ))
768-
769- a .requires_grad = True
770- b .requires_grad = True
771-
772- gradcheck (torch .add , (a , b ), raise_exception = False )
773-
774- total_used_attrs = a .used_attrs .union (b .used_attrs )
775- total_used_calls = a .used_calls .union (b .used_calls )
776-
777- # These attributes (and the functions below) may change
778- # if the gradcheck implementation changes. It's best to
779- # aim for attributes that may be commonly present on other
780- # Tensor-likes.
781- self .assertEqual (total_used_attrs , {
782- 'data' ,
783- 'dtype' ,
784- 'is_floating_point' ,
785- 'is_sparse' ,
786- 'layout' ,
787- 'nelement' ,
788- 'new_zeros' ,
789- 'requires_grad' ,
790- 'retain_grad' ,
791- 'size' ,
792- 'stride' ,
793- })
794-
795- self .assertEqual (total_used_calls , {
796- torch .Tensor .new_zeros ,
797- torch .Tensor .size ,
798- torch .Tensor .is_floating_point ,
799- torch .Tensor .nelement ,
800- torch .Tensor .retain_grad ,
801- torch .Tensor .stride ,
802- torch .autograd .grad ,
803- torch .add ,
804- })
805-
806-
807681if __name__ == '__main__' :
808682 unittest .main ()
0 commit comments