@@ -1454,6 +1454,8 @@ func TestCalculateBackoff(t *testing.T) {
14541454}
14551455
14561456func TestIsRetryableStatusCode (t * testing.T ) {
1457+ defaultConfig := getDefaultRetryConfig ()
1458+
14571459 tests := []struct {
14581460 statusCode int
14591461 expected bool
@@ -1473,7 +1475,7 @@ func TestIsRetryableStatusCode(t *testing.T) {
14731475
14741476 for _ , tt := range tests {
14751477 t .Run (fmt .Sprintf ("status_%d" , tt .statusCode ), func (t * testing.T ) {
1476- result := isRetryableStatusCode (tt .statusCode )
1478+ result := isRetryableStatusCode (tt .statusCode , defaultConfig )
14771479 assert .Equal (t , tt .expected , result )
14781480 })
14791481 }
@@ -1510,3 +1512,262 @@ func TestRetryWithContext(t *testing.T) {
15101512 assert .GreaterOrEqual (t , callCount , 1 )
15111513 assert .LessOrEqual (t , callCount , 5 )
15121514}
1515+
1516+ func TestCustomRetryableStatusCodes (t * testing.T ) {
1517+ tests := []struct {
1518+ name string
1519+ customStatusCodes []int
1520+ statusCode int
1521+ expected bool
1522+ }{
1523+ {
1524+ name : "custom codes include 418" ,
1525+ customStatusCodes : []int {418 , 422 },
1526+ statusCode : 418 ,
1527+ expected : true ,
1528+ },
1529+ {
1530+ name : "custom codes exclude 500" ,
1531+ customStatusCodes : []int {418 , 422 },
1532+ statusCode : 500 ,
1533+ expected : false ,
1534+ },
1535+ {
1536+ name : "custom codes include 422" ,
1537+ customStatusCodes : []int {418 , 422 },
1538+ statusCode : 422 ,
1539+ expected : true ,
1540+ },
1541+ {
1542+ name : "custom codes exclude 200" ,
1543+ customStatusCodes : []int {418 , 422 },
1544+ statusCode : 200 ,
1545+ expected : false ,
1546+ },
1547+ {
1548+ name : "empty custom codes use defaults" ,
1549+ customStatusCodes : []int {},
1550+ statusCode : 500 ,
1551+ expected : true ,
1552+ },
1553+ {
1554+ name : "nil custom codes use defaults" ,
1555+ customStatusCodes : nil ,
1556+ statusCode : 500 ,
1557+ expected : true ,
1558+ },
1559+ }
1560+
1561+ for _ , tt := range tests {
1562+ t .Run (tt .name , func (t * testing.T ) {
1563+ config := & RetryConfig {
1564+ RetryableStatusCodes : tt .customStatusCodes ,
1565+ }
1566+ result := isRetryableStatusCode (tt .statusCode , config )
1567+ assert .Equal (t , tt .expected , result )
1568+ })
1569+ }
1570+ }
1571+
1572+ func TestRetryCallback (t * testing.T ) {
1573+ var callbackCalls []struct {
1574+ attempt int
1575+ err error
1576+ delay time.Duration
1577+ }
1578+
1579+ retryConfig := & RetryConfig {
1580+ Enabled : true ,
1581+ MaxAttempts : 3 ,
1582+ InitialBackoffSec : 1 ,
1583+ MaxBackoffSec : 10 ,
1584+ BackoffMultiplier : 2 ,
1585+ OnRetry : func (attempt int , err error , delay time.Duration ) {
1586+ callbackCalls = append (callbackCalls , struct {
1587+ attempt int
1588+ err error
1589+ delay time.Duration
1590+ }{attempt , err , delay })
1591+ },
1592+ }
1593+
1594+ callCount := 0
1595+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1596+ if callCount < 2 {
1597+ w .WriteHeader (http .StatusInternalServerError )
1598+ err := json .NewEncoder (w ).Encode (Error {Error : stringPtr ("Server error" )})
1599+ assert .NoError (t , err )
1600+ } else {
1601+ w .WriteHeader (http .StatusOK )
1602+ response := ListModelsResponse {Object : "list" , Data : []Model {}}
1603+ err := json .NewEncoder (w ).Encode (response )
1604+ assert .NoError (t , err )
1605+ }
1606+ callCount ++
1607+ }))
1608+ defer server .Close ()
1609+
1610+ baseURL := server .URL + "/v1"
1611+ client := NewClient (& ClientOptions {
1612+ BaseURL : baseURL ,
1613+ RetryConfig : retryConfig ,
1614+ })
1615+
1616+ ctx := context .Background ()
1617+ _ , err := client .ListModels (ctx )
1618+
1619+ assert .NoError (t , err )
1620+ assert .Equal (t , 3 , callCount )
1621+ assert .Len (t , callbackCalls , 2 ) // Two retries after initial failure
1622+
1623+ // Verify callback calls
1624+ assert .Equal (t , 1 , callbackCalls [0 ].attempt )
1625+ assert .Contains (t , callbackCalls [0 ].err .Error (), "HTTP 500" )
1626+ assert .Equal (t , 1 * time .Second , callbackCalls [0 ].delay )
1627+
1628+ assert .Equal (t , 2 , callbackCalls [1 ].attempt )
1629+ assert .Contains (t , callbackCalls [1 ].err .Error (), "HTTP 500" )
1630+ assert .Equal (t , 2 * time .Second , callbackCalls [1 ].delay )
1631+ }
1632+
1633+ func TestRetryWithCustomStatusCodesAndCallback (t * testing.T ) {
1634+ var callbackCalls []struct {
1635+ attempt int
1636+ err error
1637+ delay time.Duration
1638+ }
1639+
1640+ retryConfig := & RetryConfig {
1641+ Enabled : true ,
1642+ MaxAttempts : 3 ,
1643+ InitialBackoffSec : 1 ,
1644+ MaxBackoffSec : 10 ,
1645+ BackoffMultiplier : 2 ,
1646+ RetryableStatusCodes : []int {418 , 503 }, // Custom codes: I'm a teapot and Service Unavailable
1647+ OnRetry : func (attempt int , err error , delay time.Duration ) {
1648+ callbackCalls = append (callbackCalls , struct {
1649+ attempt int
1650+ err error
1651+ delay time.Duration
1652+ }{attempt , err , delay })
1653+ },
1654+ }
1655+
1656+ tests := []struct {
1657+ name string
1658+ statusCodes []int
1659+ expectRetries bool
1660+ expectedCalls int
1661+ callbackCounts int
1662+ }{
1663+ {
1664+ name : "retry on custom 418 status code" ,
1665+ statusCodes : []int {418 , 418 , 200 },
1666+ expectRetries : true ,
1667+ expectedCalls : 3 ,
1668+ callbackCounts : 2 ,
1669+ },
1670+ {
1671+ name : "retry on custom 503 status code" ,
1672+ statusCodes : []int {503 , 200 },
1673+ expectRetries : true ,
1674+ expectedCalls : 2 ,
1675+ callbackCounts : 1 ,
1676+ },
1677+ {
1678+ name : "no retry on non-custom 500 status code" ,
1679+ statusCodes : []int {500 },
1680+ expectRetries : false ,
1681+ expectedCalls : 1 ,
1682+ callbackCounts : 0 ,
1683+ },
1684+ {
1685+ name : "no retry on 400 status code" ,
1686+ statusCodes : []int {400 },
1687+ expectRetries : false ,
1688+ expectedCalls : 1 ,
1689+ callbackCounts : 0 ,
1690+ },
1691+ }
1692+
1693+ for _ , tt := range tests {
1694+ t .Run (tt .name , func (t * testing.T ) {
1695+ callbackCalls = nil // Reset callback calls
1696+ callCount := 0
1697+
1698+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1699+ if callCount < len (tt .statusCodes ) {
1700+ w .WriteHeader (tt .statusCodes [callCount ])
1701+ if tt .statusCodes [callCount ] == 200 {
1702+ response := ListModelsResponse {Object : "list" , Data : []Model {}}
1703+ err := json .NewEncoder (w ).Encode (response )
1704+ assert .NoError (t , err )
1705+ } else {
1706+ err := json .NewEncoder (w ).Encode (Error {Error : stringPtr ("Server error" )})
1707+ assert .NoError (t , err )
1708+ }
1709+ }
1710+ callCount ++
1711+ }))
1712+ defer server .Close ()
1713+
1714+ baseURL := server .URL + "/v1"
1715+ client := NewClient (& ClientOptions {
1716+ BaseURL : baseURL ,
1717+ RetryConfig : retryConfig ,
1718+ })
1719+
1720+ ctx := context .Background ()
1721+ _ , err := client .ListModels (ctx )
1722+
1723+ assert .Equal (t , tt .expectedCalls , callCount )
1724+ assert .Len (t , callbackCalls , tt .callbackCounts )
1725+
1726+ if tt .expectRetries && tt .statusCodes [len (tt .statusCodes )- 1 ] == 200 {
1727+ assert .NoError (t , err )
1728+ } else if ! tt .expectRetries || tt .statusCodes [len (tt .statusCodes )- 1 ] != 200 {
1729+ assert .Error (t , err )
1730+ }
1731+ })
1732+ }
1733+ }
1734+
1735+ func TestRetryConfigWithNilCallback (t * testing.T ) {
1736+ retryConfig := & RetryConfig {
1737+ Enabled : true ,
1738+ MaxAttempts : 2 ,
1739+ InitialBackoffSec : 1 ,
1740+ MaxBackoffSec : 10 ,
1741+ BackoffMultiplier : 2 ,
1742+ OnRetry : nil , // Explicitly nil callback
1743+ }
1744+
1745+ callCount := 0
1746+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1747+ if callCount == 0 {
1748+ w .WriteHeader (http .StatusInternalServerError )
1749+ err := json .NewEncoder (w ).Encode (Error {Error : stringPtr ("Server error" )})
1750+ assert .NoError (t , err )
1751+ } else {
1752+ w .WriteHeader (http .StatusOK )
1753+ response := ListModelsResponse {Object : "list" , Data : []Model {}}
1754+ err := json .NewEncoder (w ).Encode (response )
1755+ assert .NoError (t , err )
1756+ }
1757+ callCount ++
1758+ }))
1759+ defer server .Close ()
1760+
1761+ baseURL := server .URL + "/v1"
1762+ client := NewClient (& ClientOptions {
1763+ BaseURL : baseURL ,
1764+ RetryConfig : retryConfig ,
1765+ })
1766+
1767+ ctx := context .Background ()
1768+ _ , err := client .ListModels (ctx )
1769+
1770+ // Should work without callback
1771+ assert .NoError (t , err )
1772+ assert .Equal (t , 2 , callCount )
1773+ }
0 commit comments