
Revert "Support lse in trtllm paged attn kernels"#3079

Merged
aleozlx merged 1 commit intomainfrom
revert-3058-support-trtllm-attn-lse
Apr 16, 2026

Conversation

@aleozlx (Collaborator) commented Apr 15, 2026

Reverts #3058

Summary by CodeRabbit

  • Refactor

    • Removed Log-Sum-Exp (LSE) tensor support from attention decode and context operations, simplifying API signatures.
    • Streamlined internal workspace allocation and management for generation-mode operations.
  • Tests

    • Updated test coverage to reflect simplified attention operation interfaces and removed LSE validation checks.

@claude (Bot) commented Apr 15, 2026

Code Review: Revert "Support lse in trtllm paged attn kernels"

This PR reverts #3058, removing lse/return_lse support from the TRT-LLM paged attention kernels. The overall structure of the revert is sound, but there are several issues worth addressing before merging.


🐛 Critical Bug: Wrong dtype in LSE validation (decode.py)

In flashinfer/decode.py (the BatchMLAPagedAttentionWrapper.run() path), the revert accidentally introduced a dtype regression:

# After revert — WRONG
check_shape_dtype_device(
    lse,
    (q_nope.size(0), q_nope.size(1)),
    q_nope.dtype,   # <-- should be torch.float32, not q_nope.dtype
    q_nope.device,
    "lse",
)

LSE is always float32 regardless of q_nope.dtype. The original pre-#3058 code correctly used torch.float32 here. This was not in #3058 itself, so it looks like a fresh mistake in this revert.
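For reference, a minimal corrected version matching the pre-#3058 behavior described above (only the dtype argument changes):

# Corrected check
check_shape_dtype_device(
    lse,
    (q_nope.size(0), q_nope.size(1)),
    torch.float32,  # LSE is always float32, independent of the query dtype
    q_nope.device,
    "lse",
)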


⚠️ Incomplete Struct Cleanup in fmhaRunnerParams.h

The stride fields lseStrideTokens and lseStrideHeads were removed from TllmGenFmhaRunnerParams, but lsePtr is left in the struct. If LSE support is being fully removed, lsePtr should be removed too. If LSE will be re-added soon (as the TODO comment suggests), then removing the stride fields but keeping the pointer is inconsistent.

Either:

  • Remove lsePtr entirely (clean revert), or
  • Leave a commented-out block with the TODO explaining the future plan

⚠️ softmaxStatsPtr Unset in Paged Attention Path

Before #3058, softmaxStatsPtr was always allocated. After #3058, it was allocated whenever LSE was requested. After this revert, it is no longer allocated at all in the paged attention path — only a TODO comment is left:

// todo(Yingyi): add softmax buffer later for lse return

Risk: If the TRT-LLM FMHA kernel internally reads softmaxStatsPtr in the generation path (even without an explicit LSE output), this will be a use-of-uninitialized-pointer bug. Please confirm the kernel is safe when softmaxStatsPtr == nullptr for ForGen mode.

This concern does not apply to trtllm_ragged_attention_launcher, which still correctly allocates softmaxStatsPtr.


⚠️ assert Used for Public API Validation

In decode.py and prefill.py, the revert adds:

assert maybe_lse is None

Python assertions can be disabled with -O/-OO. For user-facing API guards, prefer:

if maybe_lse is not None:
    raise ValueError("lse output is not yet supported for the trtllm-gen backend")

📝 Test Coverage Regression

The revert removes all LSE correctness checks (shape, dtype, finiteness, numerical comparison against reference). While this is expected in a revert, note that:

  1. The workspace check is now less precise — the test now checks workspace_buffer[:8192*256*4] == 0 (the counter region) rather than the softmax buffer boundary. The comment acknowledges this size "might change in the future," which will make the test brittle. It would be better to compute the boundary dynamically or document it as a known limitation.

  2. In test_trtllm_gen_mla.py, the variable lse is now referenced at line ~486 in the original test code path after should_check_lse is removed — double-check that no dangling references to lse remain in the test file after the revert.


📝 Missing Revert Motivation

The PR body doesn't explain why #3058 is being reverted. Was there a correctness issue, a kernel crash, a broken CI test, or an architectural decision? Adding a brief note (or linking to the discussion/issue) would help reviewers understand the urgency and whether a partial revert (e.g., keeping the Python API but fixing the C++ layer) was considered.


✅ What's Correct

  • The C++ launcher signature changes (removing float* lse, lse_stride_tokens, lse_stride_heads) are clean and correctly threaded through both decode and context paths.
  • trtllm_ragged_attention_launcher is correctly left with its softmaxStatsPtr allocation intact (the ragged path legitimately needs it).
  • Return type simplifications (Union[Tensor, Tuple[...]] → Union[Tensor, FP4Tensor]) are correct; see the sketch after this list.
  • The return (out ...) style cleanup (removing the intermediate out = assignment plus the conditional return) is a nice simplification.
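A sketch of that return-type change, for illustration only (parameter lists are elided with *args/**kwargs, and the FP4Tensor stub stands in for flashinfer's real class):

from typing import Tuple, Union
import torch

class FP4Tensor: ...  # stand-in for flashinfer's FP4Tensor, illustration only

# Before the revert: could return (out, lse) when return_lse=True.
def trtllm_batch_decode_with_kv_cache_old(
    *args, return_lse: bool = False, **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ...

# After the revert: always a single output tensor (regular or FP4).
def trtllm_batch_decode_with_kv_cache_new(
    *args, **kwargs
) -> Union[torch.Tensor, FP4Tensor]: ...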

Summary: The revert is mostly correct, but please address the dtype bug in decode.py before merging, and confirm that softmaxStatsPtr = nullptr is safe for the ForGen paged-attention path.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request removes Log-Sum-Exp (LSE) support and its associated parameters from the TRT-LLM attention kernels and their Python wrappers, including paged attention and MLA implementations. The changes involve simplifying function signatures, removing LSE-related workspace allocations, and updating tests to reflect the removal of LSE return values. Feedback includes a correction for a type check in flashinfer/decode.py where the LSE tensor should be validated against torch.float32 rather than the query's data type, and a suggestion to replace magic numbers in the test suite with named constants for better maintainability.

Comment thread flashinfer/decode.py
Comment on lines +1974 to +1980
check_shape_dtype_device(
    lse,
    (q_nope.size(0), q_nope.size(1)),
    q_nope.dtype,
    q_nope.device,
    "lse",
)
Severity: high

The check_shape_dtype_device for lse is using q_nope.dtype as the expected data type. The log-sum-exp tensor (lse) should have a high-precision float type, typically torch.float32, regardless of the query's data type. Using q_nope.dtype could lead to incorrect type checks when q_nope is a lower precision type like float16 or bfloat16. This appears to be a reintroduction of a bug that might have been fixed in the reverted changes.

Suggested change
check_shape_dtype_device(
    lse,
    (q_nope.size(0), q_nope.size(1)),
    q_nope.dtype,
    q_nope.device,
    "lse",
)
check_shape_dtype_device(
    lse,
    (q_nope.size(0), q_nope.size(1)),
    torch.float32,
    q_nope.device,
    "lse",
)

Comment thread tests/attention/test_trtllm_gen_attention.py
).all()
# check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
# note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
Severity: medium

The size of the workspace buffer being checked, 8192 * 256 * 4, is a magic number. This value corresponds to the size of the counter workspace. To improve readability and maintainability, it would be better to define this as a constant, for example TRTLLM_GEN_COUNTER_WORKSPACE_BYTES, and use that constant here and on line 744. The comment on line 665 already indicates that this size might change in the future, which further strengthens the case for using a named constant.
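A minimal sketch of that suggestion, assuming the constant name proposed above (the value mirrors the size currently hard-coded in the tests and may change with the kernels):

# Counter workspace size for the trtllm-gen kernels, in bytes.
# note: mirrors the hard-coded 8192 * 256 * 4 used in the tests today.
TRTLLM_GEN_COUNTER_WORKSPACE_BYTES = 8192 * 256 * 4

# Each call site then reads identically:
assert (
    workspace_buffer[:TRTLLM_GEN_COUNTER_WORKSPACE_BYTES].cpu().numpy() == 0
).all()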

aleozlx enabled auto-merge (squash) on April 15, 2026 at 17:25
@coderabbitai (Bot) commented Apr 15, 2026

📝 Walkthrough

This PR removes LSE (log-sum-exp) buffer support from the trtllm-gen paged attention implementation. Changes include eliminating LSE parameters from kernel launcher signatures, removing LSE workspace allocations, and updating Python decode/prefill/MLA APIs to no longer accept or return LSE tensors.

Changes

C++ Kernel Launcher & Header
csrc/trtllm_fmha_kernel_launcher.cu, include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
Removed lse pointer and stride parameters (lse_stride_tokens, lse_stride_heads) from trtllm_paged_attention_launcher, trtllm_paged_attention_decode, and trtllm_paged_attention_context signatures; eliminated LSE workspace allocation and parameter wiring; removed stride fields from TllmGenFmhaRunnerParams struct.

Python Decode APIs
flashinfer/decode.py, flashinfer/mla/_core.py
Removed lse and return_lse parameters from trtllm_batch_decode_with_kv_cache() and trtllm_batch_decode_with_kv_cache_mla() and eliminated LSE validation/allocation; updated return types to always be a single tensor rather than an optional tuple.

Python Prefill API
flashinfer/prefill.py
Removed lse and return_lse parameters from trtllm_batch_context_with_kv_cache() and internal _paged_run; added assertion that maybe_lse is None for the trtllm-gen backend; simplified return type to a single tensor.

Tests
tests/attention/test_trtllm_gen_attention.py, tests/attention/test_trtllm_gen_mla.py
Removed LSE validation logic, helper functions (get_lse_test_tolerances, trtllm_gen_workspace_softmax_end_bytes_context), and conditional LSE assertions; updated test control flow to treat kernel outputs as single tensors.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • yongwww
  • yzh119
  • cyx-6
  • samuellees
  • saltyminty
  • bkryu
  • yyihuang
  • kahyunnam
  • jimmyzho
  • qsang-nv
  • nv-yunzheq

Poem

🐰 The LSE hops away, no longer needed here,
Parameters trimmed, the kernel's now more clear,
Workspace shrinks with joy, return types shine so bright,
Simpler signatures dance—the refactor feels just right! ✨

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Docstring Coverage (⚠️ Warning): Docstring coverage is 39.13%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
Description check (❓ Inconclusive): The PR description is minimal but appropriate for a revert: it references the original PR being reverted (#3058). However, the pull request template requires additional sections like checklist confirmations, which are not filled out. Resolution: consider adding confirmation of pre-commit checks completion and test status to align with the repository's pull request template requirements.
✅ Passed checks (1 passed)
Title check (✅ Passed): The title clearly and concisely describes the main change, reverting support for LSE in the trtllm paged attention kernels, which aligns with the extensive changes removing LSE parameters and logic across multiple files.



Warning

Review ran into problems

🔥 Problems

Git: Failed to clone repository. Please run the @coderabbitai full review command to re-trigger a full review. If the issue persists, set path_filters to include or exclude specific files.



@coderabbitai (Bot) left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/attention/test_trtllm_gen_mla.py (1)

665-675: ⚠️ Potential issue | 🟠 Major

Re-zero the shared TRT-LLM workspace before each sparse run.

Unlike trtllm_batch_decode_mla(), this path reuses global_trtllm_gen_fmha_workspace_buffer without resetting it first. Because the buffer is global, earlier parametrized cases can dirty the counter region and make the zero-region assertion at Line 697 flaky.

♻️ Suggested fix
     if global_trtllm_gen_fmha_workspace_buffer is None:
         global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
             workspace_size, dtype=torch.int8, device=device
         )
     workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer.zero_()
     # workspace_buffer_ref = global_workspace_buffer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_mla.py` around lines 665 - 675, Reset the
shared TRT-LLM workspace buffer before each sparse run by explicitly zeroing
global_trtllm_gen_fmha_workspace_buffer prior to setting workspace_buffer;
locate the block that sets global_trtllm_gen_fmha_workspace_buffer and
workspace_buffer in test_trtllm_gen_mla.py and call an in-place zeroing
operation (e.g., fill_(0) or torch.zeros_like assignment) on
global_trtllm_gen_fmha_workspace_buffer so the counter region is cleared before
reuse (similar to how trtllm_batch_decode_mla() initializes its buffer).
tests/attention/test_trtllm_gen_attention.py (1)

639-666: ⚠️ Potential issue | 🟠 Major

These fixed zero-slice assertions need a per-call workspace reset.

create_workspace_buffers() returns a global TRT-LLM workspace, and the decode helper also reuses that buffer for the XQA path, which writes semaphore state into this exact prefix. After switching to a fixed zero-region assert, these tests become order-dependent unless the TRT-LLM workspace is zero_()ed before each kernel invocation.

♻️ Suggested fix
     workspace_buffer, workspace_buffer_ref = create_workspace_buffers(GPU_DEVICE)
+    workspace_buffer.zero_()

Apply the same reset in both the prefill and decode helpers before the TRT-LLM call.

Also applies to: 1095-1125

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_attention.py` around lines 639 - 666, The
tests assume the TRT-LLM global workspace is zeroed but it’s reused across
calls; before each kernel invocation (e.g., before calling
trtllm_batch_context_with_kv_cache in the prefill helper and similarly in the
decode helper), explicitly reset the shared workspace returned by
create_workspace_buffers() by calling its zeroing method (e.g.,
workspace_buffer.zero_() or .zero()) so the fixed zero-region assertion on
workspace_buffer[:8192 * 256 * 4] is valid and tests are order-independent;
apply the same reset in both the prefill and decode helper locations referenced
in the diff.


📥 Commits

Reviewing files that changed from the base of the PR and between 25b324d and 70f0763.

📒 Files selected for processing (7)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/mla/_core.py
  • flashinfer/prefill.py
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_mla.py
💤 Files with no reviewable changes (1)
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h

Comment thread flashinfer/decode.py
Comment on lines +1367 to +1375
if return_lse:
    if lse is None:
        lse = torch.empty(
            (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
        )
    else:
        check_shape_dtype_device(
            lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
        )

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -e
sed -n '1218,1525p' flashinfer/decode.py
echo '---'
sed -n '2096,2160p' flashinfer/decode.py



🏁 Script executed:

# Check for existing guards on lse/return_lse with trtllm-gen
rg -n "trtllm-gen" flashinfer/decode.py | head -20



🏁 Script executed:

# Check the complete argument assembly and how lse is passed to the backend
sed -n '1350,1450p' flashinfer/decode.py | cat -n



🏁 Script executed:

# Verify the custom op assert and surrounding context
sed -n '2145,2165p' flashinfer/decode.py



Block lse/return_lse on the TRT-LLM decode wrapper instead of asserting internally.

The public wrapper accepts return_lse=True and explicit lse tensors without checking the backend, but passes them to the custom op's paged_run() which asserts maybe_lse is None. This causes an AssertionError instead of a stable user-facing error, and fails silently under python -O.

♻️ Suggested fix
+        if self._backend == "trtllm-gen" and (return_lse or lse is not None):
+            raise ValueError(
+                "trtllm-gen backend does not support lse/return_lse"
+            )
         if return_lse:
             if lse is None:
                 lse = torch.empty(
                     (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
                 )
             else:
                 check_shape_dtype_device(
                     lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1367 - 1375, The TRT-LLM decode wrapper
must refuse use of return_lse/lse instead of letting the custom op's paged_run
assert; add an explicit check in the wrapper (the block handling return_lse and
lse) that if the backend is TRT-LLM (or when calling paged_run) and (return_lse
is True or lse is not None) raise a clear ValueError with a user-facing message;
reference the existing symbols return_lse, lse, paged_run and maybe_lse so you
locate the code path and replace the silent assertion with this explicit check.

Comment thread flashinfer/decode.py
Comment on lines +1966 to +1980
if return_lse:
    if lse is None:
        lse = torch.empty(
            (q_nope.size(0), q_nope.size(1)),
            dtype=torch.float32,
            device=device,
        )
    else:
        check_shape_dtype_device(
            lse,
            (q_nope.size(0), q_nope.size(1)),
            q_nope.dtype,
            q_nope.device,
            "lse",
        )

⚠️ Potential issue | 🟠 Major

Validate caller-supplied MLA lse buffers as float32.

This branch allocates lse as torch.float32, but the explicit-buffer path validates against q_nope.dtype. A correctly preallocated float32 lse tensor will fail here for bf16/fp16 inputs.

♻️ Suggested fix
                 check_shape_dtype_device(
                     lse,
                     (q_nope.size(0), q_nope.size(1)),
-                    q_nope.dtype,
+                    torch.float32,
                     q_nope.device,
                     "lse",
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1966 - 1980, The code allocates lse as
torch.float32 when return_lse is true but validates a caller-supplied lse
against q_nope.dtype (which may be bf16/fp16), causing valid float32 buffers to
fail; update the validation to expect torch.float32 instead of q_nope.dtype by
calling check_shape_dtype_device(lse, (q_nope.size(0), q_nope.size(1)),
torch.float32, q_nope.device, "lse") (ensure you import/qualify torch.float32 if
needed) while keeping the shape and device checks the same.

Comment thread flashinfer/prefill.py
Comment on lines +2285 to +2293
if return_lse:
    if lse is None:
        lse = torch.empty(
            (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
        )
    else:
        check_shape_dtype_device(
            lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
        )

⚠️ Potential issue | 🟠 Major

Preserve explicit lse validation and reject TRT-LLM LSE requests earlier.

lse is still forwarded to the backend even when return_lse is False, but this block now validates it only inside the return_lse path. That drops the shape/dtype/device check for callers reusing an explicit lse buffer, and on backend="trtllm-gen" it allocates lse here only to trip the internal assert at Line 682. Please validate any provided lse unconditionally, then raise a real NotImplementedError/ValueError before allocation when the selected backend does not support LSE.

Suggested fix
-        if return_lse:
-            if lse is None:
-                lse = torch.empty(
-                    (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
-                )
-            else:
-                check_shape_dtype_device(
-                    lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
-                )
+        if lse is not None:
+            check_shape_dtype_device(
+                lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
+            )
+        if return_lse:
+            if self._backend == "trtllm-gen":
+                raise NotImplementedError(
+                    "return_lse is not supported for backend='trtllm-gen'."
+                )
+            if lse is None:
+                lse = torch.empty(
+                    (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
+                )

aleozlx added the v0.6.9 label (release blocker label for 0.6.9) on Apr 15, 2026
aleozlx merged commit a99ee72 into main on Apr 16, 2026
68 of 98 checks passed
aleozlx deleted the revert-3058-support-trtllm-attn-lse branch on April 16, 2026 at 02:45
ziang-and pushed a commit to zianglih/flashinfer that referenced this pull request Apr 17, 2026

Labels

op: attention, v0.6.9 (release blocker label for 0.6.9)


3 participants