[tx] Implement sampling from base model #527

Merged: pcmoritz merged 12 commits into NovaSky-AI:main from pcmoritz:tx-sample-base-model on Oct 20, 2025

pcmoritz (Collaborator) commented Oct 20, 2025

This is a first, simple implementation of `sample` that can only sample from the `base_model`. It does not batch requests yet; we will implement batching for inference in a follow-up PR.
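
For readers unfamiliar with the term, the following is a minimal sketch of what unbatched sampling can look like: each request gets its own generation loop, one after the other. It is not the engine code from this PR. The names `sample_token`, `generate`, `sample_unbatched`, and `toy_logits` are hypothetical, and the "model" is a toy function over a four-token vocabulary.

```python
import math
import random
from typing import Callable, List, Sequence

# Hypothetical stand-in for a model forward pass: token ids -> next-token logits.
LogitsFn = Callable[[Sequence[int]], List[float]]


def sample_token(logits: List[float], temperature: float, rng: random.Random) -> int:
    """Pick the next token id; temperature 0 is treated as greedy decoding."""
    if temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exponentiating for stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def generate(logits_fn: LogitsFn, prompt: Sequence[int], max_tokens: int,
             temperature: float, seed: int) -> List[int]:
    """Autoregressively extend one prompt and return only the new tokens."""
    rng = random.Random(seed)
    tokens = list(prompt)
    for _ in range(max_tokens):
        tokens.append(sample_token(logits_fn(tokens), temperature, rng))
    return tokens[len(prompt):]


def sample_unbatched(logits_fn: LogitsFn, prompts: Sequence[Sequence[int]],
                     max_tokens: int, temperature: float, seed: int = 0) -> List[List[int]]:
    # One generate() call per request, in order: nothing is batched across requests.
    return [generate(logits_fn, p, max_tokens, temperature, seed + i)
            for i, p in enumerate(prompts)]


def toy_logits(tokens: Sequence[int]) -> List[float]:
    # Toy "model" over a 4-token vocabulary that prefers (last token + 1) mod 4.
    preferred = (tokens[-1] + 1) % 4
    return [5.0 if i == preferred else 0.0 for i in range(4)]


completions = sample_unbatched(toy_logits, [[0], [2]], max_tokens=3, temperature=0.0)
```

A batched implementation would instead run the forward pass for several requests at once, which is the part deferred to the follow-up PR.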

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces the capability to sample from a base model, which is a great first step. The implementation is mostly solid, but there are a few critical issues related to model parameter validation that could lead to runtime errors, particularly when extending this to support LoRA models. I've also identified some areas for improvement in the tests and logging to enhance maintainability and clarity. My review includes suggestions to fix these issues and align the code with best practices.

Comment thread skyrl-tx/tx/tinker/types.py Outdated
Comment thread skyrl-tx/tx/tinker/api.py Outdated
Comment thread skyrl-tx/tx/tinker/engine.py
Comment thread skyrl-tx/tx/tinker/engine.py
Comment thread skyrl-tx/tests/tinker/test_api.py Outdated
Comment thread skyrl-tx/tests/tinker/test_api.py Outdated
Comment thread skyrl-tx/tx/tinker/engine.py
Comment thread skyrl-tx/tx/tinker/engine.py
pcmoritz changed the title from "[tx][WIP] Implement sampling from base_model" to "[tx][WIP] Implement sampling from base model" on Oct 20, 2025
pcmoritz and others added 7 commits October 20, 2025 12:53
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
pcmoritz (Collaborator, Author) commented

/gemini review

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces the functionality to sample from a base model, which is a great first step. The changes in the API layer to support base_model in SampleRequest are well-implemented with proper validation. The core logic in the engine for sampling is also in place. I have a couple of suggestions for improvement in engine.py: one is to replace an assert with a more robust error handling mechanism, and the other is to improve the seeding logic for better sample diversity in the future. Overall, this is a solid WIP implementation.
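
The two engine suggestions above could look roughly like the sketch below. The actual code in `engine.py` is not shown in this conversation, so every name here (`UnsupportedSamplingTarget`, `check_sampling_target`, `per_sample_seeds`) is hypothetical, and the seed-offset scheme is only one possible way to get distinct yet reproducible samples.

```python
from typing import List, Optional


class UnsupportedSamplingTarget(ValueError):
    """Raised when a request targets something this sketch cannot sample from."""


def check_sampling_target(base_model: Optional[str], loaded_base_model: str) -> None:
    # An `assert` is stripped when Python runs with -O; raising an exception
    # keeps the check active and gives the caller a readable message.
    if base_model is None:
        raise UnsupportedSamplingTarget("only sampling from the base model is supported")
    if base_model != loaded_base_model:
        raise UnsupportedSamplingTarget(
            f"unknown base model {base_model!r}; loaded model is {loaded_base_model!r}"
        )


def per_sample_seeds(request_seed: int, num_samples: int) -> List[int]:
    # Offsetting the request seed per sample keeps a request reproducible
    # while letting the individual samples differ from each other.
    return [request_seed + i for i in range(num_samples)]


check_sampling_target("Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B")  # passes silently
seeds = per_sample_seeds(7, 3)
```

Raising a dedicated exception type also lets the API layer translate the failure into a client-facing error rather than an internal crash.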

Comment thread skyrl-tx/tx/tinker/engine.py Outdated
Comment thread skyrl-tx/tx/tinker/engine.py Outdated
pcmoritz (Collaborator, Author) commented

/gemini review

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces the functionality to sample from a base model. The changes include updating the API to accept sampling requests for a base model, implementing the sampling logic in the engine, and adding corresponding tests. The implementation is a good first step. My main feedback is to add validation for sampling parameters that are exposed in the API but not yet implemented, to avoid user confusion.
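
As an illustration of that feedback, here is a minimal sketch of rejecting parameters that are accepted by the API but ignored by the engine. The conversation does not say which parameters are affected, so the choice of `top_k` and `top_p`, their "disabled" defaults, and the names `SamplingParams`, `unsupported_params`, and `validate` are all assumptions made for this example.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SamplingParams:
    # Hypothetical request model; field names and defaults are assumptions.
    max_tokens: int = 16
    temperature: float = 1.0
    top_k: int = -1     # -1 means "disabled" in this sketch
    top_p: float = 1.0  # 1.0 means "disabled" in this sketch


def unsupported_params(params: SamplingParams) -> List[str]:
    """Return the names of parameters set to a value this sketch would ignore."""
    problems = []
    if params.top_k != -1:
        problems.append("top_k")
    if params.top_p != 1.0:
        problems.append("top_p")
    return problems


def validate(params: SamplingParams) -> None:
    # Fail loudly instead of silently ignoring a setting the user asked for.
    bad = unsupported_params(params)
    if bad:
        raise ValueError(f"sampling parameters not implemented yet: {', '.join(bad)}")


validate(SamplingParams(max_tokens=20, temperature=0.0))  # accepted
```

Rejecting the request up front means a caller who sets `top_p=0.9` gets an error rather than output that quietly ignores the setting.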

Comment thread skyrl-tx/tx/tinker/api.py
pcmoritz and others added 2 commits October 20, 2025 14:10
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
pcmoritz changed the title from "[tx][WIP] Implement sampling from base model" to "[tx] Implement sampling from base model" on Oct 20, 2025
pcmoritz merged commit a8a86dd into NovaSky-AI:main on Oct 20, 2025
4 checks passed
pcmoritz added a commit that referenced this pull request Oct 21, 2025
Follow-up to #527 that implements
the same functionality for LoRA models. This allows us to run the full
Pig Latin example now:

Start the server with
```bash
uv run --extra tinker -m tx.tinker.api --host 0.0.0.0 --port 8000 --base-model "Qwen/Qwen3-0.6B"
```

and run with `uv run --with tinker python`:

```python
import tinker
service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="dummy")
print("Available models:")
for item in service_client.get_server_capabilities().supported_models:
    print("- " + item.model_name)

base_model = "Qwen/Qwen3-0.6B"
training_client = service_client.create_lora_training_client(
    base_model=base_model
)

# Create some training examples
examples = [
    {
        "input": "banana split",
        "output": "anana-bay plit-say"
    },
    {
        "input": "quantum physics",
        "output": "uantum-qay ysics-phay"
    },
    {
        "input": "donut shop",
        "output": "onut-day op-shay"
    },
    {
        "input": "pickle jar",
        "output": "ickle-pay ar-jay"
    },
    {
        "input": "space exploration",
        "output": "ace-spay exploration-way"
    },
    {
        "input": "rubber duck",
        "output": "ubber-ray uck-day"
    },
    {
        "input": "coding wizard",
        "output": "oding-cay izard-way"
    },
]
 
# Convert examples into the format expected by the training client
from tinker import types
 
# Get the tokenizer from the training client
tokenizer = training_client.get_tokenizer()
 
def process_example(example: dict, tokenizer) -> types.Datum:
    # Format the input with Input/Output template
    # For most real use cases, you'll want to use a renderer / chat template,
    # (see later docs) but here, we'll keep it simple.
    prompt = f"English: {example['input']}\nPig Latin:"
 
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    # Add a space before the output string, and finish with double newline
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)
 
    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights
 
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:] # We're predicting the next token, so targets need to be shifted.
    weights = weights[1:]
 
    # A datum is a single training example for the loss function.
    # It has model_input, which is the input sequence that'll be passed into the LLM,
    # loss_fn_inputs, which is a dictionary of extra inputs used by the loss function.
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )
 
processed_examples = [process_example(ex, tokenizer) for ex in examples]
 
# Visualize the first example for debugging purposes
datum0 = processed_examples[0]
print(f"{'Input':<20} {'Target':<20} {'Weight':<10}")
print("-" * 50)
for inp, tgt, wgt in zip(
    datum0.model_input.to_ints(),
    datum0.loss_fn_inputs['target_tokens'].tolist(),
    datum0.loss_fn_inputs['weights'].tolist(),
):
    print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}")

import numpy as np
for _ in range(6):
    fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
    optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
 
    # Wait for the results
    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()
 
    # fwdbwd_result contains the logprobs of all the tokens we put in. Now we can compute the weighted
    # average log loss per token.
    logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
    weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
    print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")

## Sampling

# First, create a sampling client. We need to transfer weights
sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model')
 
# Now, we can sample from the model.
prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
params = types.SamplingParams(max_tokens=20, temperature=0.0) # Greedy sampling
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)
result = future.result()
print("Responses:")
for i, seq in enumerate(result.sequences):
    print(f"{i}: {repr(tokenizer.decode(seq.tokens))}")
```
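
The token shifting in `process_example` and the weighted loss printed in the training loop can be checked by hand on a tiny example. The sketch below uses plain Python lists with made-up token ids and log-probabilities; the helper names `shift_for_next_token` and `weighted_loss_per_token` are hypothetical and not part of the `tinker` API.

```python
from typing import List, Tuple


def shift_for_next_token(tokens: List[int], weights: List[int]) -> Tuple[List[int], List[int], List[int]]:
    """Return (input_tokens, target_tokens, target_weights) for next-token prediction."""
    # Position i of the input predicts token i + 1, so targets and their
    # weights drop the first element while inputs drop the last.
    return tokens[:-1], tokens[1:], weights[1:]


def weighted_loss_per_token(logprobs: List[float], weights: List[int]) -> float:
    """Negative weighted mean log-probability over the weighted positions."""
    return -sum(lp * w for lp, w in zip(logprobs, weights)) / sum(weights)


# Made-up ids: three prompt tokens (weight 0) followed by two completion tokens (weight 1).
tokens = [101, 102, 103, 201, 202]
weights = [0, 0, 0, 1, 1]
inputs, targets, target_weights = shift_for_next_token(tokens, weights)

# Made-up per-position log-probabilities of each target token.
logprobs = [-3.0, -2.0, -0.5, -1.5]
loss = weighted_loss_per_token(logprobs, target_weights)
```

Only the two completion positions carry weight, so the loss is the mean of 0.5 and 1.5, that is 1.0; the prompt positions contribute nothing regardless of their log-probabilities.
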
atemaguer pushed a commit to atemaguer/SkyRL that referenced this pull request Oct 24, 2025
atemaguer pushed a commit to atemaguer/SkyRL that referenced this pull request Oct 24, 2025
li-boxuan pushed a commit to li-boxuan/SkyRL that referenced this pull request Nov 23, 2025
li-boxuan pushed a commit to li-boxuan/SkyRL that referenced this pull request Nov 23, 2025