@@ -49,9 +49,15 @@ impl SecurityLoRAClassifier {
             candle_core::Error::from(unified_err)
         })?;
 
+        // Load threshold from global config instead of hardcoding
+        let confidence_threshold = {
+            use crate::core::config_loader::GlobalConfigLoader;
+            GlobalConfigLoader::load_security_threshold().unwrap_or(0.7) // Default from config.yaml prompt_guard.threshold
+        };
+
         Ok(Self {
             bert_classifier,
-            confidence_threshold: 0.5,
+            confidence_threshold,
             threat_types,
             model_path: model_path.to_string(),
         })
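`GlobalConfigLoader::load_security_threshold()` is referenced but not defined in this diff; a minimal sketch of what such a loader could look like, assuming a `config.yaml` with a `prompt_guard.threshold` key and an `Option`-returning signature (the real loader in `crate::core::config_loader` may differ, e.g. return a `Result`):

```rust
// Hypothetical sketch only; the real GlobalConfigLoader is not shown in this diff.
use std::fs;

pub struct GlobalConfigLoader;

impl GlobalConfigLoader {
    /// Reads `prompt_guard.threshold` from config.yaml, returning None when the
    /// file or key is missing so callers can fall back to a default (0.7 above).
    pub fn load_security_threshold() -> Option<f32> {
        let raw = fs::read_to_string("config.yaml").ok()?;
        let doc: serde_yaml::Value = serde_yaml::from_str(&raw).ok()?;
        doc.get("prompt_guard")?
            .get("threshold")?
            .as_f64()
            .map(|v| v as f32)
    }
}
```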
@@ -83,22 +89,38 @@ impl SecurityLoRAClassifier {
             candle_core::Error::from(unified_err)
         })?;
 
-        // Determine if threat is detected based on predicted class
-        let is_threat = predicted_class > 0; // Assuming class 0 is "benign" or "safe"
+        // Map class index to threat type label - fail if class not found
+        let threat_type = if predicted_class < self.threat_types.len() {
+            self.threat_types[predicted_class].clone()
+        } else {
+            let unified_err = model_error!(
+                ModelErrorType::LoRA,
+                "security classification",
+                format!(
+                    "Invalid class index {} not found in labels (max: {})",
+                    predicted_class,
+                    self.threat_types.len()
+                ),
+                text
+            );
+            return Err(candle_core::Error::from(unified_err));
+        };
 
-        // Get detected threat types
-        let mut detected_threats = Vec::new();
-        if is_threat && predicted_class < self.threat_types.len() {
-            detected_threats.push(self.threat_types[predicted_class].clone());
-        }
+        // Determine if threat is detected based on class label (instead of hardcoded index)
+        let is_threat = !threat_type.to_lowercase().contains("safe")
+            && !threat_type.to_lowercase().contains("benign")
+            && !threat_type.to_lowercase().contains("no_threat");
 
-        // Calculate severity score based on confidence and threat type
-        let severity_score = if is_threat {
-            confidence * 0.9 // High severity for detected threats
+        // Get detected threat types
+        let detected_threats = if is_threat {
+            vec![threat_type]
         } else {
-            0.0 // No severity for safe content
+            Vec::new()
         };
 
+        // Use confidence as severity score (no artificial scaling)
+        let severity_score = if is_threat { confidence } else { 0.0 };
+
         let processing_time = start_time.elapsed().as_millis() as u64;
 
         Ok(SecurityResult {
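The threat check above is purely label-driven: any class label containing "safe", "benign", or "no_threat" is treated as non-threatening, and everything else counts as a threat. A self-contained sketch of that rule, using hypothetical label names:

```rust
/// Mirrors the substring rule used in the classifier: labels containing
/// "safe", "benign", or "no_threat" (case-insensitive) are non-threats.
fn is_benign_label(label: &str) -> bool {
    let l = label.to_lowercase();
    l.contains("safe") || l.contains("benign") || l.contains("no_threat")
}

fn main() {
    // Hypothetical labels; the real list comes from the model's threat_types.
    assert!(is_benign_label("SAFE"));
    assert!(is_benign_label("no_threat"));
    assert!(!is_benign_label("jailbreak"));
    assert!(!is_benign_label("prompt_injection"));
}
```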
@@ -129,24 +151,41 @@ impl SecurityLoRAClassifier {
         let processing_time = start_time.elapsed().as_millis() as u64;
 
         let mut results = Vec::new();
-        for (predicted_class, confidence) in batch_results {
-            // Determine if threat is detected
-            let is_threat = predicted_class > 0; // Assuming class 0 is "benign"
+        for (i, (predicted_class, confidence)) in batch_results.iter().enumerate() {
+            // Map class index to threat type label - fail if class not found
+            let threat_type = if *predicted_class < self.threat_types.len() {
+                self.threat_types[*predicted_class].clone()
+            } else {
+                let unified_err = model_error!(
+                    ModelErrorType::LoRA,
+                    "batch security classification",
+                    format!("Invalid class index {} not found in labels (max: {}) for text at position {}",
+                        predicted_class, self.threat_types.len(), i),
+                    &format!("batch[{}]", i)
+                );
+                return Err(candle_core::Error::from(unified_err));
+            };
+
+            // Determine if threat is detected based on class label
+            let is_threat = !threat_type.to_lowercase().contains("safe")
+                && !threat_type.to_lowercase().contains("benign")
+                && !threat_type.to_lowercase().contains("no_threat");
 
             // Get detected threat types
-            let mut detected_threats = Vec::new();
-            if is_threat && predicted_class < self.threat_types.len() {
-                detected_threats.push(self.threat_types[predicted_class].clone());
-            }
+            let detected_threats = if is_threat {
+                vec![threat_type]
+            } else {
+                Vec::new()
+            };
 
-            // Calculate severity score
-            let severity_score = if is_threat { confidence * 0.9 } else { 0.0 };
+            // Use confidence as severity score (no artificial scaling)
+            let severity_score = if is_threat { *confidence } else { 0.0 };
 
             results.push(SecurityResult {
                 is_threat,
                 threat_types: detected_threats,
                 severity_score,
-                confidence,
+                confidence: *confidence,
                 processing_time_ms: processing_time,
             });
         }
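Both the single and batch paths construct the same result type; the shape implied by the fields used in this diff, with assumed types, is roughly:

```rust
/// Field names taken from the code above; types are assumptions, the actual
/// definition lives elsewhere in the crate.
#[derive(Debug, Clone)]
pub struct SecurityResult {
    pub is_threat: bool,
    pub threat_types: Vec<String>, // label(s) for the predicted class
    pub severity_score: f32,       // equals confidence when a threat is flagged
    pub confidence: f32,
    pub processing_time_ms: u64,
}
```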