1717import org .elasticsearch .common .io .Streams ;
1818import org .elasticsearch .common .unit .ByteSizeUnit ;
1919import org .elasticsearch .common .unit .ByteSizeValue ;
20+ import org .elasticsearch .core .Nullable ;
2021import org .elasticsearch .core .SuppressForbidden ;
2122import org .elasticsearch .rest .RestStatus ;
2223import org .elasticsearch .xcontent .XContentParser ;
3435import java .security .AccessController ;
3536import java .security .MessageDigest ;
3637import java .security .PrivilegedAction ;
38+ import java .util .ArrayList ;
3739import java .util .HashMap ;
3840import java .util .List ;
3941import java .util .Locale ;
4042import java .util .Map ;
43+ import java .util .concurrent .atomic .AtomicInteger ;
44+ import java .util .concurrent .atomic .AtomicLong ;
4145import java .util .stream .Collectors ;
4246
4347import static java .net .HttpURLConnection .HTTP_MOVED_PERM ;
4448import static java .net .HttpURLConnection .HTTP_MOVED_TEMP ;
4549import static java .net .HttpURLConnection .HTTP_NOT_FOUND ;
4650import static java .net .HttpURLConnection .HTTP_OK ;
51+ import static java .net .HttpURLConnection .HTTP_PARTIAL ;
4752import static java .net .HttpURLConnection .HTTP_SEE_OTHER ;
4853
4954/**
@@ -61,6 +66,73 @@ final class ModelLoaderUtils {
6166
6267 record VocabularyParts (List <String > vocab , List <String > merges , List <Double > scores ) {}
6368
69+ // Range in bytes
70+ record RequestRange (long rangeStart , long rangeEnd , int startPart , int numParts ) {
71+ public String bytesRange () {
72+ return "bytes=" + rangeStart + "-" + rangeEnd ;
73+ }
74+ }
75+
76+ static class HttpStreamChunker {
77+
78+ record BytesAndPartIndex (BytesArray bytes , int partIndex ) {}
79+
80+ private final InputStream inputStream ;
81+ private final int chunkSize ;
82+ private final AtomicLong totalBytesRead = new AtomicLong ();
83+ private final AtomicInteger currentPart ;
84+ private final int lastPartNumber ;
85+
86+ HttpStreamChunker (URI uri , RequestRange range , int chunkSize ) {
87+ var inputStream = getHttpOrHttpsInputStream (uri , range );
88+ this .inputStream = inputStream ;
89+ this .chunkSize = chunkSize ;
90+ this .lastPartNumber = range .startPart () + range .numParts ();
91+ this .currentPart = new AtomicInteger (range .startPart ());
92+ }
93+
94+ // This ctor exists for testing purposes only.
95+ HttpStreamChunker (InputStream inputStream , RequestRange range , int chunkSize ) {
96+ this .inputStream = inputStream ;
97+ this .chunkSize = chunkSize ;
98+ this .lastPartNumber = range .startPart () + range .numParts ();
99+ this .currentPart = new AtomicInteger (range .startPart ());
100+ }
101+
102+ public boolean hasNext () {
103+ return currentPart .get () < lastPartNumber ;
104+ }
105+
106+ public BytesAndPartIndex next () throws IOException {
107+ int bytesRead = 0 ;
108+ byte [] buf = new byte [chunkSize ];
109+
110+ while (bytesRead < chunkSize ) {
111+ int read = inputStream .read (buf , bytesRead , chunkSize - bytesRead );
112+ // EOF??
113+ if (read == -1 ) {
114+ break ;
115+ }
116+ bytesRead += read ;
117+ }
118+
119+ if (bytesRead > 0 ) {
120+ totalBytesRead .addAndGet (bytesRead );
121+ return new BytesAndPartIndex (new BytesArray (buf , 0 , bytesRead ), currentPart .getAndIncrement ());
122+ } else {
123+ return new BytesAndPartIndex (BytesArray .EMPTY , currentPart .get ());
124+ }
125+ }
126+
127+ public long getTotalBytesRead () {
128+ return totalBytesRead .get ();
129+ }
130+
131+ public int getCurrentPart () {
132+ return currentPart .get ();
133+ }
134+ }
135+
64136 static class InputStreamChunker {
65137
66138 private final InputStream inputStream ;
@@ -101,21 +173,26 @@ public int getTotalBytesRead() {
101173 }
102174 }
103175
104- static InputStream getInputStreamFromModelRepository (URI uri ) throws IOException {
176+ static InputStream getInputStreamFromModelRepository (URI uri ) {
105177 String scheme = uri .getScheme ().toLowerCase (Locale .ROOT );
106178
107179 // if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository}
108180 switch (scheme ) {
109181 case "http" :
110182 case "https" :
111- return getHttpOrHttpsInputStream (uri );
183+ return getHttpOrHttpsInputStream (uri , null );
112184 case "file" :
113185 return getFileInputStream (uri );
114186 default :
115187 throw new IllegalArgumentException ("unsupported scheme" );
116188 }
117189 }
118190
191+ static boolean uriIsFile (URI uri ) {
192+ String scheme = uri .getScheme ().toLowerCase (Locale .ROOT );
193+ return "file" .equals (scheme );
194+ }
195+
119196 static VocabularyParts loadVocabulary (URI uri ) {
120197 if (uri .getPath ().endsWith (".json" )) {
121198 try (InputStream vocabInputStream = getInputStreamFromModelRepository (uri )) {
@@ -174,7 +251,7 @@ private ModelLoaderUtils() {}
174251
175252 @ SuppressWarnings ("'java.lang.SecurityManager' is deprecated and marked for removal " )
176253 @ SuppressForbidden (reason = "we need socket connection to download" )
177- private static InputStream getHttpOrHttpsInputStream (URI uri ) throws IOException {
254+ private static InputStream getHttpOrHttpsInputStream (URI uri , @ Nullable RequestRange range ) {
178255
179256 assert uri .getUserInfo () == null : "URI's with credentials are not supported" ;
180257
@@ -186,18 +263,30 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
186263 PrivilegedAction <InputStream > privilegedHttpReader = () -> {
187264 try {
188265 HttpURLConnection conn = (HttpURLConnection ) uri .toURL ().openConnection ();
266+ if (range != null ) {
267+ conn .setRequestProperty ("Range" , range .bytesRange ());
268+ }
189269 switch (conn .getResponseCode ()) {
190270 case HTTP_OK :
271+ case HTTP_PARTIAL :
191272 return conn .getInputStream ();
273+
192274 case HTTP_MOVED_PERM :
193275 case HTTP_MOVED_TEMP :
194276 case HTTP_SEE_OTHER :
195277 throw new IllegalStateException ("redirects aren't supported yet" );
196278 case HTTP_NOT_FOUND :
197279 throw new ResourceNotFoundException ("{} not found" , uri );
280+ case 416 : // Range not satisfiable, for some reason not in the list of constants
281+ throw new IllegalStateException ("Invalid request range [" + range .bytesRange () + "]" );
198282 default :
199283 int responseCode = conn .getResponseCode ();
200- throw new ElasticsearchStatusException ("error during downloading {}" , RestStatus .fromCode (responseCode ), uri );
284+ throw new ElasticsearchStatusException (
285+ "error during downloading {}. Got response code {}" ,
286+ RestStatus .fromCode (responseCode ),
287+ uri ,
288+ responseCode
289+ );
201290 }
202291 } catch (IOException e ) {
203292 throw new UncheckedIOException (e );
@@ -209,7 +298,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
209298
210299 @ SuppressWarnings ("'java.lang.SecurityManager' is deprecated and marked for removal " )
211300 @ SuppressForbidden (reason = "we need load model data from a file" )
212- private static InputStream getFileInputStream (URI uri ) {
301+ static InputStream getFileInputStream (URI uri ) {
213302
214303 SecurityManager sm = System .getSecurityManager ();
215304 if (sm != null ) {
@@ -232,4 +321,53 @@ private static InputStream getFileInputStream(URI uri) {
232321 return AccessController .doPrivileged (privilegedFileReader );
233322 }
234323
324+ /**
325+ * Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1
326+ * ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a
327+ * whole number of chunks.
328+ * The first {@code numberOfStreams} ranges will be split evenly (in terms of
329+ * number of chunks not the byte size), the final range split
330+ * is for the single final chunk and will be no more than {@code chunkSizeBytes}
331+ * in size. The separate range for the final chunk is because when streaming and
332+ * uploading a large model definition, writing the last part has to handled
333+ * as a special case.
334+ * @param sizeInBytes The total size of the stream
335+ * @param numberOfStreams Divide the bulk of the size into this many streams.
336+ * @param chunkSizeBytes The size of each chunk
337+ * @return List of {@code numberOfStreams} + 1 ranges.
338+ */
339+ static List <RequestRange > split (long sizeInBytes , int numberOfStreams , long chunkSizeBytes ) {
340+ int numberOfChunks = (int ) ((sizeInBytes + chunkSizeBytes - 1 ) / chunkSizeBytes );
341+
342+ var ranges = new ArrayList <RequestRange >();
343+
344+ int baseChunksPerStream = numberOfChunks / numberOfStreams ;
345+ int remainder = numberOfChunks % numberOfStreams ;
346+ long startOffset = 0 ;
347+ int startChunkIndex = 0 ;
348+
349+ for (int i = 0 ; i < numberOfStreams - 1 ; i ++) {
350+ int numChunksInStream = (i < remainder ) ? baseChunksPerStream + 1 : baseChunksPerStream ;
351+ long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes ) - 1 ; // range index is 0 based
352+ ranges .add (new RequestRange (startOffset , rangeEnd , startChunkIndex , numChunksInStream ));
353+ startOffset = rangeEnd + 1 ; // range is inclusive start and end
354+ startChunkIndex += numChunksInStream ;
355+ }
356+
357+ // Want the final range request to be a single chunk
358+ if (baseChunksPerStream > 1 ) {
359+ int numChunksExcludingFinal = baseChunksPerStream - 1 ;
360+ long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes ) - 1 ;
361+ ranges .add (new RequestRange (startOffset , rangeEnd , startChunkIndex , numChunksExcludingFinal ));
362+
363+ startOffset = rangeEnd + 1 ;
364+ startChunkIndex += numChunksExcludingFinal ;
365+ }
366+
367+ // The final range is a single chunk the end of which should not exceed sizeInBytes
368+ long rangeEnd = Math .min (sizeInBytes , startOffset + (baseChunksPerStream * chunkSizeBytes )) - 1 ;
369+ ranges .add (new RequestRange (startOffset , rangeEnd , startChunkIndex , 1 ));
370+
371+ return ranges ;
372+ }
235373}
0 commit comments