Update Mamba types and pass through use_cache attr to MambaModel #29605
ArthurZucker merged 14 commits into huggingface:main
Conversation
Thanks for adding this @koayon. Pinging @gante for first review of the cache logic, as @ArthurZucker is off this week.
gante
left a comment
LGTM, thank you for the PR 🤗
I'd like a final check from @ArthurZucker, though -- there are some terminology updates in the docstrings, and I'm not very familiar with Mamba :)
Hey @ArthurZucker! Hope you had a great holiday 🙌
```diff
 is_fast_path_available = all(
-    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+    (
+        selective_state_update,
+        selective_scan_fn,
+        causal_conv1d_fn,
+        causal_conv1d_update,
+        mamba_inner_fn,
+    )
 )
```
this is unrelated and is styling, should be reverted!
Thanks, I've updated the styling 👌
```python
class MambaCache:
    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.seqlen_offset = 0
        self.dtype = dtype
        intermediate_size = config.intermediate_size
        ssm_state_size = config.state_size
        conv_kernel_size = config.conv_kernel

        self.conv_states = {
            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
```
if moved, let's just keep the styling of this one please
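For readers following along, here is a minimal sketch of how a cache like the one above is constructed and what it holds once built. The import path, batch size, and dtype here are assumptions for illustration only, not part of the diff:

```python
import torch
from transformers import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaCache  # import path assumed

config = MambaConfig()
cache = MambaCache(config, batch_size=2, dtype=torch.float32)

# One tensor per layer, keyed by layer index:
#   conv_states[i] -> (batch_size, intermediate_size, conv_kernel)
#   ssm_states[i]  -> (batch_size, intermediate_size, state_size)
print(cache.conv_states[0].shape)
print(cache.ssm_states[0].shape)
print(cache.seqlen_offset)  # 0 on construction, per __init__ above
```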
```diff
-            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+            ssm_parameters,
+            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
+            dim=-1,
```
same here, unrelated change
```python
        else:
            if cache_params is not None:
                conv_states = nn.functional.pad(
                    hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                cache_params.conv_states[self.layer_idx].copy_(conv_states)
            hidden_states = causal_conv1d_fn(
                hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
            )
```
same here, unrelated change
```diff
-        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+        self.x_proj = nn.Linear(
+            self.intermediate_size,
+            self.time_step_rank + self.ssm_state_size * 2,
+            bias=False,
+        )
```
```python
        if cache_params is None and use_cache:
            cache_params = MambaCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )
```
unrelated change, let's revert
```python
        return model_kwargs

    def prepare_inputs_for_generation(
        self, input_ids, cache_params=None, inputs_embeds=None, attention_mask=None, **kwargs
    ):
```

```python
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
```
why is this required? it should not. The cache params are passed right above
I believe it's the use_cache argument that needs to be passed in for this to work as expected - we could restrict to just passing that through?
Have amended this to only pass through the use_cache argument
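Concretely, the amended version forwards only use_cache rather than spreading **kwargs. A sketch of what that call looks like (keyword names taken from the hunk above; the surrounding code is assumed, not quoted from the diff):

```python
# Inside MambaForCausalLM.forward: pass only use_cache through to the backbone
# instead of arbitrary **kwargs (sketch; the exact merged code may differ).
mamba_outputs = self.backbone(
    input_ids,
    cache_params=cache_params,
    inputs_embeds=inputs_embeds,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    use_cache=use_cache,
)
```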
Hey @ArthurZucker, thanks for your review! 🙌 In terms of the image that you were sending, it's unfortunately not showing up for me. But without the change to pass in use_cache, I don't see the cache_params being returned. If there's a difference for you, I've just thought that it might be how it's running on CUDA vs MPS/CPU. I append the following to the file:

```python
import torch as t

from transformers import AutoTokenizer

if __name__ == "__main__":
    model = MambaForCausalLM(MambaConfig())
    tokeniser = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
    input_ids: t.Tensor = tokeniser("Hey how are you doing?", return_tensors="pt")["input_ids"]  # type: ignore
    out: MambaCausalLMOutput = model(input_ids=input_ids, use_cache=True)
    assert out.cache_params is not None
    print(out.cache_params.ssm_states)
```

If the use_cache argument isn't passed through to the backbone (either with kwargs or separately as in the newer version), there is no cache_params returned and running `python src/transformers/models/mamba/modeling_mamba.py` gives the error:

```
...
Traceback (most recent call last):
  File "/[PATH_TO_TRANSFORMERS]/transformers/src/transformers/models/mamba/modeling_mamba.py", line 688, in <module>
    assert out.cache_params is not None
AssertionError
```

whereas with the use_cache argument being passed through I get a tensor returned:

```
{0: tensor([[[-5.5237e-04,  9.6599e-04,  6.6771e-04,  ..., -5.3982e-04,
              -4.6061e-04, -7.0508e-04],
             [-3.7170e-05, -2.2089e-04, -1.0218e-04,  ...,  7.1232e-05,
...
```

It does seem like this would be required for the expected behaviour. Please let me know if you have any questions! 😄
Alright: when using from_pretrained, the cache is used and passed on subsequently, but not when initialising the model directly from a config.
ArthurZucker
left a comment
Almost good to go!
```python
        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
```
let's add use_cache as an arg here
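A sketch of the suggested signature change (only the arguments shown in the hunk above are taken from the diff; the remaining arguments, return annotation, and body are assumptions for illustration):

```python
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_params: Optional[MambaCache] = None,
    labels: Optional[torch.LongTensor] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    use_cache: Optional[bool] = None,  # new explicit argument
    **kwargs,
) -> Union[Tuple, MambaCausalLMOutput]:
    ...
```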
ArthurZucker
left a comment
Sorry forgot about these!
@ArthurZucker great suggestion, didn't realise that was an attribute of the Config 👌
The failing test seems new, but it's because when training, use_cache should be disabled by the model.
I'll have a look.
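For reference, one common way to get that behaviour is to fall back to the config default only when the model is not training. A one-line sketch (not necessarily the exact fix that landed):

```python
# Inside MambaModel.forward: disable the cache by default during training.
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
```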
ArthurZucker
left a comment
Thanks for iterating!
* Update docstring for RMSNorm
* Update cache_params object to correct MambaCache type
* Update docstrings and type info
* Pass through use_cache
* ruff
* Reformat with 119 char limit per line (thanks Arthur)
* Pass through use_cache specifically to the backbone rather than all keyword arguments
* Update src/transformers/models/mamba/modeling_mamba.py
* Update src/transformers/models/mamba/modeling_mamba.py
* Update src/transformers/models/mamba/modeling_mamba.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Update src/transformers/models/mamba/modeling_mamba.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Update tab
* Update src/transformers/models/mamba/modeling_mamba.py
* Update src/transformers/models/mamba/modeling_mamba.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
What does this PR do?
* The cache type hints were previously a mix of MambaCache, torch.Tensor or list[torch.Tensor]. This PR updates this to MambaCache everywhere, which is in line with the attributes that are being accessed in the logic.
* use_cache was not being passed through to MambaModel. This PR fixes this as below: the use_cache information is allowed to be passed through, so that you can run the model with use_cache=True (see the sketch below) and get back the ssm_states, which was not previously possible.
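A minimal sketch of the intended usage, mirroring the reproduction snippet from the discussion above (the randomly initialised model and the 130m tokenizer are used purely for illustration):

```python
import torch
from transformers import AutoTokenizer, MambaConfig, MambaForCausalLM

model = MambaForCausalLM(MambaConfig())
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

out = model(input_ids=input_ids, use_cache=True)
assert out.cache_params is not None
print(out.cache_params.ssm_states)  # per-layer SSM states are now returned
```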
Who can review?
@ArthurZucker
@gante