Commit 73f9908
feat(python): support index for nested field (#5027)
Fixes #5026. Also fixes vector-index failures on nested fields at the Rust level; the previous fix did not fully work.
1 parent add6fa6 commit 73f9908
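
As a quick illustration of the new behavior (a minimal sketch drawn from the tests added below; the dataset path and the meta/lang field names are illustrative), a scalar index can now be created on a struct sub-field addressed by dot path:

import pyarrow as pa
import lance

# Struct column with a nested string field (names are illustrative).
schema = pa.schema(
    [
        pa.field("id", pa.int64()),
        pa.field("meta", pa.struct([pa.field("lang", pa.string())])),
    ]
)
data = pa.table(
    {"id": [1, 2], "meta": [{"lang": "en"}, {"lang": "fr"}]}, schema=schema
)
dataset = lance.write_dataset(data, "/tmp/nested_index_demo")

# The dot path is now resolved against the Lance schema, so indexing a
# nested field no longer raises KeyError.
dataset.create_scalar_index("meta.lang", index_type="BTREE")
print(dataset.scanner(filter="meta.lang = 'en'").explain_plan())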

File tree

10 files changed: +739 -41 lines changed
python/python/lance/dataset.py

Lines changed: 9 additions & 6 deletions
@@ -2341,7 +2341,8 @@ def create_scalar_index(
                 )

             column = column[0]
-        if column not in self.schema.names:
+        lance_field = self._ds.lance_schema.field(column)
+        if lance_field is None:
             raise KeyError(f"{column} not found in schema")

         # TODO: Add documentation of IndexConfig approach for creating
@@ -2365,7 +2366,7 @@ def create_scalar_index(
                 )
             )

-        field = self.schema.field(column)
+        field = lance_field.to_arrow()

         field_type = field.type
         if hasattr(field_type, "storage_type"):
@@ -2618,9 +2619,10 @@ def create_index(

         # validate args
         for c in column:
-            if c not in self.schema.names:
+            lance_field = self._ds.lance_schema.field(c)
+            if lance_field is None:
                 raise KeyError(f"{c} not found in schema")
-            field = self.schema.field(c)
+            field = lance_field.to_arrow()
             is_multivec = False
             if pa.types.is_fixed_size_list(field.type):
                 dimension = field.type.list_size
@@ -4347,10 +4349,11 @@ def nearest(
     ) -> ScannerBuilder:
         q, q_dim = _coerce_query_vector(q)

-        if self.ds.schema.get_field_index(column) < 0:
+        lance_field = self.ds._ds.lance_schema.field(column)
+        if lance_field is None:
             raise ValueError(f"Embedding column {column} is not in the dataset")

-        column_field = self.ds.schema.field(column)
+        column_field = lance_field.to_arrow()
         column_type = column_field.type
         if hasattr(column_type, "storage_type"):
             column_type = column_type.storage_type
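
The create_index and nearest changes above mean nested embedding columns also pass validation. A hypothetical sketch of what that enables (the emb.vector column name, the IVF_PQ parameters, and the dataset path are assumptions for illustration, not code from this commit):

import numpy as np
import pyarrow as pa
import lance

dim = 16
n = 1024  # enough rows for PQ training (assumption)
vectors = pa.FixedSizeListArray.from_arrays(
    pa.array(np.random.rand(n * dim).astype(np.float32)), dim
)
data = pa.table(
    {
        "id": range(n),
        "emb": pa.StructArray.from_arrays([vectors], names=["vector"]),
    }
)
ds = lance.write_dataset(data, "/tmp/nested_vector_demo")

# create_index() and nearest() now resolve "emb.vector" via the Lance schema,
# so the nested vector column validates and can be indexed and searched.
ds.create_index(
    "emb.vector", index_type="IVF_PQ", num_partitions=1, num_sub_vectors=2
)
hits = ds.to_table(
    nearest={
        "column": "emb.vector",
        "q": np.random.rand(dim).astype(np.float32),
        "k": 5,
    }
)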

python/python/tests/test_scalar_index.py

Lines changed: 255 additions & 0 deletions
@@ -3690,3 +3690,258 @@ def scan_stats_callback(stats: lance.ScanStatistics):
     for key, value in scan_stats.all_counts.items():
         assert isinstance(key, str)
         assert isinstance(value, int)
+
+
+def test_nested_field_btree_index(tmp_path):
+    """Test BTREE index creation and querying on nested fields"""
+    # Create a dataset with nested structure
+    schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field(
+                "meta",
+                pa.struct(
+                    [pa.field("lang", pa.string()), pa.field("version", pa.int32())]
+                ),
+            ),
+        ]
+    )
+
+    data = pa.table(
+        {
+            "id": [1, 2, 3, 4, 5],
+            "meta": [
+                {"lang": "en", "version": 1},
+                {"lang": "fr", "version": 2},
+                {"lang": "en", "version": 1},
+                {"lang": "es", "version": 3},
+                {"lang": "fr", "version": 2},
+            ],
+        },
+        schema=schema,
+    )
+
+    # Create dataset
+    uri = tmp_path / "test_nested_btree"
+    dataset = lance.write_dataset(data, uri)
+
+    # Create BTREE index on nested string column
+    dataset.create_scalar_index(column="meta.lang", index_type="BTREE")
+
+    # Verify index was created
+    indices = dataset.list_indices()
+    assert len(indices) == 1
+    assert indices[0]["fields"] == ["meta.lang"]
+    assert indices[0]["type"] == "BTree"
+
+    # Test query using the index - filter for English language
+    result = dataset.scanner(filter="meta.lang = 'en'").to_table()
+    assert len(result) == 2
+    for i in range(len(result)):
+        assert result["meta"][i]["lang"].as_py() == "en"
+
+    # Test query for French language
+    result = dataset.scanner(filter="meta.lang = 'fr'").to_table()
+    assert len(result) == 2
+    for i in range(len(result)):
+        assert result["meta"][i]["lang"].as_py() == "fr"
+
+    # Verify the index is being used
+    plan = dataset.scanner(filter="meta.lang = 'en'").explain_plan()
+    assert "ScalarIndexQuery" in plan
+
+    # Write additional data to the dataset
+    new_data = pa.table(
+        {
+            "id": [6, 7, 8],
+            "meta": [
+                {"lang": "de", "version": 4},
+                {"lang": "en", "version": 2},
+                {"lang": "de", "version": 4},
+            ],
+        },
+        schema=schema,
+    )
+
+    dataset = lance.write_dataset(new_data, uri, mode="append")
+
+    # Verify query still works after appending data
+    result = dataset.scanner(filter="meta.lang = 'en'").to_table()
+    assert len(result) == 3, f"Expected 3 English records, got {len(result)}"
+    for i in range(len(result)):
+        assert result["meta"][i]["lang"].as_py() == "en"
+
+    # Test query for new German language entries
+    result = dataset.scanner(filter="meta.lang = 'de'").to_table()
+    assert len(result) == 2
+    for i in range(len(result)):
+        assert result["meta"][i]["lang"].as_py() == "de"
+
+    # Test optimize_indices with nested field BTREE index
+    dataset.optimize.optimize_indices()
+
+    # Verify query still works after optimization
+    result = dataset.scanner(filter="meta.lang = 'en'").to_table()
+    assert len(result) == 3
+    result = dataset.scanner(filter="meta.lang = 'de'").to_table()
+    assert len(result) == 2
+
+    # Create BTREE index on nested integer column
+    dataset.create_scalar_index(column="meta.version", index_type="BTREE", replace=True)
+
+    # Test query using the version index
+    result = dataset.scanner(filter="meta.version = 1").to_table()
+    assert len(result) == 2
+    for i in range(len(result)):
+        assert result["meta"][i]["version"].as_py() == 1
+
+    # Test query for version 4 (new data)
+    result = dataset.scanner(filter="meta.version = 4").to_table()
+    assert len(result) == 2
+    for i in range(len(result)):
+        assert result["meta"][i]["version"].as_py() == 4
+
+    # Verify total row count
+    total = dataset.count_rows()
+    assert total == 8, f"Expected 8 total rows, got {total}"
+
+
+def test_nested_field_fts_index(tmp_path):
+    """Test FTS index creation and querying on nested fields"""
+    # Create dataset with nested text field
+    data = pa.table(
+        {
+            "id": range(100),
+            "data": pa.StructArray.from_arrays(
+                [
+                    pa.array(
+                        [f"document {i} about lance database" for i in range(100)]
+                    ),
+                    pa.array([f"label_{i}" for i in range(100)]),
+                ],
+                names=["text", "label"],
+            ),
+        }
+    )
+
+    ds = lance.write_dataset(data, tmp_path)
+
+    # Create FTS index on nested field
+    ds.create_scalar_index("data.text", index_type="INVERTED", with_position=False)
+
+    # Verify index was created
+    indices = ds.list_indices()
+    assert len(indices) == 1
+    assert indices[0]["fields"] == ["data.text"]
+    assert indices[0]["type"] == "Inverted"
+
+    # Test full text search on nested field
+    results = ds.to_table(full_text_query="lance")
+    assert results.num_rows == 100
+
+    # Verify the results contain the expected text
+    for i in range(results.num_rows):
+        text = results["data"][i]["text"].as_py()
+        assert "lance" in text
+
+    # Test with prefilter using another nested field
+    results = ds.to_table(
+        full_text_query="database",
+        filter="data.label = 'label_5'",
+        prefilter=True,
+    )
+    assert results.num_rows == 1
+    assert results["id"][0].as_py() == 5
+
+    # Test optimize_indices with nested field FTS index
+    # Append more data
+    new_data = pa.table(
+        {
+            "id": range(100, 150),
+            "data": pa.StructArray.from_arrays(
+                [
+                    pa.array(
+                        [f"document {i} about lance search" for i in range(100, 150)]
+                    ),
+                    pa.array([f"label_{i}" for i in range(100, 150)]),
+                ],
+                names=["text", "label"],
+            ),
+        }
+    )
+    ds = lance.write_dataset(new_data, tmp_path, mode="append")
+
+    # Optimize indices
+    ds.optimize.optimize_indices()
+
+    # Verify search still works after optimization
+    results = ds.to_table(full_text_query="lance")
+    assert results.num_rows == 150
+
+    results = ds.to_table(full_text_query="search")
+    assert results.num_rows == 50
+
+
+def test_nested_field_bitmap_index(tmp_path):
+    """Test BITMAP index creation and querying on nested fields"""
+    # Create dataset with nested categorical field
+    data = pa.table(
+        {
+            "id": range(100),
+            "attributes": pa.StructArray.from_arrays(
+                [
+                    pa.array(["red", "green", "blue"][i % 3] for i in range(100)),
+                    pa.array([f"size_{i % 5}" for i in range(100)]),
+                ],
+                names=["color", "size"],
+            ),
+        }
+    )
+
+    ds = lance.write_dataset(data, tmp_path)
+
+    # Create BITMAP index on nested field
+    ds.create_scalar_index("attributes.color", index_type="BITMAP")
+
+    # Verify index was created
+    indices = ds.list_indices()
+    assert len(indices) == 1
+    assert indices[0]["fields"] == ["attributes.color"]
+    assert indices[0]["type"] == "Bitmap"
+
+    # Test equality query
+    results = ds.to_table(filter="attributes.color = 'red'", prefilter=True)
+    assert results.num_rows == 34  # 0, 3, 6, 9, ... 99 (34 values)
+
+    # Verify the index is being used
+    plan = ds.scanner(filter="attributes.color = 'red'", prefilter=True).explain_plan()
+    assert "ScalarIndexQuery" in plan
+
+    # Test with different color
+    results = ds.to_table(filter="attributes.color = 'green'", prefilter=True)
+    assert results.num_rows == 33  # 1, 4, 7, 10, ... 97 (33 values)
+
+    results = ds.to_table(filter="attributes.color = 'blue'", prefilter=True)
+    assert results.num_rows == 33  # 2, 5, 8, 11, ... 98 (33 values)
+
+    # Test optimize_indices with nested field BITMAP index
+    new_data = pa.table(
+        {
+            "id": range(100, 150),
+            "attributes": pa.StructArray.from_arrays(
+                [
+                    pa.array(["red", "green", "blue"][i % 3] for i in range(50)),
+                    pa.array([f"size_{i % 5}" for i in range(50)]),
+                ],
+                names=["color", "size"],
+            ),
+        }
+    )
+    ds = lance.write_dataset(new_data, tmp_path, mode="append")
+
+    # Optimize indices
+    ds.optimize.optimize_indices()
+
+    # Verify query still works after optimization
+    results = ds.to_table(filter="attributes.color = 'red'", prefilter=True)
+    assert results.num_rows == 51  # 34 + 17 from new data
