Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 137 additions & 4 deletions src/semantic-router/pkg/services/classification.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func GetGlobalClassificationService() *ClassificationService {

// HasClassifier returns true if the service has a real classifier (not placeholder)
func (s *ClassificationService) HasClassifier() bool {
return s.classifier != nil
return s.unifiedClassifier != nil || s.classifier != nil
}

// NewPlaceholderClassificationService creates a placeholder service for API-only mode
Expand Down Expand Up @@ -118,7 +118,12 @@ func (s *ClassificationService) ClassifyIntent(req IntentRequest) (*IntentRespon
return nil, fmt.Errorf("text cannot be empty")
}

// Check if classifier is available
// Prioritize unified classifier if available
if s.unifiedClassifier != nil {
return s.ClassifyIntentUnified(req)
}

// Check if legacy classifier is available
if s.classifier == nil {
// Return placeholder response
processingTime := time.Since(start).Milliseconds()
Expand Down Expand Up @@ -210,7 +215,12 @@ func (s *ClassificationService) DetectPII(req PIIRequest) (*PIIResponse, error)
return nil, fmt.Errorf("text cannot be empty")
}

// Check if classifier is available
// Prioritize unified classifier if available
if s.unifiedClassifier != nil {
return s.DetectPIIUnified(req)
}

// Check if legacy classifier is available
if s.classifier == nil {
// Return placeholder response
processingTime := time.Since(start).Milliseconds()
Expand Down Expand Up @@ -290,7 +300,12 @@ func (s *ClassificationService) CheckSecurity(req SecurityRequest) (*SecurityRes
return nil, fmt.Errorf("text cannot be empty")
}

// Check if classifier is available
// Prioritize unified classifier if available
if s.unifiedClassifier != nil {
return s.CheckSecurityUnified(req)
}

// Check if legacy classifier is available
if s.classifier == nil {
// Return placeholder response
processingTime := time.Since(start).Milliseconds()
Expand Down Expand Up @@ -454,6 +469,59 @@ func (s *ClassificationService) ClassifyPIIUnified(texts []string) ([]classifica
return results.PIIResults, nil
}

// DetectPIIUnified performs PII detection using unified classifier and returns PIIResponse format
func (s *ClassificationService) DetectPIIUnified(req PIIRequest) (*PIIResponse, error) {
start := time.Now()

if req.Text == "" {
return nil, fmt.Errorf("text cannot be empty")
}

// Use unified classifier for PII detection
piiResults, err := s.ClassifyPIIUnified([]string{req.Text})
if err != nil {
return nil, fmt.Errorf("PII detection failed: %w", err)
}

processingTime := time.Since(start).Milliseconds()

// Convert PIIResult to PIIResponse format
if len(piiResults) == 0 {
return &PIIResponse{
HasPII: false,
Entities: []PIIEntity{},
SecurityRecommendation: "allow",
ProcessingTimeMs: processingTime,
}, nil
}

piiResult := piiResults[0]
response := &PIIResponse{
HasPII: piiResult.HasPII,
Entities: []PIIEntity{},
ProcessingTimeMs: processingTime,
}

// Convert PII types to entities
for _, piiType := range piiResult.PIITypes {
entity := PIIEntity{
Type: piiType,
Value: "[DETECTED]", // Placeholder - unified classifier doesn't provide exact positions yet
Confidence: float64(piiResult.Confidence),
}
response.Entities = append(response.Entities, entity)
}

// Set security recommendation
if response.HasPII {
response.SecurityRecommendation = "block"
} else {
response.SecurityRecommendation = "allow"
}

return response, nil
}

// ClassifySecurityUnified performs security detection using unified classifier
func (s *ClassificationService) ClassifySecurityUnified(texts []string) ([]classification.SecurityResult, error) {
if s.unifiedClassifier == nil {
Expand All @@ -468,6 +536,71 @@ func (s *ClassificationService) ClassifySecurityUnified(texts []string) ([]class
return results.SecurityResults, nil
}

// CheckSecurityUnified performs security detection using unified classifier and returns SecurityResponse format
func (s *ClassificationService) CheckSecurityUnified(req SecurityRequest) (*SecurityResponse, error) {
start := time.Now()

if req.Text == "" {
return nil, fmt.Errorf("text cannot be empty")
}

// Use unified classifier for security detection
securityResults, err := s.ClassifySecurityUnified([]string{req.Text})
if err != nil {
return nil, fmt.Errorf("security detection failed: %w", err)
}

processingTime := time.Since(start).Milliseconds()

// Convert SecurityResult to SecurityResponse format
if len(securityResults) == 0 {
return &SecurityResponse{
IsJailbreak: false,
RiskScore: 0.1,
DetectionTypes: []string{},
Confidence: 0.9,
Recommendation: "allow",
PatternsDetected: []string{},
ProcessingTimeMs: processingTime,
}, nil
}

securityResult := securityResults[0]
response := &SecurityResponse{
IsJailbreak: securityResult.IsJailbreak,
RiskScore: float64(securityResult.Confidence),
Confidence: float64(securityResult.Confidence),
ProcessingTimeMs: processingTime,
}

// Set detection types based on threat type
if securityResult.ThreatType != "" {
response.DetectionTypes = []string{securityResult.ThreatType}
response.PatternsDetected = []string{securityResult.ThreatType}
} else {
response.DetectionTypes = []string{}
response.PatternsDetected = []string{}
}

// Set recommendation based on jailbreak detection
if response.IsJailbreak {
response.Recommendation = "block"
} else {
response.Recommendation = "allow"
}

// Add reasoning if requested
if req.Options != nil && req.Options.IncludeReasoning {
if response.IsJailbreak {
response.Reasoning = fmt.Sprintf("Detected %s with confidence %.2f", securityResult.ThreatType, securityResult.Confidence)
} else {
response.Reasoning = "No security threats detected"
}
}

return response, nil
}

// HasUnifiedClassifier returns true if the service has a unified classifier
func (s *ClassificationService) HasUnifiedClassifier() bool {
return s.unifiedClassifier != nil && s.unifiedClassifier.IsInitialized()
Expand Down
Loading