3535# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636# DATA.
3737#
38+ import dataclasses
3839import json
39- from dataclasses import dataclass , field
40- from datetime import datetime
4140from typing import List , Any , Optional
4241
4342import requests
43+ from pydantic import BaseModel , ConfigDict , Field , alias_generators
4444
4545from app .config import settings
4646from app .services .utils import raise_for_http_error , body_to_json
4747
4848
49- @dataclass
50- class SessionQueryConfiguration :
49+ class SessionQueryConfiguration (BaseModel ):
50+ model_config = ConfigDict (
51+ alias_generator = alias_generators .to_camel ,
52+ validate_by_name = True ,
53+ revalidate_instances = "always" ,
54+ )
55+
5156 enable_hyde : bool
5257 enable_summary_filter : bool
5358 enable_tool_calling : bool = False
54- selected_tools : list [str ] = field (default_factory = list )
59+ selected_tools : list [str ] = Field (default_factory = list )
5560 disable_streaming : bool = False
5661
5762
58- @dataclass
59- class Session :
63+ class Session (BaseModel ):
64+ model_config = ConfigDict (
65+ alias_generator = alias_generators .to_camel ,
66+ validate_by_name = True ,
67+ revalidate_instances = "always" ,
68+ )
69+
6070 id : int
6171 name : str
62- data_source_ids : List [int ]
72+ data_source_ids : list [int ]
6373 project_id : int
64- time_created : datetime
65- time_updated : datetime
66- created_by_id : str
67- updated_by_id : str
6874 inference_model : str
69- rerank_model : str
75+ rerank_model : Optional [ str ]
7076 response_chunks : int
7177 query_configuration : SessionQueryConfiguration
7278 associated_data_source_id : Optional [int ] = None
7379
74- def get_all_data_source_ids (self ) -> List [int ]:
80+ def get_all_data_source_ids (self ) -> list [int ]:
7581 """
7682 Returns all data source IDs associated with the session.
7783 If the session has an associated data source ID, it is included in the list.
@@ -81,14 +87,14 @@ def get_all_data_source_ids(self) -> List[int]:
8187 )
8288
8389
84- @dataclass
90+ @dataclasses . dataclass
8591class UpdatableSession :
8692 id : int
8793 name : str
8894 dataSourceIds : List [int ]
8995 projectId : int
9096 inferenceModel : str
91- rerankModel : str
97+ rerankModel : Optional [ str ]
9298 responseChunks : int
9399 queryConfiguration : dict [str , bool | List [str ]]
94100 associatedDataSourceId : Optional [int ]
@@ -114,10 +120,6 @@ def session_from_java_response(data: dict[str, Any]) -> Session:
114120 name = data ["name" ],
115121 data_source_ids = data ["dataSourceIds" ],
116122 project_id = data ["projectId" ],
117- time_created = datetime .fromtimestamp (data ["timeCreated" ]),
118- time_updated = datetime .fromtimestamp (data ["timeUpdated" ]),
119- created_by_id = data ["createdById" ],
120- updated_by_id = data ["updatedById" ],
121123 inference_model = data ["inferenceModel" ],
122124 rerank_model = data ["rerankModel" ],
123125 response_chunks = data ["responseChunks" ],
@@ -127,9 +129,7 @@ def session_from_java_response(data: dict[str, Any]) -> Session:
127129 enable_tool_calling = data ["queryConfiguration" ].get (
128130 "enableToolCalling" , False
129131 ),
130- disable_streaming = data ["queryConfiguration" ].get (
131- "disableStreaming" , False
132- ),
132+ disable_streaming = data ["queryConfiguration" ].get ("disableStreaming" , False ),
133133 selected_tools = data ["queryConfiguration" ]["selectedTools" ] or [],
134134 ),
135135 associated_data_source_id = data .get ("associatedDataSourceId" , None ),
0 commit comments