Skip to content

Commit 4d2d053

Browse files
authored
[ML] Implement CCMCache (#137743)
Implements a lazy-loading cache in front of the CCM Inference index. By default, the cache holds a single entry for 15 minutes, and cache misses search the CCM index and load the responses into the cache. Some design decisions: - The cache maintains an "empty" entry so that the `isPresent` call can reuse the "empty" response to quickly return false. Invalidating the cache or calling get will drop this "empty" entry. - Invalidating the cache will broadcast a message to all nodes so that all caches on all nodes will invalidate their caches. - Since the broadcast message only works if all nodes are on the latest version, there is a new NodeFeature to enable the cache once all nodes in the cluster have upgraded.
1 parent 346abfc commit 4d2d053

File tree

7 files changed

+583
-3
lines changed

7 files changed

+583
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.ccm;
9+
10+
import org.elasticsearch.ResourceNotFoundException;
11+
import org.elasticsearch.action.support.ActionTestUtils;
12+
import org.elasticsearch.action.support.TestPlainActionFuture;
13+
import org.elasticsearch.common.bytes.BytesArray;
14+
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.plugins.Plugin;
16+
import org.elasticsearch.test.ESSingleNodeTestCase;
17+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
18+
import org.junit.After;
19+
import org.junit.Before;
20+
21+
import java.io.IOException;
22+
import java.util.Collection;
23+
import java.util.List;
24+
import java.util.concurrent.CountDownLatch;
25+
import java.util.concurrent.TimeUnit;
26+
27+
import static org.hamcrest.Matchers.equalTo;
28+
import static org.hamcrest.Matchers.not;
29+
import static org.hamcrest.Matchers.sameInstance;
30+
31+
public class CCMCacheTests extends ESSingleNodeTestCase {
32+
33+
private static final TimeValue TIMEOUT = TimeValue.THIRTY_SECONDS;
34+
35+
private CCMCache ccmCache;
36+
private CCMPersistentStorageService ccmPersistentStorageService;
37+
38+
@Override
39+
protected Collection<Class<? extends Plugin>> getPlugins() {
40+
return List.of(LocalStateInferencePlugin.class);
41+
}
42+
43+
@Before
44+
public void createComponents() {
45+
ccmCache = node().injector().getInstance(CCMCache.class);
46+
ccmPersistentStorageService = node().injector().getInstance(CCMPersistentStorageService.class);
47+
}
48+
49+
@Override
50+
protected boolean resetNodeAfterTest() {
51+
return true;
52+
}
53+
54+
@After
55+
public void clearCacheAndIndex() {
56+
try {
57+
indicesAdmin().prepareDelete(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT);
58+
} catch (ResourceNotFoundException e) {
59+
// mission complete!
60+
}
61+
}
62+
63+
public void testCacheHit() throws IOException {
64+
var expectedCcmModel = storeCcm();
65+
var actualCcmModel = getFromCache();
66+
assertThat(actualCcmModel, equalTo(expectedCcmModel));
67+
assertThat(ccmCache.stats().getHits(), equalTo(0L));
68+
assertThat(getFromCache(), sameInstance(actualCcmModel));
69+
assertThat(ccmCache.stats().getHits(), equalTo(1L));
70+
}
71+
72+
private CCMModel storeCcm() throws IOException {
73+
var ccmModel = CCMModel.fromXContentBytes(new BytesArray("""
74+
{
75+
"api_key": "test_key"
76+
}
77+
"""));
78+
var listener = new TestPlainActionFuture<Void>();
79+
ccmPersistentStorageService.store(ccmModel, listener);
80+
listener.actionGet(TIMEOUT);
81+
return ccmModel;
82+
}
83+
84+
private CCMModel getFromCache() {
85+
var listener = new TestPlainActionFuture<CCMModel>();
86+
ccmCache.get(listener);
87+
return listener.actionGet(TIMEOUT);
88+
}
89+
90+
public void testCacheInvalidate() throws Exception {
91+
var expectedCcmModel = storeCcm();
92+
var actualCcmModel = getFromCache();
93+
assertThat(actualCcmModel, equalTo(expectedCcmModel));
94+
assertThat(ccmCache.stats().getHits(), equalTo(0L));
95+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
96+
assertThat(ccmCache.cacheCount(), equalTo(1));
97+
98+
var listener = new TestPlainActionFuture<Void>();
99+
ccmCache.invalidate(listener);
100+
listener.actionGet(TIMEOUT);
101+
102+
assertThat(getFromCache(), not(sameInstance(actualCcmModel)));
103+
assertThat(ccmCache.stats().getHits(), equalTo(0L));
104+
assertThat(ccmCache.stats().getMisses(), equalTo(2L));
105+
assertThat(ccmCache.stats().getEvictions(), equalTo(1L));
106+
assertThat(ccmCache.cacheCount(), equalTo(1));
107+
}
108+
109+
public void testEmptyInvalidate() throws InterruptedException {
110+
var latch = new CountDownLatch(1);
111+
ccmCache.invalidate(ActionTestUtils.assertNoFailureListener(success -> latch.countDown()));
112+
assertTrue(latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS));
113+
114+
assertThat(ccmCache.stats().getEvictions(), equalTo(0L));
115+
assertThat(ccmCache.cacheCount(), equalTo(0));
116+
}
117+
118+
private boolean isPresent() {
119+
var listener = new TestPlainActionFuture<Boolean>();
120+
ccmCache.isEnabled(listener);
121+
return listener.actionGet(TIMEOUT);
122+
}
123+
124+
public void testIsEnabled() throws IOException {
125+
storeCcm();
126+
127+
getFromCache();
128+
assertThat(ccmCache.stats().getHits(), equalTo(0L));
129+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
130+
131+
assertTrue(isPresent());
132+
assertThat(ccmCache.stats().getHits(), equalTo(1L));
133+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
134+
}
135+
136+
public void testIsDisabledWithMissingIndex() {
137+
assertFalse(isPresent());
138+
}
139+
140+
public void testIsDisabledWithPresentIndex() {
141+
indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT);
142+
assertFalse(isPresent());
143+
}
144+
145+
public void testIsDisabledWithCacheHit() {
146+
indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT);
147+
148+
assertFalse(isPresent());
149+
assertThat(ccmCache.stats().getHits(), equalTo(0L));
150+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
151+
152+
assertFalse(isPresent());
153+
assertThat(ccmCache.stats().getHits(), equalTo(1L));
154+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
155+
}
156+
157+
public void testIsDisabledRefreshedWithGet() throws IOException {
158+
indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT);
159+
160+
assertFalse(isPresent());
161+
assertThat(ccmCache.stats().getHits(), equalTo(0L));
162+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
163+
164+
var expectedCcmModel = storeCcm();
165+
166+
assertFalse(isPresent());
167+
assertThat(ccmCache.stats().getHits(), equalTo(1L));
168+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
169+
170+
var actualCcmModel = getFromCache();
171+
assertThat(actualCcmModel, equalTo(expectedCcmModel));
172+
assertThat(ccmCache.stats().getHits(), equalTo(2L));
173+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
174+
175+
assertTrue(isPresent());
176+
assertThat(ccmCache.stats().getHits(), equalTo(3L));
177+
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
178+
}
179+
}

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
exports org.elasticsearch.xpack.inference.registry;
4242
exports org.elasticsearch.xpack.inference.rest;
4343
exports org.elasticsearch.xpack.inference.services;
44+
exports org.elasticsearch.xpack.inference.services.elastic.ccm;
4445
exports org.elasticsearch.xpack.inference;
4546
exports org.elasticsearch.xpack.inference.action.task;
4647
exports org.elasticsearch.xpack.inference.telemetry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ public class InferenceFeatures implements FeatureSpecification {
5454
private static final NodeFeature SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT = new NodeFeature("semantic_text.fields_chunks_format");
5555

5656
public static final NodeFeature INFERENCE_ENDPOINT_CACHE = new NodeFeature("inference.endpoint.cache");
57+
public static final NodeFeature INFERENCE_CCM_CACHE = new NodeFeature("inference.ccm.cache");
5758
public static final NodeFeature SEARCH_USAGE_EXTENDED_DATA = new NodeFeature("search.usage.extended_data");
5859

5960
@Override
6061
public Set<NodeFeature> getFeatures() {
61-
return Set.of(INFERENCE_ENDPOINT_CACHE);
62+
return Set.of(INFERENCE_ENDPOINT_CACHE, INFERENCE_CCM_CACHE);
6263
}
6364

6465
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller;
150150
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor;
151151
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
152+
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMCache;
152153
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature;
153154
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMIndex;
154155
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMPersistentStorageService;
@@ -276,7 +277,8 @@ public List<ActionHandler> getActions() {
276277
new ActionHandler(StoreInferenceEndpointsAction.INSTANCE, TransportStoreEndpointsAction.class),
277278
new ActionHandler(GetCCMConfigurationAction.INSTANCE, TransportGetCCMConfigurationAction.class),
278279
new ActionHandler(PutCCMConfigurationAction.INSTANCE, TransportPutCCMConfigurationAction.class),
279-
new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class)
280+
new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class),
281+
new ActionHandler(CCMCache.ClearCCMCacheAction.INSTANCE, CCMCache.ClearCCMCacheAction.class)
280282
);
281283
}
282284

@@ -453,7 +455,19 @@ public Collection<?> createComponents(PluginServices services) {
453455
private Collection<?> createCCMComponents(PluginServices services) {
454456
ccmFeature.set(new CCMFeature(settings));
455457
var ccmPersistentStorageService = new CCMPersistentStorageService(services.client());
456-
return List.of(new CCMService(ccmPersistentStorageService), ccmFeature.get(), ccmPersistentStorageService);
458+
return List.of(
459+
new CCMService(ccmPersistentStorageService),
460+
ccmFeature.get(),
461+
ccmPersistentStorageService,
462+
new CCMCache(
463+
ccmPersistentStorageService,
464+
services.clusterService(),
465+
settings,
466+
services.featureService(),
467+
services.projectResolver(),
468+
services.client()
469+
)
470+
);
457471
}
458472

459473
@Override
@@ -653,6 +667,7 @@ public static Set<Setting<?>> getInferenceSettings() {
653667
settings.addAll(InferenceEndpointRegistry.getSettingsDefinitions());
654668
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
655669
settings.addAll(CCMSettings.getSettingsDefinitions());
670+
settings.addAll(CCMCache.getSettingsDefinitions());
656671
return Collections.unmodifiableSet(settings);
657672
}
658673

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.common;
9+
10+
import org.elasticsearch.action.FailedNodeException;
11+
import org.elasticsearch.action.support.ActionFilters;
12+
import org.elasticsearch.action.support.TransportAction;
13+
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
14+
import org.elasticsearch.action.support.nodes.BaseNodesRequest;
15+
import org.elasticsearch.action.support.nodes.BaseNodesResponse;
16+
import org.elasticsearch.action.support.nodes.TransportNodesAction;
17+
import org.elasticsearch.cluster.ClusterName;
18+
import org.elasticsearch.cluster.node.DiscoveryNode;
19+
import org.elasticsearch.cluster.service.ClusterService;
20+
import org.elasticsearch.common.Strings;
21+
import org.elasticsearch.common.io.stream.StreamInput;
22+
import org.elasticsearch.common.io.stream.StreamOutput;
23+
import org.elasticsearch.common.io.stream.Writeable;
24+
import org.elasticsearch.core.TimeValue;
25+
import org.elasticsearch.tasks.CancellableTask;
26+
import org.elasticsearch.tasks.Task;
27+
import org.elasticsearch.tasks.TaskId;
28+
import org.elasticsearch.threadpool.ThreadPool;
29+
import org.elasticsearch.transport.AbstractTransportRequest;
30+
import org.elasticsearch.transport.TransportService;
31+
32+
import java.io.IOException;
33+
import java.util.List;
34+
import java.util.Map;
35+
36+
/**
37+
* Broadcasts a {@link Writeable} to all nodes and responds with an empty object.
38+
* This is intended to be used as a fire-and-forget style, where responses and failures are logged and swallowed.
39+
*/
40+
public abstract class BroadcastMessageAction<Message extends Writeable> extends TransportNodesAction<
41+
BroadcastMessageAction.Request<Message>,
42+
BroadcastMessageAction.Response,
43+
BroadcastMessageAction.NodeRequest<Message>,
44+
BroadcastMessageAction.NodeResponse,
45+
Void> {
46+
47+
protected BroadcastMessageAction(
48+
String actionName,
49+
ClusterService clusterService,
50+
TransportService transportService,
51+
ActionFilters actionFilters,
52+
Writeable.Reader<Message> messageReader
53+
) {
54+
super(
55+
actionName,
56+
clusterService,
57+
transportService,
58+
actionFilters,
59+
in -> new NodeRequest<>(messageReader.read(in)),
60+
clusterService.threadPool().executor(ThreadPool.Names.MANAGEMENT)
61+
);
62+
}
63+
64+
@Override
65+
protected Response newResponse(Request<Message> request, List<NodeResponse> nodeResponses, List<FailedNodeException> failures) {
66+
return new Response(clusterService.getClusterName(), nodeResponses, failures);
67+
}
68+
69+
@Override
70+
protected NodeRequest<Message> newNodeRequest(Request<Message> request) {
71+
return new NodeRequest<>(request.message);
72+
}
73+
74+
@Override
75+
protected NodeResponse newNodeResponse(StreamInput in, DiscoveryNode node) throws IOException {
76+
return new NodeResponse(in, node);
77+
}
78+
79+
@Override
80+
protected NodeResponse nodeOperation(NodeRequest<Message> request, Task task) {
81+
receiveMessage(request.message);
82+
return new NodeResponse(transportService.getLocalNode());
83+
}
84+
85+
/**
86+
* This method is run on each node in the cluster.
87+
*/
88+
protected abstract void receiveMessage(Message message);
89+
90+
public static <T extends Writeable> Request<T> request(T message, TimeValue timeout) {
91+
return new Request<>(message, timeout);
92+
}
93+
94+
public static class Request<Message extends Writeable> extends BaseNodesRequest {
95+
private final Message message;
96+
97+
protected Request(Message message, TimeValue timeout) {
98+
super(Strings.EMPTY_ARRAY);
99+
this.message = message;
100+
setTimeout(timeout);
101+
}
102+
}
103+
104+
public static class Response extends BaseNodesResponse<NodeResponse> {
105+
106+
protected Response(ClusterName clusterName, List<NodeResponse> nodes, List<FailedNodeException> failures) {
107+
super(clusterName, nodes, failures);
108+
}
109+
110+
@Override
111+
protected List<NodeResponse> readNodesFrom(StreamInput in) throws IOException {
112+
return in.readCollectionAsList(NodeResponse::new);
113+
}
114+
115+
@Override
116+
protected void writeNodesTo(StreamOutput out, List<NodeResponse> nodes) {
117+
TransportAction.localOnly();
118+
}
119+
}
120+
121+
public static class NodeRequest<Message extends Writeable> extends AbstractTransportRequest {
122+
private final Message message;
123+
124+
private NodeRequest(Message message) {
125+
this.message = message;
126+
}
127+
128+
@Override
129+
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
130+
return new CancellableTask(id, type, action, "broadcasted message to an individual node", parentTaskId, headers);
131+
}
132+
}
133+
134+
public static class NodeResponse extends BaseNodeResponse {
135+
protected NodeResponse(StreamInput in) throws IOException {
136+
super(in);
137+
}
138+
139+
protected NodeResponse(StreamInput in, DiscoveryNode node) throws IOException {
140+
super(in, node);
141+
}
142+
143+
protected NodeResponse(DiscoveryNode node) {
144+
super(node);
145+
}
146+
}
147+
}

0 commit comments

Comments
 (0)