Fix RecurrentGemma device_map #30273
Conversation
```diff
  indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
- k_out, v_out = self.key_states, self.value_states
+ k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
```
Due to `_setup_cache`, `self.key_states` and `self.value_states` are initialized on the device of the hidden states that we pass to the model in `generate` (e.g. `cuda:0`). However, this layer might not be on the same device as those hidden states if we use multiple GPUs. Hence, we need to make sure that `self.key_states` is on the same device as `key_states`, and likewise for `value_states`.
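The pattern behind the fix can be illustrated with a small standalone sketch. The class and shapes below are illustrative only, not the actual RecurrentGemma cache code:

```python
import torch

class SlidingCache:
    """Toy stand-in for the per-layer cache created in _setup_cache."""

    def __init__(self, shape, device):
        # Like _setup_cache: buffers start on the device of the first hidden state.
        self.key_states = torch.zeros(shape, device=device)
        self.value_states = torch.zeros(shape, device=device)

    def update(self, key_states, value_states):
        # The fix: move the cached tensors to the device of the incoming states,
        # so a layer that device_map placed on another GPU still works.
        k_out = self.key_states.to(key_states.device)
        v_out = self.value_states.to(value_states.device)
        return k_out, v_out

cache = SlidingCache((1, 4), device="cpu")
k, v = cache.update(torch.ones(1, 4), torch.ones(1, 4))
```

On a single device `.to(...)` is a no-op, so the extra call only costs anything when the cache and the layer actually live on different GPUs.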
```python
contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
    recurrent_gate.device
)
```
Same issue with `recurrent_gate`, which is initialized in `_setup_cache`.
```diff
  contextualized_states = torch.zeros_like(hidden_states)
  for t in range(hidden_states.shape[1]):
-     recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
+     recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
```
```python
self.register_buffer(
    "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
)
```
We don't need this buffer to be persistent. This also fixes an issue that we get with accelerate.
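The effect of `persistent=False` can be checked in isolation. The module below is a toy, not the model's actual class:

```python
import torch
from torch import nn

class Norm(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        # A non-persistent buffer stays on the module at runtime but is
        # excluded from the state_dict, so checkpoint loading (including
        # accelerate's sharded loading) never expects it on disk.
        self.register_buffer(
            "normalizer", torch.tensor(hidden_size**0.5, dtype=torch.bfloat16), persistent=False
        )

m = Norm()
assert "normalizer" not in m.state_dict()  # not serialized
assert m.normalizer.dtype == torch.bfloat16  # still usable in forward
```

Since `normalizer` is deterministically recomputed from the config at `__init__` time, nothing is lost by leaving it out of checkpoints.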
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment
Thanks! Couldn't the device issue be fixed by placing them on the same device as `self.key_states`, rather than on the device that was passed?
Also a tad scared of the slowdown from doing it there? But LGTM otherwise.
I think it will slow down if we place them on the same device as `self.key_states`, for example. Let's say …
* Switch to non-persistent buffer
* Fix device mismatch issue due to cache
* Style
What does this PR do?
This PR makes RecurrentGemma compatible with multi-GPU `device_map`. To try it out:
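The exact snippet isn't in the thread; a minimal reproduction might look like the following (the checkpoint id, prompt, and helper name are my assumptions, not taken from the PR):

```python
def run_multi_gpu_generation(model_id="google/recurrentgemma-2b", prompt="Hello"):
    """Hypothetical repro helper: generate with the model sharded across GPUs."""
    # Imports are local so the sketch can be defined without transformers installed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map="auto" lets accelerate shard the layers across all visible GPUs.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=20)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```

Running the same helper with `CUDA_VISIBLE_DEVICES=0` versus several GPUs is one way to compare single-GPU and sharded outputs.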
I get the same output in the single-GPU and multi-GPU setups.