我好像发现了一个bug

这里好像不应该用一个transpose,因为SVD分解出来的vh=B x emb_dim x emb_dim,(_, s, vh = torch.svd(error_every_class, some=False), 每一个batch,其实它的列向量才是空间的基, 你这里好像把行向量作为空间的基了,这样求得的M,最后拿来投影error时(测试M能否把error投影到新空间中的0)
assert (torch.matmul(error_every_class, M) > 1E-6).sum().item() == 0时会报错,
如果把转置去掉就不会了
我好像发现了一个bug

这里好像不应该用一个transpose,因为SVD分解出来的vh=B x emb_dim x emb_dim,(
_, s, vh = torch.svd(error_every_class, some=False), 每一个batch,其实它的列向量才是空间的基, 你这里好像把行向量作为空间的基了,这样求得的M,最后拿来投影error时(测试M能否把error投影到新空间中的0)assert (torch.matmul(error_every_class, M) > 1E-6).sum().item() == 0时会报错,如果把转置去掉就不会了