diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMCacheTests.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMCacheTests.java new file mode 100644 index 0000000000000..7ec9644e9d8d7 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMCacheTests.java @@ -0,0 +1,179 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.ccm; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.TestPlainActionFuture; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; + +public class CCMCacheTests extends ESSingleNodeTestCase { + + private static final TimeValue TIMEOUT = TimeValue.THIRTY_SECONDS; + + private CCMCache ccmCache; + private CCMPersistentStorageService ccmPersistentStorageService; + + @Override + protected Collection> getPlugins() { + return List.of(LocalStateInferencePlugin.class); + } + + @Before + public void createComponents() { + ccmCache = node().injector().getInstance(CCMCache.class); + ccmPersistentStorageService = node().injector().getInstance(CCMPersistentStorageService.class); + } + + @Override + protected boolean resetNodeAfterTest() { + return true; + } + + @After + public void clearCacheAndIndex() { + try { + indicesAdmin().prepareDelete(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT); + } catch (ResourceNotFoundException e) { + // mission complete! + } + } + + public void testCacheHit() throws IOException { + var expectedCcmModel = storeCcm(); + var actualCcmModel = getFromCache(); + assertThat(actualCcmModel, equalTo(expectedCcmModel)); + assertThat(ccmCache.stats().getHits(), equalTo(0L)); + assertThat(getFromCache(), sameInstance(actualCcmModel)); + assertThat(ccmCache.stats().getHits(), equalTo(1L)); + } + + private CCMModel storeCcm() throws IOException { + var ccmModel = CCMModel.fromXContentBytes(new BytesArray(""" + { + "api_key": "test_key" + } + """)); + var listener = new TestPlainActionFuture(); + ccmPersistentStorageService.store(ccmModel, listener); + listener.actionGet(TIMEOUT); + return ccmModel; + } + + private CCMModel getFromCache() { + var listener = new TestPlainActionFuture(); + ccmCache.get(listener); + return listener.actionGet(TIMEOUT); + } + + public void testCacheInvalidate() throws Exception { + var expectedCcmModel = storeCcm(); + var actualCcmModel = getFromCache(); + assertThat(actualCcmModel, equalTo(expectedCcmModel)); + assertThat(ccmCache.stats().getHits(), equalTo(0L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + assertThat(ccmCache.cacheCount(), equalTo(1)); + + var listener = new TestPlainActionFuture(); + ccmCache.invalidate(listener); + listener.actionGet(TIMEOUT); + + assertThat(getFromCache(), not(sameInstance(actualCcmModel))); + assertThat(ccmCache.stats().getHits(), equalTo(0L)); + assertThat(ccmCache.stats().getMisses(), equalTo(2L)); + assertThat(ccmCache.stats().getEvictions(), equalTo(1L)); + assertThat(ccmCache.cacheCount(), equalTo(1)); + } + + public void testEmptyInvalidate() throws InterruptedException { + var latch = new CountDownLatch(1); + ccmCache.invalidate(ActionTestUtils.assertNoFailureListener(success -> latch.countDown())); + assertTrue(latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS)); + + assertThat(ccmCache.stats().getEvictions(), equalTo(0L)); + assertThat(ccmCache.cacheCount(), equalTo(0)); + } + + private boolean isPresent() { + var listener = new TestPlainActionFuture(); + ccmCache.isEnabled(listener); + return listener.actionGet(TIMEOUT); + } + + public void testIsEnabled() throws IOException { + storeCcm(); + + getFromCache(); + assertThat(ccmCache.stats().getHits(), equalTo(0L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + + assertTrue(isPresent()); + assertThat(ccmCache.stats().getHits(), equalTo(1L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + } + + public void testIsDisabledWithMissingIndex() { + assertFalse(isPresent()); + } + + public void testIsDisabledWithPresentIndex() { + indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT); + assertFalse(isPresent()); + } + + public void testIsDisabledWithCacheHit() { + indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT); + + assertFalse(isPresent()); + assertThat(ccmCache.stats().getHits(), equalTo(0L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + + assertFalse(isPresent()); + assertThat(ccmCache.stats().getHits(), equalTo(1L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + } + + public void testIsDisabledRefreshedWithGet() throws IOException { + indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT); + + assertFalse(isPresent()); + assertThat(ccmCache.stats().getHits(), equalTo(0L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + + var expectedCcmModel = storeCcm(); + + assertFalse(isPresent()); + assertThat(ccmCache.stats().getHits(), equalTo(1L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + + var actualCcmModel = getFromCache(); + assertThat(actualCcmModel, equalTo(expectedCcmModel)); + assertThat(ccmCache.stats().getHits(), equalTo(2L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + + assertTrue(isPresent()); + assertThat(ccmCache.stats().getHits(), equalTo(3L)); + assertThat(ccmCache.stats().getMisses(), equalTo(1L)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index bd200fd88a706..67a0cd11c6a4c 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -41,6 +41,7 @@ exports org.elasticsearch.xpack.inference.registry; exports org.elasticsearch.xpack.inference.rest; exports org.elasticsearch.xpack.inference.services; + exports org.elasticsearch.xpack.inference.services.elastic.ccm; exports org.elasticsearch.xpack.inference; exports org.elasticsearch.xpack.inference.action.task; exports org.elasticsearch.xpack.inference.telemetry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index a55a126976284..4ec2bb3280522 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -54,11 +54,12 @@ public class InferenceFeatures implements FeatureSpecification { private static final NodeFeature SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT = new NodeFeature("semantic_text.fields_chunks_format"); public static final NodeFeature INFERENCE_ENDPOINT_CACHE = new NodeFeature("inference.endpoint.cache"); + public static final NodeFeature INFERENCE_CCM_CACHE = new NodeFeature("inference.ccm.cache"); public static final NodeFeature SEARCH_USAGE_EXTENDED_DATA = new NodeFeature("search.usage.extended_data"); @Override public Set getFeatures() { - return Set.of(INFERENCE_ENDPOINT_CACHE); + return Set.of(INFERENCE_ENDPOINT_CACHE, INFERENCE_CCM_CACHE); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index c4025d0e6c4b0..f5244f2ca66a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -149,6 +149,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMCache; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMIndex; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMPersistentStorageService; @@ -276,7 +277,8 @@ public List getActions() { new ActionHandler(StoreInferenceEndpointsAction.INSTANCE, TransportStoreEndpointsAction.class), new ActionHandler(GetCCMConfigurationAction.INSTANCE, TransportGetCCMConfigurationAction.class), new ActionHandler(PutCCMConfigurationAction.INSTANCE, TransportPutCCMConfigurationAction.class), - new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class) + new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class), + new ActionHandler(CCMCache.ClearCCMCacheAction.INSTANCE, CCMCache.ClearCCMCacheAction.class) ); } @@ -453,7 +455,19 @@ public Collection createComponents(PluginServices services) { private Collection createCCMComponents(PluginServices services) { ccmFeature.set(new CCMFeature(settings)); var ccmPersistentStorageService = new CCMPersistentStorageService(services.client()); - return List.of(new CCMService(ccmPersistentStorageService), ccmFeature.get(), ccmPersistentStorageService); + return List.of( + new CCMService(ccmPersistentStorageService), + ccmFeature.get(), + ccmPersistentStorageService, + new CCMCache( + ccmPersistentStorageService, + services.clusterService(), + settings, + services.featureService(), + services.projectResolver(), + services.client() + ) + ); } @Override @@ -653,6 +667,7 @@ public static Set> getInferenceSettings() { settings.addAll(InferenceEndpointRegistry.getSettingsDefinitions()); settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions()); settings.addAll(CCMSettings.getSettingsDefinitions()); + settings.addAll(CCMCache.getSettingsDefinitions()); return Collections.unmodifiableSet(settings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/BroadcastMessageAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/BroadcastMessageAction.java new file mode 100644 index 0000000000000..c3b510636de40 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/BroadcastMessageAction.java @@ -0,0 +1,147 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.TransportAction; +import org.elasticsearch.action.support.nodes.BaseNodeResponse; +import org.elasticsearch.action.support.nodes.BaseNodesRequest; +import org.elasticsearch.action.support.nodes.BaseNodesResponse; +import org.elasticsearch.action.support.nodes.TransportNodesAction; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.AbstractTransportRequest; +import org.elasticsearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +/** + * Broadcasts a {@link Writeable} to all nodes and responds with an empty object. + * This is intended to be used as a fire-and-forget style, where responses and failures are logged and swallowed. + */ +public abstract class BroadcastMessageAction extends TransportNodesAction< + BroadcastMessageAction.Request, + BroadcastMessageAction.Response, + BroadcastMessageAction.NodeRequest, + BroadcastMessageAction.NodeResponse, + Void> { + + protected BroadcastMessageAction( + String actionName, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + Writeable.Reader messageReader + ) { + super( + actionName, + clusterService, + transportService, + actionFilters, + in -> new NodeRequest<>(messageReader.read(in)), + clusterService.threadPool().executor(ThreadPool.Names.MANAGEMENT) + ); + } + + @Override + protected Response newResponse(Request request, List nodeResponses, List failures) { + return new Response(clusterService.getClusterName(), nodeResponses, failures); + } + + @Override + protected NodeRequest newNodeRequest(Request request) { + return new NodeRequest<>(request.message); + } + + @Override + protected NodeResponse newNodeResponse(StreamInput in, DiscoveryNode node) throws IOException { + return new NodeResponse(in, node); + } + + @Override + protected NodeResponse nodeOperation(NodeRequest request, Task task) { + receiveMessage(request.message); + return new NodeResponse(transportService.getLocalNode()); + } + + /** + * This method is run on each node in the cluster. + */ + protected abstract void receiveMessage(Message message); + + public static Request request(T message, TimeValue timeout) { + return new Request<>(message, timeout); + } + + public static class Request extends BaseNodesRequest { + private final Message message; + + protected Request(Message message, TimeValue timeout) { + super(Strings.EMPTY_ARRAY); + this.message = message; + setTimeout(timeout); + } + } + + public static class Response extends BaseNodesResponse { + + protected Response(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + protected List readNodesFrom(StreamInput in) throws IOException { + return in.readCollectionAsList(NodeResponse::new); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) { + TransportAction.localOnly(); + } + } + + public static class NodeRequest extends AbstractTransportRequest { + private final Message message; + + private NodeRequest(Message message) { + this.message = message; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, "broadcasted message to an individual node", parentTaskId, headers); + } + } + + public static class NodeResponse extends BaseNodeResponse { + protected NodeResponse(StreamInput in) throws IOException { + super(in); + } + + protected NodeResponse(StreamInput in, DiscoveryNode node) throws IOException { + super(in, node); + } + + protected NodeResponse(DiscoveryNode node) { + super(node); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMCache.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMCache.java new file mode 100644 index 0000000000000..d565a8ae9cc85 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMCache.java @@ -0,0 +1,236 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.ccm; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.cache.Cache; +import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.features.FeatureService; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.inference.InferenceFeatures; +import org.elasticsearch.xpack.inference.common.BroadcastMessageAction; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Objects; + +/** + * Cache for whether CCM is enabled or disabled for this cluster as well as what the CCM key is for when it is enabled. + */ +public class CCMCache { + + private static final Setting INFERENCE_CCM_CACHE_WEIGHT = Setting.intSetting( + "xpack.inference.ccm.cache.weight", + 1, + Setting.Property.NodeScope + ); + + private static final Setting INFERENCE_CCM_CACHE_EXPIRY = Setting.timeSetting( + "xpack.inference.ccm.cache.expiry_time", + TimeValue.timeValueMinutes(15), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueHours(1), + Setting.Property.NodeScope + ); + + public static Collection> getSettingsDefinitions() { + return List.of(INFERENCE_CCM_CACHE_WEIGHT, INFERENCE_CCM_CACHE_EXPIRY); + } + + private static final Logger logger = LogManager.getLogger(CCMCache.class); + private static final Cache.Stats EMPTY = new Cache.Stats(0, 0, 0); + private final CCMPersistentStorageService ccmPersistentStorageService; + private final Cache cache; + private final ClusterService clusterService; + private final FeatureService featureService; + private final ProjectResolver projectResolver; + private final Client client; + + public CCMCache( + CCMPersistentStorageService ccmPersistentStorageService, + ClusterService clusterService, + Settings settings, + FeatureService featureService, + ProjectResolver projectResolver, + Client client + ) { + this.ccmPersistentStorageService = ccmPersistentStorageService; + this.cache = CacheBuilder.builder() + .setMaximumWeight(INFERENCE_CCM_CACHE_WEIGHT.get(settings)) + .setExpireAfterWrite(INFERENCE_CCM_CACHE_EXPIRY.get(settings)) + .build(); + this.clusterService = clusterService; + this.featureService = featureService; + this.projectResolver = projectResolver; + this.client = client; + } + + /** + * Immediately returns the CCM key if it is cached, or goes to the index if there is no value cached or the previous call returned + * nothing. The expectation is that the caller checks if CCM is enabled via the {@link #isEnabled(ActionListener)} API, which caches + * a boolean value if the CCM key is present or absent in the underlying index. + */ + public void get(ActionListener listener) { + var projectId = projectResolver.getProjectId(); + var cachedEntry = getCacheEntry(projectId); + if (cachedEntry != null && cachedEntry.enabled()) { + listener.onResponse(cachedEntry.ccmModel()); + } else { + ccmPersistentStorageService.get(ActionListener.wrap(ccmModel -> { + putEnabledEntry(projectId, ccmModel); + listener.onResponse(ccmModel); + }, e -> { + if (e instanceof ResourceNotFoundException) { + putDisabledEntry(projectId); + } + listener.onFailure(e); + })); + } + } + + private CCMModelEntry getCacheEntry(ProjectId projectId) { + return cacheEnabled() ? cache.get(projectId) : null; + } + + private boolean cacheEnabled() { + var state = clusterService.state(); + return state.clusterRecovered() && featureService.clusterHasFeature(state, InferenceFeatures.INFERENCE_CCM_CACHE); + } + + private void putEnabledEntry(ProjectId projectId, CCMModel ccmModel) { + if (cacheEnabled()) { + cache.put(projectId, CCMModelEntry.enabled(ccmModel)); + } + } + + private void putDisabledEntry(ProjectId projectId) { + if (cacheEnabled()) { + cache.put(projectId, CCMModelEntry.DISABLED); + } + } + + /** + * Checks if the value is present or absent based on a previous call to {@link #isEnabled(ActionListener)} + * or {@link #get(ActionListener)}. If the cache entry is missing or expired, then it will call through to the backing index. + */ + public void isEnabled(ActionListener listener) { + var projectId = projectResolver.getProjectId(); + var cachedEntry = getCacheEntry(projectId); + if (cachedEntry != null) { + listener.onResponse(cachedEntry.enabled()); + } else { + ccmPersistentStorageService.get(ActionListener.wrap(ccmModel -> { + putEnabledEntry(projectId, ccmModel); + listener.onResponse(true); + }, e -> { + if (e instanceof ResourceNotFoundException) { + putDisabledEntry(projectId); + listener.onResponse(false); + } else { + listener.onFailure(e); + } + })); + } + } + + public void invalidate(ActionListener listener) { + if (cacheEnabled()) { + client.execute( + ClearCCMCacheAction.INSTANCE, + ClearCCMCacheAction.request(ClearCCMMessage.INSTANCE, null), + ActionListener.wrap(ack -> { + logger.debug("Successfully refreshed inference CCM cache for project {}.", projectResolver::getProjectId); + listener.onResponse((Void) null); + }, e -> { + logger.atDebug() + .withThrowable(e) + .log("Failed to refresh inference CCM cache for project {}.", projectResolver::getProjectId); + listener.onFailure(e); + }) + ); + } + } + + private void invalidate(ProjectId projectId) { + if (cacheEnabled()) { + var cacheKeys = cache.keys().iterator(); + while (cacheKeys.hasNext()) { + if (cacheKeys.next().equals(projectId)) { + cacheKeys.remove(); + } + } + } + } + + public Cache.Stats stats() { + return cacheEnabled() ? cache.stats() : EMPTY; + } + + public int cacheCount() { + return cacheEnabled() ? cache.count() : 0; + } + + private record CCMModelEntry(boolean enabled, @Nullable CCMModel ccmModel) { + private static final CCMModelEntry DISABLED = new CCMModelEntry(false, null); + + private static CCMModelEntry enabled(CCMModel ccmModel) { + return new CCMModelEntry(true, Objects.requireNonNull(ccmModel)); + } + } + + public static class ClearCCMCacheAction extends BroadcastMessageAction { + private static final String NAME = "cluster:internal/xpack/inference/clear_inference_ccm_cache"; + public static final ActionType INSTANCE = new ActionType<>(NAME); + + private final ProjectResolver projectResolver; + private final CCMCache ccmCache; + + @Inject + public ClearCCMCacheAction( + TransportService transportService, + ClusterService clusterService, + ActionFilters actionFilters, + ProjectResolver projectResolver, + CCMCache ccmCache + ) { + super(NAME, clusterService, transportService, actionFilters, in -> ClearCCMMessage.INSTANCE); + + this.projectResolver = projectResolver; + this.ccmCache = ccmCache; + } + + @Override + protected void receiveMessage(ClearCCMMessage clearCCMMessage) { + ccmCache.invalidate(projectResolver.getProjectId()); + } + } + + public record ClearCCMMessage() implements Writeable { + private static final ClearCCMMessage INSTANCE = new ClearCCMMessage(); + + @Override + public void writeTo(StreamOutput out) throws IOException {} + } +} diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 0957ec55e882a..b10112c842de7 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -328,6 +328,7 @@ public class Constants { "cluster:admin/xpack/watcher/watch/put", "cluster:internal/remote_cluster/nodes", "cluster:internal/xpack/inference", + "cluster:internal/xpack/inference/clear_inference_ccm_cache", "cluster:internal/xpack/inference/clear_inference_endpoint_cache", "cluster:internal/xpack/inference/create_endpoints", "cluster:internal/xpack/inference/rerankwindowsize/get",