Skip to content

Commit c6cc86c

Browse files
Weak Supervision without manual labels (#33)
* implements overwriting and default value of ws stats
* adds extraction task to weak supervision overwrite
* updates submodule
* pr comments, changed attribute name for clarity
* pr comments, typing
* updates weak_nlp version
1 parent 1c0896c commit c6cc86c

File tree

5 files changed

+114
-48
lines changed

5 files changed

+114
-48
lines changed

app.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastapi import FastAPI, HTTPException, responses, status
22
from pydantic import BaseModel
3+
from typing import Union, Dict, Optional
34

45
from controller import stats
56
from controller import integration
@@ -14,6 +15,7 @@ class WeakSupervisionRequest(BaseModel):
1415
labeling_task_id: str
1516
user_id: str
1617
weak_supervision_task_id: str
18+
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]]
1719

1820

1921
class TaskStatsRequest(BaseModel):
@@ -31,6 +33,7 @@ class SourceStatsRequest(BaseModel):
3133
class ExportWsStatsRequest(BaseModel):
    """Request payload for the /export_ws_stats endpoint.

    overwrite_weak_supervision: optional precision override — either one
    float applied to every heuristic of the task, or a mapping from
    heuristic id to precision.
    """

    project_id: str
    labeling_task_id: str
    # Explicit None default: pydantic v1 treats a bare Optional[...] field as
    # not required, but pydantic v2 would make it required — the default keeps
    # the field optional in both and changes nothing for existing callers.
    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None
3437

3538

3639
@app.post("/fit_predict")
@@ -43,6 +46,7 @@ def weakly_supervise(
4346
request.labeling_task_id,
4447
request.user_id,
4548
request.weak_supervision_task_id,
49+
request.overwrite_weak_supervision,
4650
)
4751
general.remove_and_refresh_session(session_token)
4852
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
@@ -80,7 +84,7 @@ def calculate_source_stats(
8084
def export_ws_stats(request: ExportWsStatsRequest) -> responses.PlainTextResponse:
8185
session_token = general.get_ctx_token()
8286
status_code, message = integration.export_weak_supervision_stats(
83-
request.project_id, request.labeling_task_id
87+
request.project_id, request.labeling_task_id, request.overwrite_weak_supervision
8488
)
8589
general.remove_and_refresh_session(session_token)
8690

controller/integration.py

Lines changed: 106 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any, Dict, List, Tuple
2+
from typing import Any, Dict, List, Tuple, Optional, Union
33
import traceback
44
import pandas as pd
55
import pickle
@@ -18,19 +18,62 @@
1818
labeling_task,
1919
record_label_association,
2020
weak_supervision,
21+
labeling_task_label,
22+
information_source,
2123
)
2224

25+
# Fallback precision assumed for every heuristic when no record has been
# manually labeled yet, i.e. real quality metrics cannot be computed.
# NOTE(review): 0.8 looks like a chosen default — confirm with product.
NO_LABEL_WS_PRECISION = 0.8


def __create_quality_metrics(
    project_id: str,
    labeling_task_id: str,
    overwrite_weak_supervision: Union[float, Dict[str, float]],
) -> Dict[Tuple[str, str], Dict[str, float]]:
    """Build a (heuristic_id, label_id) -> {"precision": value} lookup.

    A single float is expanded to every heuristic of the labeling task;
    a dict is taken as an explicit heuristic-id -> precision mapping.
    """
    if isinstance(overwrite_weak_supervision, float):
        # Same precision for every heuristic of this labeling task.
        ws_weights = {
            str(heuristic_id): overwrite_weak_supervision
            for heuristic_id in information_source.get_all_ids_by_labeling_task_id(
                project_id, labeling_task_id
            )
        }
    else:
        ws_weights = overwrite_weak_supervision

    # Hoisted out of the loop — the task's label ids do not depend on the
    # heuristic. Materialized as a list so it can be iterated once per weight.
    label_ids = list(labeling_task_label.get_all_ids(project_id, labeling_task_id))

    ws_stats = {}
    for heuristic_id, weight in ws_weights.items():
        for (label_id,) in label_ids:
            ws_stats[(heuristic_id, str(label_id))] = {"precision": weight}
    return ws_stats
50+
2351

2452
def fit_predict(
25-
project_id: str, labeling_task_id: str, user_id: str, weak_supervision_task_id: str
53+
project_id: str,
54+
labeling_task_id: str,
55+
user_id: str,
56+
weak_supervision_task_id: str,
57+
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
2658
):
59+
quality_metrics_overwrite = None
60+
if overwrite_weak_supervision is not None:
61+
quality_metrics_overwrite = __create_quality_metrics(
62+
project_id, labeling_task_id, overwrite_weak_supervision
63+
)
64+
elif not record_label_association.is_any_record_manually_labeled(
65+
project_id, labeling_task_id
66+
):
67+
quality_metrics_overwrite = __create_quality_metrics(
68+
project_id, labeling_task_id, NO_LABEL_WS_PRECISION
69+
)
70+
2771
task_type, df = collect_data(project_id, labeling_task_id, True)
2872
try:
2973
if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
30-
results = integrate_classification(df)
31-
74+
results = integrate_classification(df, quality_metrics_overwrite)
3275
else:
33-
results = integrate_extraction(df)
76+
results = integrate_extraction(df, quality_metrics_overwrite)
3477
weak_supervision.store_data(
3578
project_id,
3679
labeling_task_id,
@@ -52,46 +95,62 @@ def fit_predict(
5295

5396

5497
def export_weak_supervision_stats(
    project_id: str,
    labeling_task_id: str,
    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
) -> Tuple[int, str]:
    """Compute weak-supervision quality stats and pickle them for inference.

    Returns an (http_status_code, message) tuple; the stats lookup is written
    to /inference/<project_id>/weak-supervision-<labeling_task_id>.pkl.
    """
    if overwrite_weak_supervision is not None:
        # Caller supplied an explicit precision override — skip computation.
        ws_stats = __create_quality_metrics(
            project_id, labeling_task_id, overwrite_weak_supervision
        )
    elif not record_label_association.is_any_record_manually_labeled(
        project_id, labeling_task_id
    ):
        # No manual labels to evaluate against: fall back to the default
        # per-heuristic precision.
        ws_stats = __create_quality_metrics(
            project_id, labeling_task_id, NO_LABEL_WS_PRECISION
        )
    else:
        task_type, df = collect_data(project_id, labeling_task_id, False)
        try:
            if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
                stats_df = util.get_cnlm_from_df(df).quality_metrics()
            elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
                stats_df = util.get_enlm_from_df(df).quality_metrics()
            else:
                return 404, f"Task type {task_type} not implemented"

            if len(stats_df) == 0:
                return 404, "Can't compute weak supervision"
            ws_stats = stats_df.set_index(["identifier", "label_name"]).to_dict(
                orient="index"
            )
        except Exception:
            # Boundary handler: log the traceback, roll the session back and
            # translate into an HTTP-style error code.
            print(traceback.format_exc(), flush=True)
            general.rollback()
            return 500, "Internal server error"

    os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
    target_path = os.path.join(
        "/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
    )
    with open(target_path, "wb") as stats_file:
        pickle.dump(ws_stats, stats_file)

    return 200, "OK"
90146

91147

92-
def integrate_classification(df: pd.DataFrame):
148+
def integrate_classification(
149+
df: pd.DataFrame,
150+
quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
151+
):
93152
cnlm = util.get_cnlm_from_df(df)
94-
weak_supervision_results = cnlm.weakly_supervise()
153+
weak_supervision_results = cnlm.weakly_supervise(quality_metrics_overwrite)
95154
return_values = defaultdict(list)
96155
for record_id, (
97156
label_id,
@@ -103,9 +162,12 @@ def integrate_classification(df: pd.DataFrame):
103162
return return_values
104163

105164

106-
def integrate_extraction(df: pd.DataFrame):
165+
def integrate_extraction(
166+
df: pd.DataFrame,
167+
quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
168+
):
107169
enlm = util.get_enlm_from_df(df)
108-
weak_supervision_results = enlm.weakly_supervise()
170+
weak_supervision_results = enlm.weakly_supervise(quality_metrics_overwrite)
109171
return_values = defaultdict(list)
110172
for record_id, preds in weak_supervision_results.items():
111173
for pred in preds:
@@ -128,12 +190,12 @@ def collect_data(
128190

129191
query_results = []
130192
if labeling_task_item.task_type == enums.LabelingTaskType.CLASSIFICATION.value:
131-
for information_source in labeling_task_item.information_sources:
132-
if only_selected and not information_source.is_selected:
193+
for information_source_item in labeling_task_item.information_sources:
194+
if only_selected and not information_source_item.is_selected:
133195
continue
134196
results = (
135197
record_label_association.get_all_classifications_for_information_source(
136-
project_id, information_source.id
198+
project_id, information_source_item.id
137199
)
138200
)
139201
query_results.extend(results)
@@ -149,11 +211,11 @@ def collect_data(
149211
labeling_task_item.task_type
150212
== enums.LabelingTaskType.INFORMATION_EXTRACTION.value
151213
):
152-
for information_source in labeling_task_item.information_sources:
153-
if only_selected and not information_source.is_selected:
214+
for information_source_item in labeling_task_item.information_sources:
215+
if only_selected and not information_source_item.is_selected:
154216
continue
155217
results = record_label_association.get_all_extraction_tokens_for_information_source(
156-
project_id, information_source.id
218+
project_id, information_source_item.id
157219
)
158220
query_results.extend(results)
159221

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,5 @@ urllib3==1.26.16
107107
# requests
108108
uvicorn==0.22.0
109109
# via -r requirements/common-requirements.txt
110-
weak-nlp==0.0.12
110+
weak-nlp==0.0.13
111111
# via -r requirements/requirements.in

requirements/requirements.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
-r common-requirements.txt
2-
weak-nlp==0.0.12
2+
weak-nlp==0.0.13

0 commit comments

Comments
 (0)