From 4702dc0a269859fc4a72a03823d035854e606520 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Sun, 16 Nov 2025 12:38:16 +0200 Subject: [PATCH 1/3] Introduce FP16 (HalfFloat) tensor integration across TornadoVM layers and kernels. - Adapt tensors, task graphs, and layer planners to support `HalfFloatArray`. - Replace FP32 arrays with FP16-compatible implementations in key inference states (`wrapX`). - Add new FP16-specific kernels for data transfer and activation computations. - Optimize Q8_0 quantized operations with FP16 tensor support for improved efficiency. - Update `State` classes and TornadoVM integrations to utilize FP16 data structures for key activation paths. --- set_paths | 4 +- .../gpullama3/inference/InferenceCore.java | 3 +- .../gpullama3/inference/state/LlamaState.java | 4 +- .../gpullama3/inference/state/Phi3State.java | 3 +- .../gpullama3/inference/state/Qwen2State.java | 3 +- .../gpullama3/inference/state/Qwen3State.java | 3 +- .../gpullama3/inference/state/State.java | 8 +- .../model/loader/Phi3ModelLoader.java | 2 +- .../tornadovm/TornadoVMMasterPlan.java | 3 +- .../kernels/TransformerComputeKernels.java | 38 ++++- .../TransformerComputeKernelsLayered.java | 155 ++++++++++++++++-- .../tornadovm/layers/Activation.java | 2 +- .../layers/type/fp16/LogitsFP16Layer.java | 5 +- 13 files changed, 205 insertions(+), 28 deletions(-) diff --git a/set_paths b/set_paths index fd807c5e..fe79810e 100644 --- a/set_paths +++ b/set_paths @@ -6,10 +6,10 @@ # Resolve root of this project (LLaMA3) and TornadoVM export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" +#export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" # Set the path to TornadoVM SDK binaries -export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" +#export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" # Add TornadoVM and LLaMA bin directories to PATH export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}" diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 8104e561..936c706a 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -583,7 +583,8 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i final Configuration configuration = model.configuration(); final TornadoWeights weights = (TornadoWeights) model.weights(); - MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES); + //TODO: Xxxx + MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Short.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Short.BYTES); return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); } diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index 9f9fdcdb..db1704ba 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -4,6 +4,7 @@ import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -45,8 +46,9 @@ protected StateFields createStateFields(Configuration config) { fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + fields.hackX = new HalfFloatArray(config.dim()); // TornadoVM wrappers with Llama/Mistral dimensions - fields.wrapX = new FloatArray(config.dim()); + fields.wrapX = new HalfFloatArray(config.dim()); fields.wrapXb = new FloatArray(config.dim()); fields.wrapXb2 = new FloatArray(config.dim()); fields.wrapHb = new FloatArray(config.hiddenDim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java index d29ba130..2a944445 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -79,7 +80,7 @@ protected StateFields createStateFields(Configuration config) { fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new); // TornadoVM wrapper arrays for GPU acceleration - fields.wrapX = new FloatArray(dim); + fields.wrapX = new HalfFloatArray(dim); fields.wrapXb = new FloatArray(dim); fields.wrapXb2 = new FloatArray(dim); fields.wrapHb = new FloatArray(2 * hiddenDim); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java index da6d7046..8b53fd28 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -40,7 +41,7 @@ protected StateFields createStateFields(Configuration configuration) { fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); // TornadoVM wrappers with Qwen2 dimensions - fields.wrapX = new FloatArray(config.dim()); + fields.wrapX = new HalfFloatArray(config.dim()); fields.wrapXb = new FloatArray(config.dim()); fields.wrapXb2 = new FloatArray(config.dim()); fields.wrapHb = new FloatArray(config.hiddenDim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java index d6a6d087..0bd5f86e 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -65,7 +66,7 @@ protected StateFields createStateFields(Configuration configuration) { fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); // TornadoVM wrappers with Qwen3-specific sizes - fields.wrapX = new FloatArray(config.dim()); + fields.wrapX = new HalfFloatArray(config.dim()); fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads()); fields.wrapXb2 = new FloatArray(config.dim()); fields.wrapHb = new FloatArray(config.hiddenDim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index 01d94936..35b31f07 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -3,6 +3,7 @@ import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; /** @@ -48,7 +49,7 @@ public abstract class State { public final FloatArray wrapXb2; // FloatArray wrapper for xb2, another residual buffer to aid in computations with TornadoVM. public final FloatArray wrapHb; // FloatArray wrapper for hb (hidden dimension buffer for FFN), optimized for TornadoVM. public final FloatArray wrapHb2; // FloatArray wrapper for hb2, additional hidden buffer for FFN, for compatibility with TornadoVM. - public final FloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM. + public final HalfFloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM. public final FloatArray wrapQ; // FloatArray wrapper for the query tensor, optimized for TornadoVM. public final FloatArray wrapK; // FloatArray wrapper for the key tensor, optimized for TornadoVM. public final FloatArray wrapV; // FloatArray wrapper for the value tensor, optimized for TornadoVM. @@ -64,6 +65,7 @@ public abstract class State { public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size. public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models. + public HalfFloatArray hackX; /** last index in previous block */ protected State(Configuration config, int batchsize) { @@ -108,6 +110,7 @@ protected State(Configuration config, int batchsize) { this.temp = fields.temp; this.tempFFN = fields.tempFFN; this.tempLogits = fields.tempLogits; + this.hackX = fields.wrapX; } // Abstract method - subclasses implement their specific allocation logic and sizes @@ -117,10 +120,11 @@ protected State(Configuration config, int batchsize) { protected static class StateFields { public FloatTensor x, xb, xb2, hb, hb2, q, k, v, att, logits; public FloatTensor[] keyCache, valueCache; - public FloatArray wrapX, wrapXb, wrapXb2, wrapHb, wrapHb2, wrapLogits; + public FloatArray wrapXb, wrapXb2, wrapHb, wrapHb2, wrapLogits; public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache; public IntArray positionHolder; public FloatArray temp, tempFFN, tempLogits; + public HalfFloatArray wrapX, hackX; } @Override diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java index 745367c7..61606e7b 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java @@ -130,7 +130,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr // Load all tensors uniformly as TornadoTensor hierarchy return new Phi3TornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 293d2c0c..8b17b05e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -7,6 +7,7 @@ import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; public class TornadoVMMasterPlan { @@ -179,7 +180,7 @@ private int getFinalLogitsGraphIndex() { /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer. public void forceCopyInReadOnlyDataLayered() { // Execute all TornadoVM graphs - state.wrapX.init(0.0f); + state.wrapX.init(new HalfFloat(0)); state.positionHolder.init(0); // Execute activation update graph diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 7f69e496..c3a06c0b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -1,8 +1,11 @@ package org.beehive.gpullama3.tornadovm.kernels; import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; public class TransformerComputeKernels { @@ -12,13 +15,20 @@ public class TransformerComputeKernels { public TransformerComputeKernels() { } - public static void emptyTaskToForceCopyIn(FloatArray buffer) { + public static void copyInEmbeddingActivation(FloatArray buffer) { float dummy = buffer.get(0); if (dummy > Float.MAX_VALUE) { buffer.set(0, dummy); } } + public static void copyInEmbeddingActivationFP16(HalfFloatArray buffer) { + float dummy = buffer.get(0).getFloat32(); + if (dummy > Float.MAX_VALUE) { + buffer.set(0, new HalfFloat(dummy)); + } + } + /** * Performs RMS (Root Mean Square) normalization using parallel reduction. * This is a two-phase reduction: first within work groups, then across work groups. @@ -33,7 +43,7 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) { * @param ermsNorm Epsilon value for numerical stability (epsilon * epsilon) * @param localMemSize Size of local memory allocation (work group size) */ - public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) { + public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, HalfFloatArray x, int size, float ermsNorm, int localMemSize) { int gid = context.globalIdx; int lid = context.localIdx; int groupId = context.groupIdx; @@ -44,7 +54,7 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray // Load input value and compute square if (gid < size) { - localX[lid] = x.get(gid); + localX[lid] = x.get(gid).getFloat32(); localX[lid] = localX[lid] * localX[lid]; } else { localX[lid] = 0.0f; @@ -87,11 +97,29 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray * @param output Array for normalized output * @param weights Weight values to normalize * @param temp Temporary array containing a normalization factor at index 0 + * + * */ - public static void reductionOneBlock2WithLogits(KernelContext context, FloatArray output, FloatArray weights, FloatArray temp) { + + public static void copyHack(HalfFloatArray x, HalfFloatArray hackX) { + for (@Parallel int i = 0; i < x.getSize(); i++) { + hackX.set(i, x.get(i)); + } + } + + public static void reductionOneBlock2WithLogits(KernelContext context, HalfFloatArray output, FloatArray weights, FloatArray temp) { + int gid = context.globalIdx; + float ss = temp.get(0); + output.set(gid, new HalfFloat((weights.get(gid) * (ss * output.get(gid).getFloat32())))); + } + + + public static void reductionOneBlock2WithLogits2(KernelContext context, HalfFloatArray input, HalfFloatArray output, FloatArray weights, FloatArray temp) { int gid = context.globalIdx; float ss = temp.get(0); - output.set(gid, weights.get(gid) * (ss * output.get(gid))); + float inter = ss * input.get(gid).getHalfFloatValue(); + HalfFloat x = new HalfFloat((weights.get(gid) * inter)); + output.set(gid, x); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index dfe4ef27..5e2e0bb3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -3,6 +3,7 @@ import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.Int8Array; @@ -35,7 +36,7 @@ public TransformerComputeKernelsLayered() { * @param localMemSize * Size of local memory allocation (must match work group size) */ - public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) { + public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, HalfFloatArray x, int size, float ermsNorm, int localMemSize) { int gid = context.globalIdx; int lid = context.localIdx; int groupId = context.groupIdx; @@ -46,7 +47,7 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray // Load input value and compute square if (gid < size) { - localX[lid] = x.get(gid); + localX[lid] = x.get(gid).getFloat32(); localX[lid] = localX[lid] * localX[lid]; } else { localX[lid] = 0.0f; @@ -97,11 +98,11 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray * @param temp * Temporary array containing normalization factor at index 0 */ - public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) { + public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, HalfFloatArray x, FloatArray weights, FloatArray temp) { int gid = context.globalIdx; float ss = temp.get(0); - output.set(gid, weights.get(gid) * (ss * x.get(gid))); + output.set(gid, weights.get(gid) * (ss * x.get(gid).getFloat32())); } /** @@ -690,6 +691,32 @@ public static void matrixVectorGeneric( hb.set(rowId, sum); } } + + public static void matrixVectorGeneric( + KernelContext context, + HalfFloatArray x, + FloatArray hb, // output + HalfFloatArray w, + int dim1, // inner loop + int dim0, // outer loop + int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= dim0) { + return; + } + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } + // @formatter:on /** @@ -712,7 +739,7 @@ public static void matrixVectorGeneric( * @param localWorkGroupSize * Work group size */ - public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) { + public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, HalfFloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) { // One row per workgroup (not per thread) int rowId = context.groupIdx; int localId = context.localIdx; @@ -727,8 +754,8 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA // Thread 0 in each workgroup writes the final result if (localId == 0) { - float result = hb.get(rowId) + sum; - hb.set(rowId, result); + float result = hb.get(rowId).getFloat32() + sum; + hb.set(rowId, new HalfFloat(result)); } } @@ -847,6 +874,38 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc return localSum[0]; } + + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product + float partialSum = 0.0f; + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32(); + } + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n) { int rowId = context.groupIdx; int localId = context.localIdx; @@ -959,6 +1018,25 @@ public static void matrixVectorGeneric(KernelContext context, FloatArray x, Floa } } + public static void matrixVectorGeneric(KernelContext context, HalfFloatArray x, FloatArray output, Int8Array weightsQ, HalfFloatArray weightScales, int dim1, int dim0, int localWorkGroupSize) { + + // One row per workgroup + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Early exit if this workgroup is beyond output dimension + if (rowId >= dim0) { + return; + } + + float sum = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, weightsQ, weightScales, dim1); + + // Thread 0 writes the result + if (localId == 0) { + output.set(rowId, sum); + } + } + /** * Helper method to compute dot product for a single row with Q8_0 quantized weights. Uses 4-way unrolling for better performance. */ @@ -1015,7 +1093,64 @@ public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int return localSums[0]; } - public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) { + + /** + * Helper method to compute dot product for a single row with Q8_0 quantized weights. Uses 4-way unrolling for better performance. + */ + public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int localSize, HalfFloatArray x, Int8Array weightsQ, HalfFloatArray weightScales, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + int blockSize = 32; + + // Allocate local memory for reduction + float[] localSums = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + int scalesRowOffset = rowId * (n / blockSize); + + // 4-way unrolling + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + // Main loop - process 4 elements at a time + for (int j = localId * 4; j < n - 3; j += localSize * 4) { + int blockIdx = j / blockSize; + float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32(); + + // Dequantize and multiply + partialSum1 += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j).getHalfFloatValue(); + partialSum2 += ((float) weightsQ.get(rowOffset + j + 1) * scale) * x.get(j + 1).getHalfFloatValue(); + partialSum3 += ((float) weightsQ.get(rowOffset + j + 2) * scale) * x.get(j + 2).getHalfFloatValue(); + partialSum4 += ((float) weightsQ.get(rowOffset + j + 3) * scale) * x.get(j + 3).getHalfFloatValue(); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + // Handle remaining elements + for (int j = ((n / 4) * 4) + localId; j < n; j += localSize) { + int blockIdx = j / blockSize; + float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32(); + partialSum += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j).getHalfFloatValue(); + } + + // Store partial sum + localSums[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + return localSums[0]; + } + + public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, HalfFloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) { // One row per workgroup (not per thread) int rowId = context.groupIdx; int localId = context.localIdx; @@ -1030,8 +1165,8 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA // Thread 0 in each workgroup writes the final result if (localId == 0) { - float result = hb.get(rowId) + sum; - hb.set(rowId, result); + float result = hb.get(rowId).getFloat32() + sum; + hb.set(rowId, new HalfFloat(result)); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index 16783829..67ca726a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -19,7 +19,7 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur // formatter:off this.activationUpdate = new TaskGraph(taskGraphHandle).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX).persistOnDevice(state.wrapX); + .task("updateX", TransformerComputeKernels::copyInEmbeddingActivationFP16, state.wrapX).persistOnDevice(state.wrapX); // formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index a674c1c5..c4743c26 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -16,6 +16,7 @@ import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.HalfFloat; public class LogitsFP16Layer extends AbstractLayer { @@ -29,6 +30,7 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; state.tempLogits.init(0.0f); + state.hackX.init(new HalfFloat(0)); var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); this.schedulerType = schedulerType; @@ -45,7 +47,8 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); } - logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) + logits.task("hackCopy", TransformerComputeKernels::copyHack, state.wrapX, state.hackX); + logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits2, context, state.hackX, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); From 23feddece8a839a103b4469a7a1152f7ba674161 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Sun, 16 Nov 2025 20:52:52 +0200 Subject: [PATCH 2/3] Refactor `reductionOneBlock2WithLogits` for improved readability and maintainability by adding step-by-step comments and simplifying scaled output computation. --- .../kernels/TransformerComputeKernels.java | 27 ++++++++++++++++++- .../layers/type/fp16/LogitsFP16Layer.java | 4 +-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index c3a06c0b..cd5f10cb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -108,9 +108,34 @@ public static void copyHack(HalfFloatArray x, HalfFloatArray hackX) { } public static void reductionOneBlock2WithLogits(KernelContext context, HalfFloatArray output, FloatArray weights, FloatArray temp) { +// int gid = context.globalIdx; +// float ss = temp.get(0); +// output.set(gid, new HalfFloat((weights.get(gid) * (ss * output.get(gid).getFloat32())))); + + int gid = context.globalIdx; + + // Step 1: read normalization scalar float ss = temp.get(0); - output.set(gid, new HalfFloat((weights.get(gid) * (ss * output.get(gid).getFloat32())))); + + // Step 2: read current output value as float + HalfFloat hf = output.get(gid); + float out_f = hf.getFloat32(); + + // Step 3: read weight +// float w = weights.get(gid); + + // Step 4: compute scaled output + float scaled = ss * out_f; + + // Step 5: multiply by weight + float prod = weights.get(gid) * scaled; + + // Step 6: create HalfFloat result + HalfFloat result = new HalfFloat(prod); + + // Step 7: write back + output.set(gid, result); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index c4743c26..36eacfd5 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -47,8 +47,8 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); } - logits.task("hackCopy", TransformerComputeKernels::copyHack, state.wrapX, state.hackX); - logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits2, context, state.hackX, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) +// logits.task("hackCopy", TransformerComputeKernels::copyHack, state.wrapX, state.hackX); + logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); From dbb2c1990ffc9298453bfed99e6da74870abe153 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Sun, 16 Nov 2025 21:22:32 +0200 Subject: [PATCH 3/3] Refactor tensor loading and inference flow to improve FP16 integration and remove obsolete hacky methods. - Replace `loadTornadoTensorAsFP32` with `loadTornadoTensor` for cleaner tensor loading. - Add logging for tensor loading details in `loadTornadoTensor`. - Remove `copyHack` method and associated comments from compute kernels and logits layer. - Update `wrapX` state in inference to utilize `asHalfFloatArray` for FP16 support. - Cleanup redundant initialization and tasks in FP16 logits layer. --- .../org/beehive/gpullama3/inference/InferenceCore.java | 8 ++++++-- .../beehive/gpullama3/model/loader/LlamaModelLoader.java | 2 +- .../org/beehive/gpullama3/model/loader/ModelLoader.java | 2 ++ .../tornadovm/kernels/TransformerComputeKernels.java | 5 ----- .../tornadovm/layers/type/fp16/LogitsFP16Layer.java | 2 -- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 936c706a..480fdcfb 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -15,7 +15,9 @@ import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import java.lang.foreign.MemorySegment; @@ -583,8 +585,10 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i final Configuration configuration = model.configuration(); final TornadoWeights weights = (TornadoWeights) model.weights(); - //TODO: Xxxx - MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Short.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Short.BYTES); +// MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES); + +// System.out.println("token emdfing table type " + weights.getTokenEmbeddingTable().getClass().getName()); + MemorySegment.copy(weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(), (long) token * configuration.dim() * 2, state.wrapX.getSegment(), 0, configuration.dim() * 2); return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java index aa3a3894..4605e56c 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java @@ -116,7 +116,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr // Load all tensors uniformly as TornadoTensor hierarchy return new LlamaTornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index ce8e6ca9..8c079769 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -126,6 +126,8 @@ public static FloatTensor[] loadArrayOfTensors(int size, IntFunction " + entry.shape() + " and memory segment " + entry.memorySegment()); +// return switch (ggmlType) { case F32 -> new FP32TornadoTensor(size, entry.memorySegment()); case F16 -> new FP16TornadoTensor(size, entry.memorySegment()); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index cd5f10cb..d39108b8 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -101,11 +101,6 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray * */ - public static void copyHack(HalfFloatArray x, HalfFloatArray hackX) { - for (@Parallel int i = 0; i < x.getSize(); i++) { - hackX.set(i, x.get(i)); - } - } public static void reductionOneBlock2WithLogits(KernelContext context, HalfFloatArray output, FloatArray weights, FloatArray temp) { // int gid = context.globalIdx; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index 36eacfd5..ff0c26c7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -30,7 +30,6 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; state.tempLogits.init(0.0f); - state.hackX.init(new HalfFloat(0)); var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); this.schedulerType = schedulerType; @@ -47,7 +46,6 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); } -// logits.task("hackCopy", TransformerComputeKernels::copyHack, state.wrapX, state.hackX); logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);