Skip to content

Commit c60f84e

Browse files
authored
feat: add dataset caching in worker node (#24)
1 parent a288086 commit c60f84e

6 files changed

Lines changed: 206 additions & 21 deletions

File tree

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ jobs:
1818
distribution: 'temurin'
1919
cache: maven
2020

21-
- name: Build and test with Maven
22-
run: ./mvnw clean install -Dair.check.skip-enforcer=true
21+
- name: Build and test
22+
run: make install
2323

2424
- name: Upload test results
2525
uses: actions/upload-artifact@v4

Makefile

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
.PHONY: build test clean install compile package help run lint format check
2+
3+
# Default target
4+
help:
5+
@echo "Available targets:"
6+
@echo " build - Build the project (compile + package)"
7+
@echo " test - Run tests"
8+
@echo " install - Build and install to local Maven repository"
9+
@echo " compile - Compile source code"
10+
@echo " package - Package the plugin"
11+
@echo " clean - Clean build artifacts"
12+
@echo " verify - Run full verification (compile, test, package)"
13+
@echo " run - Run the development query runner server"
14+
@echo " lint - Run all code style checks (checkstyle, modernizer, sortpom)"
15+
@echo " format - Format pom.xml files"
16+
@echo " check - Run lint checks without tests"
17+
18+
# Build the project
19+
build: compile package
20+
21+
# Run tests
22+
test:
23+
./mvnw test -Dair.check.skip-enforcer=true
24+
25+
# Install to local repository
26+
install:
27+
./mvnw clean install -Dair.check.skip-enforcer=true
28+
29+
# Compile source code
30+
compile:
31+
./mvnw compile -Dair.check.skip-enforcer=true
32+
33+
# Package the plugin
34+
package:
35+
./mvnw package -Dair.check.skip-enforcer=true -DskipTests
36+
37+
# Clean build artifacts
38+
clean:
39+
./mvnw clean
40+
41+
# Full verification
42+
verify:
43+
./mvnw verify -Dair.check.skip-enforcer=true
44+
45+
# Run the development query runner server
46+
run:
47+
./mvnw exec:java -pl plugin/trino-lance -Dexec.mainClass="io.trino.plugin.lance.LanceQueryRunner" -Dair.check.skip-enforcer=true
48+
49+
# Run all code style checks (Trino standard)
50+
lint:
51+
./mvnw checkstyle:check sortpom:verify modernizer:modernizer -Dair.check.skip-enforcer=true
52+
53+
# Format pom.xml files
54+
format:
55+
./mvnw sortpom:sort
56+
57+
# Run all checks without tests
58+
check:
59+
./mvnw compile checkstyle:check sortpom:verify modernizer:modernizer -Dair.check.skip-enforcer=true -DskipTests

README.md

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,23 @@ GROUP BY category;
7272
Build the connector:
7373

7474
```bash
75-
./mvnw clean install
75+
make install
76+
```
77+
78+
Other available commands:
79+
80+
```bash
81+
make help # Show all available commands
82+
make build # Compile and package
83+
make test # Run tests
84+
make compile # Compile only
85+
make package # Package without tests
86+
make clean # Clean build artifacts
87+
make verify # Full verification
88+
make run # Run development server
89+
make lint # Run code style checks (checkstyle, modernizer, sortpom)
90+
make format # Format pom.xml files
91+
make check # Run all checks without tests
7692
```
7793

7894
## Installation
@@ -93,16 +109,15 @@ We periodically upgrade the Trino version to stay up to date with the latest Tri
93109
### Running Tests
94110

95111
```bash
96-
./mvnw test -pl plugin/trino-lance
112+
make test
97113
```
98114

99115
### Running the Query Runner (Development Server)
100116

101117
You can run a local Trino server for development:
102118

103119
```bash
104-
cd plugin/trino-lance
105-
mvn exec:java -Dexec.mainClass="io.trino.plugin.lance.LanceQueryRunner"
120+
make run
106121
```
107122

108123
This starts a Trino server on port 8080 with the Lance connector configured.

plugin/trino-lance/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@
5353
<artifactId>units</artifactId>
5454
</dependency>
5555

56+
<dependency>
57+
<groupId>io.trino</groupId>
58+
<artifactId>trino-cache</artifactId>
59+
</dependency>
60+
5661
<dependency>
5762
<groupId>io.trino</groupId>
5863
<artifactId>trino-plugin-toolkit</artifactId>

plugin/trino-lance/src/main/java/io/trino/plugin/lance/LanceFragmentPageSource.java

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import io.trino.plugin.lance.internal.LanceReader;
1818
import io.trino.plugin.lance.internal.ScannerFactory;
1919
import org.apache.arrow.memory.BufferAllocator;
20-
import org.lance.Dataset;
2120
import org.lance.Fragment;
2221
import org.lance.ipc.LanceScanner;
2322
import org.lance.ipc.ScanOptions;
@@ -46,7 +45,6 @@ public static class FragmentScannerFactory
4645
implements ScannerFactory
4746
{
4847
private final int fragmentId;
49-
private Dataset lanceDataset;
5048
private Fragment lanceFragment;
5149
private LanceScanner lanceScanner;
5250

@@ -58,12 +56,11 @@ public FragmentScannerFactory(int fragmentId)
5856
@Override
5957
public LanceScanner open(String tablePath, BufferAllocator allocator, List<String> columns)
6058
{
61-
this.lanceDataset = Dataset.open(tablePath, allocator);
62-
// Find fragment by ID, not by list index (fragment IDs may not match list positions)
63-
this.lanceFragment = lanceDataset.getFragments().stream()
64-
.filter(f -> f.getId() == this.fragmentId)
65-
.findFirst()
66-
.orElseThrow(() -> new RuntimeException("Fragment not found: " + this.fragmentId));
59+
// Use cached fragment lookup instead of opening dataset and filtering
60+
this.lanceFragment = LanceReader.getFragment(tablePath, this.fragmentId);
61+
if (this.lanceFragment == null) {
62+
throw new RuntimeException("Fragment not found: " + this.fragmentId);
63+
}
6764
ScanOptions.Builder optionsBuilder = new ScanOptions.Builder();
6865
// Only set columns if non-empty; empty list means read all columns
6966
if (!columns.isEmpty()) {
@@ -76,6 +73,7 @@ public LanceScanner open(String tablePath, BufferAllocator allocator, List<Strin
7673
@Override
7774
public void close()
7875
{
76+
// Only close the scanner; the dataset is managed by LanceReader's cache
7977
try {
8078
if (lanceScanner != null) {
8179
lanceScanner.close();
@@ -84,9 +82,6 @@ public void close()
8482
catch (Exception e) {
8583
log.warn("error while closing lance scanner, Exception: %s", e.getMessage());
8684
}
87-
if (lanceDataset != null) {
88-
lanceDataset.close();
89-
}
9085
}
9186
}
9287
}

plugin/trino-lance/src/main/java/io/trino/plugin/lance/internal/LanceReader.java

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
*/
1414
package io.trino.plugin.lance.internal;
1515

16+
import com.google.common.cache.CacheBuilder;
17+
import com.google.common.cache.CacheLoader;
1618
import com.google.inject.Inject;
19+
import io.airlift.log.Logger;
20+
import io.trino.cache.NonEvictableLoadingCache;
21+
import io.trino.cache.SafeCaches;
1722
import io.trino.plugin.lance.LanceColumnHandle;
1823
import io.trino.plugin.lance.LanceConfig;
1924
import io.trino.plugin.lance.LanceNamespaceProperties;
@@ -38,20 +43,62 @@
3843
import java.util.LinkedHashMap;
3944
import java.util.List;
4045
import java.util.Map;
46+
import java.util.Objects;
4147
import java.util.Set;
48+
import java.util.concurrent.ExecutionException;
49+
import java.util.concurrent.TimeUnit;
4250
import java.util.stream.Collectors;
4351

4452
import static com.google.common.collect.ImmutableList.toImmutableList;
4553

4654
public class LanceReader
4755
implements Closeable
4856
{
57+
private static final Logger log = Logger.get(LanceReader.class);
58+
4959
// TODO: support multiple schemas
5060
public static final String SCHEMA = "default";
5161
private static final String TABLE_PATH_SUFFIX = ".lance";
5262
private static final BufferAllocator allocator = new RootAllocator(
5363
RootAllocator.configBuilder().from(RootAllocator.defaultConfig()).maxAllocation(4 * 1024 * 1024).build());
5464

65+
// Cache for dataset metadata (fragments) - shared across all LanceReader instances per worker JVM
66+
// Maximum 100 entries, expires 1 hour after last access (same as lance-spark)
67+
private static final NonEvictableLoadingCache<CacheKey, Map<Integer, Fragment>> FRAGMENT_CACHE =
68+
SafeCaches.buildNonEvictableCache(
69+
CacheBuilder.newBuilder()
70+
.maximumSize(100)
71+
.expireAfterAccess(1, TimeUnit.HOURS),
72+
new CacheLoader<>()
73+
{
74+
@Override
75+
public Map<Integer, Fragment> load(CacheKey key)
76+
{
77+
log.debug("Loading fragments for table: %s", key.getTablePath());
78+
Dataset dataset = Dataset.open(key.getTablePath(), allocator);
79+
return dataset.getFragments().stream()
80+
.collect(Collectors.toMap(Fragment::getId, f -> f));
81+
}
82+
});
83+
84+
// Cache for schema metadata - shared across all LanceReader instances per worker JVM
85+
private static final NonEvictableLoadingCache<CacheKey, Schema> SCHEMA_CACHE =
86+
SafeCaches.buildNonEvictableCache(
87+
CacheBuilder.newBuilder()
88+
.maximumSize(100)
89+
.expireAfterAccess(1, TimeUnit.HOURS),
90+
new CacheLoader<>()
91+
{
92+
@Override
93+
public Schema load(CacheKey key)
94+
{
95+
log.debug("Loading schema for table: %s", key.getTablePath());
96+
try (Dataset dataset = Dataset.open(key.getTablePath(), allocator)) {
97+
return dataset.getSchema();
98+
}
99+
}
100+
});
101+
55102
private final String root;
56103
private final LanceNamespace namespace;
57104

@@ -138,17 +185,46 @@ public List<Fragment> getFragments(LanceTableHandle tableHandle)
138185
return getFragments(getTablePath(tableHandle.getTableName()));
139186
}
140187

188+
/**
189+
* Get a specific fragment by ID from the cache.
190+
* This is useful for workers that need to access a specific fragment for data reading.
191+
*
192+
* @param tablePath the path to the lance table
193+
* @param fragmentId the fragment ID to retrieve
194+
* @return the Fragment object, or null if not found
195+
*/
196+
public static Fragment getFragment(String tablePath, int fragmentId)
197+
{
198+
try {
199+
CacheKey key = new CacheKey(tablePath);
200+
Map<Integer, Fragment> fragments = FRAGMENT_CACHE.get(key);
201+
return fragments.get(fragmentId);
202+
}
203+
catch (ExecutionException e) {
204+
throw new RuntimeException("Failed to get fragment from cache for table: " + tablePath, e);
205+
}
206+
}
207+
141208
private static List<Fragment> getFragments(String tablePath)
142209
{
143-
try (Dataset dataset = Dataset.open(tablePath, allocator)) {
144-
return dataset.getFragments();
210+
try {
211+
CacheKey key = new CacheKey(tablePath);
212+
Map<Integer, Fragment> fragmentMap = FRAGMENT_CACHE.get(key);
213+
return List.copyOf(fragmentMap.values());
214+
}
215+
catch (ExecutionException e) {
216+
throw new RuntimeException("Failed to get fragments from cache for table: " + tablePath, e);
145217
}
146218
}
147219

148220
private static Schema getSchema(String tablePath)
149221
{
150-
try (Dataset dataset = Dataset.open(tablePath, allocator)) {
151-
return dataset.getSchema();
222+
try {
223+
CacheKey key = new CacheKey(tablePath);
224+
return SCHEMA_CACHE.get(key);
225+
}
226+
catch (ExecutionException e) {
227+
throw new RuntimeException("Failed to get schema from cache for table: " + tablePath, e);
152228
}
153229
}
154230

@@ -176,4 +252,39 @@ public void close()
176252
}
177253
}
178254
}
255+
256+
/**
257+
* Cache key for dataset metadata caching.
258+
* Uses table path as the unique identifier.
259+
*/
260+
private static class CacheKey
261+
{
262+
private final String tablePath;
263+
264+
CacheKey(String tablePath)
265+
{
266+
this.tablePath = tablePath;
267+
}
268+
269+
public String getTablePath()
270+
{
271+
return tablePath;
272+
}
273+
274+
@Override
275+
public boolean equals(Object o)
276+
{
277+
if (o == null || getClass() != o.getClass()) {
278+
return false;
279+
}
280+
CacheKey cacheKey = (CacheKey) o;
281+
return Objects.equals(tablePath, cacheKey.tablePath);
282+
}
283+
284+
@Override
285+
public int hashCode()
286+
{
287+
return Objects.hash(tablePath);
288+
}
289+
}
179290
}

0 commit comments

Comments
 (0)