Skip to content

Commit 16b383f

Browse files
committed
feat(pt): enforce strip mode for model compression and enhance type embedding handling
1 parent 878e3f8 commit 16b383f

File tree

2 files changed

+39
-811
lines changed

2 files changed

+39
-811
lines changed

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ def enable_compression(
553553
assert not self.se_ttebd.resnet_dt, (
554554
"Model compression error: descriptor resnet_dt must be false!"
555555
)
556+
if self.tebd_input_mode != "strip":
557+
raise RuntimeError("Cannot compress model when tebd_input_mode != 'strip'")
556558
for tt in self.se_ttebd.exclude_types:
557559
if (tt[0] not in range(self.se_ttebd.ntypes)) or (
558560
tt[1] not in range(self.se_ttebd.ntypes)
@@ -573,9 +575,6 @@ def enable_compression(
573575
"Empty embedding-nets are not supported in model compression!"
574576
)
575577

576-
if self.tebd_input_mode != "strip":
577-
raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
578-
579578
data = self.serialize()
580579
self.table = DPTabulate(
581580
self,
@@ -597,9 +596,12 @@ def enable_compression(
597596
)
598597

599598
self.se_ttebd.enable_compression(
600-
self.table.data, self.table_config, self.lower, self.upper
599+
self.type_embedding,
600+
self.table.data,
601+
self.table_config,
602+
self.lower,
603+
self.upper,
601604
)
602-
self.se_ttebd.type_embedding_compression(self.type_embedding)
603605
self.compress = True
604606

605607

@@ -695,13 +697,17 @@ def __init__(
695697
self.stats = None
696698
# compression related variables
697699
self.compress = False
700+
# For geometric compression
698701
self.compress_info = nn.ParameterList(
699702
[nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
700703
)
701704
self.compress_data = nn.ParameterList(
702705
[nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
703706
)
704-
self.type_embd_data: Optional[torch.Tensor] = None
707+
# For type embedding compression
708+
self.register_buffer(
709+
"type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE)
710+
)
705711

706712
def get_rcut(self) -> float:
707713
"""Returns the cut-off radius."""
@@ -840,30 +846,6 @@ def reinit_exclude(
840846
self.exclude_types = exclude_types
841847
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
842848

843-
def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
844-
"""Precompute strip-mode type embeddings for all type pairs."""
845-
if self.tebd_input_mode != "strip":
846-
raise RuntimeError("Type embedding compression only works in strip mode")
847-
if self.filter_layers_strip is None:
848-
raise RuntimeError(
849-
"filter_layers_strip must exist for type embedding compression"
850-
)
851-
852-
with torch.no_grad():
853-
full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
854-
nt, t_dim = full_embd.shape
855-
type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
856-
type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
857-
two_side_type_embedding = torch.cat(
858-
[type_embedding_i, type_embedding_j], dim=-1
859-
).reshape(-1, t_dim * 2)
860-
embd_tensor = self.filter_layers_strip.networks[0](
861-
two_side_type_embedding
862-
).detach()
863-
if hasattr(self, "type_embd_data"):
864-
del self.type_embd_data
865-
self.register_buffer("type_embd_data", embd_tensor)
866-
867849
def forward(
868850
self,
869851
nlist: torch.Tensor,
@@ -1013,7 +995,7 @@ def forward(
1013995
idx_i = nei_type_i * ntypes_with_padding
1014996
idx_j = nei_type_j
1015997
idx = (idx_i + idx_j).reshape(-1).to(torch.long)
1016-
if self.type_embd_data is not None:
998+
if self.compress:
1017999
tt_full = self.type_embd_data
10181000
else:
10191001
type_embedding_i = torch.tile(
@@ -1061,6 +1043,7 @@ def forward(
10611043

10621044
def enable_compression(
10631045
self,
1046+
type_embedding_net: TypeEmbedNet,
10641047
table_data: dict,
10651048
table_config: dict,
10661049
lower: dict,
@@ -1070,6 +1053,8 @@ def enable_compression(
10701053
10711054
Parameters
10721055
----------
1056+
type_embedding_net : TypeEmbedNet
1057+
The type embedding network
10731058
table_data : dict
10741059
The tabulated data from DPTabulate
10751060
table_config : dict
@@ -1079,6 +1064,13 @@ def enable_compression(
10791064
upper : dict
10801065
Upper bounds for compression
10811066
"""
1067+
if self.tebd_input_mode != "strip":
1068+
raise RuntimeError("Type embedding compression only works in strip mode")
1069+
if self.filter_layers_strip is None:
1070+
raise RuntimeError(
1071+
"filter_layers_strip must exist for type embedding compression"
1072+
)
1073+
10821074
# Compress the main geometric embedding network (self.filter_layers)
10831075
net_key = "filter_net"
10841076
self.compress_info[0] = torch.as_tensor(
@@ -1097,6 +1089,22 @@ def enable_compression(
10971089
device=env.DEVICE, dtype=self.prec
10981090
)
10991091

1092+
# Compress the type embedding network (self.filter_layers_strip)
1093+
with torch.no_grad():
1094+
full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
1095+
nt, t_dim = full_embd.shape
1096+
type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
1097+
type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
1098+
two_side_type_embedding = torch.cat(
1099+
[type_embedding_i, type_embedding_j], dim=-1
1100+
).reshape(-1, t_dim * 2)
1101+
embd_tensor = self.filter_layers_strip.networks[0](
1102+
two_side_type_embedding
1103+
).detach()
1104+
if hasattr(self, "type_embd_data"):
1105+
del self.type_embd_data
1106+
self.register_buffer("type_embd_data", embd_tensor)
1107+
11001108
self.compress = True
11011109

11021110
def has_message_passing(self) -> bool:

0 commit comments

Comments
 (0)