From baad80a4e70fa870bfd712514f6752e96fd1ef4f Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Sun, 30 Apr 2023 15:18:12 -0500 Subject: [PATCH 1/9] Update Solr and Lucene to 9.0.0, fixing a bunch of incompatibilities / conflicts too --- pom.xml | 63 ++++++++--- .../java/org/myrobotlab/service/Solr.java | 101 ++++++++++++------ .../service/meta/ProgramABMeta.java | 8 +- .../org/myrobotlab/service/meta/SolrMeta.java | 14 ++- .../service/meta/TensorflowMeta.java | 6 +- .../resource/Solr/core1/conf/solrconfig.xml | 2 +- 6 files changed, 144 insertions(+), 50 deletions(-) diff --git a/pom.xml b/pom.xml index 7fad72c287..7caceb9769 100644 --- a/pom.xml +++ b/pom.xml @@ -1185,6 +1185,10 @@ 0.0.8.9 provided + + org.apache.lucene + * + ch.qos.logback logback-classic @@ -1205,14 +1209,14 @@ org.apache.lucene - lucene-analyzers-common - 8.11.2 + lucene-analysis-common + 9.0.0 provided org.apache.lucene - lucene-analyzers-kuromoji - 8.11.2 + lucene-analysis-kuromoji + 9.0.0 provided @@ -1372,13 +1376,19 @@ org.apache.lucene lucene-core - 8.11.2 + 9.0.0 + provided + + + org.apache.lucene + lucene-codecs + 9.0.0 provided org.apache.solr solr-core - 8.11.2 + 9.0.0 provided @@ -1399,10 +1409,22 @@ + + org.apache.solr + solr-scripting + 9.0.0 + provided + + + com.google.guava + * + + + org.apache.solr solr-test-framework - 8.11.2 + 9.0.0 provided @@ -1426,7 +1448,7 @@ org.apache.solr solr-solrj - 8.11.2 + 9.0.0 provided @@ -1452,6 +1474,24 @@ + + com.robrua.nlp + easy-bert + 1.0.3 + provided + + + com.robrua.nlp.models + easy-bert-uncased-L-12-H-768-A-12 + 1.0.0 + provided + + + org.tensorflow + tensorflow + 1.15.0 + provided + @@ -1470,12 +1510,7 @@ - - org.tensorflow - tensorflow - 1.8.0 - provided - + diff --git a/src/main/java/org/myrobotlab/service/Solr.java b/src/main/java/org/myrobotlab/service/Solr.java index 025be8a13b..ee9d51569b 100644 --- a/src/main/java/org/myrobotlab/service/Solr.java +++ b/src/main/java/org/myrobotlab/service/Solr.java @@ -9,27 +9,35 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import javax.imageio.ImageIO; +import com.robrua.nlp.bert.Bert; import org.apache.commons.codec.binary.Base64; import org.apache.commons.lang3.StringUtils; +import org.apache.lucene.analysis.core.KeywordTokenizerFactory; import org.apache.solr.client.solrj.SolrClient; import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.client.solrj.SolrQuery.ORDER; import org.apache.solr.client.solrj.SolrQuery.SortClause; +import org.apache.solr.client.solrj.SolrRequest; import org.apache.solr.client.solrj.SolrServerException; import org.apache.solr.client.solrj.embedded.EmbeddedSolrServer; import org.apache.solr.client.solrj.impl.HttpSolrClient; import org.apache.solr.client.solrj.response.QueryResponse; import org.apache.solr.common.SolrDocument; +import org.apache.solr.common.SolrDocumentList; import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.SolrInputField; import org.apache.solr.core.CoreContainer; import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter; @@ -140,6 +148,9 @@ public void startEmbedded(String path) throws SolrServerException, IOException { // File.separator + "solr.xml"); // load up the solr core container and start solr +// System.setProperty("solr.modules", "scripting"); +// System.setProperty("solr.install.dir", "."); + // FIXME - a bit unsatisfactory File f = new File(getDataInstanceDir()); f.mkdirs(); @@ -994,38 +1005,68 @@ public static void main(String[] args) { * index solr.commit(); */ - doc = new SolrInputDocument(); - doc.setField("id", "Doc3"); - doc.setField("title", "My title 3"); - doc.setField("content", "This is the text field, for a sample document in myrobotlab. 2 "); - doc.setField("annoyance", 1); - // add the document to the index - solr.addDocument(doc); - // commit the index - solr.commit(); - - // search for the word myrobotlab - String queryString = "myrobotlab"; - QueryResponse resp = solr.search(queryString); - for (int i = 0; i < resp.getResults().size(); i++) { - System.out.println("---------------------------------"); - System.out.println("-- Printing Result number :" + i); - // grab a document out of the result set. - SolrDocument d = resp.getResults().get(i); - // iterate over the fields on the returned document - for (String fieldName : d.getFieldNames()) { - - System.out.print(fieldName + "\t"); - // fields can be multi-valued - for (Object value : d.getFieldValues(fieldName)) { - System.out.print(value); - System.out.print("\t"); - } - System.out.println(""); + // Loading a BERT model that is stored in one of our Maven dependencies + try (Bert bert = Bert.load("com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12")) { + String sentence = "Hello, my name is AP."; + doc.addField("id", "doc1"); + doc.addField("text_field", sentence); + // I don't know what I should be doing here, I need a dense vector field but can't figure out the type + doc.addField("vector_field", bert.embedSequence(sentence)); + solr.addDocument(doc); + solr.commit(); + SolrQuery query = new SolrQuery(); + float[] embeddings = bert.embedSequence("What is my name?"); + + query.setQuery("*:*"); + query.setParam("rq", "{!knn f=vector topK=3}" + Arrays.toString(embeddings)); + query.setParam("q1", "vector_field:[0 TO *]"); + query.setParam("fl", "*,score"); + + String vector = IntStream.range(0, embeddings.length) + .mapToObj(i -> String.valueOf(embeddings[i])) + .collect(Collectors.joining(",")); + query.setParam("vector", String.join(",", vector)); + QueryResponse response = solr.search(query); + SolrDocumentList results = response.getResults(); + for (SolrDocument docResult : results) { + String textData = (String) docResult.getFieldValue("text_field"); + System.out.println(textData); } + } - System.out.println("---------------------------------"); - System.out.println("Done."); + +// doc = new SolrInputDocument(); +// doc.setField("id", "Doc3"); +// doc.setField("title", "My title 3"); +// doc.setField("content", "This is the text field, for a sample document in myrobotlab. 2 "); +// doc.setField("annoyance", 1); +// // add the document to the index +// solr.addDocument(doc); +// // commit the index +// solr.commit(); +// +// // search for the word myrobotlab +// String queryString = "myrobotlab"; +// QueryResponse resp = solr.search(queryString); +// for (int i = 0; i < resp.getResults().size(); i++) { +// System.out.println("---------------------------------"); +// System.out.println("-- Printing Result number :" + i); +// // grab a document out of the result set. +// SolrDocument d = resp.getResults().get(i); +// // iterate over the fields on the returned document +// for (String fieldName : d.getFieldNames()) { +// +// System.out.print(fieldName + "\t"); +// // fields can be multi-valued +// for (Object value : d.getFieldValues(fieldName)) { +// System.out.print(value); +// System.out.print("\t"); +// } +// System.out.println(""); +// } +// } +// System.out.println("---------------------------------"); +// System.out.println("Done."); } catch (Exception e) { Logging.logError(e); diff --git a/src/main/java/org/myrobotlab/service/meta/ProgramABMeta.java b/src/main/java/org/myrobotlab/service/meta/ProgramABMeta.java index 91aafa2465..a3c3e5fca2 100644 --- a/src/main/java/org/myrobotlab/service/meta/ProgramABMeta.java +++ b/src/main/java/org/myrobotlab/service/meta/ProgramABMeta.java @@ -21,6 +21,9 @@ public ProgramABMeta() { addDependency("program-ab", "program-ab-data", null, "zip"); addDependency("program-ab", "program-ab-kw", "0.0.8.9"); + // For now we ignore Lucene deps from program AB because they conflict with Solr + // We should update ProgramAB pom + exclude("org.apache.lucene", "*"); exclude("ch.qos.logback", "logback-classic"); exclude("ch.qos.logback", "logback-core"); @@ -31,8 +34,9 @@ public ProgramABMeta() { addDependency("commons-io", "commons-io", "2.7"); // TODO: This is for CJK support in ProgramAB move this into the published // POM for ProgramAB so they are pulled in transiently. - addDependency("org.apache.lucene", "lucene-analyzers-common", "8.11.2"); - addDependency("org.apache.lucene", "lucene-analyzers-kuromoji", "8.11.2"); + // For version 9 the coordinates were changed from *-analyzers-* to *-analysis-* + addDependency("org.apache.lucene", "lucene-analysis-common", "9.0.0"); + addDependency("org.apache.lucene", "lucene-analysis-kuromoji", "9.0.0"); addCategory("ai", "control"); } diff --git a/src/main/java/org/myrobotlab/service/meta/SolrMeta.java b/src/main/java/org/myrobotlab/service/meta/SolrMeta.java index 0f2f25fffe..2255da1588 100644 --- a/src/main/java/org/myrobotlab/service/meta/SolrMeta.java +++ b/src/main/java/org/myrobotlab/service/meta/SolrMeta.java @@ -16,15 +16,22 @@ public SolrMeta() { addDescription("Solr Service - Open source search engine"); addCategory("search"); - String solrVersion = "8.11.2"; + // Solr version 9.1.* requires us to set solr.install.dir sys property + // for some reason, this one does not. Probably should investigate further + String solrVersion = "9.0.0"; String luceneVersion = solrVersion; addDependency("org.apache.lucene", "lucene-core", luceneVersion); + addDependency("org.apache.lucene", "lucene-codecs", luceneVersion); addDependency("org.apache.solr", "solr-core", solrVersion); exclude("log4j", "*"); exclude("org.apache.logging.log4j", "*"); exclude("com.fasterxml.jackson.core", "*"); exclude("io.netty", "*"); // prevent it from bringing in an old version of netty + // Some parts of Solr 8 were factored out into modules it seems + addDependency("org.apache.solr", "solr-scripting", solrVersion); + exclude("com.google.guava", "*"); + addDependency("org.apache.solr", "solr-test-framework", solrVersion); exclude("log4j", "*"); exclude("org.apache.logging.log4j", "*"); @@ -45,6 +52,11 @@ public SolrMeta() { // force correct version of netty addDependency("io.netty", "netty-all", "4.1.82.Final"); + // BERT embeddings. Could be moved to diff service + addDependency("com.robrua.nlp", "easy-bert", "1.0.3"); + addDependency("com.robrua.nlp.models", "easy-bert-uncased-L-12-H-768-A-12", "1.0.0"); + addDependency("org.tensorflow", "tensorflow", "1.15.0"); + // Dependencies issue setAvailable(true); diff --git a/src/main/java/org/myrobotlab/service/meta/TensorflowMeta.java b/src/main/java/org/myrobotlab/service/meta/TensorflowMeta.java index 9f8e43c10a..75dc0e9ebf 100644 --- a/src/main/java/org/myrobotlab/service/meta/TensorflowMeta.java +++ b/src/main/java/org/myrobotlab/service/meta/TensorflowMeta.java @@ -22,10 +22,12 @@ public TensorflowMeta() { addCategory("ai"); // TODO: what happens when you try to install this on an ARM processor like // RasPI or the Jetson TX2 ? - addDependency("org.tensorflow", "tensorflow", "1.8.0"); + // Needed to update because conflicts with BERT. + // FIXME our POM generation is still putting two artifacts with same ID but diff version in pom + addDependency("org.tensorflow", "tensorflow", "1.15.0"); // enable GPU support ? - boolean gpu = Boolean.valueOf(System.getProperty("gpu.enabled", "false")); + boolean gpu = Boolean.parseBoolean(System.getProperty("gpu.enabled", "false")); if (gpu) { // Currently only supported on Linux. 64 bit. addDependency("org.tensorflow", "libtensorflow", "1.8.0"); diff --git a/src/main/resources/resource/Solr/core1/conf/solrconfig.xml b/src/main/resources/resource/Solr/core1/conf/solrconfig.xml index 90f2e8234e..27178c3512 100755 --- a/src/main/resources/resource/Solr/core1/conf/solrconfig.xml +++ b/src/main/resources/resource/Solr/core1/conf/solrconfig.xml @@ -1309,7 +1309,7 @@ in Solr's conf/xslt directory. Changes to xslt files are checked for every xsltCacheLifetimeSeconds. --> - + 5 From da6ec8af6b9e67960229875c49fd46d5a778928f Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Thu, 4 May 2023 22:32:20 -0500 Subject: [PATCH 2/9] Add knn_vector fieldType and _vector dynamic field to managed_schema. --- .../Solr/core1/conf/managed-schema.xml | 499 ++++++++++++++++++ 1 file changed, 499 insertions(+) create mode 100644 src/main/resources/resource/Solr/core1/conf/managed-schema.xml diff --git a/src/main/resources/resource/Solr/core1/conf/managed-schema.xml b/src/main/resources/resource/Solr/core1/conf/managed-schema.xml new file mode 100644 index 0000000000..553e762613 --- /dev/null +++ b/src/main/resources/resource/Solr/core1/conf/managed-schema.xml @@ -0,0 +1,499 @@ + + + + id + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 9df8c5f542dd18107b21aad9202b36d1999b0d31 Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Thu, 4 May 2023 22:33:45 -0500 Subject: [PATCH 3/9] Fix deprecated LRUCache implementations and add schema management --- .../resource/Solr/core1/conf/solrconfig.xml | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/main/resources/resource/Solr/core1/conf/solrconfig.xml b/src/main/resources/resource/Solr/core1/conf/solrconfig.xml index 27178c3512..0b58e2e6e2 100755 --- a/src/main/resources/resource/Solr/core1/conf/solrconfig.xml +++ b/src/main/resources/resource/Solr/core1/conf/solrconfig.xml @@ -21,6 +21,11 @@ this file, see http://wiki.apache.org/solr/SolrConfigXml. --> + + true + managed-schema.xml + + - @@ -423,7 +428,7 @@ maxRamMB - the maximum amount of RAM (in MB) that this cache is allowed to occupy --> - @@ -434,14 +439,14 @@ document). Since Lucene internal document ids are transient, this cache will not be autowarmed. --> - From 0ca26c1c3278559ebaff321ff24de31f396d4d6b Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Thu, 4 May 2023 22:34:43 -0500 Subject: [PATCH 4/9] Make Solr vector search example work --- .../java/org/myrobotlab/service/Solr.java | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/myrobotlab/service/Solr.java b/src/main/java/org/myrobotlab/service/Solr.java index ee9d51569b..93005b4875 100644 --- a/src/main/java/org/myrobotlab/service/Solr.java +++ b/src/main/java/org/myrobotlab/service/Solr.java @@ -21,15 +21,14 @@ import javax.imageio.ImageIO; +import com.google.common.primitives.Floats; import com.robrua.nlp.bert.Bert; import org.apache.commons.codec.binary.Base64; import org.apache.commons.lang3.StringUtils; -import org.apache.lucene.analysis.core.KeywordTokenizerFactory; import org.apache.solr.client.solrj.SolrClient; import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.client.solrj.SolrQuery.ORDER; import org.apache.solr.client.solrj.SolrQuery.SortClause; -import org.apache.solr.client.solrj.SolrRequest; import org.apache.solr.client.solrj.SolrServerException; import org.apache.solr.client.solrj.embedded.EmbeddedSolrServer; import org.apache.solr.client.solrj.impl.HttpSolrClient; @@ -37,7 +36,6 @@ import org.apache.solr.common.SolrDocument; import org.apache.solr.common.SolrDocumentList; import org.apache.solr.common.SolrInputDocument; -import org.apache.solr.common.SolrInputField; import org.apache.solr.core.CoreContainer; import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter; @@ -1008,28 +1006,45 @@ public static void main(String[] args) { // Loading a BERT model that is stored in one of our Maven dependencies try (Bert bert = Bert.load("com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12")) { String sentence = "Hello, my name is AP."; + solr.deleteEmbeddedIndex(); + solr.deleteDocument("doc1"); + solr.deleteDocument("doc2"); doc.addField("id", "doc1"); doc.addField("text_field", sentence); // I don't know what I should be doing here, I need a dense vector field but can't figure out the type - doc.addField("vector_field", bert.embedSequence(sentence)); + doc.addField("test_vector", Floats.asList(bert.embedSequence(sentence))); + + solr.addDocument(doc); + + for (int i = 0; i < 1000; i++) { + SolrInputDocument doc2 = new SolrInputDocument(); + doc2.addField("id", "doc2" + i); + doc2.addField("text_field", "Make a Python program: " + i * i); + doc2.addField("test_vector", Floats.asList(bert.embedSequence("Make a Python program: " + i * i))); + solr.deleteDocument("doc2" + i); +// solr.addDocument(doc2); + + } + solr.commit(); SolrQuery query = new SolrQuery(); float[] embeddings = bert.embedSequence("What is my name?"); query.setQuery("*:*"); - query.setParam("rq", "{!knn f=vector topK=3}" + Arrays.toString(embeddings)); - query.setParam("q1", "vector_field:[0 TO *]"); + query.setParam("q", "{!knn f=test_vector topK=3}" + Arrays.toString(embeddings)); query.setParam("fl", "*,score"); + String vector = IntStream.range(0, embeddings.length) .mapToObj(i -> String.valueOf(embeddings[i])) .collect(Collectors.joining(",")); - query.setParam("vector", String.join(",", vector)); +// query.setParam("vector", String.join(",", vector)); QueryResponse response = solr.search(query); SolrDocumentList results = response.getResults(); for (SolrDocument docResult : results) { - String textData = (String) docResult.getFieldValue("text_field"); + String textData = ((ArrayList) docResult.getFieldValue("text_field")).toString(); + System.out.println(docResult.getFieldValue("score")); System.out.println(textData); } From c1ad4350cdcd301f8f59f4cfea6e8327f1fe1edc Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Fri, 5 May 2023 01:07:41 -0500 Subject: [PATCH 5/9] Add ConversationTurn dataclass and ChatbotMemory interface --- .../service/data/ConversationTurn.java | 75 +++++++++++++++++++ .../service/interfaces/ChatbotMemory.java | 60 +++++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 src/main/java/org/myrobotlab/service/data/ConversationTurn.java create mode 100644 src/main/java/org/myrobotlab/service/interfaces/ChatbotMemory.java diff --git a/src/main/java/org/myrobotlab/service/data/ConversationTurn.java b/src/main/java/org/myrobotlab/service/data/ConversationTurn.java new file mode 100644 index 0000000000..fdb142e2ac --- /dev/null +++ b/src/main/java/org/myrobotlab/service/data/ConversationTurn.java @@ -0,0 +1,75 @@ +package org.myrobotlab.service.data; + +import java.util.Objects; + +/** + * Represents one "turn" of a conversation. + * In a conversation the participants take turns + * speaking, while one is speaking the others + * should be listening. In most conversations, + * a participant can speak for as long as they like, + * but in all cases respondents must know who was speaking. + * Thus, this class contains the name of the speaker + * and what they said during their turn. + * Since one cannot unsay what has been said, + * this class is immutable. It is meant as a data object + * to pass records of the conversation around. + * + * @author AutonomicPerfectionist + */ +public class ConversationTurn { + + /** + * When an AI / Chatbot is speaking during + * the conversation, its speakerName is the value + * of this constant. This allows a chatbot to be + * renamed without forgetting everything. + */ + public static final String AI = "AI"; + + /** + * The person who was speaking during this turn. + * If a chatbot was speaking, then this field should have + * the value of {@link #AI}. + */ + public final String speaker; + + /** + * What the {@link #speaker} said during + * their turn. + */ + public final String turnContents; + + /** + * The ID of the conversation this turn was a part of. + * This ID can be generated through a number of ways, a simple + * way would be to add the hashcodes of the participants' names. + */ + public final long conversationId; + + public ConversationTurn(String speaker, String turnContents, long conversationId) { + Objects.requireNonNull(speaker, "Speaker may not be null"); + Objects.requireNonNull(turnContents, "Turn contents may not be null"); + this.speaker = speaker; + this.turnContents = turnContents; + this.conversationId = conversationId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ConversationTurn that = (ConversationTurn) o; + + if (!speaker.equals(that.speaker)) return false; + return turnContents.equals(that.turnContents); + } + + @Override + public int hashCode() { + int result = speaker.hashCode(); + result = 31 * result + turnContents.hashCode(); + return result; + } +} diff --git a/src/main/java/org/myrobotlab/service/interfaces/ChatbotMemory.java b/src/main/java/org/myrobotlab/service/interfaces/ChatbotMemory.java new file mode 100644 index 0000000000..c8dd9d3b9b --- /dev/null +++ b/src/main/java/org/myrobotlab/service/interfaces/ChatbotMemory.java @@ -0,0 +1,60 @@ +package org.myrobotlab.service.interfaces; + + +import org.myrobotlab.service.data.ConversationTurn; + +import java.util.List; + +/** + * Provides a form of memory for chatbots. + * The idea is to store information in an + * index of some kind during the conversation, + * and then recall specific information using + * the current input request. Usually + * this would be implemented through a vector store. + * + * @author AutonomicPerfectionist + */ +public interface ChatbotMemory { + + /** + * Commit a piece of the conversation to memory. + * Once memorized, the memory can be recalled if a request + * has high enough similarity to the memory. + * + * @param memory The turn to be remembered. + */ + void memorize(ConversationTurn memory); + + /** + * Recall a number of memorized conversation turns + * that have similarity to the request. The maximum number + * of memories recalled is set via {@link #setMaxNumMemoriesRecalled(int)}. + * This usually corresponds to the {@code top_k} parameter in vector stores. + * @param request The most recent conversation turn that is used to recall memories. + * @return Recalled memories + */ + List recallMemories(ConversationTurn request); + + /** + * Upon recalling memories, they are published through this method. + * Services that are interested in recalled memories should subscribe to this method. + * @param memories The memories that have been recalled. + * @return The recalled memories. + */ + List publishMemories(List memories); + + /** + * Sets the maximum number of memories to be recalled + * via {@link #recallMemories(ConversationTurn)}. + * @param number The maximum number of memories that can be recalled at once + */ + void setMaxNumMemoriesRecalled(int number); + + /** + * Gets the maximum number of memories to be recalled + * via {@link #recallMemories(ConversationTurn)}. + * @return The maximum number of memories that can be recalled at once. + */ + int getMaxNumMemoriesRecalled(); +} From b5d59239a3621c16e58cf6c570a4b7b6be8b142b Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Sat, 13 May 2023 12:53:42 -0500 Subject: [PATCH 6/9] Refactor memory interface into ChatMessageVectorStore interface --- .../java/org/myrobotlab/service/Solr.java | 240 ++++++++++++++---- ...ConversationTurn.java => ChatMessage.java} | 39 ++- ...emory.java => ChatMessageVectorStore.java} | 27 +- 3 files changed, 237 insertions(+), 69 deletions(-) rename src/main/java/org/myrobotlab/service/data/{ConversationTurn.java => ChatMessage.java} (67%) rename src/main/java/org/myrobotlab/service/interfaces/{ChatbotMemory.java => ChatMessageVectorStore.java} (71%) diff --git a/src/main/java/org/myrobotlab/service/Solr.java b/src/main/java/org/myrobotlab/service/Solr.java index 93005b4875..a4635816f0 100644 --- a/src/main/java/org/myrobotlab/service/Solr.java +++ b/src/main/java/org/myrobotlab/service/Solr.java @@ -9,15 +9,12 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import javax.imageio.ImageIO; @@ -34,7 +31,6 @@ import org.apache.solr.client.solrj.impl.HttpSolrClient; import org.apache.solr.client.solrj.response.QueryResponse; import org.apache.solr.common.SolrDocument; -import org.apache.solr.common.SolrDocumentList; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.core.CoreContainer; import org.bytedeco.javacv.Java2DFrameConverter; @@ -58,6 +54,8 @@ import org.myrobotlab.opencv.OpenCVData; import org.myrobotlab.opencv.YoloDetectedObject; import org.myrobotlab.programab.Response; +import org.myrobotlab.service.data.ChatMessage; +import org.myrobotlab.service.interfaces.ChatMessageVectorStore; import org.myrobotlab.service.interfaces.DocumentListener; import org.myrobotlab.service.interfaces.SpeechRecognizer; import org.myrobotlab.service.interfaces.TextListener; @@ -77,7 +75,7 @@ * @author kwatters * */ -public class Solr extends Service implements DocumentListener, TextListener, MessageListener { +public class Solr extends Service implements DocumentListener, TextListener, MessageListener, ChatMessageVectorStore { private static final String CORE_NAME = "core1"; public final static Logger log = LoggerFactory.getLogger(Solr.class); @@ -98,6 +96,15 @@ public class Solr extends Service implements DocumentListener, TextListener, Mes public int yoloPersonTrainingCount = 0; public String yoloPersonLabel = null; + /** + * The maximum number of memories that can be recalled + * at once via {@link ChatMessageVectorStore#recallMemories(List)}. + */ + private int maxNumRecalledMemories = 3; + + private final String CONVERSATION_DOC_ID_TEMPLATE = "conversation_%d"; + private final Bert bert = Bert.load("com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12"); + public Solr(String n, String id) { super(n, id); } @@ -254,6 +261,7 @@ public void deleteDocument(String docId) { /** * Returns a document given the doc id from the index if it exists otherwise + * null. * * @param docId * - the doc id @@ -992,9 +1000,17 @@ public static void main(String[] args) { try { Solr solr = (Solr) Runtime.start("solr", "Solr"); solr.startEmbedded(); + solr.deleteEmbeddedIndex(); + // WebGui webgui = (WebGui)Runtime.start("webgui", "WebGui"); // Create a test document SolrInputDocument doc = new SolrInputDocument(); + solr.memorize(new ChatMessage("AP", "I have a cat named Sunny", 1234)); + solr.memorize(new ChatMessage(ChatMessage.AI, "Hello AP, I am Hugo. How may I help you?", 1234)); + solr.memorize(new ChatMessage("AP", "Write a Python program for me.", 1234)); + solr.memorize(new ChatMessage(ChatMessage.AI, "Certainly, here is a hello world program:\n```python\nprint(\"hello world\")\n```", 1234)); + + System.out.println(solr.recallMemories(new ChatMessage("AP", "What is my cat's name?", 1234))); /* * doc.setField("id", "Doc1"); doc.setField("title", "My title"); * doc.setField("content", @@ -1003,52 +1019,51 @@ public static void main(String[] args) { * index solr.commit(); */ - // Loading a BERT model that is stored in one of our Maven dependencies - try (Bert bert = Bert.load("com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12")) { - String sentence = "Hello, my name is AP."; - solr.deleteEmbeddedIndex(); - solr.deleteDocument("doc1"); - solr.deleteDocument("doc2"); - doc.addField("id", "doc1"); - doc.addField("text_field", sentence); - // I don't know what I should be doing here, I need a dense vector field but can't figure out the type - doc.addField("test_vector", Floats.asList(bert.embedSequence(sentence))); - - - solr.addDocument(doc); - - for (int i = 0; i < 1000; i++) { - SolrInputDocument doc2 = new SolrInputDocument(); - doc2.addField("id", "doc2" + i); - doc2.addField("text_field", "Make a Python program: " + i * i); - doc2.addField("test_vector", Floats.asList(bert.embedSequence("Make a Python program: " + i * i))); - solr.deleteDocument("doc2" + i); -// solr.addDocument(doc2); - - } - - solr.commit(); - SolrQuery query = new SolrQuery(); - float[] embeddings = bert.embedSequence("What is my name?"); - - query.setQuery("*:*"); - query.setParam("q", "{!knn f=test_vector topK=3}" + Arrays.toString(embeddings)); - query.setParam("fl", "*,score"); - - - String vector = IntStream.range(0, embeddings.length) - .mapToObj(i -> String.valueOf(embeddings[i])) - .collect(Collectors.joining(",")); -// query.setParam("vector", String.join(",", vector)); - QueryResponse response = solr.search(query); - SolrDocumentList results = response.getResults(); - for (SolrDocument docResult : results) { - String textData = ((ArrayList) docResult.getFieldValue("text_field")).toString(); - System.out.println(docResult.getFieldValue("score")); - System.out.println(textData); - } +// // Loading a BERT model that is stored in one of our Maven dependencies +// try (Bert bert = Bert.load("com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12")) { +// String sentence = "Hello, my name is AP."; +//// solr.deleteDocument("doc1"); +//// solr.deleteDocument("doc2"); +// doc.addField("id", "doc1"); +// doc.addField("text_field", sentence); +// // I don't know what I should be doing here, I need a dense vector field but can't figure out the type +// doc.addField("test_vector", Floats.asList(bert.embedSequence(sentence))); +// +// +//// solr.addDocument(doc); +// +// for (int i = 0; i < 1; i++) { +// SolrInputDocument doc2 = new SolrInputDocument(); +// doc2.addField("id", "doc2" + i); +// doc2.addField("text_field", "Make a Python program: " + i * i); +// doc2.addField("test_vector", Floats.asList(bert.embedSequence("Make a Python program: " + i * i))); +//// solr.deleteDocument("doc2" + i); +//// solr.addDocument(doc2); +// +// } +// +// solr.commit(); +// SolrQuery query = new SolrQuery(); +// float[] embeddings = bert.embedSequence("What is my name."); +// +// query.setQuery("*:*"); +// query.setParam("q", "{!knn f=test_vector topK=3}" + Arrays.toString(embeddings)); +// query.setParam("fl", "*,score"); +// +// +// String vector = IntStream.range(0, embeddings.length) +// .mapToObj(i -> String.valueOf(embeddings[i])) +// .collect(Collectors.joining(",")); +//// query.setParam("vector", String.join(",", vector)); +// QueryResponse response = solr.search(query); +// SolrDocumentList results = response.getResults(); +// for (SolrDocument docResult : results) { +// String textData = ((ArrayList) docResult.getFieldValue("text_field")).toString(); +// System.out.println(docResult.getFieldValue("score")); +// System.out.println(textData); +// } - } +// } // doc = new SolrInputDocument(); // doc.setField("id", "Doc3"); @@ -1153,4 +1168,129 @@ public void releaseService() { super.releaseService(); } + + public void memorize(ChatMessage memory) { + memorize(memory, Floats.asList(bert.embedSequence(memory.message))); + } + + /** + * Commit a piece of the conversation to memory. + * Once memorized, the memory can be recalled if a request + * has high enough similarity to the memory. + * + * @param memory The turn to be remembered. + * @param embeddings + */ + @Override + public void memorize(ChatMessage memory, List embeddings) { + SolrInputDocument memoryDoc = new SolrInputDocument(); + + memoryDoc.setField("id", String.format(CONVERSATION_DOC_ID_TEMPLATE, memory.conversationId) + memory.message.hashCode()); + memoryDoc.setField("text_field", memory.message); + memoryDoc.setField("speaker_field", memory.speaker); + memoryDoc.setField("conversation_id", memory.conversationId); + memoryDoc.setField("test_vector", embeddings); + addDocument(memoryDoc); + + + SolrInputDocument newConversationDoc = new SolrInputDocument(); + newConversationDoc.addField("id", String.format(CONVERSATION_DOC_ID_TEMPLATE, memory.conversationId)); + newConversationDoc.addChildDocument(memoryDoc); +// addDocument(newConversationDoc); + commit(); + } + + + public List recallMemories(ChatMessage request) { + + + float[] embeddings = bert.embedSequence(request.message); + return recallMemories(Floats.asList(embeddings)); + + + } + + /** + * Recall a number of memorized conversation turns + * that have similarity to the request. The maximum number + * of memories recalled is set via {@link #setMaxNumMemoriesRecalled(int)}. + * This usually corresponds to the {@code top_k} parameter in vector stores. + * + * @param embeddings@return Recalled memories + */ + @Override + public List recallMemories(List embeddings) { + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.setParam("q", "{!knn f=test_vector topK=3}" + embeddings.toString()); + query.setParam("fl", "*,score"); + QueryResponse response = search(query); + List turns = new ArrayList<>(); + for (SolrDocument result : response.getResults()) { + System.out.println("Score: " + result.getFieldValue("score")); + turns.add( + new ChatMessage( + ((ArrayList) result.getFieldValue("speaker_field")).get(0), + ((ArrayList) result.getFieldValue("text_field")).get(0), + Long.parseLong(((ArrayList) result.getFieldValue("conversation_id")).get(0)) + ) + ); + } + return turns; + } + + /** + * Upon recalling memories, they are published through this method. + * Services that are interested in recalled memories should subscribe to this method. + * + * @param memories The memories that have been recalled. + * @return The recalled memories. + */ + @Override + public List publishMemories(List memories) { + return memories; + } + + /** + * Sets the maximum number of memories to be recalled + * via {@link ChatMessageVectorStore#recallMemories(List)}. + * + * @param number The maximum number of memories that can be recalled at once + */ + @Override + public void setMaxNumMemoriesRecalled(int number) { + maxNumRecalledMemories = number; + } + + /** + * Gets the maximum number of memories to be recalled + * via {@link ChatMessageVectorStore#recallMemories(List)}. + * + * @return The maximum number of memories that can be recalled at once. + */ + @Override + public int getMaxNumMemoriesRecalled() { + return maxNumRecalledMemories; + } + + @Override + public int getEmbeddingDimensions() { + return 786; + } + + @Override + public void setEmbeddingDimensions(int dimensions) { + throw new UnsupportedOperationException( + "Cannot change embedding dimensions with Solr, manually modify the schema instead." + ); + } + + @Override + public void clearStore() { + try { + deleteEmbeddedIndex(); + } catch (SolrServerException | IOException e) { + error("Caught exception while trying to delete embedded index.", e); + } + } } diff --git a/src/main/java/org/myrobotlab/service/data/ConversationTurn.java b/src/main/java/org/myrobotlab/service/data/ChatMessage.java similarity index 67% rename from src/main/java/org/myrobotlab/service/data/ConversationTurn.java rename to src/main/java/org/myrobotlab/service/data/ChatMessage.java index fdb142e2ac..3c08e8e662 100644 --- a/src/main/java/org/myrobotlab/service/data/ConversationTurn.java +++ b/src/main/java/org/myrobotlab/service/data/ChatMessage.java @@ -17,7 +17,7 @@ * * @author AutonomicPerfectionist */ -public class ConversationTurn { +public class ChatMessage { /** * When an AI / Chatbot is speaking during @@ -38,38 +38,59 @@ public class ConversationTurn { * What the {@link #speaker} said during * their turn. */ - public final String turnContents; + public final String message; /** - * The ID of the conversation this turn was a part of. + * The ID of the conversation this message was a part of. * This ID can be generated through a number of ways, a simple * way would be to add the hashcodes of the participants' names. */ public final long conversationId; - public ConversationTurn(String speaker, String turnContents, long conversationId) { + public ChatMessage(String speaker, String message, long conversationId) { Objects.requireNonNull(speaker, "Speaker may not be null"); - Objects.requireNonNull(turnContents, "Turn contents may not be null"); + Objects.requireNonNull(message, "Turn contents may not be null"); this.speaker = speaker; - this.turnContents = turnContents; + this.message = message; this.conversationId = conversationId; } + public String getSpeaker() { + return speaker; + } + + public String getMessage() { + return message; + } + + public long getConversationId() { + return conversationId; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - ConversationTurn that = (ConversationTurn) o; + ChatMessage that = (ChatMessage) o; if (!speaker.equals(that.speaker)) return false; - return turnContents.equals(that.turnContents); + return message.equals(that.message); } @Override public int hashCode() { int result = speaker.hashCode(); - result = 31 * result + turnContents.hashCode(); + result = 31 * result + message.hashCode(); return result; } + + @Override + public String toString() { + return "ChatMessage{" + + "speaker='" + speaker + '\'' + + ", turnContents='" + message + '\'' + + ", conversationId=" + conversationId + + '}'; + } } diff --git a/src/main/java/org/myrobotlab/service/interfaces/ChatbotMemory.java b/src/main/java/org/myrobotlab/service/interfaces/ChatMessageVectorStore.java similarity index 71% rename from src/main/java/org/myrobotlab/service/interfaces/ChatbotMemory.java rename to src/main/java/org/myrobotlab/service/interfaces/ChatMessageVectorStore.java index c8dd9d3b9b..ff8ee48be4 100644 --- a/src/main/java/org/myrobotlab/service/interfaces/ChatbotMemory.java +++ b/src/main/java/org/myrobotlab/service/interfaces/ChatMessageVectorStore.java @@ -1,7 +1,7 @@ package org.myrobotlab.service.interfaces; -import org.myrobotlab.service.data.ConversationTurn; +import org.myrobotlab.service.data.ChatMessage; import java.util.List; @@ -15,26 +15,27 @@ * * @author AutonomicPerfectionist */ -public interface ChatbotMemory { +public interface ChatMessageVectorStore { /** * Commit a piece of the conversation to memory. * Once memorized, the memory can be recalled if a request * has high enough similarity to the memory. * - * @param memory The turn to be remembered. + * @param memory The turn to be remembered. + * @param embeddings */ - void memorize(ConversationTurn memory); + void memorize(ChatMessage memory, List embeddings); /** * Recall a number of memorized conversation turns * that have similarity to the request. The maximum number * of memories recalled is set via {@link #setMaxNumMemoriesRecalled(int)}. * This usually corresponds to the {@code top_k} parameter in vector stores. - * @param request The most recent conversation turn that is used to recall memories. - * @return Recalled memories + * + * @param embeddings@return Recalled memories */ - List recallMemories(ConversationTurn request); + List recallMemories(List embeddings); /** * Upon recalling memories, they are published through this method. @@ -42,19 +43,25 @@ public interface ChatbotMemory { * @param memories The memories that have been recalled. * @return The recalled memories. */ - List publishMemories(List memories); + List publishMemories(List memories); /** * Sets the maximum number of memories to be recalled - * via {@link #recallMemories(ConversationTurn)}. + * via {@link #recallMemories(List)}. * @param number The maximum number of memories that can be recalled at once */ void setMaxNumMemoriesRecalled(int number); /** * Gets the maximum number of memories to be recalled - * via {@link #recallMemories(ConversationTurn)}. + * via {@link #recallMemories(List)}. * @return The maximum number of memories that can be recalled at once. */ int getMaxNumMemoriesRecalled(); + + int getEmbeddingDimensions(); + + void setEmbeddingDimensions(int dimensions); + + void clearStore(); } From f57feade6f2b2ed0de8d0136d4dd4e27cd03678e Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Sat, 13 May 2023 12:53:59 -0500 Subject: [PATCH 7/9] Minor cleanup in Document code --- src/main/java/org/myrobotlab/document/Document.java | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/myrobotlab/document/Document.java b/src/main/java/org/myrobotlab/document/Document.java index fbdf020c8d..7b76a80c5e 100644 --- a/src/main/java/org/myrobotlab/document/Document.java +++ b/src/main/java/org/myrobotlab/document/Document.java @@ -17,7 +17,7 @@ public class Document { private String id; - private HashMap> data; + private final HashMap> data; private ProcessingStatus status; public Document(String id) { @@ -27,11 +27,7 @@ public Document(String id) { } public ArrayList getField(String fieldName) { - if (data.containsKey(fieldName)) { - return data.get(fieldName); - } else { - return null; - } + return data.getOrDefault(fieldName, null); } public void setField(String fieldName, ArrayList value) { @@ -151,9 +147,7 @@ public boolean equals(Object obj) { return false; } else if (!id.equals(other.id)) return false; - if (status != other.status) - return false; - return true; + return status == other.status; } @Override From 8d19f0a06c4e1d043a1965a20fbf4aa84d270ef7 Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Sat, 13 May 2023 13:13:12 -0500 Subject: [PATCH 8/9] Add TextEmbeddingGenerator interface and EasyBert service --- pom.xml | 6 +++ .../java/org/myrobotlab/service/EasyBert.java | 53 +++++++++++++++++++ .../interfaces/TextEmbeddingGenerator.java | 9 ++++ .../myrobotlab/service/meta/EasyBertMeta.java | 28 ++++++++++ 4 files changed, 96 insertions(+) create mode 100644 src/main/java/org/myrobotlab/service/EasyBert.java create mode 100644 src/main/java/org/myrobotlab/service/interfaces/TextEmbeddingGenerator.java create mode 100644 src/main/java/org/myrobotlab/service/meta/EasyBertMeta.java diff --git a/pom.xml b/pom.xml index 7caceb9769..579b6e2a1a 100644 --- a/pom.xml +++ b/pom.xml @@ -343,6 +343,12 @@ + + + + + + pl.allegro.tech diff --git a/src/main/java/org/myrobotlab/service/EasyBert.java b/src/main/java/org/myrobotlab/service/EasyBert.java new file mode 100644 index 0000000000..9f56068f20 --- /dev/null +++ b/src/main/java/org/myrobotlab/service/EasyBert.java @@ -0,0 +1,53 @@ +package org.myrobotlab.service; + +import com.google.common.primitives.Floats; +import com.robrua.nlp.bert.Bert; +import org.myrobotlab.framework.Service; +import org.myrobotlab.service.interfaces.TextEmbeddingGenerator; + +import java.io.File; +import java.util.List; + +public class EasyBert extends Service implements TextEmbeddingGenerator { + + private Bert bert; + public final String DEFAULT_BERT_MODEL = "com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12"; + + + /** + * Constructor of service, reservedkey typically is a services name and inId + * will be its process id + * + * @param reservedKey the service name + * @param inId process id + */ + public EasyBert(String reservedKey, String inId) { + super(reservedKey, inId); + bert = Bert.load(DEFAULT_BERT_MODEL); + } + + @Override + public List generateEmbeddings(String words) { + List embeddings = Floats.asList(bert.embedSequence(words)); + invoke("publishEmbeddings", embeddings); + return embeddings; + } + + @Override + public List publishEmbeddings(List embeddings) { + return embeddings; + } + + @Override + public void onText(String text) throws Exception { + generateEmbeddings(text); + } + + public void setBertModel(String resource) { + bert = Bert.load(resource); + } + + public void setBertModel(File model) { + bert = Bert.load(model); + } +} diff --git a/src/main/java/org/myrobotlab/service/interfaces/TextEmbeddingGenerator.java b/src/main/java/org/myrobotlab/service/interfaces/TextEmbeddingGenerator.java new file mode 100644 index 0000000000..1fca923232 --- /dev/null +++ b/src/main/java/org/myrobotlab/service/interfaces/TextEmbeddingGenerator.java @@ -0,0 +1,9 @@ +package org.myrobotlab.service.interfaces; + +import java.util.List; + +public interface TextEmbeddingGenerator extends TextListener { + + List generateEmbeddings(String words); + List publishEmbeddings(List embeddings); +} diff --git a/src/main/java/org/myrobotlab/service/meta/EasyBertMeta.java b/src/main/java/org/myrobotlab/service/meta/EasyBertMeta.java new file mode 100644 index 0000000000..7805b5199b --- /dev/null +++ b/src/main/java/org/myrobotlab/service/meta/EasyBertMeta.java @@ -0,0 +1,28 @@ +package org.myrobotlab.service.meta; + +import org.myrobotlab.logging.LoggerFactory; +import org.myrobotlab.service.meta.abstracts.MetaData; +import org.slf4j.Logger; + +public class EasyBertMeta extends MetaData { + private static final long serialVersionUID = 1L; + public final static Logger log = LoggerFactory.getLogger(EasyBertMeta.class); + + /** + * This class is contains all the meta data details of a service. It's peers, + * dependencies, and all other meta data related to the service. + */ + public EasyBertMeta() { + + addDescription("EasyBert service - Java BERT sentence embeddings."); + addCategory("search"); + + addDependency("com.robrua.nlp", "easy-bert", "1.0.3"); + addDependency("com.robrua.nlp.models", "easy-bert-uncased-L-12-H-768-A-12", "1.0.0"); + addDependency("org.tensorflow", "tensorflow", "1.15.0"); + + setAvailable(true); + + } + +} From f1ab1626b983fbac68e3864f065f5fa5101ef4e7 Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Sat, 20 May 2023 23:25:33 -0500 Subject: [PATCH 9/9] Switch memory embedding field to vector instead of test_vector --- src/main/java/org/myrobotlab/service/Solr.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/myrobotlab/service/Solr.java b/src/main/java/org/myrobotlab/service/Solr.java index 35c89929b9..b244f1725e 100644 --- a/src/main/java/org/myrobotlab/service/Solr.java +++ b/src/main/java/org/myrobotlab/service/Solr.java @@ -1201,7 +1201,7 @@ public void memorize(ChatMessage memory, List embeddings) { memoryDoc.setField("text_field", memory.message); memoryDoc.setField("speaker_field", memory.speaker); memoryDoc.setField("conversation_id", memory.conversationId); - memoryDoc.setField("test_vector", embeddings); + memoryDoc.setField("vector", embeddings); addDocument(memoryDoc); @@ -1234,7 +1234,7 @@ public List recallMemories(ChatMessage request) { public List recallMemories(List embeddings) { SolrQuery query = new SolrQuery(); query.setQuery("*:*"); - query.setParam("q", "{!knn f=test_vector topK=3}" + embeddings.toString()); + query.setParam("q", "{!knn f=vector topK=3}" + embeddings.toString()); query.setParam("fl", "*,score"); QueryResponse response = search(query); List turns = new ArrayList<>();