Skip to content

Commit 54eae2e

Browse files
committed
refactor: clean up duplicates
1 parent 487fc73 commit 54eae2e

2 files changed

Lines changed: 80 additions & 41 deletions

File tree

internal/server/handlers.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (h *Handler) SetBatchStore(store batchstore.Store) {
6666
}
6767

6868
func (h *Handler) translatedInference() *translatedInferenceService {
69-
return &translatedInferenceService{
69+
s := &translatedInferenceService{
7070
provider: h.provider,
7171
modelResolver: h.modelResolver,
7272
translatedRequestPatcher: h.translatedRequestPatcher,
@@ -76,6 +76,8 @@ func (h *Handler) translatedInference() *translatedInferenceService {
7676
responseCache: h.responseCache,
7777
guardrailsHash: h.guardrailsHash,
7878
}
79+
s.initHandlers()
80+
return s
7981
}
8082

8183
func (h *Handler) nativeBatch() *nativeBatchService {

internal/server/translated_inference_service.go

Lines changed: 77 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package server
22

33
import (
4+
"context"
45
"encoding/json"
56
"io"
67
"log/slog"
@@ -27,43 +28,50 @@ type translatedInferenceService struct {
2728
pricingResolver usage.PricingResolver
2829
responseCache *responsecache.ResponseCacheMiddleware
2930
guardrailsHash string
30-
}
3131

32-
func (s *translatedInferenceService) ChatCompletion(c *echo.Context) error {
33-
req, err := canonicalJSONRequestFromSemantics[*core.ChatRequest](c, core.DecodeChatRequest)
34-
if err != nil {
35-
return handleError(c, core.NewInvalidRequestError("invalid request body: "+err.Error(), err))
36-
}
37-
plan, err := ensureTranslatedRequestPlan(c, s.provider, s.modelResolver, &req.Model, &req.Provider)
38-
if err != nil {
39-
return handleError(c, err)
40-
}
41-
42-
ctx := c.Request().Context()
43-
if s.translatedRequestPatcher != nil {
44-
req, err = s.translatedRequestPatcher.PatchChatRequest(ctx, req)
45-
if err != nil {
46-
return handleError(c, err)
47-
}
48-
}
32+
// Pre-built handlers initialized via initHandlers.
33+
chatCompletionHandler echo.HandlerFunc
34+
responsesHandler echo.HandlerFunc
35+
}
4936

50-
if s.guardrailsHash != "" {
51-
ctx = core.WithGuardrailsHash(ctx, s.guardrailsHash)
52-
c.SetRequest(c.Request().WithContext(ctx))
53-
}
37+
func (s *translatedInferenceService) initHandlers() {
38+
s.chatCompletionHandler = newTranslatedHandler(s,
39+
core.DecodeChatRequest,
40+
func(r *core.ChatRequest) (*string, *string) { return &r.Model, &r.Provider },
41+
func(ctx context.Context, r *core.ChatRequest) (*core.ChatRequest, error) {
42+
return s.translatedRequestPatcher.PatchChatRequest(ctx, r)
43+
},
44+
func(r *core.ChatRequest) bool { return r.Stream },
45+
s.dispatchChatCompletion,
46+
)
47+
s.responsesHandler = newTranslatedHandler(s,
48+
core.DecodeResponsesRequest,
49+
func(r *core.ResponsesRequest) (*string, *string) { return &r.Model, &r.Provider },
50+
func(ctx context.Context, r *core.ResponsesRequest) (*core.ResponsesRequest, error) {
51+
return s.translatedRequestPatcher.PatchResponsesRequest(ctx, r)
52+
},
53+
func(r *core.ResponsesRequest) bool { return r.Stream },
54+
s.dispatchResponses,
55+
)
56+
}
5457

55-
if s.responseCache != nil && !req.Stream {
56-
body, marshalErr := marshalRequestBody(req)
57-
if marshalErr != nil {
58-
slog.Debug("marshalRequestBody failed", "err", marshalErr)
59-
} else {
60-
return s.responseCache.HandleRequest(c, body, func() error {
61-
return s.dispatchChatCompletion(c, req, plan)
62-
})
63-
}
58+
// newTranslatedHandler returns an echo.HandlerFunc that executes the
59+
// decode→plan→patch→dispatch pipeline for a translated inference endpoint.
60+
func newTranslatedHandler[R any](
61+
s *translatedInferenceService,
62+
decode func([]byte, *core.WhiteBoxPrompt) (R, error),
63+
modelProvider func(R) (*string, *string),
64+
patch func(context.Context, R) (R, error),
65+
isStream func(R) bool,
66+
dispatch func(*echo.Context, R, *core.ExecutionPlan) error,
67+
) echo.HandlerFunc {
68+
return func(c *echo.Context) error {
69+
return handleTranslatedInference(s, c, decode, modelProvider, patch, isStream, dispatch)
6470
}
71+
}
6572

66-
return s.dispatchChatCompletion(c, req, plan)
73+
func (s *translatedInferenceService) ChatCompletion(c *echo.Context) error {
74+
return s.chatCompletionHandler(c)
6775
}
6876

6977
func (s *translatedInferenceService) dispatchChatCompletion(c *echo.Context, req *core.ChatRequest, plan *core.ExecutionPlan) error {
@@ -90,40 +98,69 @@ func (s *translatedInferenceService) dispatchChatCompletion(c *echo.Context, req
9098
}
9199

92100
func (s *translatedInferenceService) Responses(c *echo.Context) error {
93-
req, err := canonicalJSONRequestFromSemantics[*core.ResponsesRequest](c, core.DecodeResponsesRequest)
101+
return s.responsesHandler(c)
102+
}
103+
104+
// handleTranslatedInference is the shared decode→plan→patch→dispatch pipeline
105+
// for ChatCompletion and Responses, parameterised over the request type.
106+
func handleTranslatedInference[R any](
107+
s *translatedInferenceService,
108+
c *echo.Context,
109+
decode func([]byte, *core.WhiteBoxPrompt) (R, error),
110+
modelProvider func(R) (*string, *string),
111+
patch func(context.Context, R) (R, error),
112+
isStream func(R) bool,
113+
dispatch func(*echo.Context, R, *core.ExecutionPlan) error,
114+
) error {
115+
req, err := canonicalJSONRequestFromSemantics(c, decode)
94116
if err != nil {
95117
return handleError(c, core.NewInvalidRequestError("invalid request body: "+err.Error(), err))
96118
}
97-
plan, err := ensureTranslatedRequestPlan(c, s.provider, s.modelResolver, &req.Model, &req.Provider)
119+
modelPtr, providerPtr := modelProvider(req)
120+
plan, err := ensureTranslatedRequestPlan(c, s.provider, s.modelResolver, modelPtr, providerPtr)
98121
if err != nil {
99122
return handleError(c, err)
100123
}
101124

102-
ctx := c.Request().Context()
103125
if s.translatedRequestPatcher != nil {
104-
req, err = s.translatedRequestPatcher.PatchResponsesRequest(ctx, req)
126+
ctx := c.Request().Context()
127+
req, err = patch(ctx, req)
105128
if err != nil {
106129
return handleError(c, err)
107130
}
108131
}
109132

133+
return handleWithCache(s, c, req, isStream(req), plan, dispatch)
134+
}
135+
136+
// handleWithCache injects the guardrails hash into context, then either routes the
137+
// request through the dual-layer response cache (non-streaming) or calls dispatch
138+
// directly (streaming). R is the post-patch request type.
139+
func handleWithCache[R any](
140+
s *translatedInferenceService,
141+
c *echo.Context,
142+
req R,
143+
stream bool,
144+
plan *core.ExecutionPlan,
145+
dispatch func(*echo.Context, R, *core.ExecutionPlan) error,
146+
) error {
110147
if s.guardrailsHash != "" {
111-
ctx = core.WithGuardrailsHash(ctx, s.guardrailsHash)
148+
ctx := core.WithGuardrailsHash(c.Request().Context(), s.guardrailsHash)
112149
c.SetRequest(c.Request().WithContext(ctx))
113150
}
114151

115-
if s.responseCache != nil && !req.Stream {
152+
if s.responseCache != nil && !stream {
116153
body, marshalErr := marshalRequestBody(req)
117154
if marshalErr != nil {
118155
slog.Debug("marshalRequestBody failed", "err", marshalErr)
119156
} else {
120157
return s.responseCache.HandleRequest(c, body, func() error {
121-
return s.dispatchResponses(c, req, plan)
158+
return dispatch(c, req, plan)
122159
})
123160
}
124161
}
125162

126-
return s.dispatchResponses(c, req, plan)
163+
return dispatch(c, req, plan)
127164
}
128165

129166
func (s *translatedInferenceService) dispatchResponses(c *echo.Context, req *core.ResponsesRequest, plan *core.ExecutionPlan) error {

0 commit comments

Comments
 (0)