diff --git a/set_paths b/set_paths
index fd807c5e..fe79810e 100644
--- a/set_paths
+++ b/set_paths
@@ -6,10 +6,10 @@
 # Resolve root of this project (LLaMA3) and TornadoVM
 export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm"
+#export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm"
 
 # Set the path to TornadoVM SDK binaries
-export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"
+#export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"
 
 # Add TornadoVM and LLaMA bin directories to PATH
 export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}"
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
index 8104e561..480fdcfb 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
@@ -15,7 +15,9 @@
 import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
 import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
 import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
 import java.lang.foreign.MemorySegment;
 
@@ -583,7 +585,10 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
         final Configuration configuration = model.configuration();
         final TornadoWeights weights = (TornadoWeights) model.weights();
 
-        MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
+//        MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
+
+//        System.out.println("token embedding table type " + weights.getTokenEmbeddingTable().getClass().getName());
+        MemorySegment.copy(weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(), (long) token * configuration.dim() * 2, state.wrapX.getSegment(), 0, configuration.dim() * 2);
 
         return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
     }
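
The new copy in forwardTornadoVM reads one token row from the FP16 embedding table, so every byte offset is scaled by 2 (the size of a HalfFloat) instead of Float.BYTES. A minimal host-side sketch of that offset arithmetic follows; the class and method names are illustrative and not part of the patch.

import java.lang.foreign.MemorySegment;

final class EmbeddingRowCopy {
    // Copies one token's embedding row between raw memory segments.
    // elementBytes is 2 for FP16 (HalfFloat) and Float.BYTES (4) for FP32.
    static void copyTokenRow(MemorySegment table, MemorySegment dst, int token, int dim, int elementBytes) {
        long srcOffset = (long) token * dim * elementBytes; // start of the row, in bytes
        long rowBytes = (long) dim * elementBytes;          // length of one full row, in bytes
        MemorySegment.copy(table, srcOffset, dst, 0L, rowBytes);
    }
}

For example, with dim = 4096 an FP16 row is 8192 bytes versus 16384 bytes in FP32, which is exactly what the literal "* 2" in the patched call encodes.
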
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
index 9f9fdcdb..db1704ba 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java
@@ -4,6 +4,7 @@
 import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.model.Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;
 
 import java.util.stream.Stream;
@@ -45,8 +46,9 @@ protected StateFields createStateFields(Configuration config) {
         fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
         fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
 
+        fields.hackX = new HalfFloatArray(config.dim());
         // TornadoVM wrappers with Llama/Mistral dimensions
-        fields.wrapX = new FloatArray(config.dim());
+        fields.wrapX = new HalfFloatArray(config.dim());
         fields.wrapXb = new FloatArray(config.dim());
         fields.wrapXb2 = new FloatArray(config.dim());
         fields.wrapHb = new FloatArray(config.hiddenDim());
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java
index d29ba130..2a944445 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java
@@ -5,6 +5,7 @@
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.phi3.Phi3Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;
 
 import java.util.stream.Stream;
@@ -79,7 +80,7 @@ protected StateFields createStateFields(Configuration config) {
         fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new);
 
         // TornadoVM wrapper arrays for GPU acceleration
-        fields.wrapX = new FloatArray(dim);
+        fields.wrapX = new HalfFloatArray(dim);
         fields.wrapXb = new FloatArray(dim);
         fields.wrapXb2 = new FloatArray(dim);
         fields.wrapHb = new FloatArray(2 * hiddenDim);
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java
index da6d7046..8b53fd28 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java
@@ -5,6 +5,7 @@
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;
 
 import java.util.stream.Stream;
@@ -40,7 +41,7 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
 
         // TornadoVM wrappers with Qwen2 dimensions
-        fields.wrapX = new FloatArray(config.dim());
+        fields.wrapX = new HalfFloatArray(config.dim());
         fields.wrapXb = new FloatArray(config.dim());
         fields.wrapXb2 = new FloatArray(config.dim());
         fields.wrapHb = new FloatArray(config.hiddenDim());
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java
index d6a6d087..0bd5f86e 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java
@@ -5,6 +5,7 @@
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;
 
 import java.util.stream.Stream;
@@ -65,7 +66,7 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
 
         // TornadoVM wrappers with Qwen3-specific sizes
-        fields.wrapX = new FloatArray(config.dim());
+        fields.wrapX = new HalfFloatArray(config.dim());
         fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
         fields.wrapXb2 = new FloatArray(config.dim());
         fields.wrapHb = new FloatArray(config.hiddenDim());
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java
index 01d94936..35b31f07 100644
--- a/src/main/java/org/beehive/gpullama3/inference/state/State.java
+++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java
@@ -3,6 +3,7 @@
 import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.model.Configuration;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.IntArray;
 
 /**
@@ -48,7 +49,7 @@ public abstract class State {
     public final FloatArray wrapXb2; // FloatArray wrapper for xb2, another residual buffer to aid in computations with TornadoVM.
     public final FloatArray wrapHb; // FloatArray wrapper for hb (hidden dimension buffer for FFN), optimized for TornadoVM.
     public final FloatArray wrapHb2; // FloatArray wrapper for hb2, additional hidden buffer for FFN, for compatibility with TornadoVM.
-    public final FloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM.
+    public final HalfFloatArray wrapX; // HalfFloatArray (FP16) wrapper for the current activation tensor, optimized for TornadoVM.
     public final FloatArray wrapQ; // FloatArray wrapper for the query tensor, optimized for TornadoVM.
     public final FloatArray wrapK; // FloatArray wrapper for the key tensor, optimized for TornadoVM.
     public final FloatArray wrapV; // FloatArray wrapper for the value tensor, optimized for TornadoVM.
@@ -64,6 +65,7 @@ public abstract class State {
     public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.
 
     public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.
+    public HalfFloatArray hackX; // FP16 activation buffer; the constructor currently points it at fields.wrapX.
 
     /** last index in previous block */
     protected State(Configuration config, int batchsize) {
@@ -108,6 +110,7 @@ protected State(Configuration config, int batchsize) {
         this.temp = fields.temp;
         this.tempFFN = fields.tempFFN;
         this.tempLogits = fields.tempLogits;
+        this.hackX = fields.wrapX;
     }
 
     // Abstract method - subclasses implement their specific allocation logic and sizes
@@ -117,10 +120,11 @@ protected State(Configuration config, int batchsize) {
     protected static class StateFields {
        public FloatTensor x, xb, xb2, hb, hb2, q, k, v, att, logits;
        public FloatTensor[] keyCache, valueCache;
-       public FloatArray wrapX, wrapXb, wrapXb2, wrapHb, wrapHb2, wrapLogits;
+       public FloatArray wrapXb, wrapXb2, wrapHb, wrapHb2, wrapLogits;
        public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache;
        public IntArray positionHolder;
        public FloatArray temp, tempFFN, tempLogits;
+       public HalfFloatArray wrapX, hackX;
     }
 
     @Override
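
Because wrapX (and hackX) are now HalfFloatArray, host code that wants to inspect the activation buffer has to widen each element through HalfFloat. A small debugging sketch, using only the HalfFloatArray and HalfFloat calls already present in this patch (get, set, getFloat32, the HalfFloat(float) constructor); the helper names are hypothetical.

import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

final class HalfFloatBuffers {
    // Widen an FP16 buffer (e.g. state.wrapX) into a float[] for host-side checks.
    static float[] toFloatArray(HalfFloatArray src, int n) {
        float[] out = new float[n];
        for (int i = 0; i < n; i++) {
            out[i] = src.get(i).getFloat32();
        }
        return out;
    }

    // Narrow a float[] back into an FP16 buffer; anything beyond FP16 precision is lost.
    static void fromFloatArray(float[] src, HalfFloatArray dst) {
        for (int i = 0; i < src.length; i++) {
            dst.set(i, new HalfFloat(src[i]));
        }
    }
}
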
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java
index aa3a3894..4605e56c 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java
@@ -116,7 +116,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr
 
         // Load all tensors uniformly as TornadoTensor hierarchy
         return new LlamaTornadoWeights(
-                loadTornadoTensorAsFP32(tokenEmbeddings),
+                loadTornadoTensor(tokenEmbeddings),
                 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
                 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
                 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
index ce8e6ca9..8c079769 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
@@ -126,6 +128,8 @@ public static FloatTensor[] loadArrayOfTensors(int size, IntFunction
 
+        System.out.println("... -> " + entry.shape() + " and memory segment " + entry.memorySegment());
+//
         return switch (ggmlType) {
             case F32 -> new FP32TornadoTensor(size, entry.memorySegment());
             case F16 -> new FP16TornadoTensor(size, entry.memorySegment());
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
index 745367c7..61606e7b 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
@@ -130,7 +130,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr
 
         // Load all tensors uniformly as TornadoTensor hierarchy
         return new Phi3TornadoWeights(
-                loadTornadoTensorAsFP32(tokenEmbeddings),
+                loadTornadoTensor(tokenEmbeddings),
                 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
                 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")),
                 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index 293d2c0c..8b17b05e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -7,6 +7,7 @@
 import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
 import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 public class TornadoVMMasterPlan {
@@ -179,7 +180,7 @@ private int getFinalLogitsGraphIndex() {
     /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer.
     public void forceCopyInReadOnlyDataLayered() {
         // Execute all TornadoVM graphs
-        state.wrapX.init(0.0f);
+        state.wrapX.init(new HalfFloat(0));
         state.positionHolder.init(0);
 
         // Execute activation update graph
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
index 7f69e496..d39108b8 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java
@@ -1,8 +1,11 @@
 package org.beehive.gpullama3.tornadovm.kernels;
 
 import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.annotations.Parallel;
 import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
 public class TransformerComputeKernels {
 
@@ -12,13 +15,20 @@ public class TransformerComputeKernels {
     public TransformerComputeKernels() {
     }
 
-    public static void emptyTaskToForceCopyIn(FloatArray buffer) {
+    public static void copyInEmbeddingActivation(FloatArray buffer) {
         float dummy = buffer.get(0);
         if (dummy > Float.MAX_VALUE) {
             buffer.set(0, dummy);
         }
     }
 
+    public static void copyInEmbeddingActivationFP16(HalfFloatArray buffer) {
+        float dummy = buffer.get(0).getFloat32();
+        if (dummy > Float.MAX_VALUE) {
+            buffer.set(0, new HalfFloat(dummy));
+        }
+    }
+
     /**
      * Performs RMS (Root Mean Square) normalization using parallel reduction.
      * This is a two-phase reduction: first within work groups, then across work groups.
@@ -33,7 +43,7 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) {
     * @param ermsNorm Epsilon value for numerical stability (epsilon * epsilon)
     * @param localMemSize Size of local memory allocation (work group size)
     */
-    public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
+    public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, HalfFloatArray x, int size, float ermsNorm, int localMemSize) {
         int gid = context.globalIdx;
         int lid = context.localIdx;
         int groupId = context.groupIdx;
@@ -44,7 +54,7 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
 
         // Load input value and compute square
         if (gid < size) {
-            localX[lid] = x.get(gid);
+            localX[lid] = x.get(gid).getFloat32();
             localX[lid] = localX[lid] * localX[lid];
         } else {
             localX[lid] = 0.0f;
@@ -87,11 +97,49 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
     * @param output Array for normalized output
     * @param weights Weight values to normalize
     * @param temp Temporary array containing a normalization factor at index 0
+     *
+     */
-    public static void reductionOneBlock2WithLogits(KernelContext context, FloatArray output, FloatArray weights, FloatArray temp) {
+
+
+    public static void reductionOneBlock2WithLogits(KernelContext context, HalfFloatArray output, FloatArray weights, FloatArray temp) {
+//        int gid = context.globalIdx;
+//        float ss = temp.get(0);
+//        output.set(gid, new HalfFloat((weights.get(gid) * (ss * output.get(gid).getFloat32()))));
+
+
+        int gid = context.globalIdx;
+
+        // Step 1: read normalization scalar
+        float ss = temp.get(0);
+
+        // Step 2: read current output value as float
+        HalfFloat hf = output.get(gid);
+        float out_f = hf.getFloat32();
+
+        // Step 3: read weight
+//        float w = weights.get(gid);
+
+        // Step 4: compute scaled output
+        float scaled = ss * out_f;
+
+        // Step 5: multiply by weight
+        float prod = weights.get(gid) * scaled;
+
+        // Step 6: create HalfFloat result
+        HalfFloat result = new HalfFloat(prod);
+
+        // Step 7: write back
+        output.set(gid, result);
+    }
+
+
+    public static void reductionOneBlock2WithLogits2(KernelContext context, HalfFloatArray input, HalfFloatArray output, FloatArray weights, FloatArray temp) {
         int gid = context.globalIdx;
         float ss = temp.get(0);
 
-        output.set(gid, weights.get(gid) * (ss * output.get(gid)));
+        float inter = ss * input.get(gid).getHalfFloatValue();
+        HalfFloat x = new HalfFloat((weights.get(gid) * inter));
+        output.set(gid, x);
     }
 }
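
These kernels split RMSNorm into a sum-of-squares reduction (reductionOneBlockWithLayer) and a scaling pass (the reductionOneBlock2WithLogits variants), with the activation now read from FP16. The sketch below is a plain host-side reference of the intended end result, under the assumption that temp[0] ends up holding 1/sqrt(mean(x^2) + eps) after the final-normalization task; the class and method names are illustrative only.

import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

final class RmsNormReference {
    // out[i] = weight[i] * x[i] / sqrt(mean(x^2) + eps), with x and out stored as FP16.
    static void rmsNorm(HalfFloatArray x, float[] weight, HalfFloatArray out, int n, float eps) {
        float sumOfSquares = 0.0f;
        for (int i = 0; i < n; i++) {
            float v = x.get(i).getFloat32();
            sumOfSquares += v * v;
        }
        float scale = (float) (1.0 / Math.sqrt(sumOfSquares / n + eps));
        for (int i = 0; i < n; i++) {
            out.set(i, new HalfFloat(weight[i] * (scale * x.get(i).getFloat32())));
        }
    }
}
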
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
index dfe4ef27..5e2e0bb3 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
@@ -3,6 +3,7 @@
 import uk.ac.manchester.tornado.api.KernelContext;
 import uk.ac.manchester.tornado.api.annotations.Parallel;
 import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
@@ -35,7 +36,7 @@ public TransformerComputeKernelsLayered() {
     * @param localMemSize
     *            Size of local memory allocation (must match work group size)
     */
-    public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
+    public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, HalfFloatArray x, int size, float ermsNorm, int localMemSize) {
         int gid = context.globalIdx;
         int lid = context.localIdx;
         int groupId = context.groupIdx;
@@ -46,7 +47,7 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
 
         // Load input value and compute square
         if (gid < size) {
-            localX[lid] = x.get(gid);
+            localX[lid] = x.get(gid).getFloat32();
             localX[lid] = localX[lid] * localX[lid];
         } else {
             localX[lid] = 0.0f;
@@ -97,11 +98,11 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
     * @param temp
     *            Temporary array containing normalization factor at index 0
     */
-    public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
+    public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, HalfFloatArray x, FloatArray weights, FloatArray temp) {
         int gid = context.globalIdx;
         float ss = temp.get(0);
 
-        output.set(gid, weights.get(gid) * (ss * x.get(gid)));
+        output.set(gid, weights.get(gid) * (ss * x.get(gid).getFloat32()));
     }
 
     /**
@@ -690,6 +691,32 @@ public static void matrixVectorGeneric(
             hb.set(rowId, sum);
         }
     }
+
+    public static void matrixVectorGeneric(
+            KernelContext context,
+            HalfFloatArray x,
+            FloatArray hb, // output
+            HalfFloatArray w,
+            int dim1, // inner loop
+            int dim0, // outer loop
+            int localWorkGroupSize) {
+        // One row per workgroup (not per thread)
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+        int localSize = localWorkGroupSize;
+
+        // Early exit if this workgroup is beyond our output dimension
+        if (rowId >= dim0) {
+            return;
+        }
+        float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1);
+
+        // Thread 0 in each workgroup writes the final result
+        if (localId == 0) {
+            hb.set(rowId, sum);
+        }
+    }
+
     // @formatter:on
 
     /**
@@ -712,7 +739,7 @@ public static void matrixVectorGeneric(
     * @param localWorkGroupSize
     *            Work group size
     */
-    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
+    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, HalfFloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
         // One row per workgroup (not per thread)
         int rowId = context.groupIdx;
         int localId = context.localIdx;
@@ -727,8 +754,8 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA
 
         // Thread 0 in each workgroup writes the final result
         if (localId == 0) {
-            float result = hb.get(rowId) + sum;
-            hb.set(rowId, result);
+            float result = hb.get(rowId).getFloat32() + sum;
+            hb.set(rowId, new HalfFloat(result));
         }
     }
 
@@ -847,6 +874,38 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
         return localSum[0];
     }
 
+
+    public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+
+        // Allocate local memory for reduction
+        float[] localSum = context.allocateFloatLocalArray(localSize);
+
+        int rowOffset = rowId * n;
+
+        // Each thread calculates partial dot product
+        float partialSum = 0.0f;
+        for (int j = localId; j < n; j += localSize) {
+            int matrixIdx = rowOffset + j;
+            partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32();
+        }
+
+        // Store partial sum in local memory
+        localSum[localId] = partialSum;
+        context.localBarrier();
+
+        // Parallel reduction within workgroup
+        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+            if (localId < stride) {
+                localSum[localId] += localSum[localId + stride];
+            }
+            context.localBarrier();
+        }
+
+        return localSum[0];
+    }
+
     public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n) {
         int rowId = context.groupIdx;
         int localId = context.localIdx;
@@ -959,6 +1018,25 @@ public static void matrixVectorGeneric(KernelContext context, FloatArray x, Floa
         }
     }
 
+    public static void matrixVectorGeneric(KernelContext context, HalfFloatArray x, FloatArray output, Int8Array weightsQ, HalfFloatArray weightScales, int dim1, int dim0, int localWorkGroupSize) {
+
+        // One row per workgroup
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+
+        // Early exit if this workgroup is beyond output dimension
+        if (rowId >= dim0) {
+            return;
+        }
+
+        float sum = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, weightsQ, weightScales, dim1);
+
+        // Thread 0 writes the result
+        if (localId == 0) {
+            output.set(rowId, sum);
+        }
+    }
+
     /**
      * Helper method to compute dot product for a single row with Q8_0 quantized weights. Uses 4-way unrolling for better performance.
      */
@@ -1015,7 +1093,64 @@ public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int
         return localSums[0];
     }
 
-    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) {
+
+    /**
+     * Helper method to compute dot product for a single row with Q8_0 quantized weights. Uses 4-way unrolling for better performance.
+     */
+    public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int localSize, HalfFloatArray x, Int8Array weightsQ, HalfFloatArray weightScales, int n) {
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+        int blockSize = 32;
+
+        // Allocate local memory for reduction
+        float[] localSums = context.allocateFloatLocalArray(localSize);
+
+        int rowOffset = rowId * n;
+        int scalesRowOffset = rowId * (n / blockSize);
+
+        // 4-way unrolling
+        float partialSum1 = 0.0f;
+        float partialSum2 = 0.0f;
+        float partialSum3 = 0.0f;
+        float partialSum4 = 0.0f;
+
+        // Main loop - process 4 elements at a time
+        for (int j = localId * 4; j < n - 3; j += localSize * 4) {
+            int blockIdx = j / blockSize;
+            float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32();
+
+            // Dequantize and multiply
+            partialSum1 += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j).getHalfFloatValue();
+            partialSum2 += ((float) weightsQ.get(rowOffset + j + 1) * scale) * x.get(j + 1).getHalfFloatValue();
+            partialSum3 += ((float) weightsQ.get(rowOffset + j + 2) * scale) * x.get(j + 2).getHalfFloatValue();
+            partialSum4 += ((float) weightsQ.get(rowOffset + j + 3) * scale) * x.get(j + 3).getHalfFloatValue();
+        }
+
+        float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4;
+
+        // Handle remaining elements
+        for (int j = ((n / 4) * 4) + localId; j < n; j += localSize) {
+            int blockIdx = j / blockSize;
+            float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32();
+            partialSum += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j).getHalfFloatValue();
+        }
+
+        // Store partial sum
+        localSums[localId] = partialSum;
+        context.localBarrier();
+
+        // Parallel reduction
+        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
+            if (localId < stride) {
+                localSums[localId] += localSums[localId + stride];
+            }
+            context.localBarrier();
+        }
+
+        return localSums[0];
+    }
+
+    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, HalfFloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) {
         // One row per workgroup (not per thread)
         int rowId = context.groupIdx;
         int localId = context.localIdx;
@@ -1030,8 +1165,8 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA
 
         // Thread 0 in each workgroup writes the final result
        if (localId == 0) {
-            float result = hb.get(rowId) + sum;
-            hb.set(rowId, result);
+            float result = hb.get(rowId).getFloat32() + sum;
+            hb.set(rowId, new HalfFloat(result));
         }
     }
 
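
The new Q8_0 path dequantizes int8 weights blockwise (32 weights share one FP16 scale) and multiplies them by FP16 activations. Below is a plain scalar reference of the same row dot product, without the 4-way unrolling or the workgroup reduction, reading activations through getFloat32(); byte[] stands in for Int8Array so the sketch stays self-contained, and the names are illustrative.

import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

final class Q8_0Reference {
    // Dot product of one Q8_0 weight row (int8 quants plus one FP16 scale per 32-element block)
    // with an FP16 activation vector x of length n.
    static float dotRow(byte[] quants, HalfFloatArray scales, HalfFloatArray x, int row, int n) {
        final int blockSize = 32;
        int rowOffset = row * n;
        int scalesRowOffset = row * (n / blockSize);
        float sum = 0.0f;
        for (int j = 0; j < n; j++) {
            float scale = scales.get(scalesRowOffset + j / blockSize).getFloat32();
            sum += (quants[rowOffset + j] * scale) * x.get(j).getFloat32();
        }
        return sum;
    }
}
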
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
index 16783829..67ca726a 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
@@ -19,7 +19,7 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur
 
         // formatter:off
         this.activationUpdate = new TaskGraph(taskGraphHandle).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
-                .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX).persistOnDevice(state.wrapX);
+                .task("updateX", TransformerComputeKernels::copyInEmbeddingActivationFP16, state.wrapX).persistOnDevice(state.wrapX);
         // formatter:on
     }
 
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
index a674c1c5..ff0c26c7 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -16,6 +16,7 @@
 import uk.ac.manchester.tornado.api.WorkerGrid;
 import uk.ac.manchester.tornado.api.WorkerGrid1D;
 import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
 
 public class LogitsFP16Layer extends AbstractLayer {
 
@@ -45,7 +46,7 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con
         if (schedulerType == SchedulerType.NON_NVIDIA) {
             logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps());
         }
-        logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits)
+        logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits)
                 .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
 
         logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
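
For context on how the FP16 kernels get wired, here is a minimal sketch mirroring the Activation task graph touched by this patch (the same transferToDevice, task, and persistOnDevice calls); the graph name and buffer size are made up, and this is a sketch rather than the project's actual wiring.

import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

final class ActivationGraphSketch {
    static TaskGraph buildActivationUpdate(HalfFloatArray wrapX) {
        // Stream the FP16 activation buffer to the device on every execution,
        // run the copy-in task, and keep the buffer resident on the device.
        return new TaskGraph("activationUpdate")
                .transferToDevice(DataTransferMode.EVERY_EXECUTION, wrapX)
                .task("updateX", TransformerComputeKernels::copyInEmbeddingActivationFP16, wrapX)
                .persistOnDevice(wrapX);
    }
}
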