[Sparsity] When sparsifying using Wanda on only Linear layers, PerChannelNormObserver() being added to embedding layers, leading to RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long #1133

@agrawal-aka

Hello, I wrote a test script that runs DistilBERT inference with the Wanda sparsifier, and I am testing it on an AArch64 platform:

import torch
from torch.ao.pruning import WeightNormSparsifier  # only needed for the comparison run mentioned below
from torch.profiler import profile
from torchao.quantization.quant_api import _is_linear
from torchao.sparsity.wanda import WandaSparsifier
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, pipeline

torch.manual_seed(100)

sparsifier = WandaSparsifier(
    sparsity_level=0.6
)

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

sparse_config = []
for name, mod in model.named_modules():
    if _is_linear(mod, name):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

print("sparse config:",sparse_config)
sparsifier.prepare(model, sparse_config)
#print(model.distilbert.embeddings)

#Calibration samples - for wanda
calibration_texts = [
    "I love using CPUs for inference.",
    "This is a sample text for calibration.",
    "Calibration is important for pruning accuracy."]


# Tokenize and pass the calibration samples through the model
for text in calibration_texts:
    inputs = tokenizer(text, return_tensors="pt")
    #print(inputs)
    with torch.no_grad():
        model(**inputs)  # Forward pass to collect activation statistics

# Now that activation statistics have been collected, you can proceed with pruning
sparsifier.step()
sparsifier.squash_mask()

# Apply sparsity to linear layers and convert to CSR format
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and "layer" in name:
        # Convert dense weights to CSR format
        module.weight = torch.nn.Parameter(module.weight.to_sparse_csr())


# Set the model to evaluation mode
model.eval()

# Initialize Hugging Face sentiment analysis pipeline with the sparsified model
sentiment_analysis_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)

# Run inference using the sparsified model
input_text = "I really love using PyTorch"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
    with profile(with_stack=True, profile_memory=True, record_shapes=True) as prof:
        outputs = model(**inputs)
print(prof.key_averages(group_by_input_shape=False).table(sort_by="self_cpu_time_total", row_limit=-1))
print(outputs)
prediction = model.config.id2label[outputs.logits.argmax().item()]

print(f"Predicted sentiment: {prediction}")

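Running this script raises the error quoted in the title (the original screenshot is not reproduced here):

RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long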

The issue occurs only with Wanda. With the weight norm sparsifier and the same sparse config, everything runs fine. As far as I can tell, the problem is that PerChannelNormObserver() gets attached to the embedding layers after sparsifier.prepare(), and the observer internally calls linalg.vector_norm. (The original post showed a screenshot of the printed model with the observer on the embeddings.)
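To illustrate the failure mode in isolation, here is a minimal sketch that does not use torchao at all. The `norm_pre_hook` function is a hypothetical stand-in for an activation-norm observer, and it assumes the observer sees the layer's input; it is not the actual PerChannelNormObserver code. It shows why a norm computed on a Linear layer's float input is fine, while the same computation on an embedding layer's integer token IDs fails:

```python
import torch
import torch.nn as nn


def norm_pre_hook(module, args):
    # Hypothetical stand-in for an activation-norm observer:
    # compute a norm over the layer's input and discard it.
    torch.linalg.vector_norm(args[0], dim=0)


def try_forward(module, x):
    """Run one forward pass with the hook attached; return 'ok' or the error text."""
    handle = module.register_forward_pre_hook(norm_pre_hook)
    try:
        module(x)
        return "ok"
    except RuntimeError as err:
        return str(err)
    finally:
        handle.remove()


# Float input to a Linear layer: the norm is well defined.
linear_result = try_forward(nn.Linear(4, 2), torch.randn(3, 4))

# Integer (Long) token IDs to an Embedding layer: vector_norm rejects the dtype.
embedding_result = try_forward(nn.Embedding(10, 4), torch.tensor([1, 2, 3]))

print("Linear:   ", linear_result)
print("Embedding:", embedding_result)
```

If this reading is right, any norm-based observer placed on the input of an embedding layer would hit the same dtype check, regardless of which layers are listed in the sparse config.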

As a workaround, right after calling sparsifier.prepare() I reinitialised the embedding layers from the pretrained model, before passing in the calibration texts. With that change the script runs successfully.

sparsifier.prepare(model, sparse_config)
# Workaround: reload the embeddings from the pretrained model so they carry no observers
model.distilbert.embeddings = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english").distilbert.embeddings

# Print the model to confirm that the observer has been removed from the embeddings block
print(model.distilbert.embeddings)

# Tokenize and pass the calibration samples through the model
#.
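Instead of reading the printed module by eye, the check can be automated. The sketch below is a hypothetical helper, not part of torchao. It assumes observers are attached as a child module named `activation_post_process`, which is the attribute name `torch.ao.quantization.prepare` uses. It is demonstrated on a toy model rather than the real DistilBERT:

```python
import torch.nn as nn


def observed_modules(model, attr="activation_post_process"):
    """Return the names of submodules that have a direct child named `attr`."""
    return [
        name
        for name, mod in model.named_modules()
        if any(child_name == attr for child_name, _ in mod.named_children())
    ]


# Toy stand-in for a prepared model: only the embedding carries an "observer".
toy = nn.ModuleDict({
    "embeddings": nn.Embedding(10, 4),
    "classifier": nn.Linear(4, 2),
})
toy["embeddings"].add_module("activation_post_process", nn.Identity())

print(observed_modules(toy))  # → ['embeddings']
```

On the real model, calling `observed_modules(model)` after `sparsifier.prepare(...)` and again after the workaround should show whether any embedding modules still appear in the list.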

Am I missing something in my script? I don't think this is the expected behaviour: the sparse config is identical whether I use the weight norm sparsifier or Wanda, yet it works in one case and fails in the other.
