33from nltk .tokenize import sent_tokenize
44from transformers import RobertaTokenizer
55
6+ from generated_text_detector .controllers .schemas_type import Author
7+ from generated_text_detector .utils .preprocessing import preprocessing_text
68from generated_text_detector .utils .model .roberta_classifier import RobertaClassifier
79
810
@@ -21,6 +23,7 @@ def __init__(
2123 model_name_or_path : str ,
2224 device : str ,
2325 max_len : int = 512 ,
26+ preprocessing : bool = False
2427 ) -> None :
2528
2629 self .device = torch .device (device )
@@ -30,6 +33,7 @@ def __init__(
3033 self .model .eval ()
3134
3235 self .__max_len = max_len
36+ self .preprocessing = preprocessing
3337
3438 # Optimizing GPU inference
3539 if self .device .type == 'cuda' :
@@ -60,14 +64,14 @@ def __split_by_chunks(self, text: str) -> list[str]:
6064 for sentence in sent_tokenize (text ):
6165 temp_count_tokens = len (self .tokenizer .encode (sentence ))
6266 if cur_count_tokens + temp_count_tokens > self .__max_len :
63- chunks .append (cur_chunk )
67+ chunks .append (cur_chunk . strip () )
6468 cur_chunk = sentence
6569 cur_count_tokens = temp_count_tokens
6670 else :
6771 cur_count_tokens += temp_count_tokens
6872 cur_chunk += " " + sentence
6973
70- chunks .append (cur_chunk )
74+ chunks .append (cur_chunk . strip () )
7175
7276 return chunks
7377
@@ -80,9 +84,6 @@ def __model_pass(self, texts: list[str]) -> list[float]:
8084 :return: List of scores
8185 :rtype: list[float]
8286 """
83- # Preprocessing
84- texts = [" " .join (text .split ()) for text in texts ]
85-
8687 tokens = self .tokenizer .batch_encode_plus (
8788 texts ,
8889 add_special_tokens = True ,
@@ -111,6 +112,12 @@ def detect(self, text: str) -> list[tuple[str, float]]:
111112 :return: Text chunks with generated scores
112113 :rtype: list[tuple[str, float]]
113114 """
115+ # Preprocessing
116+ if self .preprocessing :
117+ text = preprocessing_text (text )
118+ else :
119+ text = " " .join (text .split ())
120+
114121 text_chunks = self .__split_by_chunks (text )
115122
116123 scores = self .__model_pass (text_chunks ).tolist ()
@@ -120,12 +127,68 @@ def detect(self, text: str) -> list[tuple[str, float]]:
120127 return res
121128
122129
130+ def detect_report (self , text : str ) -> dict :
131+ """Detects if text is generated and prepare a report.
132+
133+ :param text: Input text
134+ :type text: str
135+ :return: Text chunks with generated scores
136+ :rtype: list[tuple[str, float]]
137+ """
138+ # Preprocessing
139+ if self .preprocessing :
140+ text = preprocessing_text (text )
141+ else :
142+ text = " " .join (text .split ())
143+
144+ text_chunks = self .__split_by_chunks (text )
145+ scores = self .__model_pass (text_chunks )
146+
147+ # Average scores
148+ gen_score = sum (scores ) / len (scores )
149+ gen_score = gen_score .item ()
150+ author = self .__determine_author (gen_score ).value
151+
152+ res = {
153+ "generated_score" : gen_score ,
154+ "author" : author
155+ }
156+
157+ return res
158+
159+
160+ @staticmethod
161+ def __determine_author (generated_score : float ) -> Author :
162+ """Function for converting score for final prediction
163+ The generated score is compared with heuristics obtained from analysis on validation data
164+ As a result, we get 5 categories described in the `Author` class
165+
166+ :param text: Generated score from detector model
167+ :type text: float, should be from 0 to 1
168+ :return: Final prediction athor
169+ :rtype: Autrhor
170+ """
171+ assert 0 <= generated_score <= 1
172+
173+ if generated_score > 0.9 :
174+ return Author .LLM_GENERATED
175+ elif generated_score > 0.7 :
176+ return Author .PROBABLY_LLM_GENERATED
177+ elif generated_score > 0.3 :
178+ return Author .NOT_SURE
179+ elif generated_score > 0.1 :
180+ return Author .PROBABLY_HUMAN_WRITTEN
181+ else :
182+ return Author .HUMAN
183+
184+
if __name__ == "__main__":
    # Smoke run: instantiate the detector and score a tiny input end-to-end.
    detector = GeneratedTextDetector(
        "SuperAnnotate/ai-detector",
        "cuda:0",
        preprocessing=True
    )

    report = detector.detect_report("Hello, world!")

    print(report)
0 commit comments