Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
Expand All @@ -61,6 +62,7 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import static java.util.Arrays.asList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand Down Expand Up @@ -107,7 +109,7 @@ public void testDrainBatches() throws Exception {
PartitionInfo part4 = new PartitionInfo(topic, partition4, node2, null, null);

long batchSize = value.length + DefaultRecordBatch.RECORD_BATCH_OVERHEAD;
RecordAccumulator accum = createTestRecordAccumulator((int) batchSize, 1024, CompressionType.NONE, 10);
RecordAccumulator accum = createTestRecordAccumulator((int) batchSize, Integer.MAX_VALUE, CompressionType.NONE, 10);
Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2, part3, part4),
Collections.emptySet(), Collections.emptySet());

Expand Down Expand Up @@ -142,15 +144,25 @@ public void testDrainBatches() throws Exception {
// drain batches from 2 nodes: node1 => tp2, node2 => tp3 (because tp4 is muted)
Map<Integer, List<ProducerBatch>> batches4 = accum.drain(cluster, new HashSet<Node>(Arrays.asList(node1, node2)), (int) batchSize, 0);
verifyTopicPartitionInBatches(batches4, tp2, tp3);

// add record for tp1, tp2, tp3, and unmute tp4
accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
accum.append(tp2, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
accum.append(tp3, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
accum.unmutePartition(tp4);
// set maxSize as a max value, so that the all partitions in 2 nodes should be drained: node1 => [tp1, tp2], node2 => [tp3, tp4]
Map<Integer, List<ProducerBatch>> batches5 = accum.drain(cluster, new HashSet<Node>(Arrays.asList(node1, node2)), Integer.MAX_VALUE, 0);
verifyTopicPartitionInBatches(batches5, tp1, tp2, tp3, tp4);
}

private void verifyTopicPartitionInBatches(Map<Integer, List<ProducerBatch>> batches, TopicPartition... tp) {
assertEquals(tp.length, batches.size());
private void verifyTopicPartitionInBatches(Map<Integer, List<ProducerBatch>> nodeBatches, TopicPartition... tp) {
int allTpBatchCount = nodeBatches.values().stream().flatMap(Collection::stream).collect(Collectors.toList()).size();

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allTpBatchCount represent the count of the producerbatch from the nodeBatches map

assertEquals(tp.length, allTpBatchCount);
List<TopicPartition> topicPartitionsInBatch = new ArrayList<TopicPartition>();
for (Map.Entry<Integer, List<ProducerBatch>> entry : batches.entrySet()) {
List<ProducerBatch> batchList = entry.getValue();
assertEquals(1, batchList.size());

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the batcheList.size() can be 1 or 2, remove this assert statement

topicPartitionsInBatch.add(batchList.get(0).topicPartition);
for (Map.Entry<Integer, List<ProducerBatch>> entry : nodeBatches.entrySet()) {
List<ProducerBatch> tpBatchList = entry.getValue();
List<TopicPartition> tpList = tpBatchList.stream().map(producerBatch -> producerBatch.topicPartition).collect(Collectors.toList());
topicPartitionsInBatch.addAll(tpList);
}

for (int i = 0; i < tp.length; i++) {
Expand Down