
Commit a1c00ca

[wwb] Update reranker/embedder tests
1 parent 1c2811f commit a1c00ca

4 files changed, +108 -67 lines changed

tools/who_what_benchmark/tests/test_cli_embeddings.py

Lines changed: 39 additions & 5 deletions
@@ -1,14 +1,20 @@
 import subprocess  # nosec B404
 import sys
 import pytest
+import shutil
 import logging
-from test_cli_image import run_wwb
+from test_cli_image import run_wwb, get_similarity
 
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
+def remove_artifacts(artifacts_path, file_type="outputs"):
+    logger.info(f"Remove {file_type}")
+    shutil.rmtree(artifacts_path)
+
+
 @pytest.mark.parametrize(
     ("model_id", "model_type"),
     [
@@ -21,6 +27,7 @@
 def test_embeddings_basic(model_id, model_type, tmp_path):
     GT_FILE = tmp_path / "gt.csv"
     MODEL_PATH = tmp_path / model_id.replace("/", "_")
+    SIMILARITY_THRESHOLD = 0.99
 
     result = subprocess.run(["optimum-cli", "export",
                              "openvino", "-m", model_id,
@@ -47,8 +54,9 @@ def test_embeddings_basic(model_id, model_type, tmp_path):
         "--hf",
     ])
 
+    outpus_path = tmp_path / "optimum"
     # test Optimum
-    run_wwb([
+    outpus = run_wwb([
         "--target-model",
         MODEL_PATH,
         "--num-samples",
@@ -59,10 +67,24 @@ def test_embeddings_basic(model_id, model_type, tmp_path):
         "CPU",
         "--model-type",
         model_type,
+        "--output",
+        outpus_path,
     ])
 
+    assert (outpus_path / "target").exists()
+    assert (outpus_path / "target.csv").exists()
+    assert (outpus_path / "metrics_per_question.csv").exists()
+    assert (outpus_path / "metrics.csv").exists()
+    assert "Metrics for model" in outpus
+
+    similarity = get_similarity(outpus)
+    assert similarity >= SIMILARITY_THRESHOLD
+
+    remove_artifacts(outpus_path.as_posix())
+
+    outpus_path = tmp_path / "genai"
     # test GenAI
-    run_wwb([
+    outpus = run_wwb([
         "--target-model",
         MODEL_PATH,
         "--num-samples",
@@ -75,13 +97,22 @@ def test_embeddings_basic(model_id, model_type, tmp_path):
         model_type,
         "--genai",
         "--output",
-        tmp_path,
+        outpus_path,
     ])
 
+    assert (outpus_path / "target").exists()
+    assert (outpus_path / "target.csv").exists()
+    assert (outpus_path / "metrics_per_question.csv").exists()
+    assert (outpus_path / "metrics.csv").exists()
+    assert "Metrics for model" in outpus
+
+    similarity = get_similarity(outpus)
+    assert similarity >= SIMILARITY_THRESHOLD
+
     # test w/o models
     run_wwb([
         "--target-data",
-        tmp_path / "target.csv",
+        outpus_path / "target.csv",
         "--num-samples",
         "1",
         "--gt-data",
@@ -92,3 +123,6 @@ def test_embeddings_basic(model_id, model_type, tmp_path):
         model_type,
         "--genai",
     ])
+
+    remove_artifacts(outpus_path.as_posix())
+    remove_artifacts(MODEL_PATH.as_posix(), "model")
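Note: the updated tests import get_similarity from test_cli_image, which this commit does not touch. As a rough orientation only, a hypothetical helper with that role might scrape the similarity value out of the captured wwb output; the actual implementation in test_cli_image.py may differ.

import re


def get_similarity_sketch(wwb_output: str) -> float:
    # Hypothetical illustration: grab the last number printed next to a
    # "similarity" label in the captured wwb stdout. The real
    # get_similarity() helper in test_cli_image.py may parse differently.
    matches = re.findall(r"[Ss]imilarity\D*?([0-9]+(?:\.[0-9]+)?)", wwb_output)
    assert matches, "no similarity value found in wwb output"
    return float(matches[-1])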

tools/who_what_benchmark/tests/test_cli_reranking.py

Lines changed: 59 additions & 54 deletions
@@ -4,7 +4,7 @@
 import shutil
 import logging
 import tempfile
-from test_cli_image import run_wwb
+from test_cli_image import run_wwb, get_similarity
 from pathlib import Path
 
 
@@ -13,40 +13,37 @@
 tmp_dir = tempfile.mkdtemp()
 
 
-OV_RERANK_MODELS = {
-    ("cross-encoder/ms-marco-TinyBERT-L2-v2", "text-classification"),
-    ("Qwen/Qwen3-Reranker-0.6B", "text-generation"),
-}
+def download_model(model_id, task, tmp_path):
+    MODEL_PATH = Path(tmp_path, model_id.replace("/", "_"))
+    subprocess.run(["optimum-cli", "export", "openvino", "--model", model_id, MODEL_PATH, "--task", task, "--trust-remote-code"],
+                   capture_output=True,
+                   text=True)
+    return MODEL_PATH
 
 
-def setup_module():
-    for model_info in OV_RERANK_MODELS:
-        model_id = model_info[0]
-        task = model_info[1]
-        MODEL_PATH = Path(tmp_dir, model_id.replace("/", "_"))
-        subprocess.run(["optimum-cli", "export", "openvino", "--model", model_id, MODEL_PATH, "--task", task, "--trust-remote-code"],
-                       capture_output=True,
-                       text=True)
+def remove_artifacts(artifacts_path, file_type="outputs"):
+    logger.info(f"Remove {file_type}")
+    shutil.rmtree(artifacts_path)
 
 
-def teardown_module():
-    logger.info("Remove models")
-    shutil.rmtree(tmp_dir)
-
-
-@pytest.mark.parametrize(("model_info"), OV_RERANK_MODELS)
-def test_reranking_genai(model_info, tmp_path):
-    if sys.platform == 'darwin':
-        pytest.xfail("Ticket 175534")
-
+@pytest.mark.wwb_rerank
+@pytest.mark.parametrize(
+    ("model_id", "model_task", "threshold"),
+    [
+        ("cross-encoder/ms-marco-TinyBERT-L2-v2", "text-classification", 0.6),
+        ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "text-classification", 0.6),
+        ("Qwen/Qwen3-Reranker-0.6B", "text-generation", 0.6),
+    ],
+)
+@pytest.mark.xfail(sys.platform == 'darwin', reason="Hangs. Ticket 175534", run=False)
+def test_reranking_optimum(model_id, model_task, threshold, tmp_path):
     GT_FILE = Path(tmp_dir) / "gt.csv"
-    model_id = model_info[0]
-    MODEL_PATH = Path(tmp_dir) / model_id.replace("/", "_")
+    MODEL_PATH = download_model(model_id, model_task, tmp_path)
 
-    # test GenAI
+    # Collect reference with HF model
     run_wwb([
         "--base-model",
-        MODEL_PATH,
+        model_id,
         "--num-samples",
         "1",
         "--gt-data",
@@ -55,25 +52,17 @@ def test_reranking_genai(model_info, tmp_path):
         "CPU",
         "--model-type",
         "text-reranking",
-        "--genai"
+        "--hf",
     ])
 
+    assert GT_FILE.exists()
     assert Path(tmp_dir, "reference").exists()
 
-
-@pytest.mark.parametrize(
-    ("model_info"), OV_RERANK_MODELS
-)
-@pytest.mark.xfail(sys.platform == 'darwin', reason="Hangs. Ticket 175534", run=False)
-def test_reranking_optimum(model_info, tmp_path):
-    GT_FILE = Path(tmp_dir) / "gt.csv"
-    model_id = model_info[0]
-    MODEL_PATH = Path(tmp_dir, model_id.replace("/", "_"))
-
-    # Collect reference with HF model
-    run_wwb([
-        "--base-model",
-        model_id,
+    outpus_path = tmp_path / "optimum"
+    # test Optimum
+    outpus_optimum = run_wwb([
+        "--target-model",
+        MODEL_PATH,
         "--num-samples",
         "1",
         "--gt-data",
@@ -82,14 +71,24 @@ def test_reranking_optimum(model_info, tmp_path):
         "CPU",
         "--model-type",
         "text-reranking",
-        "--hf",
+        "--output",
+        outpus_path,
     ])
 
-    assert GT_FILE.exists()
-    assert Path(tmp_dir, "reference").exists()
+    assert (outpus_path / "target").exists()
+    assert (outpus_path / "target.csv").exists()
+    assert (outpus_path / "metrics_per_question.csv").exists()
+    assert (outpus_path / "metrics.csv").exists()
+    assert "Metrics for model" in outpus_optimum
 
-    # test Optimum
-    outpus = run_wwb([
+    similarity = get_similarity(outpus_optimum)
+    assert similarity >= threshold
+
+    remove_artifacts(outpus_path.as_posix())
+
+    outpus_path = tmp_path / "genai"
+    # test GenAI
+    outpus_genai = run_wwb([
         "--target-model",
         MODEL_PATH,
         "--num-samples",
@@ -100,20 +99,23 @@ def test_reranking_optimum(model_info, tmp_path):
         "CPU",
         "--model-type",
         "text-reranking",
+        "--genai",
         "--output",
-        tmp_path,
+        outpus_path,
     ])
+    assert (outpus_path / "target").exists()
+    assert (outpus_path / "target.csv").exists()
+    assert (outpus_path / "metrics_per_question.csv").exists()
+    assert (outpus_path / "metrics.csv").exists()
+    assert "Metrics for model" in outpus_genai
 
-    assert (tmp_path / "target").exists()
-    assert (tmp_path / "target.csv").exists()
-    assert (tmp_path / "metrics_per_question.csv").exists()
-    assert (tmp_path / "metrics.csv").exists()
-    assert "Metrics for model" in outpus
+    similarity = get_similarity(outpus_genai)
+    assert similarity >= threshold
 
     # test w/o models
     run_wwb([
         "--target-data",
-        tmp_path / "target.csv",
+        outpus_path / "target.csv",
         "--num-samples",
         "1",
         "--gt-data",
@@ -124,3 +126,6 @@ def test_reranking_optimum(model_info, tmp_path):
         "text-reranking",
         "--genai"
     ])
+
+    remove_artifacts(outpus_path.as_posix())
+    remove_artifacts(MODEL_PATH.as_posix(), "model")
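Note: the reranking test now carries a custom @pytest.mark.wwb_rerank marker. If the marker is not already registered in the project's pytest configuration, pytest emits an unknown-marker warning (and fails under --strict-markers). A minimal sketch of one way to register it; the repository may already do this elsewhere, e.g. in pytest.ini or pyproject.toml.

# conftest.py (hypothetical placement) -- register the custom marker so
# `pytest --strict-markers` accepts @pytest.mark.wwb_rerank.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "wwb_rerank: who_what_benchmark text-reranking CLI tests"
    )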

tools/who_what_benchmark/whowhatbench/whowhat_metrics.py

Lines changed: 7 additions & 5 deletions
@@ -7,6 +7,7 @@
 from PIL import Image
 import torch
 import torch.nn.functional as F
+from sklearn.metrics.pairwise import cosine_similarity
 
 import numpy as np
 from sentence_transformers import SentenceTransformer, util
@@ -189,9 +190,10 @@ def evaluate(self, data_gold, data_prediction):
             with open(prediction, 'rb') as f:
                 prediction_data = np.load(f)
 
-            cos_sim = F.cosine_similarity(torch.from_numpy(gold_data), torch.from_numpy(prediction_data))
-            metric_per_passages.append(cos_sim.detach().numpy())
-            metric_per_gen.append(torch.mean(cos_sim).item())
+            cos_sim_all = cosine_similarity(gold_data, prediction_data)
+            cos_sim = np.diag(cos_sim_all)
+            metric_per_passages.append(cos_sim)
+            metric_per_gen.append(np.mean(cos_sim))
 
         metric_dict = {"similarity": np.mean(metric_per_gen)}
         return metric_dict, {"similarity": metric_per_gen, "similarity_per_passages": metric_per_passages}
@@ -222,11 +224,11 @@ def evaluate(self, data_gold, data_prediction):
                 scores_diff = self.MISSING_DOCUMENT_PENALTY
                 if document_idx in prediction_scores:
                     scores_diff = abs(gold_score - prediction_scores[document_idx])
-                per_query_text.append(scores_diff)
+                per_query_text.append(scores_diff.item())
 
             metric_per_query.append(per_query_text)
             dist = np.linalg.norm(per_query_text)
             similarity_per_query.append(1 / (1 + dist))
 
         metric_dict = {"similarity": np.mean(similarity_per_query)}
-        return metric_dict, {"similarity": similarity_per_query, "per_text_score_list": metric_per_query}
+        return metric_dict, {"similarity": similarity_per_query, "per_text_scores_diff": metric_per_query}
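Note: the embeddings metric switches from torch.nn.functional.cosine_similarity, which already returns one value per row pair, to sklearn's cosine_similarity, which returns the full pairwise matrix; taking np.diag recovers the same per-row values. A small self-contained check of that equivalence (array shapes here are illustrative, not taken from the benchmark):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
gold = rng.standard_normal((4, 8))                 # 4 "passages", 8-dim embeddings
pred = gold + 0.01 * rng.standard_normal((4, 8))

# Full 4x4 matrix of cosine similarities between every gold/prediction pair;
# the diagonal pairs row i of gold with row i of pred, i.e. the row-wise
# similarity the previous torch-based code computed.
per_row = np.diag(cosine_similarity(gold, pred))

expected = np.array([
    g @ p / (np.linalg.norm(g) * np.linalg.norm(p))
    for g, p in zip(gold, pred)
])
assert np.allclose(per_row, expected)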

tools/who_what_benchmark/whowhatbench/wwb.py

Lines changed: 3 additions & 3 deletions
@@ -658,7 +658,7 @@ def print_embeds_results(evaluator):
         )
         logger.info(f"Top-{i+1} example:")
         logger.info("## Passages num:\n%s\n", len(e["passages"]))
-        logger.info("## Similarity:\n%s\n", e["similarity"])
+        logger.info(f"## Similarity:\n{e['similarity']:.5}\n")
 
 
 def read_cb_config(path):
@@ -687,8 +687,8 @@ def print_rag_results(evaluator):
         logger.info(f"Top-{i+1} example:")
         logger.info("## Query:\n%s\n", e["query"])
         logger.info("## Passages num:\n%s\n", len(e["passages"]))
-        logger.info("## Similarity:\n%s\n", e["similarity"])
-        logger.info("## Top_n scores:\n%s\n", e["per_text_score_list"])
+        logger.info(f"## Similarity:\n{e['similarity']:.5}\n")
+        logger.info("## Difference in scores pre texts:\n%s\n", e['per_text_scores_diff'])
 
 
 def main():
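Note: the new log lines use Python's format spec :.5, which with no presentation type prints the value with five significant digits. For example:

print(f"{0.98765432:.5}")     # 0.98765
print(f"{0.00012345678:.5}")  # 0.00012346 (still 5 significant digits)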
