Skip to content

Commit e59c52e

Browse files
Sophie8rootfs
andauthored
Feat: fix-issue-336: Implement In-Tree Embedding Similarity Matching (#606)
* fix-issue-336 Signed-off-by: Sophie8 <sw3237@nyu.edu> * fix-issue-336 Signed-off-by: Sophie8 <sw3237@nyu.edu> * fix-issue-336: update classifier name Signed-off-by: Sophie8 <sw3237@nyu.edu> * fix-issue-336: update unit test Signed-off-by: Sophie8 <sw3237@nyu.edu> --------- Signed-off-by: Sophie8 <sw3237@nyu.edu> Co-authored-by: Huamin Chen <rootfs@users.noreply.github.com>
1 parent dd391fd commit e59c52e

File tree

4 files changed

+323
-8
lines changed

4 files changed

+323
-8
lines changed

src/semantic-router/pkg/classification/classifier.go

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,15 @@ type PIIAnalysisResult struct {
198198
// Classifier handles text classification, model selection, and jailbreak detection functionality
199199
type Classifier struct {
200200
// Dependencies - In-tree classifiers
201-
categoryInitializer CategoryInitializer
202-
categoryInference CategoryInference
203-
jailbreakInitializer JailbreakInitializer
204-
jailbreakInference JailbreakInference
205-
piiInitializer PIIInitializer
206-
piiInference PIIInference
207-
keywordClassifier *KeywordClassifier
201+
categoryInitializer CategoryInitializer
202+
categoryInference CategoryInference
203+
jailbreakInitializer JailbreakInitializer
204+
jailbreakInference JailbreakInference
205+
piiInitializer PIIInitializer
206+
piiInference PIIInference
207+
keywordClassifier *KeywordClassifier
208+
keywordEmbeddingInitializer EmbeddingClassifierInitializer
209+
keywordEmbeddingClassifier *EmbeddingClassifier
208210

209211
// Dependencies - MCP-based classifiers
210212
mcpCategoryInitializer MCPCategoryInitializer
@@ -254,6 +256,13 @@ func withKeywordClassifier(keywordClassifier *KeywordClassifier) option {
254256
}
255257
}
256258

259+
func withKeywordEmbeddingClassifier(keywordEmbeddingInitializer EmbeddingClassifierInitializer, keywordEmbeddingClassifier *EmbeddingClassifier) option {
260+
return func(c *Classifier) {
261+
c.keywordEmbeddingInitializer = keywordEmbeddingInitializer
262+
c.keywordEmbeddingClassifier = keywordEmbeddingClassifier
263+
}
264+
}
265+
257266
// initModels initializes the models for the classifier
258267
func initModels(classifier *Classifier) (*Classifier, error) {
259268
// Initialize either in-tree OR MCP-based category classifier
@@ -279,6 +288,12 @@ func initModels(classifier *Classifier) (*Classifier, error) {
279288
}
280289
}
281290

291+
if classifier.IsKeywordEmbeddingClassifierEnabled() {
292+
if err := classifier.initializeKeywordEmbeddingClassifier(); err != nil {
293+
return nil, err
294+
}
295+
}
296+
282297
return classifier, nil
283298
}
284299

@@ -320,6 +335,16 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p
320335
options = append(options, withKeywordClassifier(keywordClassifier))
321336
}
322337

338+
// Add keyword embedding classifier if configured
339+
if len(cfg.EmbeddingRules) > 0 {
340+
keywordEmbeddingClassifier, err := NewEmbeddingClassifier(cfg.EmbeddingRules)
341+
if err != nil {
342+
logging.Errorf("Failed to create keyword embedding classifier: %v", err)
343+
return nil, err
344+
}
345+
options = append(options, withKeywordEmbeddingClassifier(createEmbeddingInitializer(), keywordEmbeddingClassifier))
346+
}
347+
323348
// Add in-tree classifier if configured
324349
if cfg.CategoryModel.ModelID != "" {
325350
options = append(options, withCategory(categoryMapping, createCategoryInitializer(cfg.UseModernBERT), createCategoryInference(cfg.UseModernBERT)))
@@ -369,7 +394,17 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
369394
return category, confidence, nil
370395
}
371396
}
372-
397+
// TODO: more sophiscated fusion engine needs to be designed and implemented to combine classifiers' results
398+
// Try embedding based similarity classification if properly configured
399+
if c.keywordEmbeddingClassifier != nil {
400+
category, confidence, err := c.keywordEmbeddingClassifier.Classify(text)
401+
if err != nil {
402+
return "", 0.0, err
403+
}
404+
if category != "" {
405+
return category, confidence, nil
406+
}
407+
}
373408
// Try in-tree first if properly configured
374409
if c.IsCategoryEnabled() && c.categoryInference != nil {
375410
return c.classifyCategoryInTree(text)

src/semantic-router/pkg/classification/classifier_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3579,3 +3579,107 @@ func BenchmarkUnifiedClassifier_BatchSizeComparison(b *testing.B) {
35793579
}
35803580
})
35813581
}
3582+
3583+
// EmbeddingClassifier unit tests
3584+
var _ = Describe("EmbeddingClassifier", func() {
3585+
var origCalculate func(string, []string, int, string, int) (*candle_binding.BatchSimilarityOutput, error)
3586+
3587+
BeforeEach(func() {
3588+
origCalculate = calculateSimilarityBatch
3589+
})
3590+
3591+
AfterEach(func() {
3592+
calculateSimilarityBatch = origCalculate
3593+
})
3594+
3595+
It("classifies with mean aggregation", func() {
3596+
calculateSimilarityBatch = func(query string, candidates []string, topK int, modelType string, targetDim int) (*candle_binding.BatchSimilarityOutput, error) {
3597+
return &candle_binding.BatchSimilarityOutput{Matches: []candle_binding.BatchSimilarityMatch{{Index: 0, Similarity: 0.9}, {Index: 1, Similarity: 0.8}, {Index: 2, Similarity: 0.7}}}, nil
3598+
}
3599+
3600+
rules := []config.EmbeddingRule{{
3601+
Category: "cat1",
3602+
Keywords: []string{"science", "math"},
3603+
AggregationMethodConfiged: config.AggregationMethodMean,
3604+
SimilarityThreshold: 0.8,
3605+
Model: "auto",
3606+
Dimension: 768,
3607+
}}
3608+
3609+
clf, err := NewEmbeddingClassifier(rules)
3610+
Expect(err).ToNot(HaveOccurred())
3611+
3612+
cat, score, err := clf.Classify("some text")
3613+
Expect(err).ToNot(HaveOccurred())
3614+
Expect(cat).To(Equal("cat1"))
3615+
Expect(score).To(BeNumerically("~", 0.8, 1e-6))
3616+
})
3617+
3618+
It("classifies with max aggregation", func() {
3619+
calculateSimilarityBatch = func(query string, candidates []string, topK int, modelType string, targetDim int) (*candle_binding.BatchSimilarityOutput, error) {
3620+
return &candle_binding.BatchSimilarityOutput{Matches: []candle_binding.BatchSimilarityMatch{{Index: 0, Similarity: 0.4}, {Index: 1, Similarity: 0.6}}}, nil
3621+
}
3622+
3623+
rules := []config.EmbeddingRule{{
3624+
Category: "cat2",
3625+
Keywords: []string{"x", "y"},
3626+
AggregationMethodConfiged: config.AggregationMethodMax,
3627+
SimilarityThreshold: 0.5,
3628+
Model: "auto",
3629+
Dimension: 512,
3630+
}}
3631+
3632+
clf, err := NewEmbeddingClassifier(rules)
3633+
Expect(err).ToNot(HaveOccurred())
3634+
3635+
cat, score, err := clf.Classify("other text")
3636+
Expect(err).ToNot(HaveOccurred())
3637+
Expect(cat).To(Equal("cat2"))
3638+
Expect(score).To(BeNumerically("~", 0.6, 1e-6))
3639+
})
3640+
3641+
It("classifies with any aggregation", func() {
3642+
calculateSimilarityBatch = func(query string, candidates []string, topK int, modelType string, targetDim int) (*candle_binding.BatchSimilarityOutput, error) {
3643+
return &candle_binding.BatchSimilarityOutput{Matches: []candle_binding.BatchSimilarityMatch{{Index: 0, Similarity: 0.2}, {Index: 1, Similarity: 0.95}}}, nil
3644+
}
3645+
3646+
rules := []config.EmbeddingRule{{
3647+
Category: "cat3",
3648+
Keywords: []string{"p", "q"},
3649+
AggregationMethodConfiged: config.AggregationMethodAny,
3650+
SimilarityThreshold: 0.7,
3651+
Model: "auto",
3652+
Dimension: 256,
3653+
}}
3654+
3655+
clf, err := NewEmbeddingClassifier(rules)
3656+
Expect(err).ToNot(HaveOccurred())
3657+
3658+
cat, score, err := clf.Classify("third text")
3659+
Expect(err).ToNot(HaveOccurred())
3660+
Expect(cat).To(Equal("cat3"))
3661+
Expect(score).To(BeNumerically("~", 0.7, 1e-6))
3662+
})
3663+
3664+
It("returns error when CalculateSimilarityBatch fails", func() {
3665+
calculateSimilarityBatch = func(query string, candidates []string, topK int, modelType string, targetDim int) (*candle_binding.BatchSimilarityOutput, error) {
3666+
return nil, errors.New("external failure")
3667+
}
3668+
3669+
rules := []config.EmbeddingRule{{
3670+
Category: "cat4",
3671+
Keywords: []string{"z"},
3672+
AggregationMethodConfiged: config.AggregationMethodMean,
3673+
SimilarityThreshold: 0.1,
3674+
Model: "auto",
3675+
Dimension: 768,
3676+
}}
3677+
3678+
clf, err := NewEmbeddingClassifier(rules)
3679+
Expect(err).ToNot(HaveOccurred())
3680+
3681+
_, _, err = clf.Classify("will error")
3682+
Expect(err).To(HaveOccurred())
3683+
Expect(err.Error()).To(ContainSubstring("failed to calculate batch similarity"))
3684+
})
3685+
})
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package classification
2+
3+
import (
4+
"fmt"
5+
6+
candle_binding "github.com/vllm-project/semantic-router/candle-binding"
7+
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
8+
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging"
9+
)
10+
11+
// calculateSimilarityBatch is a package-level variable that points to the
12+
// actual implementation in the candle_binding package. It exists so tests can
13+
// override it.
14+
var calculateSimilarityBatch = candle_binding.CalculateSimilarityBatch
15+
16+
// EmbeddingClassifierInitializer initializes KeywordEmbeddingClassifier for embedding based classification
17+
type EmbeddingClassifierInitializer interface {
18+
Init(qwen3ModelPath string, gemmaModelPath string, useCPU bool) error
19+
}
20+
21+
type ExternalModelBasedEmbeddingInitializer struct{}
22+
23+
func (c *ExternalModelBasedEmbeddingInitializer) Init(qwen3ModelPath string, gemmaModelPath string, useCPU bool) error {
24+
err := candle_binding.InitEmbeddingModels(qwen3ModelPath, gemmaModelPath, useCPU)
25+
if err != nil {
26+
return err
27+
}
28+
logging.Infof("Initialized KeywordEmbedding classifier with qwen3 model path %q and gemma model path %s", qwen3ModelPath, gemmaModelPath)
29+
return nil
30+
}
31+
32+
// createEmbeddingInitializer creates the appropriate keyword embedding initializer based on configuration
33+
func createEmbeddingInitializer() EmbeddingClassifierInitializer {
34+
return &ExternalModelBasedEmbeddingInitializer{}
35+
}
36+
37+
type EmbeddingClassifier struct {
38+
rules []config.EmbeddingRule
39+
}
40+
41+
// NewKeywordClassifier creates a new KeywordEmbeddingClassifier.
42+
func NewEmbeddingClassifier(cfgRules []config.EmbeddingRule) (*EmbeddingClassifier, error) {
43+
return &EmbeddingClassifier{rules: cfgRules}, nil
44+
}
45+
46+
// IsKeywordEmbeddingClassifierEnabled checks if Keyword embedding classification rules are properly configured
47+
func (c *Classifier) IsKeywordEmbeddingClassifierEnabled() bool {
48+
return len(c.Config.EmbeddingRules) > 0
49+
}
50+
51+
// initializeKeywordEmbeddingClassifier initializes the KeywordEmbedding classification model
52+
func (c *Classifier) initializeKeywordEmbeddingClassifier() error {
53+
if !c.IsKeywordEmbeddingClassifierEnabled() || c.keywordEmbeddingInitializer == nil {
54+
return fmt.Errorf("keyword embedding similarity match is not properly configured")
55+
}
56+
return c.keywordEmbeddingInitializer.Init(c.Config.InlineModels.Qwen3ModelPath, c.Config.InlineModels.GemmaModelPath, c.Config.InlineModels.EmbeddingModels.UseCPU)
57+
}
58+
59+
// Classify performs keyword-based embedding similarity classification on the given text.
60+
func (c *EmbeddingClassifier) Classify(text string) (string, float64, error) {
61+
var bestScore float32
62+
var mostMatchedCategory string
63+
for _, rule := range c.rules {
64+
matched, aggregatedScore, err := c.matches(text, rule) // Error handled
65+
if err != nil {
66+
return "", 0.0, err // Propagate error
67+
}
68+
if matched {
69+
if len(rule.Keywords) > 0 {
70+
logging.Infof("Keyword-based embedding similarity classification matched category %q with keywords: %v, confidence score %s", rule.Category, rule.Keywords, aggregatedScore)
71+
} else {
72+
logging.Infof("Keyword-based embedding similarity classification do not match category %q with keywords: %v, confidence score %s", rule.Category, rule.Keywords, aggregatedScore)
73+
}
74+
if aggregatedScore > bestScore {
75+
bestScore = aggregatedScore
76+
mostMatchedCategory = rule.Category
77+
}
78+
}
79+
}
80+
return mostMatchedCategory, float64(bestScore), nil
81+
}
82+
83+
// matches checks if the text matches the given keyword rule.
84+
func (c *EmbeddingClassifier) matches(text string, rule config.EmbeddingRule) (bool, float32, error) {
85+
// Validate input
86+
if text == "" {
87+
return false, 0.0, fmt.Errorf("keyword-based embedding similarity classification: query must be provided")
88+
}
89+
if len(rule.Keywords) == 0 {
90+
return false, 0.0, fmt.Errorf("keyword-based embedding similarity classification: keywords must be provided")
91+
}
92+
// Set defaults
93+
if rule.Dimension == 0 {
94+
rule.Dimension = 768 // Default to full dimension
95+
}
96+
if rule.Model == "auto" && rule.QualityPriority == 0 && rule.LatencyPriority == 0 {
97+
rule.QualityPriority = 0.5
98+
rule.LatencyPriority = 0.5
99+
}
100+
101+
// Validate dimension
102+
validDimensions := map[int]bool{128: true, 256: true, 512: true, 768: true, 1024: true}
103+
if !validDimensions[rule.Dimension] {
104+
return false, 0.0, fmt.Errorf("keyword-based embedding similarity classification: dimension must be one of: 128, 256, 512, 768, 1024 (got %d)", rule.Dimension)
105+
}
106+
// Calculate batch similarity
107+
result, err := calculateSimilarityBatch(
108+
text,
109+
rule.Keywords,
110+
0, // return scores for all the keywords
111+
rule.Model,
112+
rule.Dimension,
113+
)
114+
if err != nil {
115+
return false, 0.0, fmt.Errorf("keyword-based embedding similarity classification: failed to calculate batch similarity: %w", err)
116+
}
117+
// Check for matches based on the aggregation method
118+
switch rule.AggregationMethodConfiged {
119+
case config.AggregationMethodMean:
120+
var aggregatedScore float32
121+
for _, match := range result.Matches {
122+
aggregatedScore += match.Similarity
123+
}
124+
aggregatedScore /= float32(len(result.Matches))
125+
if aggregatedScore >= rule.SimilarityThreshold {
126+
return true, aggregatedScore, nil
127+
} else {
128+
return false, aggregatedScore, nil
129+
}
130+
case config.AggregationMethodMax:
131+
var aggregatedScore float32
132+
for _, match := range result.Matches {
133+
if match.Similarity > aggregatedScore {
134+
aggregatedScore = match.Similarity
135+
}
136+
}
137+
if aggregatedScore >= rule.SimilarityThreshold {
138+
return true, aggregatedScore, nil
139+
} else {
140+
return false, aggregatedScore, nil
141+
}
142+
case config.AggregationMethodAny:
143+
for _, match := range result.Matches {
144+
if match.Similarity >= rule.SimilarityThreshold {
145+
return true, rule.SimilarityThreshold, nil
146+
}
147+
}
148+
return false, 0.0, nil
149+
150+
}
151+
return false, 0.0, fmt.Errorf("keyword-based embedding similarity classification: unsupported keyword rule aggregation method: %q", rule.AggregationMethodConfiged)
152+
}

src/semantic-router/pkg/config/config.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ type IntelligentRouting struct {
8181
// Keyword-based classification rules
8282
KeywordRules []KeywordRule `yaml:"keyword_rules,omitempty"`
8383

84+
// Embedding-based classification rules
85+
EmbeddingRules []EmbeddingRule `yaml:"embedding_rules,omitempty"`
86+
8487
// Categories for routing queries
8588
Categories []Category `yaml:"categories"`
8689

@@ -199,6 +202,27 @@ type KeywordRule struct {
199202
CaseSensitive bool `yaml:"case_sensitive"`
200203
}
201204

205+
// Aggregation method used in keyword embedding rule
206+
type AggregationMethod string
207+
208+
const (
209+
AggregationMethodMean AggregationMethod = "mean"
210+
AggregationMethodMax AggregationMethod = "max"
211+
AggregationMethodAny AggregationMethod = "any"
212+
)
213+
214+
// EmbeddingRule defines a rule for keyword embedding based similarity match rule.
215+
type EmbeddingRule struct {
216+
Category string `yaml:"category"`
217+
SimilarityThreshold float32 `yaml:"threshold"`
218+
Keywords []string `yaml:"keywords"`
219+
AggregationMethodConfiged AggregationMethod `yaml:"aggregation_mathod"`
220+
Model string `json:"model,omitempty"` // "auto" (default), "qwen3", "gemma"
221+
Dimension int `json:"dimension,omitempty"` // Target dimension: 768 (default), 512, 256, 128
222+
QualityPriority float32 `json:"quality_priority,omitempty"` // 0.0-1.0, only for "auto" model
223+
LatencyPriority float32 `json:"latency_priority,omitempty"` // 0.0-1.0, only for "auto" model
224+
}
225+
202226
// APIConfig represents configuration for API endpoints
203227
type APIConfig struct {
204228
// Batch classification configuration (zero-config auto-discovery)

0 commit comments

Comments
 (0)