1414import org .elasticsearch .action .ActionRequest ;
1515import org .elasticsearch .action .ActionResponse ;
1616import org .elasticsearch .action .ActionType ;
17+ import org .elasticsearch .action .search .ClearScrollRequest ;
18+ import org .elasticsearch .action .search .ClearScrollResponse ;
1719import org .elasticsearch .action .search .SearchRequest ;
1820import org .elasticsearch .action .search .SearchResponse ;
1921import org .elasticsearch .action .search .SearchScrollRequest ;
22+ import org .elasticsearch .action .search .TransportClearScrollAction ;
2023import org .elasticsearch .action .search .TransportSearchAction ;
2124import org .elasticsearch .action .search .TransportSearchScrollAction ;
2225import org .elasticsearch .client .internal .ParentTaskAssigningClient ;
3740import org .elasticsearch .threadpool .TestThreadPool ;
3841import org .elasticsearch .threadpool .ThreadPool ;
3942import org .junit .After ;
43+ import org .junit .Assert ;
4044import org .junit .Before ;
4145
46+ import java .util .ArrayList ;
4247import java .util .List ;
4348import java .util .concurrent .ArrayBlockingQueue ;
4449import java .util .concurrent .BlockingQueue ;
4550import java .util .concurrent .TimeUnit ;
51+ import java .util .concurrent .atomic .AtomicBoolean ;
4652import java .util .concurrent .atomic .AtomicInteger ;
4753import java .util .concurrent .atomic .AtomicReference ;
4854import java .util .function .Consumer ;
5258import static org .apache .lucene .tests .util .TestUtil .randomSimpleString ;
5359import static org .elasticsearch .common .bytes .BytesReferenceTestUtils .equalBytes ;
5460import static org .elasticsearch .core .TimeValue .timeValueSeconds ;
61+ import static org .hamcrest .Matchers .contains ;
5562import static org .hamcrest .Matchers .instanceOf ;
5663
5764public class ClientScrollablePaginatedHitSourceTests extends ESTestCase {
@@ -157,10 +164,61 @@ public void testScrollKeepAlive() {
157164 new SearchRequest ().scroll (timeValueSeconds (10 ))
158165 );
159166
167+ paginatedHitSource .setScroll ("scroll_id" );
160168 paginatedHitSource .startNextScroll (timeValueSeconds (100 ));
161169 client .validateRequest (TransportSearchScrollAction .TYPE , (SearchScrollRequest r ) -> assertEquals (r .scroll ().seconds (), 110 ));
162170 }
163171
172+ /** When scroll ID is empty or null, close runs cleanup immediately without calling clearScroll. */
173+ public void testCloseWhenScrollIdEmpty () {
174+ MockClient client = new MockClient (threadPool );
175+ TaskId parentTask = new TaskId ("thenode" , randomInt ());
176+ ClientScrollablePaginatedHitSource paginatedHitSource = new ClientScrollablePaginatedHitSource (
177+ logger ,
178+ BackoffPolicy .constantBackoff (TimeValue .ZERO , 0 ),
179+ threadPool ,
180+ Assert ::fail ,
181+ r -> fail (),
182+ e -> fail (),
183+ new ParentTaskAssigningClient (client , parentTask ),
184+ new SearchRequest ().scroll (timeValueSeconds (10 ))
185+ );
186+ AtomicBoolean closeCallbackCalled = new AtomicBoolean ();
187+
188+ paginatedHitSource .close (() -> closeCallbackCalled .set (true ));
189+
190+ assertTrue (closeCallbackCalled .get ());
191+ assertFalse (client .hasExecuted (TransportClearScrollAction .TYPE ));
192+ }
193+
194+ /** When scroll ID is set, close calls clearScroll and runs cleanup after it completes. */
195+ public void testCloseWhenScrollIdSet () throws InterruptedException {
196+ MockClient client = new MockClient (threadPool );
197+ TaskId parentTask = new TaskId ("thenode" , randomInt ());
198+ ClientScrollablePaginatedHitSource paginatedHitSource = new ClientScrollablePaginatedHitSource (
199+ logger ,
200+ BackoffPolicy .constantBackoff (TimeValue .ZERO , 0 ),
201+ threadPool ,
202+ Assert ::fail ,
203+ r -> fail (),
204+ e -> fail (),
205+ new ParentTaskAssigningClient (client , parentTask ),
206+ new SearchRequest ().scroll (timeValueSeconds (10 ))
207+ );
208+ paginatedHitSource .setScroll ("scroll_123" );
209+ AtomicBoolean closeCallbackCalled = new AtomicBoolean ();
210+
211+ paginatedHitSource .close (() -> closeCallbackCalled .set (true ));
212+
213+ client .awaitOperation ();
214+ client .validateRequest (
215+ TransportClearScrollAction .TYPE ,
216+ (ClearScrollRequest r ) -> assertThat (r .getScrollIds (), contains ("scroll_123" ))
217+ );
218+ client .respond (TransportClearScrollAction .TYPE , new ClearScrollResponse (true , 1 ));
219+ assertTrue (closeCallbackCalled .get ());
220+ }
221+
164222 private SearchResponse createSearchResponse () {
165223 // create a simulated response.
166224 SearchHit hit = SearchHit .unpooled (0 , "id" ).sourceRef (new BytesArray ("{}" ));
@@ -214,6 +272,7 @@ public void validateRequest(ActionType<Response> actionType, Consumer<? super Re
214272
215273 private static class MockClient extends AbstractClient {
216274 private ExecuteRequest <?, ?> executeRequest ;
275+ private final List <ActionType <?>> executedActions = new ArrayList <>();
217276
218277 MockClient (ThreadPool threadPool ) {
219278 super (Settings .EMPTY , threadPool , TestProjectResolvers .alwaysThrow ());
@@ -225,11 +284,15 @@ protected synchronized <Request extends ActionRequest, Response extends ActionRe
225284 Request request ,
226285 ActionListener <Response > listener
227286 ) {
228-
287+ executedActions . add ( action );
229288 this .executeRequest = new ExecuteRequest <>(action , request , listener );
230289 this .notifyAll ();
231290 }
232291
292+ boolean hasExecuted (ActionType <?> action ) {
293+ return executedActions .contains (action );
294+ }
295+
233296 @ SuppressWarnings ("unchecked" )
234297 public <Request extends ActionRequest , Response extends ActionResponse > void respondx (
235298 ActionType <Response > action ,
0 commit comments