From 9527bf5cedc488c066d1882c1e40c9d87c1c2cbb Mon Sep 17 00:00:00 2001
From: XiaoYun Zhang
Date: Thu, 18 Apr 2024 23:03:16 -0700
Subject: [PATCH 1/3] fix bug in model.cs

---
 LLAMA.cs | 3 +--
 Model.cs | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/LLAMA.cs b/LLAMA.cs
index f2eecd2..7acfd7a 100644
--- a/LLAMA.cs
+++ b/LLAMA.cs
@@ -47,10 +47,9 @@ public static LLaMA Build(
         stopWatch.Start();
         paramJsonPath = Path.Combine(modelFolder, paramJsonPath);
         var modelArgs = JsonSerializer.Deserialize<ModelArgs>(File.ReadAllText(paramJsonPath)) ?? throw new Exception("Failed to deserialize model args");
-        modelArgs.VocabSize = tokenizer.VocabSize;
+        modelArgs.VocabSize = 128256;
         modelArgs.MaxSeqLen = maxSeqLen;
         modelArgs.MaxBatchSize = maxBatchSize;
-        torch.set_default_dtype(torch.bfloat16);
         // print model args
         var modelArgsJson = JsonSerializer.Serialize(modelArgs, new JsonSerializerOptions { WriteIndented = true });
         Console.WriteLine($"modelArgs: {modelArgsJson}");
diff --git a/Model.cs b/Model.cs
index d9421d8..97a9c34 100644
--- a/Model.cs
+++ b/Model.cs
@@ -207,7 +207,7 @@ public FeedForward(ModelArgs args)
     {
         var hiddenDim = args.Dim * 4;
         hiddenDim = 2 * hiddenDim / 3;
-        hiddenDim = args.FFNDimMultiplier.HasValue ? (int)args.FFNDimMultiplier.Value * hiddenDim : hiddenDim;
+        hiddenDim = args.FFNDimMultiplier.HasValue ? (int)(args.FFNDimMultiplier.Value * hiddenDim) : hiddenDim;
         // Round the hidden_dim to the nearest multiple of the multiple_of parameter
         hiddenDim = args.MultipleOf * ((hiddenDim + args.MultipleOf - 1) / args.MultipleOf);
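
The Model.cs hunk is the substance of this patch: in the old expression the cast bound only to args.FFNDimMultiplier.Value, so a fractional multiplier was truncated to an integer before the multiply and effectively ignored. That shrinks the FFN hidden dimension and makes the weight shapes disagree with the official checkpoints. A minimal repro of the precedence difference (the numbers below are illustrative, not taken from the patch):

    // Illustrative values: 1.3 stands in for a typical fractional
    // ffn_dim_multiplier, 10922 for dim * 4 * 2 / 3 with dim = 4096.
    double ffnDimMultiplier = 1.3;
    int hiddenDim = 10922;

    var before = (int)ffnDimMultiplier * hiddenDim;    // cast first: (int)1.3 == 1, result 10922
    var after  = (int)(ffnDimMultiplier * hiddenDim);  // multiply first, result 14198

    Console.WriteLine($"before: {before}, after: {after}");
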
From dc7d233d33fd0df10e78c362654140c0f4674033 Mon Sep 17 00:00:00 2001
From: XiaoYun Zhang
Date: Fri, 19 Apr 2024 00:03:44 -0700
Subject: [PATCH 2/3] add llama3 tokenizer

---
 ITokenizer.cs           | 241 +++++++++++++++++++++++++++++++++++++++-
 LLAMA.cs                |   5 +-
 Model.cs                |   1 -
 Program.cs              |  41 +++++--
 Torchsharp-llama.csproj |   5 +-
 5 files changed, 268 insertions(+), 25 deletions(-)

diff --git a/ITokenizer.cs b/ITokenizer.cs
index 4a9a372..059e3e1 100644
--- a/ITokenizer.cs
+++ b/ITokenizer.cs
@@ -3,6 +3,8 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
+using System.Text.Json;
+using System.Text.RegularExpressions;
 using System.Threading.Tasks;
 
 namespace LLAMA;
@@ -33,6 +35,17 @@ public override NormalizedString Normalize(string original)
     }
 }
 
+public class TikTokenNormalizer : Normalizer
+{
+    public override NormalizedString Normalize(string original)
+    {
+        // replace space with Ġ
+        var normalized = original.Replace(" ", "Ġ");
+
+        return new NormalizedString(original, normalized, null, isOneToOneMapping: true);
+    }
+}
+
 public class PreTokenizer : Microsoft.ML.Tokenizers.PreTokenizer
 {
     public override IReadOnlyList<Split> PreTokenize(string sentence)
@@ -43,16 +56,43 @@ public override IReadOnlyList<Split> PreTokenize(string sentence)
     }
 }
 
+public class SplitPreTokenizer : Microsoft.ML.Tokenizers.PreTokenizer
+{
+    private readonly string _pattern;
+
+    public SplitPreTokenizer(string pattern)
+    {
+        this._pattern = pattern;
+    }
+
+    public override IReadOnlyList<Split> PreTokenize(string? sentence)
+    {
+        if (sentence == null)
+        {
+            return [];
+        }
+
+        List<Split> list = new List<Split>();
+        foreach (Match item in Regex.Matches(sentence, _pattern))
+        {
+            list.Add(new Split(item.Value, (item.Index, item.Index + item.Length)));
+        }
+
+        return list;
+    }
+}
+
 public class TokenizeDecoder : Microsoft.ML.Tokenizers.TokenizerDecoder
 {
-    private const char spaceReplacement = '▁';
+    private char spaceReplacement = '▁';
     private string bos = "<s>";
     private string eos = "</s>";
 
-    public TokenizeDecoder(string bos = "<s>", string eos = "</s>")
+    public TokenizeDecoder(string bos = "<s>", string eos = "</s>", char spaceReplacement = '▁')
     {
         this.bos = bos;
         this.eos = eos;
+        this.spaceReplacement = spaceReplacement;
     }
 
     public override string Decode(IEnumerable<string> tokens)
@@ -74,12 +114,12 @@ public override string Decode(IEnumerable<string> tokens)
     }
 }
 
-public class BPETokenizer : ITokenizer
+public class LLama2Tokenizer : ITokenizer
 {
     private Tokenizer tokenizer;
     private bool addPrecedingSpace;
 
-    public BPETokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2)
+    public LLama2Tokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2)
     {
         this.BosId = startToken;
         this.EosId = endToken;
@@ -91,13 +131,75 @@ public BPETokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace
         this.tokenizer.Decoder = decoder;
     }
 
-    public static BPETokenizer FromPretrained(
+    public LLama2Tokenizer(Dictionary<string, int> vocab, List<string> merges, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2)
+    {
+        this.BosId = startToken;
+        this.EosId = endToken;
+        this.addPrecedingSpace = addPrecedingSpace;
+        this.PadId = padToken;
+        // save vocab to vocab-temp.json
+        var vocabTempPath = "vocab-temp.json";
+        var json = JsonSerializer.Serialize(vocab);
+        File.WriteAllText(vocabTempPath, json);
+
+        // save merges to merges-temp.txt
+        var mergesTempPath = "merges-temp.txt";
+        File.WriteAllLines(mergesTempPath, merges);
+
+        var bpe = new Bpe(vocabTempPath, mergesTempPath);
+
+        this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new Norm());
+        var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!);
+        this.tokenizer.Decoder = decoder;
+
+        // delete temp files
+        File.Delete(vocabTempPath);
+        File.Delete(mergesTempPath);
+    }
+
+    public static LLama2Tokenizer FromPretrained(
         string folder,
         string tokenizerJsonPath = "tokenizer.json"
     )
     {
-        throw new NotImplementedException();
+        tokenizerJsonPath = Path.Combine(folder, tokenizerJsonPath);
+        var json = File.ReadAllText(tokenizerJsonPath);
+        var jsonDocument = JsonDocument.Parse(json);
+        // vocab: .model.vocab
+        var vocabNode = jsonDocument.RootElement.GetProperty("model").GetProperty("vocab");
+
+        // to Dictionary<string, int>
+        var vocab = new Dictionary<string, int>();
+        foreach (var item in vocabNode.EnumerateObject())
+        {
+            vocab[item.Name] = item.Value.GetInt32();
+        }
+
+        // added tokens: .added_tokens
+        var addedTokensNode = jsonDocument.RootElement.GetProperty("added_tokens");
+        foreach (var item in addedTokensNode.EnumerateArray())
+        {
+            // get id from item.id
+            var id = item.GetProperty("id").GetInt32();
+            var content = item.GetProperty("content").GetString()!;
+            vocab[content] = id;
+        }
+
+        // merges: .model.merges
+        var mergesNode = jsonDocument.RootElement.GetProperty("model").GetProperty("merges");
+        // merges: List<string>
+        var merges = new List<string>();
+        foreach (var item in mergesNode.EnumerateArray())
+        {
+            merges.Add(item.GetString()!);
+        }
+
+        var startToken = vocab["<|begin_of_text|>"];
+        var endToken = vocab["<|end_of_text|>"];
+
+        return new LLama2Tokenizer(vocab, merges, startToken: startToken, endToken: endToken);
     }
 
+    public int VocabSize => this.tokenizer.Model.GetVocabSize();
+
     public int PadId { get; }
@@ -138,3 +240,130 @@ public int[] Encode(string input, bool bos, bool eos)
         return tokens;
     }
 }
+
+public class LLama3Tokenizer : ITokenizer
+{
+    private Tokenizer tokenizer;
+    private bool addPrecedingSpace;
+
+    public LLama3Tokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace = false, int padToken = -1, int startToken = 1, int endToken = 2)
+    {
+        this.BosId = startToken;
+        this.EosId = endToken;
+        this.addPrecedingSpace = addPrecedingSpace;
+        this.PadId = padToken;
+        var bpe = new Bpe(vocabPath, mergesPath);
+        this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new TikTokenNormalizer());
+        var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!, 'Ġ');
+        this.tokenizer.Decoder = decoder;
+    }
+
+    public LLama3Tokenizer(Dictionary<string, int> vocab, List<string> merges, bool addPrecedingSpace = false, int padToken = -1, int startToken = 1, int endToken = 2)
+    {
+        this.BosId = startToken;
+        this.EosId = endToken;
+        this.addPrecedingSpace = addPrecedingSpace;
+        this.PadId = padToken;
+        // save vocab to vocab-temp.json
+        var vocabTempPath = "vocab-temp.json";
+        var json = JsonSerializer.Serialize(vocab);
+        File.WriteAllText(vocabTempPath, json);
+
+        // save merges to merges-temp.txt
+        var mergesTempPath = "merges-temp.txt";
+        File.WriteAllLines(mergesTempPath, merges);
+
+        var bpe = new Bpe(vocabTempPath, mergesTempPath);
+        this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new TikTokenNormalizer());
+        var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!, 'Ġ');
+        this.tokenizer.Decoder = decoder;
+
+        // delete temp files
+        File.Delete(vocabTempPath);
+        File.Delete(mergesTempPath);
+    }
+
+    public static LLama3Tokenizer FromPretrained(
+        string folder,
+        string tokenizerJsonPath = "tokenizer.json"
+    )
+    {
+        tokenizerJsonPath = Path.Combine(folder, tokenizerJsonPath);
+        var json = File.ReadAllText(tokenizerJsonPath);
+        var jsonDocument = JsonDocument.Parse(json);
+        // vocab: .model.vocab
+        var vocabNode = jsonDocument.RootElement.GetProperty("model").GetProperty("vocab");
+
+        // to Dictionary<string, int>
+        var vocab = new Dictionary<string, int>();
+        foreach (var item in vocabNode.EnumerateObject())
+        {
+            vocab[item.Name] = item.Value.GetInt32();
+        }
+
+        // added tokens: .added_tokens
+        var addedTokensNode = jsonDocument.RootElement.GetProperty("added_tokens");
+        foreach (var item in addedTokensNode.EnumerateArray())
+        {
+            // get id from item.id
+            var id = item.GetProperty("id").GetInt32();
+            var content = item.GetProperty("content").GetString()!;
+            vocab[content] = id;
+        }
+
+        // merges: .model.merges
+        var mergesNode = jsonDocument.RootElement.GetProperty("model").GetProperty("merges");
+        // merges: List<string>
+        var merges = new List<string>();
+        foreach (var item in mergesNode.EnumerateArray())
+        {
+            merges.Add(item.GetString()!);
+        }
+
+        var startToken = vocab["<|begin_of_text|>"];
+        var endToken = vocab["<|end_of_text|>"];
+
+        return new LLama3Tokenizer(vocab, merges, startToken: startToken, endToken: endToken);
+    }
+
+    public int VocabSize => this.tokenizer.Model.GetVocabSize();
+
+    public int PadId { get; }
+
+    public int BosId { get; }
+
+    public int EosId { get; }
+
+    public string Decode(int[] input)
+    {
+        var str = this.tokenizer.Decode(input) ?? throw new Exception("Failed to decode");
+        if (this.addPrecedingSpace)
+        {
+            str = str.TrimStart();
+        }
+
+        return str;
+    }
+
+    public int[] Encode(string input, bool bos, bool eos)
+    {
+        if (this.addPrecedingSpace)
+        {
+            input = " " + input;
+        }
+        var tokens = this.tokenizer.Encode(input).Ids.ToArray();
+        if (bos)
+        {
+            tokens = new int[] { this.BosId }.Concat(tokens).ToArray();
+        }
+        if (eos)
+        {
+            tokens = tokens.Concat(new int[] { this.EosId }).ToArray();
+        }
+
+        Console.WriteLine($"tokens: {string.Join(",", tokens)}");
+
+        return tokens;
+    }
+}
+
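
Both tokenizer classes parse tokenizer.json the same way and then round-trip the vocab and merges through temp files, because the Bpe constructor used here only accepts file paths. A short usage sketch, assuming a local Meta-Llama-3-8B download (the folder path is a placeholder, not from the patch):

    // Hypothetical usage of the tokenizer added above; point the path at a
    // folder containing tokenizer.json.
    var tokenizer = LLama3Tokenizer.FromPretrained(@"path\to\Meta-Llama-3-8B");

    // bos: true prepends <|begin_of_text|>; eos: false leaves the sequence open for generation.
    var ids = tokenizer.Encode("I believe the meaning of life is", bos: true, eos: false);
    Console.WriteLine($"{ids.Length} tokens -> {tokenizer.Decode(ids)}");

One caveat visible in the code: the temp-file round trip uses fixed names (vocab-temp.json, merges-temp.txt), so constructing two tokenizers concurrently would race on those files.
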
diff --git a/LLAMA.cs b/LLAMA.cs
index 7acfd7a..f65bec7 100644
--- a/LLAMA.cs
+++ b/LLAMA.cs
@@ -47,7 +47,7 @@ public static LLaMA Build(
         stopWatch.Start();
         paramJsonPath = Path.Combine(modelFolder, paramJsonPath);
         var modelArgs = JsonSerializer.Deserialize<ModelArgs>(File.ReadAllText(paramJsonPath)) ?? throw new Exception("Failed to deserialize model args");
-        modelArgs.VocabSize = 128256;
+        modelArgs.VocabSize = tokenizer.VocabSize;
         modelArgs.MaxSeqLen = maxSeqLen;
         modelArgs.MaxBatchSize = maxBatchSize;
         // print model args
@@ -131,9 +131,6 @@ public static LLaMA Build(
         nextToken = nextToken.reshape(-1);
         // # only replace token if prompt has already been generated
         nextToken = torch.where(inputTextMask[.., curPos], tokens[.., curPos], nextToken);
-
-        // print nextToken
-        Console.WriteLine($"nextToken: {string.Join(",", nextToken.data<long>())}");
         tokens[.., curPos] = nextToken;
         if (logProbs)
         {
diff --git a/Model.cs b/Model.cs
index 97a9c34..1f5282e 100644
--- a/Model.cs
+++ b/Model.cs
@@ -313,7 +313,6 @@ public override Tensor forward(Tensor tokens, int startPos)
         var h = this.tok_embeddings.forward(tokens);
         var freqsComplex = this.freqs_compex[startPos..(startPos + seqLen)].to(h.device);
         Tensor? mask = null;
-        Console.WriteLine($"tokens shape: {string.Join(",", tokens.shape)}");
 
         if (seqLen > 1)
         {
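
The torch.where line that survives the debug cleanup is the standard prompt-echo guard during generation: while curPos is still inside a prompt, the sampled token is discarded in favor of the real prompt token. A standalone TorchSharp sketch of that selection, with made-up values and `using TorchSharp;` in scope:

    // Made-up batch of 2: the first sequence is still inside its prompt,
    // the second is already generating.
    var inPrompt = torch.tensor(new bool[] { true, false });
    var promptTok = torch.tensor(new long[] { 42, 0 });
    var sampledTok = torch.tensor(new long[] { 7, 99 });

    // Keep the prompt token where the mask is true, else accept the sample.
    var next = torch.where(inPrompt, promptTok, sampledTok);
    Console.WriteLine(string.Join(",", next.data<long>())); // 42,99
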
diff --git a/Program.cs b/Program.cs
index 2f65ff7..4e456b2 100644
--- a/Program.cs
+++ b/Program.cs
@@ -3,20 +3,16 @@
 using FluentAssertions;
 using LLAMA;
 using TorchSharp;
-using System.Runtime.InteropServices;
 
-var vocabPath = @"vocab.json";
-var mergesPath = @"merges.txt";
-var tokenizer = new BPETokenizer(vocabPath, mergesPath);
+
+var tokenizerFolder = @"C:\Users\xiaoyuz\source\repos\Meta-Llama-3-8B\";
+var tokenizer = LLama3Tokenizer.FromPretrained(tokenizerFolder);
 
 // update the following path to where you download the model
-var checkpointDirectory = "/home/xiaoyuz/Llama-2-7b";
+var checkpointDirectory = @"C:\Users\xiaoyuz\source\repos\llama3\Meta-Llama-3-8B\";
 var device = "cuda";
 
 if (device == "cuda")
 {
-    // Comment out the following two line if you use a torchsharp runtime package.
-    var libTorch = "/anaconda/envs/py38_default/lib/python3.8/site-packages/torch/lib/libtorch.so";
-    NativeLibrary.Load(libTorch);
     torch.InitializeDeviceType(DeviceType.CUDA);
     torch.cuda.is_available().Should().BeTrue();
 }
@@ -32,10 +28,33 @@
 var prompts = new[]
 {
     "I believe the meaning of life is",
+    "Simply put, the theory of relativity states that",
+    """
+    A brief message congratulating the team on the launch:
+
+    Hi everyone,
+
+    I just
+    """,
+    """
+    Translate English to French:
+
+    sea otter => loutre de mer
+    peppermint => menthe poivrée
+    plush girafe => girafe peluche
+    cheese =>
+    """
 };
 
-var result = model.TextCompletion(prompts, temperature: 0, echo: true, device: device);
-foreach (var item in result)
+foreach (var prompt in prompts)
 {
-    Console.WriteLine($"generation: {item.generation}");
+    Console.WriteLine($"prompt: {prompt}");
+    var result = model.TextCompletion([prompt], temperature: 0, echo: true, device: device);
+
+    foreach (var item in result)
+    {
+        Console.WriteLine($"generation: {item.generation}");
+    }
 }
+
+
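
Dropping the hard-coded libtorch.so path makes the CUDA setup portable across machines, but the FluentAssertions check still hard-fails on a box without CUDA. A small variant, not part of the patch, that degrades to CPU instead of asserting:

    // Suggested alternative to the assertion above (not from the patch):
    // fall back to CPU when CUDA is unavailable.
    torch.InitializeDeviceType(DeviceType.CUDA);
    var device = torch.cuda.is_available() ? "cuda" : "cpu";
    Console.WriteLine($"running on {device}");
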
diff --git a/Torchsharp-llama.csproj b/Torchsharp-llama.csproj
index 5599aa4..7d3a1b3 100644
--- a/Torchsharp-llama.csproj
+++ b/Torchsharp-llama.csproj
@@ -2,7 +2,7 @@
     <OutputType>Exe</OutputType>
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <RootNamespace>LLAMA</RootNamespace>
     <ImplicitUsings>enable</ImplicitUsings>
     <Nullable>enable</Nullable>
@@ -13,8 +13,7 @@
-
-
+

From 1c4cebcae7ec61c8afcf47344e20b07769cc9c37 Mon Sep 17 00:00:00 2001
From: XiaoYun Zhang
Date: Fri, 19 Apr 2024 00:26:35 -0700
Subject: [PATCH 3/3] update

---
 ITokenizer.cs | 11 ++++++++---
 Program.cs    | 23 +++++++++++------------
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/ITokenizer.cs b/ITokenizer.cs
index 059e3e1..0c97825 100644
--- a/ITokenizer.cs
+++ b/ITokenizer.cs
@@ -39,8 +39,10 @@ public class TikTokenNormalizer : Normalizer
 {
     public override NormalizedString Normalize(string original)
     {
-        // replace space with Ġ
-        var normalized = original.Replace(" ", "Ġ");
+        // replace newline with Ċ
+        var normalized = original.Replace(Environment.NewLine, "Ċ");
+        // replace whitespace with Ġ
+        normalized = normalized.Replace(' ', 'Ġ');
 
         return new NormalizedString(original, normalized, null, isOneToOneMapping: true);
     }
@@ -85,20 +87,23 @@ public override IReadOnlyList<Split> PreTokenize(string? sentence)
 public class TokenizeDecoder : Microsoft.ML.Tokenizers.TokenizerDecoder
 {
     private char spaceReplacement = '▁';
+    private char newlineReplacement = 'Ċ';
     private string bos = "<s>";
     private string eos = "</s>";
 
-    public TokenizeDecoder(string bos = "<s>", string eos = "</s>", char spaceReplacement = '▁')
+    public TokenizeDecoder(string bos = "<s>", string eos = "</s>", char spaceReplacement = '▁', char newlineReplacement = 'Ċ')
     {
         this.bos = bos;
         this.eos = eos;
         this.spaceReplacement = spaceReplacement;
+        this.newlineReplacement = newlineReplacement;
     }
 
     public override string Decode(IEnumerable<string> tokens)
     {
         var str = string.Join("", tokens);
         str = str.Replace(spaceReplacement, ' ');
+        str = str.Replace(newlineReplacement.ToString(), Environment.NewLine);
 
         if (str.StartsWith(bos))
         {
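
The normalizer and decoder changes are intended to be inverses: Ġ stands in for a space and Ċ for a newline, following the byte-level BPE vocabulary convention, so decoding should undo exactly what normalization did. A quick string-level check of that round trip, mirroring the two methods above:

    // Mirrors TikTokenNormalizer.Normalize and TokenizeDecoder.Decode.
    var original = "Hi everyone," + Environment.NewLine + "I just";

    var normalized = original.Replace(Environment.NewLine, "Ċ").Replace(' ', 'Ġ');
    var restored = normalized.Replace('Ġ', ' ').Replace("Ċ", Environment.NewLine);

    Console.WriteLine(restored == original); // True

One caveat grounded in the code: Environment.NewLine is "\r\n" on Windows, so a bare "\n" inside a prompt string would slip past the replacement.
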
diff --git a/Program.cs b/Program.cs
index 4e456b2..0e7e58b 100644
--- a/Program.cs
+++ b/Program.cs
@@ -17,11 +17,11 @@
     torch.cuda.is_available().Should().BeTrue();
 }
 
-torch.manual_seed(100);
+torch.manual_seed(1);
 var model = LLaMA.Build(
     modelFolder: checkpointDirectory,
     tokenizer: tokenizer,
-    maxSeqLen: 128,
+    maxSeqLen: 1024,
     maxBatchSize: 1,
     device: device);
 
@@ -29,27 +29,26 @@
 {
     "I believe the meaning of life is",
     "Simply put, the theory of relativity states that",
-    """
-    A brief message congratulating the team on the launch:
+    @"A brief message congratulating the team on the launch:
 
-    Hi everyone,
+    Hi everyone,
 
-    I just
-    """,
-    """
-    Translate English to French:
+    I just ",
+    @"Translate English to French:
 
     sea otter => loutre de mer
     peppermint => menthe poivrée
     plush girafe => girafe peluche
-    cheese =>
-    """
+    cheese =>",
 };
 
+// tokenizer
+
 foreach (var prompt in prompts)
 {
     Console.WriteLine($"prompt: {prompt}");
-    var result = model.TextCompletion([prompt], temperature: 0, echo: true, device: device);
+    tokenizer.Encode(prompt, false, false).Should().NotBeNull();
+    var result = model.TextCompletion([prompt], echo: true, device: device);
 
     foreach (var item in result)
     {
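
Removing temperature: 0 switches the loop from greedy decoding to whatever default sampling behavior TextCompletion defines. If deterministic greedy completions are still wanted, the argument from PATCH 2/3 can simply be kept (a sketch, assuming the TextCompletion signature is otherwise unchanged):

    // Hypothetical: keep greedy decoding by passing temperature explicitly,
    // as the previous revision of this loop did.
    var result = model.TextCompletion([prompt], temperature: 0, echo: true, device: device);
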