Skip to content

Commit 3cc87c3

Browse files
authored
Fix poor buffering case for MultipartReader (#8665)
* Demonstrate poor buffering case * Fix for repeated reads of small byteCount from large part
1 parent 9ee3463 commit 3cc87c3

3 files changed

Lines changed: 139 additions & 4 deletions

File tree

okhttp/src/commonJvmAndroid/kotlin/okhttp3/MultipartReader.kt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import java.io.Closeable
1919
import java.io.IOException
2020
import java.net.ProtocolException
2121
import okhttp3.internal.http1.HeadersReader
22+
import okhttp3.internal.limit
2223
import okio.Buffer
2324
import okio.BufferedSource
2425
import okio.ByteString.Companion.encodeUtf8
@@ -183,10 +184,14 @@ class MultipartReader
183184
* one byte left to read.
184185
*/
185186
private fun currentPartBytesRemaining(maxResult: Long): Long {
186-
source.require(crlfDashDashBoundary.size.toLong())
187-
188-
return when (val delimiterIndex = source.buffer.indexOf(crlfDashDashBoundary)) {
189-
-1L -> minOf(maxResult, source.buffer.size - crlfDashDashBoundary.size + 1)
187+
// Avoid indexOf scanning repeatedly over the entire source by using limit
188+
// Since maxResult could be midway through the boundary, read further to be safe.
189+
val limitSource = source.peek().limit(maxResult + crlfDashDashBoundary.size).buffer()
190+
limitSource.require(crlfDashDashBoundary.size.toLong())
191+
192+
val delimiterIndex = limitSource.buffer.indexOf(crlfDashDashBoundary)
193+
return when (delimiterIndex) {
194+
-1L -> minOf(maxResult, limitSource.buffer.size - crlfDashDashBoundary.size + 1)
190195
else -> minOf(maxResult, delimiterIndex)
191196
}
192197
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (C) 2024 Square, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package okhttp3.internal
18+
19+
import kotlin.jvm.JvmOverloads
20+
import okio.Buffer
21+
import okio.ForwardingSource
22+
import okio.Source
23+
24+
/**
25+
* Return a new [Source] whose [read function][Source.read] returns -1 after [byteCount]
26+
* bytes have been read.
27+
*
28+
* @param onReadExhausted Callback invoked once when the end of bytes has been reached. It receives
29+
* `true` if the end of bytes was because the underlying stream did not contain enough bytes and
30+
* `false` if [byteCount] bytes were successfully read.
31+
*/
32+
@JvmOverloads
33+
internal fun Source.limit(
34+
byteCount: Long,
35+
onReadExhausted: (eof: Boolean) -> Unit = {},
36+
): Source {
37+
require(byteCount >= 0) { "byteCount < 0: $byteCount" }
38+
return FixedLengthSource(this, byteCount, onReadExhausted, truncate = true)
39+
}
40+
41+
internal class FixedLengthSource(
42+
delegate: Source,
43+
private var bytesRemaining: Long,
44+
onReadExhausted: (eof: Boolean) -> Unit,
45+
private val truncate: Boolean,
46+
) : ForwardingSource(delegate) {
47+
/** `null` once invoked. */
48+
private var onReadExhausted: ((eof: Boolean) -> Unit)? = onReadExhausted
49+
50+
override fun read(
51+
sink: Buffer,
52+
byteCount: Long,
53+
): Long {
54+
val requestBytes =
55+
if (truncate) {
56+
if (bytesRemaining == 0L) {
57+
// If the limit was 0 we want to wait until the first call to this function before
58+
// triggering the callback.
59+
onReadExhausted?.invoke(false)
60+
onReadExhausted = null
61+
return -1L
62+
}
63+
minOf(bytesRemaining, byteCount)
64+
} else {
65+
byteCount
66+
}
67+
val readBytes = super.read(sink, requestBytes)
68+
if (readBytes == -1L) {
69+
onReadExhausted!!(true)
70+
onReadExhausted = null
71+
return -1L
72+
}
73+
bytesRemaining -= readBytes
74+
if (bytesRemaining == 0L) {
75+
onReadExhausted!!(false)
76+
onReadExhausted = null
77+
}
78+
return readBytes
79+
}
80+
}

okhttp/src/jvmTest/kotlin/okhttp3/MultipartReaderTest.kt

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import java.net.ProtocolException
2525
import kotlin.test.assertFailsWith
2626
import okhttp3.Headers.Companion.headersOf
2727
import okhttp3.MediaType.Companion.toMediaType
28+
import okhttp3.MediaType.Companion.toMediaTypeOrNull
2829
import okhttp3.RequestBody.Companion.toRequestBody
2930
import okhttp3.ResponseBody.Companion.toResponseBody
3031
import okio.Buffer
@@ -587,4 +588,53 @@ class MultipartReaderTest {
587588

588589
assertThat(reader.nextPart()).isNull()
589590
}
591+
592+
@Test
593+
fun `reading a large part with small byteCount`() {
594+
val multipartBody: RequestBody =
595+
MultipartBody.Builder("foo").addPart(
596+
headersOf("header-name", "header-value"),
597+
object : RequestBody() {
598+
override fun contentType(): MediaType? {
599+
return "application/octet-stream".toMediaTypeOrNull()
600+
}
601+
602+
override fun contentLength(): Long {
603+
return (1024 * 1024 * 100).toLong()
604+
}
605+
606+
override fun writeTo(sink: okio.BufferedSink) {
607+
repeat(100) {
608+
sink.writeUtf8(
609+
"a".repeat(1024 * 1024),
610+
)
611+
}
612+
}
613+
},
614+
).build()
615+
val buffer =
616+
Buffer().apply {
617+
multipartBody.writeTo(this)
618+
}
619+
620+
val multipartReader = MultipartReader(buffer, "foo")
621+
while (true) {
622+
val part = multipartReader.nextPart()
623+
624+
if (part == null) break
625+
626+
assertThat(part.headers["header-name"]).isEqualTo("header-value")
627+
while (true) {
628+
val readBuff = Buffer()
629+
val read = part.body.read(readBuff, (1024).toLong())
630+
if (read == -1L) {
631+
break
632+
} else {
633+
assertThat(readBuff.readUtf8()).isEqualTo(
634+
"a".repeat(read.toInt()),
635+
)
636+
}
637+
}
638+
}
639+
}
590640
}

0 commit comments

Comments
 (0)