1- import { createTestServer } from '@ai-sdk/test-server/with-vitest' ;
21import { DownloadError } from '@ai-sdk/provider-utils' ;
32import { download } from './download' ;
4- import { describe , it , expect , vi } from 'vitest' ;
5-
6- const server = createTestServer ( {
7- 'http://example.com/file' : { } ,
8- 'http://example.com/large' : { } ,
9- } ) ;
3+ import { describe , it , expect , vi , afterEach , beforeEach } from 'vitest' ;
104
115describe ( 'download SSRF protection' , ( ) => {
126 it ( 'should reject private IPv4 addresses' , async ( ) => {
@@ -28,17 +22,104 @@ describe('download SSRF protection', () => {
2822 } ) ;
2923} ) ;
3024
25+ describe ( 'download SSRF redirect protection' , ( ) => {
26+ const originalFetch = globalThis . fetch ;
27+
28+ afterEach ( ( ) => {
29+ globalThis . fetch = originalFetch ;
30+ } ) ;
31+
32+ it ( 'should reject redirects to private IP addresses' , async ( ) => {
33+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
34+ ok : true ,
35+ status : 200 ,
36+ redirected : true ,
37+ url : 'http://169.254.169.254/latest/meta-data/' ,
38+ headers : new Headers ( { 'content-type' : 'text/plain' } ) ,
39+ body : new ReadableStream ( {
40+ start ( controller ) {
41+ controller . enqueue ( new TextEncoder ( ) . encode ( 'secret' ) ) ;
42+ controller . close ( ) ;
43+ } ,
44+ } ) ,
45+ } as unknown as Response ) ;
46+
47+ await expect (
48+ download ( { url : new URL ( 'https://evil.com/redirect' ) } ) ,
49+ ) . rejects . toThrow ( DownloadError ) ;
50+ } ) ;
51+
52+ it ( 'should reject redirects to localhost' , async ( ) => {
53+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
54+ ok : true ,
55+ status : 200 ,
56+ redirected : true ,
57+ url : 'http://localhost:8080/admin' ,
58+ headers : new Headers ( { 'content-type' : 'text/plain' } ) ,
59+ body : new ReadableStream ( {
60+ start ( controller ) {
61+ controller . enqueue ( new TextEncoder ( ) . encode ( 'secret' ) ) ;
62+ controller . close ( ) ;
63+ } ,
64+ } ) ,
65+ } as unknown as Response ) ;
66+
67+ await expect (
68+ download ( { url : new URL ( 'https://evil.com/redirect' ) } ) ,
69+ ) . rejects . toThrow ( DownloadError ) ;
70+ } ) ;
71+
72+ it ( 'should allow redirects to safe URLs' , async ( ) => {
73+ const content = new Uint8Array ( [ 1 , 2 , 3 ] ) ;
74+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
75+ ok : true ,
76+ status : 200 ,
77+ redirected : true ,
78+ url : 'https://cdn.example.com/image.png' ,
79+ headers : new Headers ( { 'content-type' : 'image/png' } ) ,
80+ body : new ReadableStream ( {
81+ start ( controller ) {
82+ controller . enqueue ( content ) ;
83+ controller . close ( ) ;
84+ } ,
85+ } ) ,
86+ } as unknown as Response ) ;
87+
88+ const result = await download ( {
89+ url : new URL ( 'https://example.com/image.png' ) ,
90+ } ) ;
91+ expect ( result . data ) . toEqual ( content ) ;
92+ expect ( result . mediaType ) . toBe ( 'image/png' ) ;
93+ } ) ;
94+ } ) ;
95+
3196describe ( 'download' , ( ) => {
97+ const originalFetch = globalThis . fetch ;
98+
99+ beforeEach ( ( ) => {
100+ vi . resetAllMocks ( ) ;
101+ } ) ;
102+
103+ afterEach ( ( ) => {
104+ globalThis . fetch = originalFetch ;
105+ } ) ;
106+
32107 it ( 'should download data successfully and match expected bytes' , async ( ) => {
33108 const expectedBytes = new Uint8Array ( [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] ) ;
34109
35- server . urls [ 'http://example.com/file' ] . response = {
36- type : 'binary' ,
37- headers : {
110+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
111+ ok : true ,
112+ status : 200 ,
113+ headers : new Headers ( {
38114 'content-type' : 'application/octet-stream' ,
39- } ,
40- body : Buffer . from ( expectedBytes ) ,
41- } ;
115+ } ) ,
116+ body : new ReadableStream ( {
117+ start ( controller ) {
118+ controller . enqueue ( expectedBytes ) ;
119+ controller . close ( ) ;
120+ } ,
121+ } ) ,
122+ } as unknown as Response ) ;
42123
43124 const result = await download ( {
44125 url : new URL ( 'http://example.com/file' ) ,
@@ -48,16 +129,21 @@ describe('download', () => {
48129 expect ( result ! . data ) . toEqual ( expectedBytes ) ;
49130 expect ( result ! . mediaType ) . toBe ( 'application/octet-stream' ) ;
50131
51- // UA header assertion
52- expect ( server . calls [ 0 ] . requestUserAgent ) . toContain ( 'ai-sdk/' ) ;
132+ expect ( fetch ) . toHaveBeenCalledWith (
133+ 'http://example.com/file' ,
134+ expect . objectContaining ( {
135+ headers : expect . any ( Object ) ,
136+ } ) ,
137+ ) ;
53138 } ) ;
54139
55140 it ( 'should throw DownloadError when response is not ok' , async ( ) => {
56- server . urls [ 'http://example.com/file' ] . response = {
57- type : 'error' ,
141+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
142+ ok : false ,
58143 status : 404 ,
59- body : 'Not Found' ,
60- } ;
144+ statusText : 'Not Found' ,
145+ headers : new Headers ( ) ,
146+ } as unknown as Response ) ;
61147
62148 try {
63149 await download ( {
@@ -72,11 +158,7 @@ describe('download', () => {
72158 } ) ;
73159
74160 it ( 'should throw DownloadError when fetch throws an error' , async ( ) => {
75- server . urls [ 'http://example.com/file' ] . response = {
76- type : 'error' ,
77- status : 500 ,
78- body : 'Network error' ,
79- } ;
161+ globalThis . fetch = vi . fn ( ) . mockRejectedValue ( new Error ( 'Network error' ) ) ;
80162
81163 try {
82164 await download ( {
@@ -89,15 +171,20 @@ describe('download', () => {
89171 } ) ;
90172
91173 it ( 'should abort when response exceeds default size limit' , async ( ) => {
92- // Create a response that claims to be larger than 2 GiB
93- server . urls [ 'http://example.com/large' ] . response = {
94- type : 'binary' ,
95- headers : {
174+ globalThis . fetch = vi . fn ( ) . mockResolvedValue ( {
175+ ok : true ,
176+ status : 200 ,
177+ headers : new Headers ( {
96178 'content-type' : 'application/octet-stream' ,
97179 'content-length' : `${ 3 * 1024 * 1024 * 1024 } ` ,
98- } ,
99- body : Buffer . from ( new Uint8Array ( 10 ) ) ,
100- } ;
180+ } ) ,
181+ body : new ReadableStream ( {
182+ start ( controller ) {
183+ controller . enqueue ( new Uint8Array ( 10 ) ) ;
184+ controller . close ( ) ;
185+ } ,
186+ } ) ,
187+ } as unknown as Response ) ;
101188
102189 try {
103190 await download ( {
@@ -116,13 +203,11 @@ describe('download', () => {
116203 const controller = new AbortController ( ) ;
117204 controller . abort ( ) ;
118205
119- server . urls [ 'http://example.com/file' ] . response = {
120- type : 'binary' ,
121- headers : {
122- 'content-type' : 'application/octet-stream' ,
123- } ,
124- body : Buffer . from ( new Uint8Array ( [ 1 , 2 , 3 ] ) ) ,
125- } ;
206+ globalThis . fetch = vi
207+ . fn ( )
208+ . mockRejectedValue (
209+ new DOMException ( 'The operation was aborted.' , 'AbortError' ) ,
210+ ) ;
126211
127212 try {
128213 await download ( {
@@ -131,8 +216,14 @@ describe('download', () => {
131216 } ) ;
132217 expect . fail ( 'Expected download to throw' ) ;
133218 } catch ( error : unknown ) {
134- // The fetch should be aborted, resulting in a DownloadError wrapping an AbortError
135219 expect ( DownloadError . isInstance ( error ) ) . toBe ( true ) ;
136220 }
221+
222+ expect ( fetch ) . toHaveBeenCalledWith (
223+ 'http://example.com/file' ,
224+ expect . objectContaining ( {
225+ signal : controller . signal ,
226+ } ) ,
227+ ) ;
137228 } ) ;
138229} ) ;
0 commit comments