Skip to content

Commit ce45ae8

Browse files
authored
Implementing inference api (#12)
* save weak supervision stats as pickle
* adds endpoint to export weak supervision statistics
* adds export of weak supervision statistics for extraction tasks
* pr comments
1 parent b5cbee4 commit ce45ae8

File tree

3 files changed

+79
-6
lines changed

3 files changed

+79
-6
lines changed

app.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from fastapi import FastAPI
1+
from fastapi import FastAPI, HTTPException, responses
2+
from pydantic import BaseModel
23

34
from controller import stats
45
from controller import integration
@@ -7,8 +8,6 @@
78
# API creation and description
89
app = FastAPI()
910

10-
from pydantic import BaseModel
11-
1211

1312
class WeakSupervisionRequest(BaseModel):
1413
project_id: str
@@ -29,6 +28,11 @@ class SourceStatsRequest(BaseModel):
2928
user_id: str
3029

3130

31+
class ExportWsStatsRequest(BaseModel):
    # Request payload for POST /export_ws_stats.
    # project_id: id of the project whose statistics are exported
    # labeling_task_id: labeling task whose weak supervision stats are exported
    project_id: str
    labeling_task_id: str
34+
35+
3236
@app.post("/fit_predict")
3337
async def weakly_supervise(request: WeakSupervisionRequest) -> int:
3438
session_token = general.get_ctx_token()
@@ -43,7 +47,7 @@ async def weakly_supervise(request: WeakSupervisionRequest) -> int:
4347

4448

4549
@app.post("/labeling_task_statistics")
46-
async def calculate_stats(request: TaskStatsRequest):
50+
async def calculate_task_stats(request: TaskStatsRequest):
4751
session_token = general.get_ctx_token()
4852
stats.calculate_quality_statistics_for_labeling_task(
4953
request.project_id, request.labeling_task_id, request.user_id
@@ -53,7 +57,7 @@ async def calculate_stats(request: TaskStatsRequest):
5357

5458

5559
@app.post("/source_statistics")
56-
async def calculate_stats(request: SourceStatsRequest):
60+
async def calculate_source_stats(request: SourceStatsRequest):
5761
session_token = general.get_ctx_token()
5862
has_coverage = stats.calculate_quantity_statistics_for_labeling_task_from_source(
5963
request.project_id, request.source_id, request.user_id
@@ -64,3 +68,16 @@ async def calculate_stats(request: SourceStatsRequest):
6468
)
6569
general.remove_and_refresh_session(session_token)
6670
return None, 200
71+
72+
73+
@app.post("/export_ws_stats")
async def export_ws_stats(request: ExportWsStatsRequest) -> responses.HTMLResponse:
    """Export weak supervision statistics for a labeling task.

    Delegates to integration.export_weak_supervision_stats, which persists the
    statistics to disk. Raises an HTTPException with the returned message on
    any non-200 status; otherwise responds with an empty body and that status.
    """
    session_token = general.get_ctx_token()
    status_code, message = integration.export_weak_supervision_stats(
        request.project_id, request.labeling_task_id
    )
    # Session is refreshed regardless of outcome so the next request gets a
    # fresh DB context.
    general.remove_and_refresh_session(session_token)

    if status_code != 200:
        raise HTTPException(status_code=status_code, detail=message)
    return responses.HTMLResponse(status_code=status_code)

controller/integration.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import os
12
from typing import Any, Dict, List, Tuple
23
import traceback
34
import pandas as pd
5+
import pickle
46
from collections import defaultdict
57

68
from submodules.model.models import (
@@ -26,6 +28,7 @@ def fit_predict(
2628
try:
2729
if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
2830
results = integrate_classification(df)
31+
2932
else:
3033
results = integrate_extraction(df)
3134
weak_supervision.store_data(
@@ -37,7 +40,7 @@ def fit_predict(
3740
weak_supervision_task_id,
3841
with_commit=True,
3942
)
40-
except:
43+
except Exception:
4144
print(traceback.format_exc(), flush=True)
4245
general.rollback()
4346
weak_supervision.update_state(
@@ -48,6 +51,44 @@ def fit_predict(
4851
)
4952

5053

54+
def export_weak_supervision_stats(
55+
project_id: str, labeling_task_id: str
56+
) -> Tuple[int, str]:
57+
58+
task_type, df = collect_data(project_id, labeling_task_id, False)
59+
try:
60+
if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
61+
cnlm = util.get_cnlm_from_df(df)
62+
stats_df = cnlm.quality_metrics()
63+
elif task_type == enums.LabelingTaskType.INFORMATION_EXTRACTION.value:
64+
enlm = util.get_enlm_from_df(df)
65+
stats_df = enlm.quality_metrics()
66+
else:
67+
return 404, f"Task type {task_type} not implemented"
68+
69+
if len(stats_df) != 0:
70+
stats_lkp = stats_df.set_index(["identifier", "label_name"]).to_dict(
71+
orient="index"
72+
)
73+
else:
74+
return 404, "Can't compute weak supervision"
75+
76+
os.makedirs(os.path.join("/inference", project_id), exist_ok=True)
77+
with open(
78+
os.path.join(
79+
"/inference", project_id, f"weak-supervision-{labeling_task_id}.pkl"
80+
),
81+
"wb",
82+
) as f:
83+
pickle.dump(stats_lkp, f)
84+
85+
except Exception:
86+
print(traceback.format_exc(), flush=True)
87+
general.rollback()
88+
return 500, "Internal server error"
89+
return 200, "OK"
90+
91+
5192
def integrate_classification(df: pd.DataFrame):
5293
cnlm = util.get_cnlm_from_df(df)
5394
weak_supervision_results = cnlm.weakly_supervise()

start

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@ echo -ne 'stopping old container...'
55
docker stop refinery-weak-supervisor > /dev/null 2>&1
66
echo -ne '\t [done]\n'
77

8+
INFERENCE_DIR=${PWD%/*}/dev-setup/inference/
9+
if [ ! -d "$_DIR" ]
10+
then
11+
INFERENCE_DIR=${PWD%/*/*}/dev-setup/inference/
12+
if [ ! -d "$INFERENCE_DIR" ]
13+
then
14+
# to include volume for local development, use the dev-setup inference folder:
15+
# alternative use manual logic with
16+
# -v /path/to/dev-setup/inference:/models \
17+
echo "Can't find model data directory: $INFERENCE_DIR -> stopping"
18+
exit 1
19+
fi
20+
fi
21+
822
echo -ne 'building container...'
923
docker build -t refinery-weak-supervisor-dev -f dev.Dockerfile . > /dev/null 2>&1
1024
echo -ne '\t\t [done]\n'
@@ -17,6 +31,7 @@ docker run -d --rm \
1731
-e WS_NOTIFY_ENDPOINT="http://refinery-websocket:8080" \
1832
--mount type=bind,source="$(pwd)"/,target=/app \
1933
-v /var/run/docker.sock:/var/run/docker.sock \
34+
-v "$INFERENCE_DIR":/inference \
2035
--network dev-setup_default \
2136
refinery-weak-supervisor-dev > /dev/null 2>&1
2237
echo -ne '\t\t\t [done]\n'

0 commit comments

Comments
 (0)