[ML] Refactor OpenAI request managers #124144
Conversation
    public static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions";
    public static final String USER_ROLE = "user";

    static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
These changes basically move the logic from the request manager files into here.
    // It's possible that two inference endpoints have the same information defining the group but have different
    // rate limits. In that case they should be in different groups; otherwise whoever initially created the group
    // will set the rate and the other inference endpoint's rate will be ignored.
    return new EndpointGrouping(rateLimitGroup, rateLimitSettings);
We don't need to recreate the object on each call.
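A minimal sketch of the caching idea, assuming the grouping fields are final (the record and field names here are stand-ins, not the exact implementation):

```java
// Stand-in for the real settings class.
record RateLimitSettings(long requestsPerTimeUnit) {}

// The grouping is derived entirely from final fields, so it can be built once
// in the constructor rather than on every call.
class RateLimitGroupingExample {
    record EndpointGrouping(int group, RateLimitSettings settings) {}

    private final EndpointGrouping endpointGrouping;

    RateLimitGroupingExample(int rateLimitGroupHash, RateLimitSettings rateLimitSettings) {
        this.endpointGrouping = new EndpointGrouping(rateLimitGroupHash, rateLimitSettings);
    }

    EndpointGrouping endpointGrouping() {
        return endpointGrouping;
    }
}
```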
        this.rateLimitSettings = rateLimitSettings;
    }

    BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel) {
This is a stopgap. Once all the request managers are refactored, the old constructor can be removed.
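A rough sketch of the stopgap, assuming the new constructor derives the grouping fields from the model (stub types; not the actual Elasticsearch classes):

```java
// Stub types standing in for the real inference API classes.
interface ThreadPool {}

record RateLimitSettings(long requestsPerTimeUnit) {}

interface RateLimitGroupingModel {
    String inferenceEntityId();
    int rateLimitGroupingHash();
    RateLimitSettings rateLimitSettings();
}

abstract class BaseRequestManager {
    private final String inferenceEntityId;
    private final int rateLimitGroup;
    private final RateLimitSettings rateLimitSettings;

    // Old constructor: callers pass the grouping pieces individually.
    // It can be removed once every request manager is refactored.
    BaseRequestManager(ThreadPool threadPool, String inferenceEntityId, int rateLimitGroup, RateLimitSettings settings) {
        this.inferenceEntityId = inferenceEntityId;
        this.rateLimitGroup = rateLimitGroup;
        this.rateLimitSettings = settings;
    }

    // New constructor: the model supplies its own grouping information.
    BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel model) {
        this(threadPool, model.inferenceEntityId(), model.rateLimitGroupingHash(), model.rateLimitSettings());
    }
}
```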
     * This is a temporary class to use while we refactor all the request managers. After all the request managers extend
     * this class, we'll move this functionality directly into the {@link BaseRequestManager}.
     */
    public class GenericRequestManager<T extends InferenceInputs> extends BaseRequestManager {
Once all the request managers are refactored, I envision that we'll be able to move this logic up into the base class.
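A minimal sketch of how the generic manager could collapse the per-service subclasses, assuming request construction is injected as a function (stub types; names are assumptions, not the exact implementation):

```java
import java.util.function.Function;

// Stub types standing in for the real inference API classes.
interface InferenceInputs {}
interface Request {}
interface ResponseHandler {}

// One configurable class instead of a subclass per service: the caller injects
// the function that turns inputs into a concrete request.
class GenericRequestManager<T extends InferenceInputs> {
    private final ResponseHandler responseHandler;
    private final Function<T, Request> requestCreator;
    private final Class<T> inputType;

    GenericRequestManager(ResponseHandler responseHandler, Function<T, Request> requestCreator, Class<T> inputType) {
        this.responseHandler = responseHandler;
        this.requestCreator = requestCreator;
        this.inputType = inputType;
    }

    Request createRequest(InferenceInputs inputs) {
        // The cast is validated against the declared input type before the request is built.
        return requestCreator.apply(inputType.cast(inputs));
    }
}
```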
    import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

    public class TruncatingRequestManager extends BaseRequestManager {
Currently this would only be used for text embedding requests.
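A rough sketch of what the truncating variant might do, assuming it applies truncation before handing the inputs to the request creator (the character-based truncation below is a crude stand-in for the real token-aware Truncator.truncate):

```java
import java.util.List;

// Crude stand-in for org.elasticsearch.xpack.inference.common.Truncator:
// real truncation is token-aware; this just caps the string length.
class TruncatorStub {
    static List<String> truncate(List<String> inputs, int maxLength) {
        return inputs.stream()
            .map(s -> s.length() > maxLength ? s.substring(0, maxLength) : s)
            .toList();
    }
}

// Illustrative truncating manager for text embedding inputs.
class TruncatingRequestManager {
    private final int maxInputLength;

    TruncatingRequestManager(int maxInputLength) {
        this.maxInputLength = maxInputLength;
    }

    List<String> prepareInputs(List<String> inputs) {
        // Truncate each input to the model's limit before building the request.
        return TruncatorStub.truncate(inputs, maxInputLength);
    }
}
```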
      public HttpRequest createHttpRequest() {
-         HttpPost httpPost = new HttpPost(account.uri());
+         HttpPost httpPost = new HttpPost(model.uri());
I pushed all the account-related stuff into the model since we need it there to calculate the hash anyway.
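A sketch of the consolidation, assuming the model now owns the account fields so that both request creation and the grouping hash read them from one place (names are assumptions):

```java
import java.net.URI;
import java.util.Objects;

// Illustrative model that owns the former account fields (uri, apiKey).
class OpenAiModelSketch {
    private final URI uri;
    private final String apiKey;
    private final String modelId;

    OpenAiModelSketch(URI uri, String apiKey, String modelId) {
        this.uri = uri;
        this.apiKey = apiKey;
        this.modelId = modelId;
    }

    // Request creation reads the uri from the model instead of a separate account.
    URI uri() {
        return uri;
    }

    // The same fields also feed the rate-limit grouping hash.
    int rateLimitGroupingHash() {
        return Objects.hash(modelId, apiKey, uri);
    }
}
```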
        );
    }

    public static URI buildDefaultUri() throws URISyntaxException {
public because it's used in a few tests.
    public abstract ExecutableAction accept(OpenAiActionVisitor creator, Map<String, Object> taskSettings);

    public int rateLimitGroupingHash() {
        return Objects.hash(rateLimitServiceSettings.modelId(), apiKey, uri);
We could probably calculate this only once. I suppose to avoid weird bugs it might be better to do it on every call, in the event that one of the fields gets reset; they shouldn't, though, since they're final.
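If the fields really are final, a sketch of computing the hash once at construction time (illustrative names, not the actual class):

```java
import java.util.Objects;

// Because every contributing field is final, the hash could be precomputed
// in the constructor and simply returned afterwards.
class CachedGroupingHashSketch {
    private final String modelId;
    private final String apiKey;
    private final String uri;
    private final int cachedHash;

    CachedGroupingHashSketch(String modelId, String apiKey, String uri) {
        this.modelId = modelId;
        this.apiKey = apiKey;
        this.uri = uri;
        // Safe to precompute: every contributing field is final.
        this.cachedHash = Objects.hash(modelId, apiKey, uri);
    }

    int rateLimitGroupingHash() {
        return cachedHash;
    }
}
```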
Pinging @elastic/ml-core (Team:ML)
    public abstract int rateLimitGroupingHash();

    public abstract RateLimitSettings rateLimitSettings();
Should we maybe (eventually) move this into Model, since I think everyone has RateLimitSettings anyway?
Good idea. There might be a few places, like the Bedrock implementation, that don't have it. I'll see if we can handle that elegantly.
* Code compiling
* Removing OpenAiAccount
💚 Backport successful
This PR demonstrates how we can remove many of the RequestManager files within the inference API. I only did this for OpenAI as a demonstration. If we're ok with the approach, I can do it for the rest of the services.