Cohere: Use diff tool instead of Copied from mechanism #31211

Closed
younesbelkada wants to merge 5 commits into main from cohere-diff-2

Conversation

@younesbelkada
Contributor

What does this PR do?

As per title

cc @ArthurZucker

ALL_LAYERNORM_LAYERS.append(CohereRMSNorm)


class CohereLayerNorm(CohereRMSNorm):
Contributor Author

In case users rely on CohereLayerNorm class
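The backward-compatibility idea can be sketched in isolation (the class bodies below are illustrative stand-ins, not the actual transformers implementation): keeping the old name as an empty subclass of the new class means downstream imports and `isinstance` checks keep working unchanged.

```python
class CohereRMSNorm:
    """Illustrative stand-in for the new canonical norm class."""
    def __init__(self, hidden_size, eps=1e-5):
        self.hidden_size = hidden_size
        self.eps = eps


class CohereLayerNorm(CohereRMSNorm):
    """Alias kept so that existing user code importing CohereLayerNorm
    continues to work after the rename."""
    pass


norm = CohereLayerNorm(1024)
print(isinstance(norm, CohereRMSNorm))  # True — old name satisfies new-type checks
```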

Comment on lines +745 to +756
    @property
    def logit_scale(self):
        logger.warning(
            "`logit_scale` attribute is going to be deprecated in future versions, please use `model.config.logit_scale` instead."
        )
        return self.config.logit_scale

    @property
    def tie_word_embeddings(self):
        logger.warning(
            "`tie_word_embeddings` attribute is going to be deprecated in future versions, please use `model.config.tie_word_embeddings` instead."
        )
        return self.config.tie_word_embeddings
Contributor Author

these attributes are public, but I suggest using the config value directly to make it cleaner, with a deprecation cycle for the old accessors
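That deprecation-cycle pattern can be sketched standalone (using the stdlib `warnings` module instead of the transformers logger, purely so the snippet is self-contained; class names and the `0.0625` value are illustrative):

```python
import warnings


class Config:
    logit_scale = 0.0625  # illustrative value


class Model:
    def __init__(self, config):
        self.config = config

    @property
    def logit_scale(self):
        # The old public attribute keeps working, but nudges users to the config.
        warnings.warn(
            "`logit_scale` attribute is going to be deprecated in future versions, "
            "please use `model.config.logit_scale` instead.",
            FutureWarning,
        )
        return self.config.logit_scale


model = Model(Config())
print(model.logit_scale)  # emits the FutureWarning, prints 0.0625
```

A read-only property keeps `model.logit_scale` callers working through the deprecation window while making `model.config.logit_scale` the single source of truth.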

Comment on lines +116 to +142
class CohereLinearScalingRotaryEmbedding(CohereRotaryEmbedding):
    """CohereRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: a scaling factor is applied to the position ids
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids)
        return cos, sin


class CohereDynamicNTKScalingRotaryEmbedding(CohereRotaryEmbedding):
    """CohereRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin
Contributor Author

These classes are never used, but I couldn't find a way to remove them

Collaborator

there is no way to do so 😓 Maybe a skip layer?

Contributor Author

hmmm yeah, or maybe it is ok to manually remove them for now

Comment on lines +742 to +754
    @property
    def logit_scale(self):
        logger.warning(
            "`logit_scale` attribute is going to be deprecated in future versions, please use `model.config.logit_scale` instead."
        )
        return self.config.logit_scale

    @property
    def tie_word_embeddings(self):
        logger.warning(
            "`tie_word_embeddings` attribute is going to be deprecated in future versions, please use `model.config.tie_word_embeddings` instead."
        )
        return self.config.tie_word_embeddings
Contributor Author

Any idea why these are not propagated in the generated modeling code?

Collaborator

I'll have to dive a bit into this!

Collaborator

Ok that's on me to do now!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@ArthurZucker ArthurZucker left a comment

Nice!

model_type = "cohere"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
Collaborator

For the init we can call super().__init__() and should use super_kwargs to only override the arguments that actually differ from the defaults we have in Gemma 😉
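A rough sketch of that suggestion (class names, arguments, and defaults below are illustrative, not the actual transformers configs): the child config only spells out the defaults that differ from the parent and forwards everything else through `**super_kwargs`.

```python
class GemmaLikeConfig:
    """Illustrative parent config carrying the shared defaults."""
    def __init__(self, hidden_size=2048, rms_norm_eps=1e-6, tie_word_embeddings=True):
        self.hidden_size = hidden_size
        self.rms_norm_eps = rms_norm_eps
        self.tie_word_embeddings = tie_word_embeddings


class CohereLikeConfig(GemmaLikeConfig):
    """Only overrides what actually differs; the rest rides on super().__init__."""
    def __init__(self, rms_norm_eps=1e-5, **super_kwargs):
        super().__init__(rms_norm_eps=rms_norm_eps, **super_kwargs)


cfg = CohereLikeConfig()
print(cfg.rms_norm_eps, cfg.hidden_size)  # 1e-05 2048 — only the eps default changed
```

This keeps the child `__init__` short and makes the real differences between the two configs readable at a glance.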

Comment on lines +742 to +754
    @property
    def logit_scale(self):
        logger.warning(
            "`logit_scale` attribute is going to be deprecated in future versions, please use `model.config.logit_scale` instead."
        )
        return self.config.logit_scale

    @property
    def tie_word_embeddings(self):
        logger.warning(
            "`tie_word_embeddings` attribute is going to be deprecated in future versions, please use `model.config.tie_word_embeddings` instead."
        )
        return self.config.tie_word_embeddings
Collaborator

Ok that's on me to do now!

Comment on lines +116 to +142
class CohereLinearScalingRotaryEmbedding(CohereRotaryEmbedding):
    """CohereRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: a scaling factor is applied to the position ids
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids)
        return cos, sin


class CohereDynamicNTKScalingRotaryEmbedding(CohereRotaryEmbedding):
    """CohereRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin
Collaborator

there is no way to do so 😓 Maybe a skip layer?

_CONFIG_FOR_DOC = "CohereConfig"


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
Collaborator

that's a problem, no? `_get_unpad_data` should still be present!

Contributor Author

        return attn_output, None, past_key_value


def _get_unpad_data(attention_mask):
Contributor Author

`_get_unpad_data` is pasted here @ArthurZucker
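For context, the bookkeeping `_get_unpad_data` performs can be sketched in plain Python (the real helper operates on torch tensors; this list-based version only illustrates the logic, and the function name is a hypothetical stand-in):

```python
def get_unpad_data_sketch(attention_mask):
    """Plain-Python illustration of flash-attention unpadding bookkeeping.
    `attention_mask` is a list of rows of 0/1 ints (1 = real token)."""
    seqlens_in_batch = [sum(row) for row in attention_mask]
    flat = [tok for row in attention_mask for tok in row]
    # Positions of the non-padding tokens in the flattened batch.
    indices = [i for i, tok in enumerate(flat) if tok == 1]
    max_seqlen_in_batch = max(seqlens_in_batch)
    # Cumulative sequence lengths prefixed with 0 (cu_seqlens in flash-attn terms).
    cu_seqlens = [0]
    for n in seqlens_in_batch:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return indices, cu_seqlens, max_seqlen_in_batch


mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
print(get_unpad_data_sketch(mask))  # ([0, 1, 2, 4, 5], [0, 3, 5], 3)
```

These three values are what the flash-attention varlen kernels consume after padding tokens are stripped, which is why dropping the helper from the generated file would break the flash-attention path.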

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jul 23, 2024