
MPS cumsum issue - RuntimeError: MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3 #96610

@jacobmlloyd


🐛 Describe the bug

I'm on a MacBook Pro (M1 Pro) and have upgraded to macOS 13.3 Beta 3, but I'm still running into the cumsum issue. I created two new conda environments and installed the nightly build on 3/11/2023 at 12 PM PST, one using
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
and the other using
conda install pytorch torchvision torchaudio -c pytorch-nightly
but I'm still getting cumsum errors.
It looks like #96512 may have broken the previous patch, unless I'm missing something? I'm new to this, so any help is greatly appreciated!
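
Since the error message hinges on the macOS version, it can help to confirm what version Python itself reports. The following is a minimal sketch using only the standard library; the helper names `parse_version` and `meets_minimum` are hypothetical, and PyTorch's internal check may work differently from a simple version-string comparison.

```python
import platform


def parse_version(version: str) -> tuple:
    """Turn a dotted version string such as '13.3' into a tuple of ints."""
    parts = [p for p in version.split(".") if p.isdigit()]
    return tuple(int(p) for p in parts)


def meets_minimum(version: str, minimum: tuple = (13, 3)) -> bool:
    """Return True if `version` is at least `minimum` (default: macOS 13.3)."""
    parsed = parse_version(version)
    if not parsed:
        # Empty string: not running on macOS, or the version is unavailable.
        return False
    # Pad to the same length so that '13.3' compares equal to (13, 3).
    padded = parsed + (0,) * (len(minimum) - len(parsed))
    return padded >= minimum


# platform.mac_ver() returns ('', ('', '', ''), '') on non-macOS systems.
reported = platform.mac_ver()[0]
print("Reported macOS version:", reported or "(not macOS)")
print("At least 13.3:", meets_minimum(reported))
```

If this prints a version below 13.3 on a machine believed to be upgraded, the mismatch is on the OS side rather than in PyTorch.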

Code:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model = "Salesforce/codegen-350M-mono"
device = torch.device("mps")

# Load the CodeGen tokenizer and model
print("[+] Initializing Tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model)
print("[+] Finished Initializing Tokenizer")
print("[+] Initializing Model")
model = AutoModelForCausalLM.from_pretrained(model).to(device)
print("[+] Finished Initializing Model")

prompt = "write a hello world function"

# Tokenize the description and generate the code
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    min_new_tokens=10,
    max_new_tokens=1000,
    pad_token_id=0,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=True,
)

generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(generated_code)

Traceback:

/Users/lloyd/miniforge3/envs/coding2/lib/python3.9/site-packages/transformers/generation/utils.py:662: UserWarning: MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1678522277666/work/aten/src/ATen/native/mps/operations/Repeat.mm:229.)
  input_ids = input_ids.repeat_interleave(expand_size, dim=0)
Traceback (most recent call last):
  File "/Users/lloyd/Projects/machine-learning/models/salesforce-live.py", line 155, in <module>
    outputs = model.generate(
  File "/Users/lloyd/miniforge3/envs/coding2/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/lloyd/miniforge3/envs/coding2/lib/python3.9/site-packages/transformers/generation/utils.py", line 1437, in generate
    return self.sample(
  File "/Users/lloyd/miniforge3/envs/coding2/lib/python3.9/site-packages/transformers/generation/utils.py", line 2440, in sample
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  File "/Users/lloyd/miniforge3/envs/coding2/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py", line 646, in prepare_inputs_for_generation
    position_ids = attention_mask.long().cumsum(-1) - 1
RuntimeError: MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3
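
For reference, the failing line computes position IDs from the attention mask with an int64 cumulative sum. A possible stopgap, sketched below, is to run that one operation on the CPU and move the result back to the original device. The helper `cumsum_via_cpu` is hypothetical and not part of PyTorch or transformers; it is an illustration of the idea, not a fix for the underlying bug, and it would have to be applied by patching the model code locally.

```python
import torch


def cumsum_via_cpu(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute an int64 cumulative sum on the CPU, then return the
    result on the tensor's original device (hypothetical helper)."""
    return t.cpu().long().cumsum(dim).to(t.device)


# Small stand-in for the attention mask from the traceback (CPU tensor here;
# on an affected machine it would live on the "mps" device).
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Same arithmetic as the failing line in modeling_codegen.py.
position_ids = cumsum_via_cpu(attention_mask) - 1
print(position_ids)  # tensor([[0, 1, 2, 2]])
```

The round trip through the CPU costs a device transfer on every generation step, so this only makes sense as a temporary measure.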

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230311
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.3 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.25.2
Libc version: N/A

Python version: 3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:38:11) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.3-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] torch==2.1.0.dev20230311
[pip3] torchaudio==2.0.0.dev20230311
[pip3] torchvision==0.15.0.dev20230311
[conda] numpy 1.24.2 py39hff61c6a_0 conda-forge
[conda] pytorch 2.1.0.dev20230311 py3.9_0 pytorch-nightly
[conda] torchaudio 2.0.0.dev20230311 py39_cpu pytorch-nightly
[conda] torchvision 0.15.0.dev20230311 py39_cpu pytorch-nightly

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Labels

module: mps (Related to Apple Metal Performance Shaders framework)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
