Skip to content

Commit 23895f1

Browse files
committed
refactor: header validation and add custom error
1 parent 5b01ca4 commit 23895f1

3 files changed

Lines changed: 40 additions & 14 deletions

File tree

mobius/internal/providers/azure/base.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package azure
33
import (
44
"net/http"
55

6-
"github.com/go-playground/validator/v10"
76
"github.com/missingstudio/studio/backend/internal/providers/base"
87
"github.com/missingstudio/studio/backend/pkg/utils"
98
"github.com/missingstudio/studio/common/errors"
@@ -12,10 +11,10 @@ import (
1211
type AzureProviderFactory struct{}
1312

1413
type AzureHeaders struct {
15-
APIKey string `validate:"required" json:"Authorization"`
16-
ResourceName string `validate:"required" json:"X-Ms-Azure-Resource-Name"`
17-
DeploymentID string `validate:"required" json:"X-Ms-Deployment-ID"`
18-
APIVersion string `validate:"required" json:"X-Ms-API-Version"`
14+
APIKey string `validate:"required" json:"Authorization" error:"API key is required"`
15+
ResourceName string `validate:"required" json:"X-Ms-Azure-Resource-Name" error:"Resource Name is required"`
16+
DeploymentID string `validate:"required" json:"X-Ms-Deployment-ID" error:"Deployment ID is required"`
17+
APIVersion string `validate:"required" json:"X-Ms-API-Version" error:"API Version is required"`
1918
}
2019

2120
func (azf AzureProviderFactory) Validate(headers http.Header) (*AzureHeaders, error) {
@@ -24,9 +23,8 @@ func (azf AzureProviderFactory) Validate(headers http.Header) (*AzureHeaders, er
2423
return nil, errors.New(err)
2524
}
2625

27-
validate := validator.New()
28-
if err := validate.Struct(azHeaders); err != nil {
29-
return nil, errors.NewBadRequest("provider's required headers are missing")
26+
if err := utils.ValidateHeaders(azHeaders); err != nil {
27+
return nil, err
3028
}
3129

3230
return &azHeaders, nil

mobius/internal/providers/openai/base.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"net/http"
55
"strings"
66

7-
"github.com/go-playground/validator/v10"
87
"github.com/missingstudio/studio/backend/internal/providers/base"
98
"github.com/missingstudio/studio/backend/pkg/utils"
109
"github.com/missingstudio/studio/common/errors"
@@ -13,18 +12,17 @@ import (
1312
type OpenAIProviderFactory struct{}
1413

1514
type OpenAIHeaders struct {
16-
APIKey string `validate:"required" json:"Authorization"`
15+
APIKey string `validate:"required" json:"Authorization" error:"API key is required"`
1716
}
1817

19-
func (oaif OpenAIProviderFactory) Validate(headers http.Header) (*OpenAIHeaders, error) {
18+
func (azf OpenAIProviderFactory) Validate(headers http.Header) (*OpenAIHeaders, error) {
2019
var oaiHeaders OpenAIHeaders
2120
if err := utils.UnmarshalHeader(headers, &oaiHeaders); err != nil {
2221
return nil, errors.New(err)
2322
}
2423

25-
validate := validator.New()
26-
if err := validate.Struct(oaiHeaders); err != nil {
27-
return nil, errors.NewBadRequest("provider's required headers are missing")
24+
if err := utils.ValidateHeaders(oaiHeaders); err != nil {
25+
return nil, err
2826
}
2927

3028
return &oaiHeaders, nil

mobius/pkg/utils/headers.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package utils
22

33
import (
4+
"fmt"
45
"net/http"
56
"reflect"
7+
"strings"
8+
9+
"github.com/go-playground/validator/v10"
10+
"github.com/missingstudio/studio/common/errors"
611
)
712

813
// UnmarshalHeader unmarshals an http.Header into a struct
@@ -22,3 +27,28 @@ func UnmarshalHeader(header http.Header, v interface{}) error {
2227

2328
return nil
2429
}
30+
31+
// ValidateHeaders is a generic function to validate any structure with the `validate` struct tag.
32+
func ValidateHeaders(data interface{}) error {
33+
validate := validator.New()
34+
if err := validate.Struct(data); err != nil {
35+
errorMessages := []string{}
36+
37+
// Collect all validation errors
38+
for _, e := range err.(validator.ValidationErrors) {
39+
fieldName := e.Field()
40+
field, _ := reflect.TypeOf(data).FieldByName(fieldName)
41+
customMessage := field.Tag.Get("error")
42+
43+
if customMessage == "" {
44+
errorMessages = append(errorMessages, fmt.Sprintf("Validation error on field %s", fieldName))
45+
} else {
46+
errorMessages = append(errorMessages, customMessage)
47+
}
48+
}
49+
50+
return errors.New(fmt.Errorf("Validation failed: %v", strings.Join(errorMessages, ", ")))
51+
}
52+
53+
return nil
54+
}

0 commit comments

Comments
 (0)