[tx] Implement sampling from LoRA models#531
Conversation
There was a problem hiding this comment.
Code Review
This pull request successfully implements sampling from LoRA models, extending the existing functionality to support adapter-based generation. The changes are well-organized across the API, engine, and utility layers, and the inclusion of parameterized tests for both base model and LoRA sampling is a great addition. My review includes a couple of minor suggestions to improve code quality by removing an unused import and adhering to standard Python import practices.
|
I'm branching off of this to work on generation and batching |
| case []: | ||
| stop_tokens = None | ||
| case list(elements) if isinstance(elements[0], str): | ||
| stop_tokens = {token for s in elements for token in self.tokenizer.encode(s, add_special_tokens=False)} |
There was a problem hiding this comment.
If I'm reading this correctly, then a multi-token stop string will be broken up into it's tokens and flattened into a list. So if the stop string is encoded to {token_1, token_2} then if either of those tokens are generated, then generation stops. Am I reading this correctly?
There was a problem hiding this comment.
Yeah, you are right, I think I better remove the stop tokens right now and then we implement it properly (I'm actually not super sure how it needs to be implemented yet)
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request successfully implements sampling from LoRA models, which is a great addition. The changes are well-structured, particularly the introduction of the GenerateResult dataclass for cleaner output from the generate method, and the new load_sampler_weights function in the engine provides a clear separation of concerns. The test coverage has also been appropriately extended to cover both base model and LoRA sampling scenarios.
I've found one significant issue in the generate method where it doesn't handle end-of-sequence (EOS) tokens, causing it to always generate up to max_new_tokens and incorrectly report the stop reason. I've left a detailed comment with a suggested fix for this.
Followup to NovaSky-AI#527 to implement the same functionality for LoRA models. This allows us to run the full piglatin example now: Start the server with ```bash uv run --extra tinker -m tx.tinker.api --host 0.0.0.0 --port 8000 --base-model "Qwen/Qwen3-0.6B" ``` and run with `uv run --with tinker python`: ```python import tinker service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="dummy") print("Available models:") for item in service_client.get_server_capabilities().supported_models: print("- " + item.model_name) base_model = "Qwen/Qwen3-0.6B" training_client = service_client.create_lora_training_client( base_model=base_model ) # Create some training examples examples = [ { "input": "banana split", "output": "anana-bay plit-say" }, { "input": "quantum physics", "output": "uantum-qay ysics-phay" }, { "input": "donut shop", "output": "onut-day op-shay" }, { "input": "pickle jar", "output": "ickle-pay ar-jay" }, { "input": "space exploration", "output": "ace-spay exploration-way" }, { "input": "rubber duck", "output": "ubber-ray uck-day" }, { "input": "coding wizard", "output": "oding-cay izard-way" }, ] # Convert examples into the format expected by the training client from tinker import types # Get the tokenizer from the training client tokenizer = training_client.get_tokenizer() def process_example(example: dict, tokenizer) -> types.Datum: # Format the input with Input/Output template # For most real use cases, you'll want to use a renderer / chat template, # (see later docs) but here, we'll keep it simple. prompt = f"English: {example['input']}\nPig Latin:" prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True) prompt_weights = [0] * len(prompt_tokens) # Add a space before the output string, and finish with double newline completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False) completion_weights = [1] * len(completion_tokens) tokens = prompt_tokens + completion_tokens weights = prompt_weights + completion_weights input_tokens = tokens[:-1] target_tokens = tokens[1:] # We're predicting the next token, so targets need to be shifted. weights = weights[1:] # A datum is a single training example for the loss function. # It has model_input, which is the input sequence that'll be passed into the LLM, # loss_fn_inputs, which is a dictionary of extra inputs used by the loss function. return types.Datum( model_input=types.ModelInput.from_ints(tokens=input_tokens), loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens) ) processed_examples = [process_example(ex, tokenizer) for ex in examples] # Visualize the first example for debugging purposes datum0 = processed_examples[0] print(f"{'Input':<20} {'Target':<20} {'Weight':<10}") print("-" * 50) for i, (inp, tgt, wgt) in enumerate(zip(datum0.model_input.to_ints(), datum0.loss_fn_inputs['target_tokens'].tolist(), datum0.loss_fn_inputs['weights'].tolist())): print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}") import numpy as np for _ in range(6): fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy") optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4)) # Wait for the results fwdbwd_result = fwdbwd_future.result() optim_result = optim_future.result() # fwdbwd_result contains the logprobs of all the tokens we put in. Now we can compute the weighted # average log loss per token. logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]) weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples]) print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}") ## Sampling # First, create a sampling client. We need to transfer weights sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model') # Now, we can sample from the model. prompt=types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:")) params = types.SamplingParams(max_tokens=20, temperature=0.0) # Greedy sampling future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8) result = future.result() print("Responses:") for i, seq in enumerate(result.sequences): print(f"{i}: {repr(tokenizer.decode(seq.tokens))}") ```
Followup to NovaSky-AI#527 to implement the same functionality for LoRA models. This allows us to run the full piglatin example now: Start the server with ```bash uv run --extra tinker -m tx.tinker.api --host 0.0.0.0 --port 8000 --base-model "Qwen/Qwen3-0.6B" ``` and run with `uv run --with tinker python`: ```python import tinker service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="dummy") print("Available models:") for item in service_client.get_server_capabilities().supported_models: print("- " + item.model_name) base_model = "Qwen/Qwen3-0.6B" training_client = service_client.create_lora_training_client( base_model=base_model ) # Create some training examples examples = [ { "input": "banana split", "output": "anana-bay plit-say" }, { "input": "quantum physics", "output": "uantum-qay ysics-phay" }, { "input": "donut shop", "output": "onut-day op-shay" }, { "input": "pickle jar", "output": "ickle-pay ar-jay" }, { "input": "space exploration", "output": "ace-spay exploration-way" }, { "input": "rubber duck", "output": "ubber-ray uck-day" }, { "input": "coding wizard", "output": "oding-cay izard-way" }, ] # Convert examples into the format expected by the training client from tinker import types # Get the tokenizer from the training client tokenizer = training_client.get_tokenizer() def process_example(example: dict, tokenizer) -> types.Datum: # Format the input with Input/Output template # For most real use cases, you'll want to use a renderer / chat template, # (see later docs) but here, we'll keep it simple. prompt = f"English: {example['input']}\nPig Latin:" prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True) prompt_weights = [0] * len(prompt_tokens) # Add a space before the output string, and finish with double newline completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False) completion_weights = [1] * len(completion_tokens) tokens = prompt_tokens + completion_tokens weights = prompt_weights + completion_weights input_tokens = tokens[:-1] target_tokens = tokens[1:] # We're predicting the next token, so targets need to be shifted. weights = weights[1:] # A datum is a single training example for the loss function. # It has model_input, which is the input sequence that'll be passed into the LLM, # loss_fn_inputs, which is a dictionary of extra inputs used by the loss function. return types.Datum( model_input=types.ModelInput.from_ints(tokens=input_tokens), loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens) ) processed_examples = [process_example(ex, tokenizer) for ex in examples] # Visualize the first example for debugging purposes datum0 = processed_examples[0] print(f"{'Input':<20} {'Target':<20} {'Weight':<10}") print("-" * 50) for i, (inp, tgt, wgt) in enumerate(zip(datum0.model_input.to_ints(), datum0.loss_fn_inputs['target_tokens'].tolist(), datum0.loss_fn_inputs['weights'].tolist())): print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}") import numpy as np for _ in range(6): fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy") optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4)) # Wait for the results fwdbwd_result = fwdbwd_future.result() optim_result = optim_future.result() # fwdbwd_result contains the logprobs of all the tokens we put in. Now we can compute the weighted # average log loss per token. logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]) weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples]) print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}") ## Sampling # First, create a sampling client. We need to transfer weights sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model') # Now, we can sample from the model. prompt=types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:")) params = types.SamplingParams(max_tokens=20, temperature=0.0) # Greedy sampling future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8) result = future.result() print("Responses:") for i, seq in enumerate(result.sequences): print(f"{i}: {repr(tokenizer.decode(seq.tokens))}") ```
Followup to #527 to implement the same functionality for LoRA models. This allows us to run the full piglatin example now:
Start the server with
uv run --extra tinker -m tx.tinker.api --host 0.0.0.0 --port 8000 --base-model "Qwen/Qwen3-0.6B"and run with
uv run --with tinker python: