[Kernels] Migrate sampling to WebGPU #737
```diff
@@ -53,6 +53,8 @@ export class LLMChatPipeline {
   private fapplyPenalty: tvmjs.PackedFunc;
   private fapplyLogitBias: tvmjs.PackedFunc;
   private fsoftmaxWithTemperature: tvmjs.PackedFunc;
+  private fsampleWithTopP: tvmjs.PackedFunc;
+  private fargsortProbs: tvmjs.PackedFunc;

   // Functions related to PagedKVCache
   private fclearKVCaches: tvmjs.PackedFunc;
@@ -142,6 +144,10 @@ export class LLMChatPipeline {
   private curRoundGrammarInitTotalTime = 0;
   // Total time of getting next bitmask and accepting token in seconds
   private curRoundGrammarPerTokenTotalTime = 0;
+  // Instance variables for supporting sampling on WebGPU
+  private sampleIndices: Int32Array;
+  private sampleIndicesDevice: tvmjs.Tensor;
+  private topPDevice: tvmjs.Tensor;

   constructor(
     tvm: tvmjs.Instance,
@@ -213,6 +219,12 @@ export class LLMChatPipeline {
     this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope(
       this.vm.getFunction("softmax_with_temperature"),
     );
+    this.fsampleWithTopP = this.tvm.detachFromCurrentScope(
+      this.vm.getFunction("sample_with_top_p"),
+    );
+    this.fargsortProbs = this.tvm.detachFromCurrentScope(
+      this.vm.getFunction("argsort_probs"),
+    );
     try {
       this.image_embed = this.tvm.detachFromCurrentScope(
         this.vm.getFunction("image_embed"),
@@ -310,6 +322,25 @@ export class LLMChatPipeline {
     this.filledKVCacheLength = 0;
     this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence

+    // Initialize WebGPU sampling related device tensors
+    const numSamples = 1;
+    const numProbs = 1;
+
+    this.sampleIndices = new Int32Array(numSamples);
+    for (let i = 0; i < numSamples; i++) {
+      this.sampleIndices[i] = i;
+    }
+    this.sampleIndicesDevice = this.tvm.detachFromCurrentScope(
+      this.tvm
+        .empty([numSamples], "int32", this.device)
+        .copyFrom(this.sampleIndices),
+    );
+
+    this.topPDevice = this.tvm.detachFromCurrentScope(
+      this.tvm.empty([numProbs], "float32", this.device),
+    );
+
     tvm.endScope();
   }

@@ -1271,11 +1302,13 @@ export class LLMChatPipeline {
     // If logprobs, need the actual distribution via softmax, otherwise directly sample from logits
     const sampleBegin = performance.now();
     let sampledToken: number;
-    if (logprobs) {
+    let sampledTokensDevice: tvmjs.Tensor;
+    if (logprobs && _hasValue(top_p)) {
       // Inplace transform logitsOnCPU to a distribution
       temperature = Math.max(1e-6, temperature); // to prevent division by zero

       const numSeqs = 1;
+      const numProbs = 1;

       const temperatures = new Float32Array([temperature]);
@@ -1284,18 +1317,52 @@ export class LLMChatPipeline {
         .empty([numSeqs], "float32", this.device)
         .copyFrom(temperatures);

-      const probs = this.fsoftmaxWithTemperature(
-        logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]),
+      let probs = this.fsoftmaxWithTemperature(
+        logitsOnGPU.view([numSeqs, numProbs, this.fullVocabSize]),
         temperaturesDevice,
       );
-      this.updateLogitsOnCPU(probs);
+      probs = probs.view([numProbs, this.fullVocabSize]);
+
+      const argsortResults = this.fargsortProbs(probs);
+      const sortedProbsDevice = argsortResults.get(0);
+      const sortedIndicesDevice = argsortResults.get(1);
+
+      const uniformSamplesDevice = this.tvm.uniform([1], 0.0, 1.0, this.device);
+
+      const topPHost = new Float32Array(numProbs).fill(-1);
+      const topPValue = Math.max(top_p, 1e-5);
+      this.sampleIndices.forEach((row) => {
+        topPHost[row] = topPValue;
+      });
+      this.topPDevice.copyFrom(topPHost);
+
+      sampledTokensDevice = this.tvm.detachFromCurrentScope(
+        this.fsampleWithTopP(
+          sortedProbsDevice,
+          sortedIndicesDevice,
+          uniformSamplesDevice,
+          this.sampleIndicesDevice,
+          this.topPDevice,
+        ),
+      );
+      const sampledTokensHost = this.tvm.detachFromCurrentScope(
+        this.tvm
+          .empty([numSeqs], "int32", this.tvm.cpu())
+          .copyFrom(sampledTokensDevice),
+      );
+      if (top_logprobs! > 0) {
+        this.updateLogitsOnCPU(probs);
+      }
       this.tvm.endScope();
       await this.device.sync();

-      sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU!, top_p);
-      this.tokenLogprobArray.push(
-        this.getTokenLogprob(sampledToken, top_logprobs!),
-      );
+      sampledToken = sampledTokensHost.toArray()[0];
+
+      if (top_logprobs! > 0) {
+        this.tokenLogprobArray.push(
+          this.getTokenLogprob(sampledToken, top_logprobs!),
+        );
+      }
     } else {
       // temperature being 0 is allowed here, equivalent to argmax
       this.tvm.beginScope();
```
Could you remind me why we add a `_hasValue(top_p)` here? If a user wants `logprobs` but does not provide a `top_p`, it would go to the else branch, and thus not populate the `tokenLogprobArray`. Let's set `top_p` to `1.0` -- the default value -- at the start, when we are pre-processing the sampling parameters. Then we can remove this condition change.
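For illustration, a minimal sketch of the suggested pre-processing default. The `SamplingConfig` shape and `normalizeSamplingConfig` helper below are hypothetical stand-ins, not this repo's actual config code:

```typescript
// Hypothetical sketch: normalize sampling parameters before generation so that
// an unset top_p falls back to 1.0 (i.e. no nucleus truncation).
interface SamplingConfig {
  temperature?: number;
  top_p?: number;
  logprobs?: boolean;
  top_logprobs?: number;
}

function normalizeSamplingConfig(config: SamplingConfig): Required<SamplingConfig> {
  return {
    temperature: config.temperature ?? 1.0,
    top_p: config.top_p ?? 1.0, // default: keep the full distribution
    logprobs: config.logprobs ?? false,
    top_logprobs: config.top_logprobs ?? 0,
  };
}

// With top_p always defined, the `_hasValue(top_p)` guard in the diff above
// becomes unnecessary, and the GPU branch can be taken whenever logprobs are
// requested.
```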
Agreed. Based on our discussion offline (summarized here for reference), I've removed the else branch (corresponding to sampling on CPU).

Based on the code in https://github.com/mlc-ai/mlc-llm/blob/main/cpp/serve/sampler/gpu_sampler.cc, the `logprobs` flag does not affect the kernels being called. Specifically, branching only occurs based on whether or not we wish to use `top_p`. In the case where `top_p` is not set, we would need to use `multinomial_sampling_func` / `parallel_sampling_from_prob`, which is currently not possible as described above. Most models have a default `top_p` value, and I set `top_p = 1.0` when this is not the case, which allows removal of the else branch.

Note: Currently, the path corresponding to `logprobs` takes place on the CPU, though this can likely be replaced by a call to the `sampler_take_probs` kernel. I will take a deeper look at this in the near future.
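For reference, a condensed sketch of the single WebGPU sampling path that remains once the CPU else branch is removed. It is simplified from the diff above: scope management and logprob bookkeeping are omitted, types are loosened to `any`, and `sampleOnGPU` is an illustrative standalone function, not a method added by this PR:

```typescript
// Condensed view of the WebGPU sampling flow in this PR (illustrative only).
// `pipeline` stands in for the LLMChatPipeline instance.
async function sampleOnGPU(
  pipeline: any,      // LLMChatPipeline (typed loosely for this sketch)
  logitsOnGPU: any,   // tvmjs.Tensor of shape [1, 1, fullVocabSize]
  temperature: number,
  top_p: number,      // already defaulted to 1.0 when unset
): Promise<number> {
  const tvm = pipeline.tvm;
  const temperaturesDevice = tvm
    .empty([1], "float32", pipeline.device)
    .copyFrom(new Float32Array([Math.max(1e-6, temperature)]));

  // 1. Logits -> probabilities on device.
  let probs = pipeline.fsoftmaxWithTemperature(
    logitsOnGPU.view([1, 1, pipeline.fullVocabSize]),
    temperaturesDevice,
  );
  probs = probs.view([1, pipeline.fullVocabSize]);

  // 2. Sort probabilities (with their token indices) on device.
  const sorted = pipeline.fargsortProbs(probs);

  // 3. One uniform sample per sequence drives the top-p draw.
  const uniform = tvm.uniform([1], 0.0, 1.0, pipeline.device);
  pipeline.topPDevice.copyFrom(new Float32Array([Math.max(top_p, 1e-5)]));

  // 4. Sample on device, then copy the single token id back to the CPU.
  const sampledDevice = pipeline.fsampleWithTopP(
    sorted.get(0), sorted.get(1), uniform,
    pipeline.sampleIndicesDevice, pipeline.topPDevice,
  );
  const sampledHost = tvm.empty([1], "int32", tvm.cpu()).copyFrom(sampledDevice);
  await pipeline.device.sync();
  return sampledHost.toArray()[0];
}
```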