Skip to content

Commit 53085dc

Browse files
Change session_metadata_api Session from std lib dataclass to Pydantic (#313)
* Convert Session into Pydantic * Remove creation date and username * Enable field validation * Swap lines to make a bit more sequential sense
1 parent 613e807 commit 53085dc

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

llm-service/app/services/metadata_apis/session_metadata_api.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,43 +35,49 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38+
import dataclasses
3839
import json
39-
from dataclasses import dataclass, field
40-
from datetime import datetime
4140
from typing import List, Any, Optional
4241

4342
import requests
43+
from pydantic import BaseModel, ConfigDict, Field, alias_generators
4444

4545
from app.config import settings
4646
from app.services.utils import raise_for_http_error, body_to_json
4747

4848

49-
@dataclass
50-
class SessionQueryConfiguration:
49+
class SessionQueryConfiguration(BaseModel):
50+
model_config = ConfigDict(
51+
alias_generator=alias_generators.to_camel,
52+
validate_by_name=True,
53+
revalidate_instances="always",
54+
)
55+
5156
enable_hyde: bool
5257
enable_summary_filter: bool
5358
enable_tool_calling: bool = False
54-
selected_tools: list[str] = field(default_factory=list)
59+
selected_tools: list[str] = Field(default_factory=list)
5560
disable_streaming: bool = False
5661

5762

58-
@dataclass
59-
class Session:
63+
class Session(BaseModel):
64+
model_config = ConfigDict(
65+
alias_generator=alias_generators.to_camel,
66+
validate_by_name=True,
67+
revalidate_instances="always",
68+
)
69+
6070
id: int
6171
name: str
62-
data_source_ids: List[int]
72+
data_source_ids: list[int]
6373
project_id: int
64-
time_created: datetime
65-
time_updated: datetime
66-
created_by_id: str
67-
updated_by_id: str
6874
inference_model: str
69-
rerank_model: str
75+
rerank_model: Optional[str]
7076
response_chunks: int
7177
query_configuration: SessionQueryConfiguration
7278
associated_data_source_id: Optional[int] = None
7379

74-
def get_all_data_source_ids(self) -> List[int]:
80+
def get_all_data_source_ids(self) -> list[int]:
7581
"""
7682
Returns all data source IDs associated with the session.
7783
If the session has an associated data source ID, it is included in the list.
@@ -81,14 +87,14 @@ def get_all_data_source_ids(self) -> List[int]:
8187
)
8288

8389

84-
@dataclass
90+
@dataclasses.dataclass
8591
class UpdatableSession:
8692
id: int
8793
name: str
8894
dataSourceIds: List[int]
8995
projectId: int
9096
inferenceModel: str
91-
rerankModel: str
97+
rerankModel: Optional[str]
9298
responseChunks: int
9399
queryConfiguration: dict[str, bool | List[str]]
94100
associatedDataSourceId: Optional[int]
@@ -114,10 +120,6 @@ def session_from_java_response(data: dict[str, Any]) -> Session:
114120
name=data["name"],
115121
data_source_ids=data["dataSourceIds"],
116122
project_id=data["projectId"],
117-
time_created=datetime.fromtimestamp(data["timeCreated"]),
118-
time_updated=datetime.fromtimestamp(data["timeUpdated"]),
119-
created_by_id=data["createdById"],
120-
updated_by_id=data["updatedById"],
121123
inference_model=data["inferenceModel"],
122124
rerank_model=data["rerankModel"],
123125
response_chunks=data["responseChunks"],
@@ -127,9 +129,7 @@ def session_from_java_response(data: dict[str, Any]) -> Session:
127129
enable_tool_calling=data["queryConfiguration"].get(
128130
"enableToolCalling", False
129131
),
130-
disable_streaming=data["queryConfiguration"].get(
131-
"disableStreaming", False
132-
),
132+
disable_streaming=data["queryConfiguration"].get("disableStreaming", False),
133133
selected_tools=data["queryConfiguration"]["selectedTools"] or [],
134134
),
135135
associated_data_source_id=data.get("associatedDataSourceId", None),

0 commit comments

Comments
 (0)