-
Notifications
You must be signed in to change notification settings - Fork 23
ENG-2161: Update Model Params for Embeddings #534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENG-2161: Update Model Params for Embeddings #534
Conversation
| from aixplain.factories import ModelFactory | ||
|
|
||
| model = ModelFactory.get(embedding_model) | ||
| self.embedding_size = model.additional_info["embedding_size"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.embedding_size = model.additional_info.get("embedding_size")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@thiago-aixplain In that case, let me update this as try catch doesnt need to check for embedding_size key retrieval. It can only check if ModelFactory.get doesnt fail
| data = super().to_dict() | ||
| data["embedding_model"] = self.embedding_model | ||
| data["embedding_size"] = self.embedding_size | ||
| data["collection_type"] = self.version.split("-", 1)[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there any case that version is None? In case not, this is fine. Otherwise, it will crash.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
version shouldnt be none ever as it points to collection name and environment name in the format:
{index_type}-{env}-{team}-{collection_name}
|
|
||
| params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) | ||
| index_model = IndexFactory.create(params=params) | ||
| assert index_model.embedding_model == embedding_model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what will happen if I add index_model = ModelFactory.get(index_model.id) before the assertions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@thiago-aixplain It should work the same way. IndexFactory.create also actually also returns a ModelFactory.get in the end
Add attributes to model card additional_info key:
Example Usage: