Skip to content

Commit e23b925

Browse files
committed
Added ability to get project users
tod
1 parent a08e04f commit e23b925

File tree

12 files changed

+311
-64
lines changed

12 files changed

+311
-64
lines changed

src/superannotate/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44

55

6-
__version__ = "4.4.32"
6+
__version__ = "4.4.33dev1"
77

88
os.environ.update({"sa_version": __version__})
99
sys.path.append(os.path.split(os.path.realpath(__file__))[0])

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,18 @@ def set_user_custom_field(
372372
parent_entity=CustomFieldEntityEnum.TEAM,
373373
)
374374

375-
def list_users(self, *, include: List[Literal["custom_fields"]] = None, **filters):
375+
def list_users(
376+
self,
377+
*,
378+
project: Union[int, str] = None,
379+
include: List[Literal["custom_fields"]] = None,
380+
**filters,
381+
):
376382
"""
377383
Search users by filtering criteria
384+
:param project: Project name or ID, if provided, results will be for project-level,
385+
otherwise results will be for team level.
386+
:type project: str or int
378387
379388
:param include: Specifies additional fields to be included in the response.
380389
@@ -454,9 +463,22 @@ def list_users(self, *, include: List[Literal["custom_fields"]] = None, **filter
454463
}
455464
]
456465
"""
457-
return BaseSerializer.serialize_iterable(
458-
self.controller.work_management.list_users(include=include, **filters)
466+
if project is not None:
467+
if isinstance(project, int):
468+
project = self.controller.get_project_by_id(project)
469+
else:
470+
project = self.controller.get_project(project)
471+
response = BaseSerializer.serialize_iterable(
472+
self.controller.work_management.list_users(
473+
project=project, include=include, **filters
474+
)
459475
)
476+
if project:
477+
for user in response:
478+
user["role"] = self.controller.service_provider.get_role_name(
479+
project, user["role"]
480+
)
481+
return response
460482

461483
def pause_user_activity(
462484
self, pk: Union[int, str], projects: Union[List[int], List[str], Literal["*"]]

src/superannotate/lib/core/entities/work_managament.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,33 @@ def json(self, **kwargs):
119119
if "exclude" not in kwargs:
120120
kwargs["exclude"] = {"custom_fields"}
121121
return super().json(**kwargs)
122+
123+
124+
class WMProjectUserEntity(TimedBaseModel):
125+
id: Optional[int]
126+
team_id: Optional[int]
127+
role: int
128+
email: Optional[str]
129+
state: Optional[WMUserStateEnum]
130+
custom_fields: Optional[dict] = Field(dict(), alias="customField")
131+
132+
class Config:
133+
extra = Extra.ignore
134+
use_enum_names = True
135+
136+
json_encoders = {
137+
Enum: lambda v: v.value,
138+
datetime.date: lambda v: v.isoformat(),
139+
datetime.datetime: lambda v: v.isoformat(),
140+
}
141+
142+
@validator("custom_fields")
143+
def custom_fields_transformer(cls, v):
144+
if v and "custom_field_values" in v:
145+
return v.get("custom_field_values", {})
146+
return {}
147+
148+
def json(self, **kwargs):
149+
if "exclude" not in kwargs:
150+
kwargs["exclude"] = {"custom_fields"}
151+
return super().json(**kwargs)

src/superannotate/lib/core/serviceproviders.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,12 @@ def create_project_categories(
147147

148148
@abstractmethod
149149
def list_users(
150-
self, body_query: Query, chunk_size=100, include_custom_fields=False
150+
self,
151+
body_query: Query,
152+
parent_entity: str = "Team",
153+
chunk_size=100,
154+
project_id: int = None,
155+
include_custom_fields=False,
151156
) -> WMUserListResponse:
152157
raise NotImplementedError
153158

@@ -804,23 +809,40 @@ def invite_contributors(
804809
raise NotImplementedError
805810

806811
@abstractmethod
807-
def list_custom_field_names(self, entity: CustomFieldEntityEnum) -> List[str]:
812+
def list_custom_field_names(
813+
self, pk, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
814+
) -> List[str]:
808815
raise NotImplementedError
809816

810817
@abstractmethod
811818
def get_custom_field_id(
812-
self, field_name: str, entity: CustomFieldEntityEnum
819+
self,
820+
field_name: str,
821+
entity: CustomFieldEntityEnum,
822+
parent: CustomFieldEntityEnum,
813823
) -> int:
814824
raise NotImplementedError
815825

816826
@abstractmethod
817827
def get_custom_field_name(
818-
self, field_id: int, entity: CustomFieldEntityEnum
828+
self,
829+
field_id: int,
830+
entity: CustomFieldEntityEnum,
831+
parent: CustomFieldEntityEnum,
819832
) -> str:
820833
raise NotImplementedError
821834

822835
@abstractmethod
823836
def get_custom_field_component_id(
824-
self, field_id: int, entity: CustomFieldEntityEnum
837+
self,
838+
field_id: int,
839+
entity: CustomFieldEntityEnum,
840+
parent: CustomFieldEntityEnum,
825841
) -> str:
826842
raise NotImplementedError
843+
844+
@abstractmethod
845+
def get_custom_fields_templates(
846+
self, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
847+
):
848+
raise NotImplementedError

src/superannotate/lib/core/usecases/projects.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ def execute(self):
155155
project.users = []
156156
if self._include_custom_fields:
157157
custom_fields_names = self._service_provider.list_custom_field_names(
158-
entity=CustomFieldEntityEnum.PROJECT
158+
self._project.team_id,
159+
entity=CustomFieldEntityEnum.PROJECT,
160+
parent=CustomFieldEntityEnum.TEAM,
159161
)
160162
if custom_fields_names:
161163
project_custom_fields = (
@@ -171,7 +173,9 @@ def execute(self):
171173
custom_fields_name_value_map = {}
172174
for name in custom_fields_names:
173175
field_id = self._service_provider.get_custom_field_id(
174-
name, entity=CustomFieldEntityEnum.PROJECT
176+
name,
177+
entity=CustomFieldEntityEnum.PROJECT,
178+
parent=CustomFieldEntityEnum.TEAM,
175179
)
176180
field_value = (
177181
custom_fields_id_value_map[str(field_id)]
@@ -180,7 +184,9 @@ def execute(self):
180184
)
181185
# timestamp: convert milliseconds to seconds
182186
component_id = self._service_provider.get_custom_field_component_id(
183-
field_id, entity=CustomFieldEntityEnum.PROJECT
187+
field_id,
188+
entity=CustomFieldEntityEnum.PROJECT,
189+
parent=CustomFieldEntityEnum.TEAM,
184190
)
185191
if (
186192
field_value

src/superannotate/lib/infrastructure/annotation_adapter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def get_component_value(self, component_id: str):
4444
return None
4545

4646
def set_component_value(self, component_id: str, value: Any):
47-
self.annotation.setdefault("data", {}).setdefault(component_id, {})["value"] = value
47+
self.annotation.setdefault("data", {}).setdefault(component_id, {})[
48+
"value"
49+
] = value
4850
return self
4951

5052

src/superannotate/lib/infrastructure/controller.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,22 @@ def build_condition(**kwargs) -> Condition:
7373

7474

7575
def serialize_custom_fields(
76-
service_provider: ServiceProvider, data: List[dict], entity: CustomFieldEntityEnum
76+
team_id: int,
77+
project_id: int,
78+
service_provider: ServiceProvider,
79+
data: List[dict],
80+
entity: CustomFieldEntityEnum,
81+
parent_entity: CustomFieldEntityEnum,
7782
) -> List[dict]:
78-
existing_custom_fields = service_provider.list_custom_field_names(entity)
83+
pk = (
84+
project_id
85+
if entity == CustomFieldEntityEnum.PROJECT
86+
else (team_id if parent_entity == CustomFieldEntityEnum.TEAM else project_id)
87+
)
88+
89+
existing_custom_fields = service_provider.list_custom_field_names(
90+
pk, entity, parent=parent_entity
91+
)
7992
for i in range(len(data)):
8093
if not data[i]:
8194
data[i] = {}
@@ -85,7 +98,7 @@ def serialize_custom_fields(
8598
field_id = int(custom_field_name)
8699
try:
87100
component_id = service_provider.get_custom_field_component_id(
88-
field_id, entity=entity
101+
field_id, entity=entity, parent=parent_entity
89102
)
90103
except AppException:
91104
# The component template can be deleted, but not from the entity, so it will be skipped.
@@ -95,7 +108,7 @@ def serialize_custom_fields(
95108
field_value /= 1000 # Convert timestamp
96109

97110
new_field_name = service_provider.get_custom_field_name(
98-
field_id, entity=entity
111+
field_id, entity=entity, parent=parent_entity
99112
)
100113
updated_fields[new_field_name] = field_value
101114

@@ -139,10 +152,10 @@ def set_custom_field_value(
139152
if entity == CustomFieldEntityEnum.PROJECT:
140153
_context["project_id"] = entity_id
141154
template_id = self.service_provider.get_custom_field_id(
142-
field_name, entity=entity
155+
field_name, entity=entity, parent=parent_entity
143156
)
144157
component_id = self.service_provider.get_custom_field_component_id(
145-
template_id, entity=entity
158+
template_id, entity=entity, parent=parent_entity
146159
)
147160
# timestamp: convert seconds to milliseconds
148161
if component_id == CustomFieldType.DATE_PICKER.value and value is not None:
@@ -159,40 +172,59 @@ def set_custom_field_value(
159172
context=_context,
160173
)
161174

162-
def list_users(self, include: List[Literal["custom_fields"]] = None, **filters):
175+
def list_users(
176+
self, include: List[Literal["custom_fields"]] = None, project=None, **filters
177+
):
178+
if project:
179+
parent_entity = CustomFieldEntityEnum.PROJECT
180+
project_id = project.id
181+
else:
182+
parent_entity = CustomFieldEntityEnum.TEAM
183+
project_id = None
163184
valid_fields = generate_schema(
164185
UserFilters.__annotations__,
165186
self.service_provider.get_custom_fields_templates(
166-
CustomFieldEntityEnum.CONTRIBUTOR
187+
CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity
167188
),
168189
)
169190
chain = QueryBuilderChain(
170191
[
171192
FieldValidationHandler(valid_fields.keys()),
172193
UserFilterHandler(
194+
team_id=self.service_provider.client.team_id,
195+
project_id=project_id,
173196
service_provider=self.service_provider,
174197
entity=CustomFieldEntityEnum.CONTRIBUTOR,
198+
parent=parent_entity,
175199
),
176200
]
177201
)
178202
query = chain.handle(filters, EmptyQuery())
179203
if include and "custom_fields" in include:
180204
response = self.service_provider.work_management.list_users(
181-
query, include_custom_fields=True
205+
query,
206+
include_custom_fields=True,
207+
parent_entity=parent_entity,
208+
project_id=project_id,
182209
)
183210
if not response.ok:
184211
raise AppException(response.error)
185212
users = response.data
186213
custom_fields_list = [user.custom_fields for user in users]
187214
serialized_fields = serialize_custom_fields(
215+
self.service_provider.client.team_id,
216+
project_id,
188217
self.service_provider,
189218
custom_fields_list,
190-
CustomFieldEntityEnum.CONTRIBUTOR,
219+
entity=CustomFieldEntityEnum.CONTRIBUTOR,
220+
parent_entity=parent_entity,
191221
)
192222
for users, serialized_custom_fields in zip(users, serialized_fields):
193223
users.custom_fields = serialized_custom_fields
194224
return response.data
195-
return self.service_provider.work_management.list_users(query).data
225+
return self.service_provider.work_management.list_users(
226+
query, parent_entity=parent_entity, project_id=project_id
227+
).data
196228

197229
def update_user_activity(
198230
self,
@@ -406,14 +438,18 @@ def list_projects(
406438
valid_fields = generate_schema(
407439
ProjectFilters.__annotations__,
408440
self.service_provider.get_custom_fields_templates(
409-
CustomFieldEntityEnum.PROJECT
441+
CustomFieldEntityEnum.PROJECT, parent=CustomFieldEntityEnum.TEAM
410442
),
411443
)
412444
chain = QueryBuilderChain(
413445
[
414446
FieldValidationHandler(valid_fields.keys()),
415447
ProjectFilterHandler(
416-
self.service_provider, entity=CustomFieldEntityEnum.PROJECT
448+
team_id=self.service_provider.client.team_id,
449+
project_id=None,
450+
service_provider=self.service_provider,
451+
entity=CustomFieldEntityEnum.PROJECT,
452+
parent=CustomFieldEntityEnum.TEAM,
417453
),
418454
]
419455
)
@@ -435,7 +471,11 @@ def list_projects(
435471
if include_custom_fields:
436472
custom_fields_list = [project.custom_fields for project in projects]
437473
serialized_fields = serialize_custom_fields(
438-
self.service_provider, custom_fields_list, CustomFieldEntityEnum.PROJECT
474+
self.service_provider.client.team_id,
475+
None,
476+
self.service_provider,
477+
custom_fields_list,
478+
CustomFieldEntityEnum.PROJECT,
439479
)
440480
for project, serialized_custom_fields in zip(projects, serialized_fields):
441481
project.custom_fields = serialized_custom_fields

src/superannotate/lib/infrastructure/query_builder.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,23 +152,31 @@ def _handle_special_fields(self, keys: List[str], val):
152152

153153
class BaseCustomFieldHandler(AbstractQueryHandler):
154154
def __init__(
155-
self, service_provider: BaseServiceProvider, entity: CustomFieldEntityEnum
155+
self,
156+
team_id: int,
157+
project_id: Optional[int],
158+
service_provider: BaseServiceProvider,
159+
entity: CustomFieldEntityEnum,
160+
parent: CustomFieldEntityEnum,
156161
):
157162
self._service_provider = service_provider
158163
self._entity = entity
164+
self._parent = parent
159165

160166
def _handle_custom_field_key(self, key) -> Tuple[str, str, Optional[str]]:
161167
for custom_field in sorted(
162-
self._service_provider.list_custom_field_names(entity=self._entity),
168+
self._service_provider.list_custom_field_names(
169+
entity=self._entity, parent=self._parent
170+
),
163171
key=len,
164172
reverse=True,
165173
):
166174
if custom_field in key:
167175
custom_field_id = self._service_provider.get_custom_field_id(
168-
custom_field, entity=self._entity
176+
custom_field, entity=self._entity, parent=self._parent
169177
)
170178
component_id = self._service_provider.get_custom_field_component_id(
171-
custom_field_id, entity=self._entity
179+
custom_field_id, entity=self._entity, parent=self._parent
172180
)
173181
key = key.replace(
174182
custom_field,
@@ -209,7 +217,7 @@ def _determine_condition_and_key(keys: List[str]) -> Tuple[OperatorEnum, str]:
209217
def _handle_special_fields(self, keys: List[str], val):
210218
if keys[0] == "custom_field":
211219
component_id = self._service_provider.get_custom_field_component_id(
212-
field_id=int(keys[1]), entity=self._entity
220+
field_id=int(keys[1]), entity=self._entity, parent=self._parent
213221
)
214222
if component_id == CustomFieldType.DATE_PICKER.value and val is not None:
215223
try:

0 commit comments

Comments
 (0)