llama : add option for greedy sampling with probs #3813
Force-pushed from e274fe3 to 4aa1fb0
```diff
-    if (temp <= 0) {
+    if (temp < 0.0) {
         // greedy sampling
```
The result is the same in either case, right? I'm not entirely sure it's worth special-casing this instead of just changing greedy sampling to do:

```cpp
llama_sample_softmax(ctx_main, &cur_p);
id = cur_p.data[0].id;
```

But if you did go that way, you'd probably also want to change the common args parsing stuff to clamp the user-specified temperature to 0.0, so if they pass a negative value it's the same.
It's only internal stuff that would care about probs generated vs no probs unless I'm misunderstanding.
It's the same result, yes. The probs are not used only internally - we are using them in `speculative`. Before this PR, we had to do the hack with `temp = 0.01f;` to get probs. Now we get them with `temp = 0.0f;`.

The user-specified input should not be affected by this change. Technically, the user would normally want to pass `temp = -1.0f` to save the extra softmax compute, but it's probably not something that would affect performance in a measurable way.
Sorry, "internal" was a poor choice of words. I meant it's not something someone calling the application and passing `--temp` on the command line would care about. So if they do `--temp -1` for an example that doesn't care about probs, then it's kind of weird/unnecessary to turn on generating probs in that case.

So what I'm proposing is that the argument handling stuff would do something like:

```cpp
params.sparams.temp = std::max(0.0f, atof(blah));
```

when parsing the command-line arguments, so even if the user does `--temp -1` it's still just 0.0. Then something like `speculative`, which cares about probs in the greedy sampling case, can do:

```cpp
if (params.sparams.temp == 0.0f) {
    params.sparams.temp = -1.0f;
}
```

edit: Actually, you'd need to reverse the logic for the softmax case a bit also: so `0.0` = greedy sampling, no softmax; `< 0.0` = greedy sampling with softmax.
Got it. Should be good now.
* llama : add option for greedy sampling with probs
* llama : add comment about llama_sample_token_greedy() missing probs
* sampling : temp == 0.0 -> no probs, temp < 0.0 -> probs
On `master`, when using `temp <= 0.0` we get greedy sampling, but we don't have the probs of the tokens.

This PR adds an option: when using `temp == 0.0`, do greedy sampling but also apply softmax so we get the probs.