4 changes: 2 additions & 2 deletions set_paths
@@ -6,10 +6,10 @@

# Resolve root of this project (LLaMA3) and TornadoVM
export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm"
+#export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm"

# Set the path to TornadoVM SDK binaries
-export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"
+#export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"

# Add TornadoVM and LLaMA bin directories to PATH
export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}"
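With TORNADO_ROOT and TORNADO_SDK commented out, set_paths presumably expects both to be provided by the caller's environment before sourcing; otherwise PATH is extended with an empty TORNADO_SDK entry. A hypothetical setup (paths are illustrative, not from this PR):

export TORNADO_ROOT=/opt/tornadovm
export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"
source ./set_paths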
@@ -15,7 +15,9 @@
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

import java.lang.foreign.MemorySegment;

@@ -583,7 +585,10 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
final Configuration configuration = model.configuration();
final TornadoWeights weights = (TornadoWeights) model.weights();

-MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
+// MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
+
+// System.out.println("token embedding table type " + weights.getTokenEmbeddingTable().getClass().getName());
+MemorySegment.copy(weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(), (long) token * configuration.dim() * 2, state.wrapX.getSegment(), 0, configuration.dim() * 2);

return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
}
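The replacement copy moves one embedding row as raw FP16 bytes: a HalfFloat element is 2 bytes wide, so the source offset and copy length scale by 2 where the old FP32 path used Float.BYTES. A minimal sketch of the arithmetic, with hypothetical values (not part of the PR):

// Illustration only: offset math for the FP16 row copy.
int dim = 4096;                               // configuration.dim()
int token = 7;
long srcOffsetBytes = (long) token * dim * 2; // row start: 7 * 4096 * 2 = 57,344 bytes
long copyLengthBytes = (long) dim * 2;        // one row: 8,192 bytes
// MemorySegment.copy(srcSegment, srcOffsetBytes, dstSegment, 0, copyLengthBytes);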
@@ -4,6 +4,7 @@
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;
@@ -45,8 +46,9 @@ protected StateFields createStateFields(Configuration config) {
fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

+fields.hackX = new HalfFloatArray(config.dim());
// TornadoVM wrappers with Llama/Mistral dimensions
-fields.wrapX = new FloatArray(config.dim());
+fields.wrapX = new HalfFloatArray(config.dim());
fields.wrapXb = new FloatArray(config.dim());
fields.wrapXb2 = new FloatArray(config.dim());
fields.wrapHb = new FloatArray(config.hiddenDim());
@@ -5,6 +5,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;
@@ -79,7 +80,7 @@ protected StateFields createStateFields(Configuration config) {
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new);

// TornadoVM wrapper arrays for GPU acceleration
-fields.wrapX = new FloatArray(dim);
+fields.wrapX = new HalfFloatArray(dim);
fields.wrapXb = new FloatArray(dim);
fields.wrapXb2 = new FloatArray(dim);
fields.wrapHb = new FloatArray(2 * hiddenDim);
@@ -5,6 +5,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;
@@ -40,7 +41,7 @@ protected StateFields createStateFields(Configuration configuration) {
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

// TornadoVM wrappers with Qwen2 dimensions
-fields.wrapX = new FloatArray(config.dim());
+fields.wrapX = new HalfFloatArray(config.dim());
fields.wrapXb = new FloatArray(config.dim());
fields.wrapXb2 = new FloatArray(config.dim());
fields.wrapHb = new FloatArray(config.hiddenDim());
@@ -5,6 +5,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;
@@ -65,7 +66,7 @@ protected StateFields createStateFields(Configuration configuration) {
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

// TornadoVM wrappers with Qwen3-specific sizes
-fields.wrapX = new FloatArray(config.dim());
+fields.wrapX = new HalfFloatArray(config.dim());
fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
fields.wrapXb2 = new FloatArray(config.dim());
fields.wrapHb = new FloatArray(config.hiddenDim());
@@ -3,6 +3,7 @@
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

/**
@@ -48,7 +49,7 @@ public abstract class State {
public final FloatArray wrapXb2; // FloatArray wrapper for xb2, another residual buffer to aid in computations with TornadoVM.
public final FloatArray wrapHb; // FloatArray wrapper for hb (hidden dimension buffer for FFN), optimized for TornadoVM.
public final FloatArray wrapHb2; // FloatArray wrapper for hb2, additional hidden buffer for FFN, for compatibility with TornadoVM.
-public final FloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM.
+public final HalfFloatArray wrapX; // HalfFloatArray wrapper for the current activation tensor, optimized for TornadoVM.
public final FloatArray wrapQ; // FloatArray wrapper for the query tensor, optimized for TornadoVM.
public final FloatArray wrapK; // FloatArray wrapper for the key tensor, optimized for TornadoVM.
public final FloatArray wrapV; // FloatArray wrapper for the value tensor, optimized for TornadoVM.
@@ -64,6 +65,7 @@ public abstract class State {
public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.
public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.

+public HalfFloatArray hackX;
/** last index in previous block */

protected State(Configuration config, int batchsize) {
Expand Down Expand Up @@ -108,6 +110,7 @@ protected State(Configuration config, int batchsize) {
this.temp = fields.temp;
this.tempFFN = fields.tempFFN;
this.tempLogits = fields.tempLogits;
+this.hackX = fields.wrapX;
}

// Abstract method - subclasses implement their specific allocation logic and sizes
@@ -117,10 +120,11 @@ protected State(Configuration config, int batchsize) {
protected static class StateFields {
public FloatTensor x, xb, xb2, hb, hb2, q, k, v, att, logits;
public FloatTensor[] keyCache, valueCache;
-public FloatArray wrapX, wrapXb, wrapXb2, wrapHb, wrapHb2, wrapLogits;
+public FloatArray wrapXb, wrapXb2, wrapHb, wrapHb2, wrapLogits;
public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache;
public IntArray positionHolder;
public FloatArray temp, tempFFN, tempLogits;
+public HalfFloatArray wrapX, hackX;
}

@Override
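Two details worth noting in the abstract State class: wrapX is now FP16, and the new hackX field is assigned fields.wrapX in the constructor, so hackX and wrapX alias the same HalfFloatArray (the separate HalfFloatArray allocated for fields.hackX in the Llama/Mistral state above is discarded). Host code that used to read wrapX as float now needs explicit conversions; a minimal round-trip sketch using the HalfFloat accessors this PR already uses elsewhere (sketch only, not part of the PR):

// FP16 <-> FP32 round-trip on a HalfFloatArray.
HalfFloatArray wrapX = new HalfFloatArray(4096);
float v = wrapX.get(0).getFloat32();     // FP16 -> FP32 read
wrapX.set(0, new HalfFloat(v + 1.0f));   // FP32 -> FP16 write (precision may be lost)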
@@ -116,7 +116,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr

// Load all tensors uniformly as TornadoTensor hierarchy
return new LlamaTornadoWeights(
-loadTornadoTensorAsFP32(tokenEmbeddings),
+loadTornadoTensor(tokenEmbeddings),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
@@ -126,6 +126,8 @@ public static FloatTensor[] loadArrayOfTensors(int size, IntFunction<GGMLTensorE
public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
GGMLType ggmlType = entry.ggmlType();
int size = FloatTensor.numberOfElements(entry.shape());
System.out.println("Loading tensor of type " + ggmlType + " with shape " + entry.name() + " -> " + entry.shape() + " and memory segment " + entry.memorySegment());
//
return switch (ggmlType) {
case F32 -> new FP32TornadoTensor(size, entry.memorySegment());
case F16 -> new FP16TornadoTensor(size, entry.memorySegment());
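Context for the loader swaps in the surrounding files: judging by its name and the old asFloatArray()/Float.BYTES copy path, loadTornadoTensorAsFP32 widens tensor data to FP32 at load time, while loadTornadoTensor (above) keeps the native GGML type, so an F16 token-embedding table now stays FP16 end to end. A sketch of the assumed difference for an F16 entry (tensor name is illustrative; behavior of the FP32 loader is inferred, not shown in this PR):

GGMLTensorEntry entry = tensorEntries.get("token_embd.weight");
TornadoTensor fp16 = loadTornadoTensor(entry);          // F16 -> FP16TornadoTensor (backs asHalfFloatArray())
// TornadoTensor fp32 = loadTornadoTensorAsFP32(entry); // old path: widened to FP32 (backs asFloatArray())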
@@ -130,7 +130,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr

// Load all tensors uniformly as TornadoTensor hierarchy
return new Phi3TornadoWeights(
-loadTornadoTensorAsFP32(tokenEmbeddings),
+loadTornadoTensor(tokenEmbeddings),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")),
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
@@ -7,6 +7,7 @@
import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class TornadoVMMasterPlan {
@@ -179,7 +180,7 @@ private int getFinalLogitsGraphIndex() {
/// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer.
public void forceCopyInReadOnlyDataLayered() {
// Execute all TornadoVM graphs
-state.wrapX.init(0.0f);
+state.wrapX.init(new HalfFloat(0));
state.positionHolder.init(0);

// Execute activation update graph
@@ -1,8 +1,11 @@
package org.beehive.gpullama3.tornadovm.kernels;

import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public class TransformerComputeKernels {

@@ -12,13 +15,20 @@ public class TransformerComputeKernels {
public TransformerComputeKernels() {
}

-public static void emptyTaskToForceCopyIn(FloatArray buffer) {
+public static void copyInEmbeddingActivation(FloatArray buffer) {
float dummy = buffer.get(0);
if (dummy > Float.MAX_VALUE) {
buffer.set(0, dummy);
}
}

+public static void copyInEmbeddingActivationFP16(HalfFloatArray buffer) {
+float dummy = buffer.get(0).getFloat32();
+if (dummy > Float.MAX_VALUE) {
+buffer.set(0, new HalfFloat(dummy));
+}
+}
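A note on these copy-in kernels: they are deliberate near no-ops. The guard dummy > Float.MAX_VALUE can only be true for +Infinity, so in practice nothing is written; the read of buffer.get(0) is what makes TornadoVM treat the buffer as a kernel input and schedule its host-to-device copy. A hypothetical wiring sketch (graph/task names and transfer modes are illustrative, not taken from this PR):

// Hypothetical sketch: a task whose only purpose is to force copy-in of the
// FP16 activation buffer (assumes TornadoVM's TaskGraph API).
TaskGraph activations = new TaskGraph("activationUpdate")
        .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
        .task("copyIn", TransformerComputeKernels::copyInEmbeddingActivationFP16, state.wrapX);
new TornadoExecutionPlan(activations.snapshot()).execute();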

/**
* Performs RMS (Root Mean Square) normalization using parallel reduction.
* This is a two-phase reduction: first within work groups, then across work groups.
@@ -33,7 +43,7 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) {
* @param ermsNorm Epsilon value for numerical stability (epsilon * epsilon)
* @param localMemSize Size of local memory allocation (work group size)
*/
-public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
+public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, HalfFloatArray x, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
int groupId = context.groupIdx;
@@ -44,7 +54,7 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray

// Load input value and compute square
if (gid < size) {
-localX[lid] = x.get(gid);
+localX[lid] = x.get(gid).getFloat32();
localX[lid] = localX[lid] * localX[lid];
} else {
localX[lid] = 0.0f;
@@ -87,11 +97,49 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
* @param output Array for normalized output
* @param weights Weight values to normalize
* @param temp Temporary array containing a normalization factor at index 0
*/
-public static void reductionOneBlock2WithLogits(KernelContext context, FloatArray output, FloatArray weights, FloatArray temp) {
+public static void reductionOneBlock2WithLogits(KernelContext context, HalfFloatArray output, FloatArray weights, FloatArray temp) {
+int gid = context.globalIdx;
+
+// Read the normalization scalar produced by the reduction kernel.
+float ss = temp.get(0);
+
+// Scale the current FP16 value, apply the per-element weight, and write back as FP16.
+float scaled = ss * output.get(gid).getFloat32();
+output.set(gid, new HalfFloat(weights.get(gid) * scaled));
+}
+
+public static void reductionOneBlock2WithLogits2(KernelContext context, HalfFloatArray input, HalfFloatArray output, FloatArray weights, FloatArray temp) {
+int gid = context.globalIdx;
+float ss = temp.get(0);
+
+// Same normalization, but reading from a separate FP16 input buffer. Note that
+// getFloat32() converts FP16 to FP32; getHalfFloatValue() returns the raw half representation.
+float inter = ss * input.get(gid).getFloat32();
+output.set(gid, new HalfFloat(weights.get(gid) * inter));
+}

}
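Because these kernels take a KernelContext, launching them requires an explicit worker grid. A minimal, hypothetical launch sketch for the FP16 logits kernel (sizes and names are illustrative; the real wiring lives in TornadoVMMasterPlan and the layer planners):

// Hypothetical sketch, assuming TornadoVM's KernelContext/GridScheduler API.
KernelContext context = new KernelContext();
WorkerGrid worker = new WorkerGrid1D(dim);       // one work-item per element
GridScheduler scheduler = new GridScheduler();
scheduler.setWorkerGrid("logits.norm", worker);  // "<graph>.<task>"
TaskGraph logits = new TaskGraph("logits")
        .transferToDevice(DataTransferMode.EVERY_EXECUTION, output, weights, temp)
        .task("norm", TransformerComputeKernels::reductionOneBlock2WithLogits, context, output, weights, temp)
        .transferToHost(DataTransferMode.EVERY_EXECUTION, output);
new TornadoExecutionPlan(logits.snapshot()).withGridScheduler(scheduler).execute();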