Skip to content

Commit 944029a

Browse files
committed
Optimize MistralAiEmbeddingModel dimensions method
- Calculate and cache values for unknown models only if necessary - Make known embedding dimensions a mutable map attribute - Polish MistralAiEmbeddingModelTests Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 47e4232 commit 944029a

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19+
import java.util.HashMap;
1920
import java.util.List;
2021
import java.util.Map;
2122

@@ -56,16 +57,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
5657

5758
private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class);
5859

60+
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
61+
5962
/**
6063
* Known embedding dimensions for Mistral AI models. Maps model names to their
6164
* respective embedding vector dimensions. This allows the dimensions() method to
6265
* return the correct value without making an API call.
6366
*/
64-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Map.of(
65-
MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024, MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(),
66-
1536);
67-
68-
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
67+
private final Map<String, Integer> knownEmbeddingDimensions = createKnownEmbeddingDimensions();
6968

7069
private final MistralAiEmbeddingOptions defaultOptions;
7170

@@ -85,6 +84,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
8584
*/
8685
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
8786

87+
private static Map<String, Integer> createKnownEmbeddingDimensions() {
88+
Map<String, Integer> knownEmbeddingDimensions = new HashMap<>();
89+
knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024);
90+
knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(), 1536);
91+
92+
return knownEmbeddingDimensions;
93+
}
94+
8895
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode,
8996
MistralAiEmbeddingOptions options, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
9097
Assert.notNull(mistralAiApi, "mistralAiApi must not be null");
@@ -174,7 +181,8 @@ public float[] embed(Document document) {
174181

175182
@Override
176183
public int dimensions() {
177-
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
184+
return this.knownEmbeddingDimensions.computeIfAbsent(this.defaultOptions.getModel(),
185+
model -> super.dimensions());
178186
}
179187

180188
/**

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19+
import java.util.Arrays;
1920
import java.util.List;
2021

2122
import org.junit.jupiter.api.Test;
@@ -94,7 +95,7 @@ void testDimensionsFallbackForUnknownModel() {
9495

9596
@Test
9697
void testAllEmbeddingModelsHaveDimensionMapping() {
97-
// This test ensures that KNOWN_EMBEDDING_DIMENSIONS map stays in sync with the
98+
// This test ensures that knownEmbeddingDimensions map stays in sync with the
9899
// EmbeddingModel enum
99100
// If a new model is added to the enum but not to the dimensions map, this test
100101
// will help catch it
@@ -138,16 +139,13 @@ private MistralAiApi createMockApiWithEmbeddingResponse(int dimensions) {
138139

139140
// Create a mock embedding response with the specified dimensions
140141
float[] embedding = new float[dimensions];
141-
for (int i = 0; i < dimensions; i++) {
142-
embedding[i] = 0.1f;
143-
}
142+
Arrays.fill(embedding, 0.1f);
144143

145144
MistralAiApi.Embedding embeddingData = new MistralAiApi.Embedding(0, embedding, "embedding");
146145

147146
MistralAiApi.Usage usage = new MistralAiApi.Usage(10, 0, 10);
148147

149-
MistralAiApi.EmbeddingList embeddingList = new MistralAiApi.EmbeddingList("object", List.of(embeddingData),
150-
"model", usage);
148+
var embeddingList = new MistralAiApi.EmbeddingList<>("object", List.of(embeddingData), "model", usage);
151149

152150
when(mockApi.embeddings(any())).thenReturn(ResponseEntity.ok(embeddingList));
153151

0 commit comments

Comments
 (0)