diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java index 8e25e0e55f08c..886126a79901c 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java @@ -13,13 +13,20 @@ import java.util.List; -public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) { +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; + +public record ChunkInferenceInput(InferenceString input, @Nullable ChunkingSettings chunkingSettings) { public ChunkInferenceInput(String input) { - this(input, null); + this(new InferenceString(input, TEXT), null); } - public static List inputs(List chunkInferenceInputs) { + public static List inputs(List chunkInferenceInputs) { return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList(); } + + public String inputText() { + assert input.isText(); + return input.value(); + } } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceString.java b/server/src/main/java/org/elasticsearch/inference/InferenceString.java new file mode 100644 index 0000000000000..ed01b5a4863f5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/InferenceString.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.inference; + +import java.util.EnumSet; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * This class represents a String which may be raw text, or the String representation of some other data such as an image in base64 + */ +public record InferenceString(String value, DataType dataType) { + /** + * Describes the type of data represented by an {@link InferenceString} + */ + public enum DataType { + TEXT, + IMAGE_BASE64 + } + + private static final EnumSet<DataType> IMAGE_TYPES = EnumSet.of(DataType.IMAGE_BASE64); + + /** + * Constructs an {@link InferenceString} with the given value and {@link DataType} + * @param value the String value + * @param dataType the type of data that the String represents + */ + public InferenceString(String value, DataType dataType) { + this.value = Objects.requireNonNull(value); + this.dataType = Objects.requireNonNull(dataType); + } + + public boolean isImage() { + return IMAGE_TYPES.contains(dataType); + } + + public boolean isText() { + return DataType.TEXT.equals(dataType); + } + + public static List<String> toStringList(List<InferenceString> inferenceStrings) { + return inferenceStrings.stream().map(InferenceString::value).collect(Collectors.toList()); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunker.java index 7d1fb251445d5..8ba5e6be0938c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunker.java @@ -15,6 +15,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.InferenceServiceResults; +import
org.elasticsearch.inference.InferenceString; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.chunking.Chunker.ChunkOffset; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; @@ -36,7 +37,7 @@ * chunks. Multiple inputs may be fit into a single batch or * a single large input that has been chunked may spread over * multiple batches. - * + *

* The final aspect is to gather the responses from the batch * processing and map the results back to the original element * in the input list. @@ -44,14 +45,18 @@ public class EmbeddingRequestChunker> { // Visible for testing - record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) { - public String chunkText() { - return input.substring(chunk.start(), chunk.end()); + record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, InferenceString input) { + public InferenceString chunkText() { + if (chunk.start() == 0 && chunk.end() == input.value().length()) { + return input; + } else { + return new InferenceString(input.value().substring(chunk.start(), chunk.end()), input.dataType()); + } } } public record BatchRequest(List requests) { - public Supplier> inputs() { + public Supplier> inputs() { return () -> requests.stream().map(Request::chunkText).collect(Collectors.toList()); } } @@ -107,13 +112,21 @@ public EmbeddingRequestChunker( List allRequests = new ArrayList<>(); for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { - ChunkingSettings chunkingSettings = inputs.get(inputIndex).chunkingSettings(); + ChunkInferenceInput chunkInferenceInput = inputs.get(inputIndex); + ChunkingSettings chunkingSettings = chunkInferenceInput.chunkingSettings(); if (chunkingSettings == null) { chunkingSettings = defaultChunkingSettings; } - Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker); - String inputString = inputs.get(inputIndex).input(); - List chunks = chunker.chunk(inputString, chunkingSettings); + Chunker chunker; + if (chunkInferenceInput.input().isText()) { + chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker); + } else { + // Do not chunk non-text inputs + chunker = NoopChunker.INSTANCE; + chunkingSettings = NoneChunkingSettings.INSTANCE; + } + InferenceString inputString = chunkInferenceInput.input(); + List chunks = 
chunker.chunk(inputString.value(), chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); resultOffsetStarts.add(new ArrayList<>(resultCount)); @@ -121,7 +134,7 @@ public EmbeddingRequestChunker( for (int chunkIndex = 0; chunkIndex < chunks.size(); chunkIndex++) { // If the number of chunks is larger than the maximum allowed value, - // scale the indices to [0, MAX) with similar number of original + // scale the indices to [0, MAX) with similar number of original // chunks in the final chunks. int targetChunkIndex = chunks.size() <= MAX_CHUNKS ? chunkIndex : chunkIndex * MAX_CHUNKS / chunks.size(); if (resultOffsetStarts.getLast().size() <= targetChunkIndex) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java index 75aee7230be57..12575210dbc1f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceString; import org.elasticsearch.xcontent.XContent; import java.io.IOException; @@ -19,15 +20,18 @@ public record ChunkedInferenceEmbedding(List<EmbeddingResults.Chunk> chunks) implements ChunkedInference { - public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) { - validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size()); + public static List<ChunkedInference> listOf(List<InferenceString> inputs, SparseEmbeddingResults sparseEmbeddingResults) { + validateInputSizeAgainstEmbeddings(inputs.size(), sparseEmbeddingResults.embeddings().size()); var results =
new ArrayList(inputs.size()); for (int i = 0; i < inputs.size(); i++) { results.add( new ChunkedInferenceEmbedding( List.of( - new EmbeddingResults.Chunk(sparseEmbeddingResults.embeddings().get(i), new TextOffset(0, inputs.get(i).length())) + new EmbeddingResults.Chunk( + sparseEmbeddingResults.embeddings().get(i), + new TextOffset(0, inputs.get(i).value().length()) + ) ) ) ); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java index 0ba8102b1dab7..56406b4f9def1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java @@ -9,18 +9,16 @@ import org.elasticsearch.common.Strings; -import java.util.List; - public class TextEmbeddingUtils { /** * Throws an exception if the number of elements in the input text list is different than the results in text embedding * response. 
*/ - public static void validateInputSizeAgainstEmbeddings(List inputs, int embeddingSize) { - if (inputs.size() != embeddingSize) { + public static void validateInputSizeAgainstEmbeddings(int inputsSize, int embeddingSize) { + if (inputsSize != embeddingSize) { throw new IllegalArgumentException( - Strings.format("The number of inputs [%s] does not match the embeddings [%s]", inputs.size(), embeddingSize) + Strings.format("The number of inputs [%s] does not match the embeddings [%s]", inputsSize, embeddingSize) ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunkerTests.java index 80b887428f147..7fd3518a1e4e5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunkerTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceString; import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; @@ -18,18 +19,21 @@ import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.hamcrest.Matchers; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; +import static org.elasticsearch.inference.InferenceString.toStringList; 
import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; import static org.hamcrest.Matchers.startsWith; public class EmbeddingRequestChunkerTests extends ESTestCase { @@ -58,8 +62,8 @@ public void testAnyInput_NoopChunker() { var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput(randomInput)), 10, NoneChunkingSettings.INSTANCE) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(randomInput)); + assertThat(batches.getFirst().batch().inputs().get(), hasSize(1)); + assertThat(batches.getFirst().batch().inputs().get().getFirst().value(), is(randomInput)); } public void testWhitespaceInput_SentenceChunker() { @@ -69,8 +73,8 @@ public void testWhitespaceInput_SentenceChunker() { new SentenceBoundaryChunkingSettings(250, 1) ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(" ")); + assertThat(batches.getFirst().batch().inputs().get(), hasSize(1)); + assertThat(batches.getFirst().batch().inputs().get().getFirst().value(), is(" ")); } public void testBlankInput_WordChunker() { @@ -78,16 +82,16 @@ public void testBlankInput_WordChunker() { testListener() ); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("")); + assertThat(batches.getFirst().batch().inputs().get(), hasSize(1)); + 
assertThat(batches.getFirst().batch().inputs().get().getFirst().value(), is("")); } public void testBlankInput_SentenceChunker() { var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 10, new SentenceBoundaryChunkingSettings(250, 1)) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("")); + assertThat(batches.getFirst().batch().inputs().get(), hasSize(1)); + assertThat(batches.getFirst().batch().inputs().get().getFirst().value(), is("")); } public void testInputThatDoesNotChunk_WordChunker() { @@ -95,8 +99,8 @@ public void testInputThatDoesNotChunk_WordChunker() { testListener() ); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA")); + assertThat(batches.getFirst().batch().inputs().get(), hasSize(1)); + assertThat(batches.getFirst().batch().inputs().get().getFirst().value(), is("ABBAABBA")); } public void testInputThatDoesNotChunk_SentenceChunker() { @@ -106,15 +110,15 @@ public void testInputThatDoesNotChunk_SentenceChunker() { new SentenceBoundaryChunkingSettings(250, 1) ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA")); + assertThat(batches.getFirst().batch().inputs().get(), hasSize(1)); + assertThat(batches.getFirst().batch().inputs().get().getFirst().value(), is("ABBAABBA")); } public void testShortInputsAreSingleBatch() { ChunkInferenceInput input = new ChunkInferenceInput("one chunk"); var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - 
assertThat(batches.get(0).batch().inputs().get(), contains(input.input())); + assertThat(toStringList(batches.getFirst().batch().inputs().get()), contains(input.inputText())); } public void testMultipleShortInputsAreSingleBatch() { @@ -129,7 +133,7 @@ public void testMultipleShortInputsAreSingleBatch() { assertEquals(batch.inputs().get(), ChunkInferenceInput.inputs(inputs)); for (int i = 0; i < inputs.size(); i++) { var request = batch.requests().get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i).input())); + assertThat(request.chunkText().value(), equalTo(inputs.get(i).inputText())); assertEquals(i, request.inputIndex()); assertEquals(0, request.chunkIndex()); } @@ -151,20 +155,20 @@ public void testManyInputsMakeManyBatches() { assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(3).batch().inputs().get(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get().get(0)); - assertEquals("input 9", batches.get(0).batch().inputs().get().get(9)); + assertEquals("input 0", batches.get(0).batch().inputs().get().get(0).value()); + assertEquals("input 9", batches.get(0).batch().inputs().get().get(9).value()); assertThat( - batches.get(1).batch().inputs().get(), + toStringList(batches.get(1).batch().inputs().get()), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get().get(0)); - assertEquals("input 29", batches.get(2).batch().inputs().get().get(9)); - assertThat(batches.get(3).batch().inputs().get(), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().inputs().get().get(0).value()); + assertEquals("input 29", batches.get(2).batch().inputs().get().get(9).value()); + assertThat(toStringList(batches.get(3).batch().inputs().get()), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < 
requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i).input())); + assertThat(request.chunkText().value(), equalTo(inputs.get(i).inputText())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -187,20 +191,20 @@ public void testChunkingSettingsProvided() { assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(3).batch().inputs().get(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get().get(0)); - assertEquals("input 9", batches.get(0).batch().inputs().get().get(9)); + assertEquals("input 0", batches.get(0).batch().inputs().get().get(0).value()); + assertEquals("input 9", batches.get(0).batch().inputs().get().get(9).value()); assertThat( - batches.get(1).batch().inputs().get(), + toStringList(batches.get(1).batch().inputs().get()), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get().get(0)); - assertEquals("input 29", batches.get(2).batch().inputs().get().get(9)); - assertThat(batches.get(3).batch().inputs().get(), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().inputs().get().get(0).value()); + assertEquals("input 29", batches.get(2).batch().inputs().get().get(9).value()); + assertThat(toStringList(batches.get(3).batch().inputs().get()), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i).input())); + assertThat(request.chunkText().value(), equalTo(inputs.get(i).inputText())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -230,21 +234,21 @@ public void 
testLongInputChunkedOverMultipleBatches() { assertThat(batches, hasSize(2)); - var batch = batches.get(0).batch(); + var batch = batches.getFirst().batch(); assertThat(batch.inputs().get(), hasSize(batchSize)); assertThat(batch.requests(), hasSize(batchSize)); - EmbeddingRequestChunker.Request request = batch.requests().get(0); + EmbeddingRequestChunker.Request request = batch.requests().getFirst(); assertThat(request.inputIndex(), equalTo(0)); assertThat(request.chunkIndex(), equalTo(0)); - assertThat(request.chunkText(), equalTo("1st small")); + assertThat(request.chunkText().value(), equalTo("1st small")); for (int requestIndex = 1; requestIndex < 5; requestIndex++) { request = batch.requests().get(requestIndex); assertThat(request.inputIndex(), equalTo(1)); int chunkIndex = requestIndex - 1; assertThat(request.chunkIndex(), equalTo(chunkIndex)); - assertThat(request.chunkText(), startsWith((chunkIndex == 0 ? "" : " ") + "passage_input" + 20 * chunkIndex)); + assertThat(request.chunkText().value(), startsWith((chunkIndex == 0 ? 
"" : " ") + "passage_input" + 20 * chunkIndex)); } batch = batches.get(1).batch(); @@ -256,18 +260,18 @@ public void testLongInputChunkedOverMultipleBatches() { assertThat(request.inputIndex(), equalTo(1)); int chunkIndex = requestIndex + 4; assertThat(request.chunkIndex(), equalTo(chunkIndex)); - assertThat(request.chunkText(), startsWith(" passage_input" + 20 * chunkIndex)); + assertThat(request.chunkText().value(), startsWith(" passage_input" + 20 * chunkIndex)); } request = batch.requests().get(2); assertThat(request.inputIndex(), equalTo(2)); assertThat(request.chunkIndex(), equalTo(0)); - assertThat(request.chunkText(), equalTo("2nd small")); + assertThat(request.chunkText().value(), equalTo("2nd small")); request = batch.requests().get(3); assertThat(request.inputIndex(), equalTo(3)); assertThat(request.chunkIndex(), equalTo(0)); - assertThat(request.chunkText(), equalTo("3rd small")); + assertThat(request.chunkText().value(), equalTo("3rd small")); } public void testVeryLongInput_Sparse() { @@ -314,13 +318,13 @@ public void testVeryLongInput_Sparse() { assertThat(finalListener.results, hasSize(3)); // The first input has the token with weight 1/16384f. 
- ChunkedInference inference = finalListener.results.get(0); + ChunkedInference inference = finalListener.results.getFirst(); assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); - SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.getFirst().inputText(), chunkedEmbedding.chunks().getFirst().offset()), equalTo("1st small")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().getFirst().embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 1 / 16384f))); // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for @@ -334,19 +338,22 @@ public void testVeryLongInput_Sparse() { // The first merged chunk consists of 20 small chunks (so 400 words) and the max // weight is the weight of the 20th small chunk (so 21/16384). 
- assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); - embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().getFirst().offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().getFirst().offset()), endsWith(" word398 word399")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().getFirst().embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 21 / 16384f))); // The last merged chunk consists of 19 small chunks (so 380 words) and the max // weight is the weight of the 10000th small chunk (so 10001/16384). 
assertThat( - getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ") ); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().get(511).offset()), + endsWith(" word199998 word199999") + ); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10001 / 16384f))); @@ -356,9 +363,9 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); - embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(2).inputText(), chunkedEmbedding.chunks().getFirst().offset()), equalTo("2nd small")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().getFirst().embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10002 / 16384f))); } @@ -405,14 +412,14 @@ public void testVeryLongInput_Float() { assertThat(finalListener.results, hasSize(3)); // The first input has the embedding with weight 1/16384. 
- ChunkedInference inference = finalListener.results.get(0); + ChunkedInference inference = finalListener.results.getFirst(); assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + assertThat(getMatchedText(inputs.getFirst().inputText(), chunkedEmbedding.chunks().getFirst().offset()), equalTo("1st small")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); DenseEmbeddingFloatResults.Embedding embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks() - .get(0) + .getFirst() .embedding(); assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f })); @@ -427,19 +434,22 @@ public void testVeryLongInput_Float() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2/16384 ... 21/16384. 
- assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); - embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().getFirst().offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().getFirst().offset()), endsWith(" word398 word399")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().getFirst().embedding(); assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983/16384 ... 10001/16384. 
assertThat( - getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ") ); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().get(511).offset()), + endsWith(" word199998 word199999") + ); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) })); @@ -449,9 +459,9 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); - embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(2).inputText(), chunkedEmbedding.chunks().getFirst().offset()), equalTo("2nd small")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().getFirst().embedding(); assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f })); } @@ -498,13 +508,15 @@ public void testVeryLongInput_Byte() { assertThat(finalListener.results, hasSize(3)); // The first input has the embedding with weight 1. 
- ChunkedInference inference = finalListener.results.get(0); + ChunkedInference inference = finalListener.results.getFirst(); assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); - DenseEmbeddingByteResults.Embedding embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.getFirst().inputText(), chunkedEmbedding.chunks().getFirst().offset()), equalTo("1st small")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + DenseEmbeddingByteResults.Embedding embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks() + .getFirst() + .embedding(); assertThat(embedding.values(), equalTo(new byte[] { 1 })); // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for @@ -518,20 +530,23 @@ public void testVeryLongInput_Byte() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2 ... 21, so 11.5, which is rounded to 12. 
- assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); - embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().getFirst().offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().getFirst().offset()), endsWith(" word398 word399")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().getFirst().embedding(); assertThat(embedding.values(), equalTo(new byte[] { 12 })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so // the average of -1, 0, 1, ... , 17, so 8. 
assertThat( - getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ") ); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedEmbedding.chunks().get(511).offset()), + endsWith(" word199998 word199999") + ); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 8 })); @@ -541,9 +556,9 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); - embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(2).inputText(), chunkedEmbedding.chunks().getFirst().offset()), equalTo("2nd small")); + assertThat(chunkedEmbedding.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().getFirst().embedding(); assertThat(embedding.values(), equalTo(new byte[] { 18 })); } @@ -576,7 +591,7 @@ public void testMergingListener_Float() { for (int i = 0; i < batchSize; i++) { embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); } - batches.get(0).listener().onResponse(new 
DenseEmbeddingFloatResults(embeddings)); + batches.getFirst().listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); } { var embeddings = new ArrayList(); @@ -589,11 +604,14 @@ public void testMergingListener_Float() { assertNotNull(finalListener.results); assertThat(finalListener.results, hasSize(4)); { - var chunkedResult = finalListener.results.get(0); + var chunkedResult = finalListener.results.getFirst(); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat( + getMatchedText(inputs.getFirst().inputText(), chunkedFloatResult.chunks().getFirst().offset()), + equalTo("1st small") + ); } { // this is the large input split in multiple chunks @@ -601,13 +619,28 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); assertThat( - getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(5).offset()), + getMatchedText(inputs.get(1).inputText(), 
chunkedFloatResult.chunks().get(0).offset()), + startsWith("passage_input0 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedFloatResult.chunks().get(1).offset()), + startsWith(" passage_input20 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedFloatResult.chunks().get(2).offset()), + startsWith(" passage_input40 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedFloatResult.chunks().get(3).offset()), + startsWith(" passage_input60 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedFloatResult.chunks().get(4).offset()), + startsWith(" passage_input80 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedFloatResult.chunks().get(5).offset()), startsWith(" passage_input100 ") ); } @@ -616,14 +649,14 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).inputText(), chunkedFloatResult.chunks().getFirst().offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).inputText(), chunkedFloatResult.chunks().getFirst().offset()), equalTo("3rd small")); } } @@ -656,7 +689,7 @@ public void testMergingListener_Byte() { for (int i = 0; i < batchSize; i++) { embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - 
batches.get(0).listener().onResponse(new DenseEmbeddingByteResults(embeddings)); + batches.getFirst().listener().onResponse(new DenseEmbeddingByteResults(embeddings)); } { var embeddings = new ArrayList(); @@ -669,11 +702,11 @@ public void testMergingListener_Byte() { assertNotNull(finalListener.results); assertThat(finalListener.results, hasSize(4)); { - var chunkedResult = finalListener.results.get(0); + var chunkedResult = finalListener.results.getFirst(); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.getFirst().inputText(), chunkedByteResult.chunks().getFirst().offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -681,26 +714,44 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + 
assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(0).offset()), + startsWith("passage_input0 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(1).offset()), + startsWith(" passage_input20 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(2).offset()), + startsWith(" passage_input40 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(3).offset()), + startsWith(" passage_input60 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(4).offset()), + startsWith(" passage_input80 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(5).offset()), + startsWith(" passage_input100 ") + ); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).inputText(), chunkedByteResult.chunks().getFirst().offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).inputText(), chunkedByteResult.chunks().getFirst().offset()), equalTo("3rd small")); } } @@ -733,7 +784,7 @@ public void testMergingListener_Bit() { for (int i = 0; i < batchSize; i++) { embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] 
{ randomByte() })); } - batches.get(0).listener().onResponse(new DenseEmbeddingBitResults(embeddings)); + batches.getFirst().listener().onResponse(new DenseEmbeddingBitResults(embeddings)); } { var embeddings = new ArrayList(); @@ -746,11 +797,11 @@ public void testMergingListener_Bit() { assertNotNull(finalListener.results); assertThat(finalListener.results, hasSize(4)); { - var chunkedResult = finalListener.results.get(0); + var chunkedResult = finalListener.results.getFirst(); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.getFirst().inputText(), chunkedByteResult.chunks().getFirst().offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -758,26 +809,44 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" 
passage_input100 ")); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(0).offset()), + startsWith("passage_input0 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(1).offset()), + startsWith(" passage_input20 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(2).offset()), + startsWith(" passage_input40 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(3).offset()), + startsWith(" passage_input60 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(4).offset()), + startsWith(" passage_input80 ") + ); + assertThat( + getMatchedText(inputs.get(1).inputText(), chunkedByteResult.chunks().get(5).offset()), + startsWith(" passage_input100 ") + ); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).inputText(), chunkedByteResult.chunks().getFirst().offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).inputText(), chunkedByteResult.chunks().getFirst().offset()), equalTo("3rd small")); } } @@ -810,7 +879,7 @@ public void testMergingListener_Sparse() { for (int i = 0; i < batchSize; i++) { embeddings.add(new 
SparseEmbeddingResults.Embedding(List.of(new WeightedToken(randomAlphaOfLength(4), 1.0f)), false)); } - batches.get(0).listener().onResponse(new SparseEmbeddingResults(embeddings)); + batches.getFirst().listener().onResponse(new SparseEmbeddingResults(embeddings)); } { var embeddings = new ArrayList(); @@ -830,25 +899,28 @@ public void testMergingListener_Sparse() { assertNotNull(finalListener.results); assertThat(finalListener.results, hasSize(4)); { - var chunkedResult = finalListener.results.get(0); + var chunkedResult = finalListener.results.getFirst(); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat( + getMatchedText(inputs.getFirst().inputText(), chunkedSparseResult.chunks().getFirst().offset()), + equalTo("1st small") + ); } { var chunkedResult = finalListener.results.get(1); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(1).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(1).inputText(), chunkedSparseResult.chunks().getFirst().offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(2).inputText(), chunkedSparseResult.chunks().getFirst().offset()), 
equalTo("3rd small")); } { // this is the large input split in multiple chunks @@ -856,13 +928,16 @@ public void testMergingListener_Sparse() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); assertThat( - getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(1).offset()), + getMatchedText(inputs.get(3).inputText(), chunkedSparseResult.chunks().get(0).offset()), + startsWith("passage_input0 ") + ); + assertThat( + getMatchedText(inputs.get(3).inputText(), chunkedSparseResult.chunks().get(1).offset()), startsWith(" passage_input10 ") ); assertThat( - getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(8).offset()), + getMatchedText(inputs.get(3).inputText(), chunkedSparseResult.chunks().get(8).offset()), startsWith(" passage_input80 ") ); } @@ -880,8 +955,8 @@ public void testListenerErrorsWithWrongNumberOfResponses() { @Override public void onResponse(List chunkedResults) { - assertThat(chunkedResults.get(0), instanceOf(ChunkedInferenceError.class)); - var error = (ChunkedInferenceError) chunkedResults.get(0); + assertThat(chunkedResults.getFirst(), instanceOf(ChunkedInferenceError.class)); + var error = (ChunkedInferenceError) chunkedResults.getFirst(); failureMessage.set(error.exception().getMessage()); } @@ -897,10 +972,43 @@ public void onFailure(Exception e) { var embeddings = new ArrayList(); embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); - batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); + batches.getFirst().listener().onResponse(new 
DenseEmbeddingFloatResults(embeddings)); assertEquals("Error the number of embedding responses [2] does not equal the number of requests [3]", failureMessage.get()); } + public void testDoesNotChunkNonTextInputs_whenChunkingSettingsAreNullInInput() { + InferenceString imageString = new InferenceString( + "image chunks", + randomValueOtherThan(TEXT, () -> randomFrom(InferenceString.DataType.values())) + ); + ChunkInferenceInput imageInput = new ChunkInferenceInput(imageString, null); + ChunkInferenceInput textInput = new ChunkInferenceInput(new InferenceString("text chunks", TEXT), null); + + var batches = new EmbeddingRequestChunker<>(List.of(imageInput, textInput), 100, 1, 0).batchRequestsWithListeners(testListener()); + + assertThat(batches, hasSize(1)); + var expectedOutput = List.of(imageString, new InferenceString("text", TEXT), new InferenceString(" chunks", TEXT)); + assertThat(batches.getFirst().batch().inputs().get(), is(expectedOutput)); + assertThat(batches.getFirst().batch().inputs().get().getFirst(), is(sameInstance(imageString))); + } + + public void testDoesNotChunkNonTextInputs_whenChunkingSettingsAreSpecifiedInInput() { + InferenceString imageString = new InferenceString( + "image chunks", + randomValueOtherThan(TEXT, () -> randomFrom(InferenceString.DataType.values())) + ); + WordBoundaryChunkingSettings chunkingSettings = new WordBoundaryChunkingSettings(1, 0); + ChunkInferenceInput imageInput = new ChunkInferenceInput(imageString, chunkingSettings); + ChunkInferenceInput textInput = new ChunkInferenceInput(new InferenceString("text chunks", TEXT), chunkingSettings); + + var batches = new EmbeddingRequestChunker<>(List.of(imageInput, textInput), 100).batchRequestsWithListeners(testListener()); + + assertThat(batches, hasSize(1)); + var expectedOutput = List.of(imageString, new InferenceString("text", TEXT), new InferenceString(" chunks", TEXT)); + assertThat(batches.getFirst().batch().inputs().get(), is(expectedOutput)); + 
assertThat(batches.getFirst().batch().inputs().get().getFirst(), is(sameInstance(imageString))); + } + private ChunkedResultsListener testListener() { return new ChunkedResultsListener(); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 35c317e3ca1a7..da9ec1a808cee 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -121,29 +121,29 @@ public void close() throws IOException {} protected List chunkInputs(ChunkInferenceInput input) { ChunkingSettings chunkingSettings = input.chunkingSettings(); - String inputText = input.input(); - if (chunkingSettings == null) { - return List.of(new ChunkedInput(inputText, 0, inputText.length())); + String inputString = input.input().value(); + if (chunkingSettings == null || input.input().isText() == false) { + return List.of(new ChunkedInput(inputString, 0, inputString.length())); } List chunkedInputs = new ArrayList<>(); if (chunkingSettings.getChunkingStrategy() == ChunkingStrategy.NONE) { - var offsets = NoopChunker.INSTANCE.chunk(input.input(), chunkingSettings); + var offsets = NoopChunker.INSTANCE.chunk(inputString, chunkingSettings); List ret = new ArrayList<>(); for (var offset : offsets) { - ret.add(new ChunkedInput(inputText.substring(offset.start(), offset.end()), offset.start(), offset.end())); + ret.add(new ChunkedInput(inputString.substring(offset.start(), offset.end()), offset.start(), offset.end())); } return ret; } else if (chunkingSettings.getChunkingStrategy() == ChunkingStrategy.WORD) { WordBoundaryChunker chunker = new 
WordBoundaryChunker(); WordBoundaryChunkingSettings wordBoundaryChunkingSettings = (WordBoundaryChunkingSettings) chunkingSettings; List offsets = chunker.chunk( - inputText, + inputString, wordBoundaryChunkingSettings.maxChunkSize(), wordBoundaryChunkingSettings.overlap() ); for (WordBoundaryChunker.ChunkOffset offset : offsets) { - chunkedInputs.add(new ChunkedInput(inputText.substring(offset.start(), offset.end()), offset.start(), offset.end())); + chunkedInputs.add(new ChunkedInput(inputString.substring(offset.start(), offset.end()), offset.start(), offset.end())); } } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 240d4863c7a18..4c286d56d839a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -43,6 +43,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceString; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; @@ -76,6 +77,7 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; @@ -393,7 +395,7 
@@ public void onFailure(Exception exc) { } final List inputs = requests.stream() - .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) + .map(r -> new ChunkInferenceInput(new InferenceString(r.input, TEXT), r.chunkingSettings)) .collect(Collectors.toList()); ActionListener> completionListener = new ActionListener<>() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index f9fd3a2011ee0..5bff84920c468 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -8,31 +8,35 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceString; import org.elasticsearch.inference.InputType; import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; public class EmbeddingsInput extends InferenceInputs { - private final Supplier> inputListSupplier; + private final Supplier> inputListSupplier; private final InputType inputType; private final AtomicBoolean supplierInvoked = new AtomicBoolean(); public EmbeddingsInput(List input, @Nullable InputType inputType) { - this(() -> input, inputType, false); + this(() -> input.stream().map(s -> new InferenceString(s, TEXT)).collect(Collectors.toList()), inputType, false); } public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { - this(() -> input, inputType, stream); + this(() -> input.stream().map(s -> new InferenceString(s, 
TEXT)).collect(Collectors.toList()), inputType, stream); } - public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { + public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { this(inputSupplier, inputType, false); } - private EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType, boolean stream) { + private EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType, boolean stream) { super(stream); this.inputListSupplier = Objects.requireNonNull(inputSupplier); this.inputType = inputType; @@ -44,13 +48,32 @@ private EmbeddingsInput(Supplier> inputSupplier, @Nullable InputTyp * Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling * this method twice in a non-production environment will cause an {@link AssertionError} to be thrown. * - * @return a list of String embedding inputs + * @return a list of {@link InferenceString} embedding inputs */ - public List getInputs() { + public List getInputs() { assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice"; return inputListSupplier.get(); } + /** + * This method should only be called in code paths that do not deal with multimodal embeddings; where all inputs are guaranteed to be + * raw text, since it discards the {@link org.elasticsearch.inference.InferenceString.DataType} associated with each input. + *

+ * Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply + * returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input + * Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling + * this method twice in a non-production environment will cause an {@link AssertionError} to be thrown. + * + * @return a list of String embedding inputs that do not contain any non-text inputs + */ + public List getTextInputs() { + assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice"; + return inputListSupplier.get().stream().map(i -> { + assert i.isText() : "Non-text input returned from EmbeddingsInput.getTextInputs"; + return i.value(); + }).toList(); + } + public InputType getInputType() { return this.inputType; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java index 4a485f87858aa..5753199b239f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java @@ -52,7 +52,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getTextInputs(); var truncatedInput = truncate(docsInput, maxInputTokens); var request = requestCreator.apply(truncatedInput); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java index 8a77f65592226..aa7794e60414a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java @@ -73,7 +73,7 @@ public void execute( ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java index 20ff8ce58b550..11d906d821ec4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java @@ -72,7 +72,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchSparseRequest request = new 
AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java index 387d8b65f40d6..390031e741295 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java @@ -57,7 +57,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType inputType = input.getInputType(); var serviceSettings = embeddingsModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java index afb268ab499a9..e405df3c125a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java @@ -51,7 +51,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java index db38b3fb0def3..790abb909ddcd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java @@ -64,7 +64,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index 121f0e1e80a96..440e401fbdb5f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -81,8 +81,8 @@ public ExecutableAction create(CohereEmbeddingsModel model, Map : overriddenModel.getTaskSettings().getInputType(); return switch (overriddenModel.getServiceSettings().getCommonSettings().apiVersion()) { - case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); - case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); + case V1 -> new 
CohereV1EmbeddingsRequest(inferenceInputs.getTextInputs(), requestInputType, overriddenModel); + case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getTextInputs(), requestInputType, overriddenModel); }; }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 45fcc7790c55e..be3f8bc85e031 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -22,6 +22,7 @@ import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InferenceString; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -56,6 +57,7 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; @@ -154,7 +156,7 @@ private static RequestParameters createParameters(CustomModel model) { case RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input"))); case COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input"))); case TEXT_EMBEDDING, SPARSE_EMBEDDING -> EmbeddingParameters.of( - new EmbeddingsInput(List.of("test input"), null), + new EmbeddingsInput(() -> 
List.of(new InferenceString("test input", TEXT)), null), model.getServiceSettings().getInputTypeTranslator() ); default -> throw new IllegalStateException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java index ef91045eb1dab..281432d0cd020 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java @@ -28,7 +28,7 @@ public static EmbeddingParameters of(EmbeddingsInput embeddingsInput, InputTypeT private final InputTypeTranslator translator; private EmbeddingParameters(EmbeddingsInput embeddingsInput, InputTypeTranslator translator) { - super(embeddingsInput.getInputs()); + super(embeddingsInput.getTextInputs()); this.inputType = embeddingsInput.getInputType(); this.translator = translator; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index f34d538c413dc..a8b1cd8b11b7e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -70,7 +70,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType 
inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index 749a3277929f8..e980d7f713495 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -97,7 +97,7 @@ public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel m DENSE_TEXT_EMBEDDINGS_HANDLER, (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( model, - embeddingsInput.getInputs(), + embeddingsInput.getTextInputs(), traceContext, extractRequestMetadataFromThreadContext(threadPool.getThreadContext()), embeddingsInput.getInputType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 2d2a04ba3f7fe..55c2f7623c157 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -74,6 +74,7 @@ import java.util.function.Consumer; import java.util.function.Function; +import static org.elasticsearch.inference.InferenceString.toStringList; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; 
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; @@ -1139,7 +1140,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft var inferenceRequest = buildInferenceRequest( esModel.mlNodeDeploymentId(), EmptyConfigUpdate.INSTANCE, - batch.batch().inputs().get(), + toStringList(batch.batch().inputs().get()), inputType, timeout ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java index e65bfda857c1b..84485fb9f8c77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java @@ -57,7 +57,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java index 90d7a0b1b0a11..3ace1f9c0aab0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java @@ -65,7 +65,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java index 833f9bc6b347b..1ae1ff599136b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java @@ -62,7 +62,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List inputs = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + List inputs = inferenceInputs.castTo(EmbeddingsInput.class).getTextInputs(); var truncatedInput = truncate(inputs, model.getTokenLimit()); var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 681058d8c21a5..42f5e4591a7bc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -116,7 +116,7 @@ protected void doChunkedInfer( // TODO chunking sparse embeddings not implemented doInfer( model, - new EmbeddingsInput(inputs.stream().map(ChunkInferenceInput::input).toList(), inputType), + new EmbeddingsInput(() -> inputs.stream().map(ChunkInferenceInput::input).toList(), inputType), taskSettings, timeout, inferListener @@ -128,7 +128,7 @@ private static List translateToChunkedResults( InferenceServiceResults inferenceResults ) { if (inferenceResults instanceof DenseEmbeddingFloatResults denseEmbeddingResults) { - validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), denseEmbeddingResults.embeddings().size()); + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs).size(), denseEmbeddingResults.embeddings().size()); var results = new ArrayList(inputs.size()); @@ -138,7 +138,7 @@ private static List translateToChunkedResults( List.of( new EmbeddingResults.Chunk( denseEmbeddingResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, inputs.get(i).input().length()) + new ChunkedInference.TextOffset(0, inputs.get(i).input().value().length()) ) ) ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java index 92434d371c7e8..5a904c3420c01 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java @@ -55,7 +55,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = 
inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getTextInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); execute( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java index 3e3918acb78dc..ab72aab578365 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java @@ -53,7 +53,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getInputs(); + List docsInput = input.getTextInputs(); InputType inputType = input.getInputType(); JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java index f647338ba3110..3e071d46c0cc4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java @@ -73,7 +73,7 @@ public ExecutableAction create(LlamaEmbeddingsModel model) { EMBEDDINGS_HANDLER, embeddingsInput -> new LlamaEmbeddingsRequest( serviceComponents.truncator(), - truncate(embeddingsInput.getInputs(), model.getServiceSettings().maxInputTokens()), + 
truncate(embeddingsInput.getTextInputs(), model.getServiceSettings().maxInputTokens()), model ), EmbeddingsInput.class diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java index ea31435780b96..aef78f015294e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java @@ -61,7 +61,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getTextInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index 88cc21d43649b..5f279e4f60b6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -44,6 +44,7 @@ import java.util.Set; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.inference.InferenceString.toStringList; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails; @@ -288,7 +289,7 @@ public void chunkedInfer( query, null, // no return docs while chunking? null, // no topN while chunking? - request.batch().inputs().get(), + toStringList(request.batch().inputs().get()), false, // we never stream when chunking null, // since we pass sageMakerModel as the model, we already overwrote the model with the task settings inputType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java index 5bf9bd66def2f..396bbac39ad0b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java @@ -57,7 +57,7 @@ public ExecutableAction create(VoyageAIEmbeddingsModel model, Map new VoyageAIEmbeddingsRequest( - embeddingsInput.getInputs(), + embeddingsInput.getTextInputs(), embeddingsInput.getInputType(), overriddenModel ), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 070a6dc8a9538..04997af32b125 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -1068,7 +1068,7 @@ private static 
ShardBulkInferenceActionFilter createFilter( Runnable runnable = () -> { List results = new ArrayList<>(); for (ChunkInferenceInput input : inputs) { - results.add(model.getResults(input.input())); + results.add(model.getResults(input.inputText())); } listener.onResponse(results); }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java index d6ba10b1932dc..86f12d30950bd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java @@ -7,19 +7,22 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.elasticsearch.inference.InferenceString; import org.elasticsearch.test.ESTestCase; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; +import static org.elasticsearch.inference.InferenceString.DataType.IMAGE_BASE64; +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; import static org.hamcrest.Matchers.is; public class EmbeddingsInputTests extends ESTestCase { public void testCallingGetInputs_invokesSupplier() { AtomicBoolean invoked = new AtomicBoolean(); - final List list = List.of("input1", "input2"); - Supplier> supplier = () -> { + final List list = List.of(new InferenceString("input1", TEXT), new InferenceString("image_url", IMAGE_BASE64)); + Supplier> supplier = () -> { invoked.set(true); return list; }; @@ -31,11 +34,53 @@ public void testCallingGetInputs_invokesSupplier() { assertThat(invoked.get(), is(true)); } + public void testCallingGetTextInputs_invokesSupplier() { + AtomicBoolean invoked = new AtomicBoolean(); + var textInputs = List.of("input1", "input2"); + final 
List list = textInputs.stream().map(i -> new InferenceString(i, TEXT)).toList(); + Supplier> supplier = () -> { + invoked.set(true); + return list; + }; + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + // Ensure we don't invoke the supplier until we call getTextInputs() + assertThat(invoked.get(), is(false)); + + assertThat(input.getTextInputs(), is(textInputs)); + assertThat(invoked.get(), is(true)); + } + + public void testCallingGetTextInputs_withNonTextInput_throws() { + Supplier> supplier = () -> List.of( + new InferenceString("input1", TEXT), + new InferenceString("image_url", IMAGE_BASE64) + ); + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + var exception = expectThrows(AssertionError.class, input::getTextInputs); + assertThat(exception.getMessage(), is("Non-text input returned from EmbeddingsInput.getTextInputs")); + } + public void testCallingGetInputsTwice_throws() { - Supplier> supplier = () -> List.of("input"); + Supplier> supplier = () -> List.of(new InferenceString("input1", TEXT)); EmbeddingsInput input = new EmbeddingsInput(supplier, null); input.getInputs(); var exception = expectThrows(AssertionError.class, input::getInputs); assertThat(exception.getMessage(), is("EmbeddingsInput supplier invoked twice")); } + + public void testCallingGetTextInputsTwice_throws() { + Supplier> supplier = () -> List.of(new InferenceString("input1", TEXT)); + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + input.getTextInputs(); + var exception = expectThrows(AssertionError.class, input::getTextInputs); + assertThat(exception.getMessage(), is("EmbeddingsInput supplier invoked twice")); + } + + public void testCallingEitherGetInputsMethodTwice_throws() { + Supplier> supplier = () -> List.of(new InferenceString("input1", TEXT)); + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + input.getInputs(); + var exception = expectThrows(AssertionError.class, input::getTextInputs); + assertThat(exception.getMessage(), 
is("EmbeddingsInput supplier invoked twice")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java index f77553d3abf4c..571e18e887727 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java @@ -83,7 +83,7 @@ public void send( ) { sendCounter++; if (inferenceInputs instanceof EmbeddingsInput embeddingsInput) { - inputs.add(embeddingsInput.getInputs()); + inputs.add(embeddingsInput.getTextInputs()); if (embeddingsInput.getInputType() != null) { inputTypes.add(embeddingsInput.getInputType()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index ac5ce33f9a02b..68538608513bb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -1443,7 +1443,7 @@ public void testChunkingLargeDocument() throws InterruptedException { // build a doc with enough words to make numChunks of chunks int wordsPerChunk = 10; int numWords = numChunks * wordsPerChunk; - var input = new ChunkInferenceInput("word ".repeat(numWords), null); + var input = new ChunkInferenceInput("word ".repeat(numWords)); Client client = mock(Client.class); 
when(client.threadPool()).thenReturn(threadPool); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 113d874eb7280..d6e938ac17fdf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -940,7 +940,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).inputText().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -955,7 +955,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).inputText().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -979,7 +979,7 @@ 
private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(0).input()))), + Map.of("parts", List.of(Map.of("text", input.get(0).inputText()))), "taskType", "RETRIEVAL_DOCUMENT" ), @@ -987,7 +987,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(1).input()))), + Map.of("parts", List.of(Map.of("text", input.get(1).inputText()))), "taskType", "RETRIEVAL_DOCUMENT" ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 14e5cde764623..af935cf898927 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -788,7 +788,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).inputText().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -803,7 +803,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws 
assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).inputText().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java index b326664c527c1..e208b53f60763 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java @@ -448,7 +448,7 @@ private ExecutableAction createAction( threadPool, model, EMBEDDINGS_HANDLER, - (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getInputs(), embeddingsInput.getInputType(), model), + (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getTextInputs(), embeddingsInput.getInputType(), model), EmbeddingsInput.class );