@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Optional, Union
 import traceback
 import pandas as pd
 import pickle
@@ -18,19 +18,62 @@
     labeling_task,
     record_label_association,
     weak_supervision,
+    labeling_task_label,
+    information_source,
 )

+NO_LABEL_WS_PRECISION = 0.8
+
+
+def __create_quality_metrics(
+    project_id: str,
+    labeling_task_id: str,
+    overwrite_weak_supervision: Union[float, Dict[str, float]],
+) -> Dict[Tuple[str, str], Dict[str, float]]:
+    if isinstance(overwrite_weak_supervision, float):
+        ws_weights = {}
+        for heuristic_id in information_source.get_all_ids_by_labeling_task_id(
+            project_id, labeling_task_id
+        ):
+            ws_weights[str(heuristic_id)] = overwrite_weak_supervision
+    else:
+        ws_weights = overwrite_weak_supervision
+
+    ws_stats = {}
+    for heuristic_id in ws_weights:
+        label_ids = labeling_task_label.get_all_ids(project_id, labeling_task_id)
+        for (label_id,) in label_ids:
+            ws_stats[(heuristic_id, str(label_id))] = {
+                "precision": ws_weights[heuristic_id]
+            }
+    return ws_stats
+

 def fit_predict(
-    project_id: str, labeling_task_id: str, user_id: str, weak_supervision_task_id: str
+    project_id: str,
+    labeling_task_id: str,
+    user_id: str,
+    weak_supervision_task_id: str,
+    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
 ):
+    quality_metrics_overwrite = None
+    if overwrite_weak_supervision is not None:
+        quality_metrics_overwrite = __create_quality_metrics(
+            project_id, labeling_task_id, overwrite_weak_supervision
+        )
+    elif not record_label_association.is_any_record_manually_labeled(
+        project_id, labeling_task_id
+    ):
+        quality_metrics_overwrite = __create_quality_metrics(
+            project_id, labeling_task_id, NO_LABEL_WS_PRECISION
+        )
+
     task_type, df = collect_data(project_id, labeling_task_id, True)
     try:
         if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
-            results = integrate_classification(df)
-
+            results = integrate_classification(df, quality_metrics_overwrite)
         else:
-            results = integrate_extraction(df)
+            results = integrate_extraction(df, quality_metrics_overwrite)
         weak_supervision.store_data(
             project_id,
             labeling_task_id,
@@ -52,46 +95,62 @@ def fit_predict(


 def export_weak_supervision_stats(
-    project_id: str, labeling_task_id: str
+    project_id: str,
+    labeling_task_id: str,
+    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
 ) -> Tuple[int, str]:
+    if overwrite_weak_supervision is not None:
+        ws_stats = __create_quality_metrics(
+            project_id, labeling_task_id, overwrite_weak_supervision
+        )
+    elif not record_label_association.is_any_record_manually_labeled(
+        project_id, labeling_task_id
+    ):
+        ws_stats = __create_quality_metrics(
+            project_id, labeling_task_id, NO_LABEL_WS_PRECISION
+        )
+    else:
+        task_type, df = collect_data(project_id, labeling_task_id, False)
+        try:
+            if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
+                cnlm = util.get_cnlm_from_df(df)
+                stats_df = cnlm.quality_metrics()
+            elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
+                enlm = util.get_enlm_from_df(df)
+                stats_df = enlm.quality_metrics()
+            else:
+                return 404, f"Task type {task_type} not implemented"

-    task_type, df = collect_data(project_id, labeling_task_id, False)
-    try:
-        if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
-            cnlm = util.get_cnlm_from_df(df)
-            stats_df = cnlm.quality_metrics()
-        elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
-            enlm = util.get_enlm_from_df(df)
-            stats_df = enlm.quality_metrics()
-        else:
-            return 404, f"Task type {task_type} not implemented"
+            if len(stats_df) != 0:
+                ws_stats = stats_df.set_index(["identifier", "label_name"]).to_dict(
+                    orient="index"
+                )
+            else:
+                return 404, "Can't compute weak supervision"

-        if len(stats_df) != 0:
-            stats_lkp = stats_df.set_index(["identifier", "label_name"]).to_dict(
-                orient="index"
-            )
-        else:
-            return 404, "Can't compute weak supervision"
+        except Exception:
+            print(traceback.format_exc(), flush=True)
+            general.rollback()
+            return 500, "Internal server error"

-    os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
-    with open(
-        os.path.join(
-            "/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
-        ),
-        "wb",
-    ) as f:
-        pickle.dump(stats_lkp, f)
+    os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
+    with open(
+        os.path.join(
+            "/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
+        ),
+        "wb",
+    ) as f:
+        pickle.dump(ws_stats, f)

-    except Exception:
-        print(traceback.format_exc(), flush=True)
-        general.rollback()
-        return 500, "Internal server error"
     return 200, "OK"


-def integrate_classification(df: pd.DataFrame):
+def integrate_classification(
+    df: pd.DataFrame,
+    quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
+):
     cnlm = util.get_cnlm_from_df(df)
-    weak_supervision_results = cnlm.weakly_supervise()
+    weak_supervision_results = cnlm.weakly_supervise(quality_metrics_overwrite)
     return_values = defaultdict(list)
     for record_id, (
         label_id,
@@ -103,9 +162,12 @@ def integrate_classification(df: pd.DataFrame):
     return return_values


-def integrate_extraction(df: pd.DataFrame):
+def integrate_extraction(
+    df: pd.DataFrame,
+    quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
+):
     enlm = util.get_enlm_from_df(df)
-    weak_supervision_results = enlm.weakly_supervise()
+    weak_supervision_results = enlm.weakly_supervise(quality_metrics_overwrite)
     return_values = defaultdict(list)
     for record_id, preds in weak_supervision_results.items():
         for pred in preds:
@@ -128,12 +190,12 @@ def collect_data(

     query_results = []
     if labeling_task_item.task_type == enums.LabelingTaskType.CLASSIFICATION.value:
-        for information_source in labeling_task_item.information_sources:
-            if only_selected and not information_source.is_selected:
+        for information_source_item in labeling_task_item.information_sources:
+            if only_selected and not information_source_item.is_selected:
                 continue
             results = (
                 record_label_association.get_all_classifications_for_information_source(
-                    project_id, information_source.id
+                    project_id, information_source_item.id
                 )
             )
             query_results.extend(results)
@@ -149,11 +211,11 @@ def collect_data(
         labeling_task_item.task_type
         == enums.LabelingTaskType.INFORMATION_EXTRACTION.value
     ):
-        for information_source in labeling_task_item.information_sources:
-            if only_selected and not information_source.is_selected:
+        for information_source_item in labeling_task_item.information_sources:
+            if only_selected and not information_source_item.is_selected:
                 continue
             results = record_label_association.get_all_extraction_tokens_for_information_source(
-                project_id, information_source.id
+                project_id, information_source_item.id
             )
             query_results.extend(results)

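
Note: in both `fit_predict` and `export_weak_supervision_stats`, the overwrite resolves to the same shape before it reaches `weakly_supervise()` or the pickle file: a mapping from (heuristic id, label id) to {"precision": value}. A single float is fanned out to every heuristic attached to the labeling task, while a dict is taken as per-heuristic precisions. The snippet below is a minimal standalone sketch of that expansion; the heuristic and label IDs are made up and stand in for the database lookups (`information_source.get_all_ids_by_labeling_task_id`, `labeling_task_label.get_all_ids`) used by `__create_quality_metrics`.

from typing import Dict, Tuple, Union

# Hypothetical IDs standing in for the database lookups performed by
# __create_quality_metrics; in the real code these come from the
# information_source and labeling_task_label business objects.
HEURISTIC_IDS = ["heuristic-a", "heuristic-b"]
LABEL_IDS = ["label-1", "label-2"]


def build_quality_metrics(
    overwrite: Union[float, Dict[str, float]],
) -> Dict[Tuple[str, str], Dict[str, float]]:
    # A single float applies to every heuristic of the task;
    # a dict already maps heuristic id -> precision.
    if isinstance(overwrite, float):
        ws_weights = {heuristic_id: overwrite for heuristic_id in HEURISTIC_IDS}
    else:
        ws_weights = overwrite
    # One {"precision": ...} record per (heuristic, label) pair -- the shape
    # passed to weakly_supervise() and pickled for inference.
    return {
        (heuristic_id, label_id): {"precision": precision}
        for heuristic_id, precision in ws_weights.items()
        for label_id in LABEL_IDS
    }


# Float form: every heuristic gets one precision, mirroring the
# NO_LABEL_WS_PRECISION fallback of 0.8 used when no record is manually labeled.
print(build_quality_metrics(0.8))
# Dict form: caller-supplied per-heuristic precisions.
print(build_quality_metrics({"heuristic-a": 0.9, "heuristic-b": 0.6}))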