@@ -553,6 +553,8 @@ def enable_compression(
         assert not self.se_ttebd.resnet_dt, (
             "Model compression error: descriptor resnet_dt must be false!"
         )
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Cannot compress model when tebd_input_mode != 'strip'")
         for tt in self.se_ttebd.exclude_types:
             if (tt[0] not in range(self.se_ttebd.ntypes)) or (
                 tt[1] not in range(self.se_ttebd.ntypes)
@@ -573,9 +575,6 @@ def enable_compression(
                 "Empty embedding-nets are not supported in model compression!"
             )
 
-        if self.tebd_input_mode != "strip":
-            raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
-
         data = self.serialize()
         self.table = DPTabulate(
             self,
@@ -597,9 +596,12 @@ def enable_compression(
         )
 
         self.se_ttebd.enable_compression(
-            self.table.data, self.table_config, self.lower, self.upper
+            self.type_embedding,
+            self.table.data,
+            self.table_config,
+            self.lower,
+            self.upper,
         )
-        self.se_ttebd.type_embedding_compression(self.type_embedding)
         self.compress = True
 
 
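The hunks above make the strip-mode check a precondition of the descriptor-level `enable_compression`, so unsupported configurations are rejected before any tabulation work, and the inner `se_ttebd.enable_compression` now receives the type embedding net directly. A minimal, self-contained sketch of that fail-fast ordering (the class is a stand-in, not the deepmd descriptor):

```python
class CompressibleSketch:
    def __init__(self, tebd_input_mode: str) -> None:
        self.tebd_input_mode = tebd_input_mode
        self.compress = False

    def enable_compression(self) -> None:
        # Reject unsupported modes before any expensive tabulation work.
        if self.tebd_input_mode != "strip":
            raise RuntimeError(
                "Cannot compress model when tebd_input_mode != 'strip'"
            )
        # ... tabulate and delegate to the inner block here ...
        self.compress = True


try:
    CompressibleSketch("concat").enable_compression()
except RuntimeError as err:
    print(err)  # Cannot compress model when tebd_input_mode != 'strip'
```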
@@ -695,13 +697,17 @@ def __init__(
         self.stats = None
         # compression related variables
         self.compress = False
+        # For geometric compression
         self.compress_info = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
         )
         self.compress_data = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
         )
-        self.type_embd_data: Optional[torch.Tensor] = None
+        # For type embedding compression
+        self.register_buffer(
+            "type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE)
+        )
 
     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
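Replacing the `Optional[torch.Tensor]` attribute with a registered buffer means the precomputed table participates in `state_dict()` and device moves, so a compressed model can be saved and reloaded without rerunning the precomputation. A self-contained sketch of that behavior, assuming nothing beyond stock PyTorch (the buffer name mirrors the diff):

```python
import torch
from torch import nn


class BufferSketch(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Empty placeholder, as in the diff; filled later by enable_compression().
        self.register_buffer("type_embd_data", torch.zeros(0))


m = BufferSketch()
m.type_embd_data = torch.ones(4, 8)  # simulate the precomputed table
print("type_embd_data" in m.state_dict())  # True: the table is saved/loaded
print(m.type_embd_data.shape)  # torch.Size([4, 8])
```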
@@ -840,30 +846,6 @@ def reinit_exclude(
         self.exclude_types = exclude_types
         self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
 
-    def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
-        """Precompute strip-mode type embeddings for all type pairs."""
-        if self.tebd_input_mode != "strip":
-            raise RuntimeError("Type embedding compression only works in strip mode")
-        if self.filter_layers_strip is None:
-            raise RuntimeError(
-                "filter_layers_strip must exist for type embedding compression"
-            )
-
-        with torch.no_grad():
-            full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
-            nt, t_dim = full_embd.shape
-            type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
-            type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
-            two_side_type_embedding = torch.cat(
-                [type_embedding_i, type_embedding_j], dim=-1
-            ).reshape(-1, t_dim * 2)
-            embd_tensor = self.filter_layers_strip.networks[0](
-                two_side_type_embedding
-            ).detach()
-            if hasattr(self, "type_embd_data"):
-                del self.type_embd_data
-            self.register_buffer("type_embd_data", embd_tensor)
-
     def forward(
         self,
         nlist: torch.Tensor,
@@ -1013,7 +995,7 @@ def forward(
             idx_i = nei_type_i * ntypes_with_padding
             idx_j = nei_type_j
             idx = (idx_i + idx_j).reshape(-1).to(torch.long)
-            if self.type_embd_data is not None:
+            if self.compress:
                 tt_full = self.type_embd_data
             else:
                 type_embedding_i = torch.tile(
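The branch above gates on `self.compress` and reads rows of the precomputed table with a flattened pair index: the pair `(i, j)` of an `ntypes x ntypes` grid lives at row `i * ntypes + j`. A self-contained sketch of that indexing with made-up sizes:

```python
import torch

ntypes_with_padding = 3
# Stand-in table: row r just stores the value r.
tt_full = torch.arange(ntypes_with_padding**2).unsqueeze(-1)

nei_type_i = torch.tensor([0, 2])  # center-atom types
nei_type_j = torch.tensor([1, 2])  # neighbor types
idx = (nei_type_i * ntypes_with_padding + nei_type_j).reshape(-1).to(torch.long)

print(tt_full[idx].flatten().tolist())  # [1, 8] -> rows for pairs (0, 1) and (2, 2)
```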
@@ -1061,6 +1043,7 @@ def forward(
 
     def enable_compression(
         self,
+        type_embedding_net: TypeEmbedNet,
         table_data: dict,
         table_config: dict,
         lower: dict,
@@ -1070,6 +1053,8 @@ def enable_compression(
 
         Parameters
         ----------
+        type_embedding_net : TypeEmbedNet
+            The type embedding network
         table_data : dict
             The tabulated data from DPTabulate
         table_config : dict
@@ -1079,6 +1064,13 @@ def enable_compression(
         upper : dict
             Upper bounds for compression
         """
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Type embedding compression only works in strip mode")
+        if self.filter_layers_strip is None:
+            raise RuntimeError(
+                "filter_layers_strip must exist for type embedding compression"
+            )
+
         # Compress the main geometric embedding network (self.filter_layers)
         net_key = "filter_net"
         self.compress_info[0] = torch.as_tensor(
@@ -1097,6 +1089,22 @@ def enable_compression(
             device=env.DEVICE, dtype=self.prec
         )
 
+        # Compress the type embedding network (self.filter_layers_strip)
+        with torch.no_grad():
+            full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
+            nt, t_dim = full_embd.shape
+            type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
+            type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
+            two_side_type_embedding = torch.cat(
+                [type_embedding_i, type_embedding_j], dim=-1
+            ).reshape(-1, t_dim * 2)
+            embd_tensor = self.filter_layers_strip.networks[0](
+                two_side_type_embedding
+            ).detach()
+            if hasattr(self, "type_embd_data"):
+                del self.type_embd_data
+            self.register_buffer("type_embd_data", embd_tensor)
+
         self.compress = True
 
     def has_message_passing(self) -> bool:
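The precomputation block builds all `(i, j)` type pairs via `expand` and a row-major `reshape`, so its row order matches the `idx = i * ntypes + j` lookup used in `forward`. A self-contained sketch verifying that agreement, with a tiny linear layer standing in for `filter_layers_strip.networks[0]`:

```python
import torch
from torch import nn

nt, t_dim = 3, 4
full_embd = torch.randn(nt, t_dim)

# Same expand/cat/reshape pattern as the diff: row k = i * nt + j holds
# the concatenation [embd[i], embd[j]].
type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
pairs = torch.cat([type_embedding_i, type_embedding_j], dim=-1).reshape(-1, 2 * t_dim)

net = nn.Linear(2 * t_dim, 8)  # stand-in for the strip embedding network
with torch.no_grad():
    table = net(pairs)  # (nt * nt, 8): one row per (i, j) pair
    i, j = 2, 1
    expected = net(torch.cat([full_embd[i], full_embd[j]]))
print(torch.allclose(table[i * nt + j], expected))  # True
```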