Skip to content

Commit 1d3cefa

Browse files
feat(retry): Add configurable status codes and callback mechanism
- Add RetryableStatusCodes field to RetryConfig for custom status codes - Add OnRetry callback for retry event logging and monitoring - Maintain backward compatibility with default behavior - Update isRetryableStatusCode to use configurable codes - Add comprehensive test coverage for new functionality Co-authored-by: Eden Reich <edenreich@users.noreply.github.com>
1 parent b5a32db commit 1d3cefa

3 files changed

Lines changed: 287 additions & 4 deletions

File tree

sdk.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,18 @@ func isRetryableError(err error) bool {
7171
}
7272

7373
// isRetryableStatusCode determines if an HTTP status code should trigger a retry
74-
func isRetryableStatusCode(statusCode int) bool {
74+
func isRetryableStatusCode(statusCode int, config *RetryConfig) bool {
75+
// Use custom status codes if provided
76+
if len(config.RetryableStatusCodes) > 0 {
77+
for _, code := range config.RetryableStatusCodes {
78+
if statusCode == code {
79+
return true
80+
}
81+
}
82+
return false
83+
}
84+
85+
// Use default status codes
7586
switch statusCode {
7687
case
7788
http.StatusRequestTimeout, // 408
@@ -179,6 +190,11 @@ func (c *clientImpl) executeWithRetry(ctx context.Context, request func() (*rest
179190
if attempt > 0 {
180191
delay := calculateBackoff(attempt, c.retryConfig)
181192

193+
// Call OnRetry callback if provided
194+
if c.retryConfig.OnRetry != nil {
195+
c.retryConfig.OnRetry(attempt, lastErr, delay)
196+
}
197+
182198
select {
183199
case <-ctx.Done():
184200
return nil, ctx.Err()
@@ -189,13 +205,13 @@ func (c *clientImpl) executeWithRetry(ctx context.Context, request func() (*rest
189205
resp, lastErr = request()
190206

191207
if lastErr == nil {
192-
if !resp.IsError() || !isRetryableStatusCode(resp.StatusCode()) {
208+
if !resp.IsError() || !isRetryableStatusCode(resp.StatusCode(), c.retryConfig) {
193209
return resp, nil
194210
}
195211
lastErr = fmt.Errorf("HTTP %d", resp.StatusCode())
196212
}
197213

198-
if !isRetryableError(lastErr) && (resp == nil || !isRetryableStatusCode(resp.StatusCode())) {
214+
if !isRetryableError(lastErr) && (resp == nil || !isRetryableStatusCode(resp.StatusCode(), c.retryConfig)) {
199215
break
200216
}
201217

sdk_test.go

Lines changed: 262 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,8 @@ func TestCalculateBackoff(t *testing.T) {
14541454
}
14551455

14561456
func TestIsRetryableStatusCode(t *testing.T) {
1457+
defaultConfig := getDefaultRetryConfig()
1458+
14571459
tests := []struct {
14581460
statusCode int
14591461
expected bool
@@ -1473,7 +1475,7 @@ func TestIsRetryableStatusCode(t *testing.T) {
14731475

14741476
for _, tt := range tests {
14751477
t.Run(fmt.Sprintf("status_%d", tt.statusCode), func(t *testing.T) {
1476-
result := isRetryableStatusCode(tt.statusCode)
1478+
result := isRetryableStatusCode(tt.statusCode, defaultConfig)
14771479
assert.Equal(t, tt.expected, result)
14781480
})
14791481
}
@@ -1510,3 +1512,262 @@ func TestRetryWithContext(t *testing.T) {
15101512
assert.GreaterOrEqual(t, callCount, 1)
15111513
assert.LessOrEqual(t, callCount, 5)
15121514
}
1515+
1516+
func TestCustomRetryableStatusCodes(t *testing.T) {
1517+
tests := []struct {
1518+
name string
1519+
customStatusCodes []int
1520+
statusCode int
1521+
expected bool
1522+
}{
1523+
{
1524+
name: "custom codes include 418",
1525+
customStatusCodes: []int{418, 422},
1526+
statusCode: 418,
1527+
expected: true,
1528+
},
1529+
{
1530+
name: "custom codes exclude 500",
1531+
customStatusCodes: []int{418, 422},
1532+
statusCode: 500,
1533+
expected: false,
1534+
},
1535+
{
1536+
name: "custom codes include 422",
1537+
customStatusCodes: []int{418, 422},
1538+
statusCode: 422,
1539+
expected: true,
1540+
},
1541+
{
1542+
name: "custom codes exclude 200",
1543+
customStatusCodes: []int{418, 422},
1544+
statusCode: 200,
1545+
expected: false,
1546+
},
1547+
{
1548+
name: "empty custom codes use defaults",
1549+
customStatusCodes: []int{},
1550+
statusCode: 500,
1551+
expected: true,
1552+
},
1553+
{
1554+
name: "nil custom codes use defaults",
1555+
customStatusCodes: nil,
1556+
statusCode: 500,
1557+
expected: true,
1558+
},
1559+
}
1560+
1561+
for _, tt := range tests {
1562+
t.Run(tt.name, func(t *testing.T) {
1563+
config := &RetryConfig{
1564+
RetryableStatusCodes: tt.customStatusCodes,
1565+
}
1566+
result := isRetryableStatusCode(tt.statusCode, config)
1567+
assert.Equal(t, tt.expected, result)
1568+
})
1569+
}
1570+
}
1571+
1572+
func TestRetryCallback(t *testing.T) {
1573+
var callbackCalls []struct {
1574+
attempt int
1575+
err error
1576+
delay time.Duration
1577+
}
1578+
1579+
retryConfig := &RetryConfig{
1580+
Enabled: true,
1581+
MaxAttempts: 3,
1582+
InitialBackoffSec: 1,
1583+
MaxBackoffSec: 10,
1584+
BackoffMultiplier: 2,
1585+
OnRetry: func(attempt int, err error, delay time.Duration) {
1586+
callbackCalls = append(callbackCalls, struct {
1587+
attempt int
1588+
err error
1589+
delay time.Duration
1590+
}{attempt, err, delay})
1591+
},
1592+
}
1593+
1594+
callCount := 0
1595+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1596+
if callCount < 2 {
1597+
w.WriteHeader(http.StatusInternalServerError)
1598+
err := json.NewEncoder(w).Encode(Error{Error: stringPtr("Server error")})
1599+
assert.NoError(t, err)
1600+
} else {
1601+
w.WriteHeader(http.StatusOK)
1602+
response := ListModelsResponse{Object: "list", Data: []Model{}}
1603+
err := json.NewEncoder(w).Encode(response)
1604+
assert.NoError(t, err)
1605+
}
1606+
callCount++
1607+
}))
1608+
defer server.Close()
1609+
1610+
baseURL := server.URL + "/v1"
1611+
client := NewClient(&ClientOptions{
1612+
BaseURL: baseURL,
1613+
RetryConfig: retryConfig,
1614+
})
1615+
1616+
ctx := context.Background()
1617+
_, err := client.ListModels(ctx)
1618+
1619+
assert.NoError(t, err)
1620+
assert.Equal(t, 3, callCount)
1621+
assert.Len(t, callbackCalls, 2) // Two retries after initial failure
1622+
1623+
// Verify callback calls
1624+
assert.Equal(t, 1, callbackCalls[0].attempt)
1625+
assert.Contains(t, callbackCalls[0].err.Error(), "HTTP 500")
1626+
assert.Equal(t, 1*time.Second, callbackCalls[0].delay)
1627+
1628+
assert.Equal(t, 2, callbackCalls[1].attempt)
1629+
assert.Contains(t, callbackCalls[1].err.Error(), "HTTP 500")
1630+
assert.Equal(t, 2*time.Second, callbackCalls[1].delay)
1631+
}
1632+
1633+
func TestRetryWithCustomStatusCodesAndCallback(t *testing.T) {
1634+
var callbackCalls []struct {
1635+
attempt int
1636+
err error
1637+
delay time.Duration
1638+
}
1639+
1640+
retryConfig := &RetryConfig{
1641+
Enabled: true,
1642+
MaxAttempts: 3,
1643+
InitialBackoffSec: 1,
1644+
MaxBackoffSec: 10,
1645+
BackoffMultiplier: 2,
1646+
RetryableStatusCodes: []int{418, 503}, // Custom codes: I'm a teapot and Service Unavailable
1647+
OnRetry: func(attempt int, err error, delay time.Duration) {
1648+
callbackCalls = append(callbackCalls, struct {
1649+
attempt int
1650+
err error
1651+
delay time.Duration
1652+
}{attempt, err, delay})
1653+
},
1654+
}
1655+
1656+
tests := []struct {
1657+
name string
1658+
statusCodes []int
1659+
expectRetries bool
1660+
expectedCalls int
1661+
callbackCounts int
1662+
}{
1663+
{
1664+
name: "retry on custom 418 status code",
1665+
statusCodes: []int{418, 418, 200},
1666+
expectRetries: true,
1667+
expectedCalls: 3,
1668+
callbackCounts: 2,
1669+
},
1670+
{
1671+
name: "retry on custom 503 status code",
1672+
statusCodes: []int{503, 200},
1673+
expectRetries: true,
1674+
expectedCalls: 2,
1675+
callbackCounts: 1,
1676+
},
1677+
{
1678+
name: "no retry on non-custom 500 status code",
1679+
statusCodes: []int{500},
1680+
expectRetries: false,
1681+
expectedCalls: 1,
1682+
callbackCounts: 0,
1683+
},
1684+
{
1685+
name: "no retry on 400 status code",
1686+
statusCodes: []int{400},
1687+
expectRetries: false,
1688+
expectedCalls: 1,
1689+
callbackCounts: 0,
1690+
},
1691+
}
1692+
1693+
for _, tt := range tests {
1694+
t.Run(tt.name, func(t *testing.T) {
1695+
callbackCalls = nil // Reset callback calls
1696+
callCount := 0
1697+
1698+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1699+
if callCount < len(tt.statusCodes) {
1700+
w.WriteHeader(tt.statusCodes[callCount])
1701+
if tt.statusCodes[callCount] == 200 {
1702+
response := ListModelsResponse{Object: "list", Data: []Model{}}
1703+
err := json.NewEncoder(w).Encode(response)
1704+
assert.NoError(t, err)
1705+
} else {
1706+
err := json.NewEncoder(w).Encode(Error{Error: stringPtr("Server error")})
1707+
assert.NoError(t, err)
1708+
}
1709+
}
1710+
callCount++
1711+
}))
1712+
defer server.Close()
1713+
1714+
baseURL := server.URL + "/v1"
1715+
client := NewClient(&ClientOptions{
1716+
BaseURL: baseURL,
1717+
RetryConfig: retryConfig,
1718+
})
1719+
1720+
ctx := context.Background()
1721+
_, err := client.ListModels(ctx)
1722+
1723+
assert.Equal(t, tt.expectedCalls, callCount)
1724+
assert.Len(t, callbackCalls, tt.callbackCounts)
1725+
1726+
if tt.expectRetries && tt.statusCodes[len(tt.statusCodes)-1] == 200 {
1727+
assert.NoError(t, err)
1728+
} else if !tt.expectRetries || tt.statusCodes[len(tt.statusCodes)-1] != 200 {
1729+
assert.Error(t, err)
1730+
}
1731+
})
1732+
}
1733+
}
1734+
1735+
func TestRetryConfigWithNilCallback(t *testing.T) {
1736+
retryConfig := &RetryConfig{
1737+
Enabled: true,
1738+
MaxAttempts: 2,
1739+
InitialBackoffSec: 1,
1740+
MaxBackoffSec: 10,
1741+
BackoffMultiplier: 2,
1742+
OnRetry: nil, // Explicitly nil callback
1743+
}
1744+
1745+
callCount := 0
1746+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1747+
if callCount == 0 {
1748+
w.WriteHeader(http.StatusInternalServerError)
1749+
err := json.NewEncoder(w).Encode(Error{Error: stringPtr("Server error")})
1750+
assert.NoError(t, err)
1751+
} else {
1752+
w.WriteHeader(http.StatusOK)
1753+
response := ListModelsResponse{Object: "list", Data: []Model{}}
1754+
err := json.NewEncoder(w).Encode(response)
1755+
assert.NoError(t, err)
1756+
}
1757+
callCount++
1758+
}))
1759+
defer server.Close()
1760+
1761+
baseURL := server.URL + "/v1"
1762+
client := NewClient(&ClientOptions{
1763+
BaseURL: baseURL,
1764+
RetryConfig: retryConfig,
1765+
})
1766+
1767+
ctx := context.Background()
1768+
_, err := client.ListModels(ctx)
1769+
1770+
// Should work without callback
1771+
assert.NoError(t, err)
1772+
assert.Equal(t, 2, callCount)
1773+
}

types.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ type RetryConfig struct {
7676
MaxBackoffSec int
7777
// BackoffMultiplier is the multiplier for exponential backoff
7878
BackoffMultiplier int
79+
// RetryableStatusCodes is a custom list of HTTP status codes that should trigger a retry.
80+
// If nil or empty, uses default status codes (408, 429, 500, 502, 503, 504)
81+
RetryableStatusCodes []int
82+
// OnRetry is called before each retry attempt with attempt number, error, and delay.
83+
// The attempt number starts from 1 for the first retry (after initial request fails)
84+
OnRetry func(attempt int, err error, delay time.Duration)
7985
}
8086

8187
// MiddlewareOptions represents options for controlling middleware behavior

0 commit comments

Comments
 (0)