Skip to content

Commit 2e80804

Browse files
authored
Merge pull request #798 from superannotateai/FRIDAY-3975
added project category related functions
2 parents 40ba873 + 3f8cabb commit 2e80804

File tree

9 files changed

+372
-9
lines changed

9 files changed

+372
-9
lines changed

docs/source/api_reference/api_project.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@ Projects
2727
.. automethod:: superannotate.SAClient.get_project_steps
2828
.. automethod:: superannotate.SAClient.set_project_steps
2929
.. automethod:: superannotate.SAClient.get_component_config
30+
.. automethod:: superannotate.SAClient.create_categories
31+
.. automethod:: superannotate.SAClient.list_categories
32+
.. automethod:: superannotate.SAClient.remove_categories

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
from lib.app.serializers import WMProjectSerializer
7474
from lib.core.entities.work_managament import WMUserTypeEnum
7575
from lib.core.jsx_conditions import EmptyQuery
76+
from lib.core.entities.items import ProjectCategoryEntity
77+
7678

7779
logger = logging.getLogger("sa")
7880

@@ -1194,6 +1196,154 @@ def clone_project(
11941196
)
11951197
return data
11961198

1199+
def create_categories(
1200+
self, project: Union[NotEmptyStr, int], categories: List[str]
1201+
):
1202+
"""
1203+
Create one or more categories in a project.
1204+
1205+
:param project: The name or ID of the project.
1206+
:type project: Union[NotEmptyStr, int]
1207+
1208+
:param categories: A list of categories to create
1209+
:type categories: list of str
1210+
1211+
Request Example:
1212+
::
1213+
1214+
client.create_categories(
1215+
project="product-review-mm",
1216+
categories=["Shoes", "T-Shirt"]
1217+
)
1218+
"""
1219+
project = (
1220+
self.controller.get_project_by_id(project).data
1221+
if isinstance(project, int)
1222+
else self.controller.get_project(project)
1223+
)
1224+
self.controller.check_multimodal_project_categorization(project)
1225+
1226+
response = (
1227+
self.controller.service_provider.work_management.create_project_categories(
1228+
project_id=project.id, categories=categories
1229+
)
1230+
)
1231+
logger.info(
1232+
f"{len(response.data)} categories successfully added to the project."
1233+
)
1234+
1235+
def list_categories(self, project: Union[NotEmptyStr, int]):
1236+
"""
1237+
List all categories in the project.
1238+
1239+
:param project: The name or ID of the project.
1240+
:type project: Union[NotEmptyStr, int]
1241+
1242+
:return: List of categories
1243+
:rtype: list of dict
1244+
1245+
Request Example:
1246+
::
1247+
1248+
client.list_categories(
1249+
project="product-review-mm"
1250+
)
1251+
1252+
Response Example:
1253+
::
1254+
1255+
[
1256+
{
1257+
"createdAt": "2025-01-29T13:51:39.000Z",
1258+
"updatedAt": "2025-01-29T13:51:39.000Z",
1259+
"id": 328577,
1260+
"name": "category1",
1261+
"project_id": 1234
1262+
},
1263+
{
1264+
"createdAt": "2025-01-29T13:51:39.000Z",
1265+
"updatedAt": "2025-01-29T13:51:39.000Z",
1266+
"id": 328577,
1267+
"name": "category2",
1268+
"project_id": 1234
1269+
},
1270+
]
1271+
1272+
"""
1273+
project = (
1274+
self.controller.get_project_by_id(project).data
1275+
if isinstance(project, int)
1276+
else self.controller.get_project(project)
1277+
)
1278+
self.controller.check_multimodal_project_categorization(project)
1279+
1280+
response = (
1281+
self.controller.service_provider.work_management.list_project_categories(
1282+
project_id=project.id, entity=ProjectCategoryEntity
1283+
)
1284+
)
1285+
return BaseSerializer.serialize_iterable(response.data)
1286+
1287+
def remove_categories(
1288+
self,
1289+
project: Union[NotEmptyStr, int],
1290+
categories: Union[List[str], Literal["*"]],
1291+
):
1292+
"""
1293+
Remove one or more categories in a project. "*" in the category list will match all categories defined in the project.
1294+
1295+
1296+
:param project: The name or ID of the project.
1297+
:type project: Union[NotEmptyStr, int]
1298+
1299+
:param categories: A list of categories to remove, Accepts "*" to indicate all available categories in the project.
1300+
:type categories: Union[List[str], Literal["*"]]
1301+
1302+
Request Example:
1303+
::
1304+
1305+
client.remove_categories(
1306+
project="product-review-mm",
1307+
categories=["Shoes", "T-Shirt"]
1308+
)
1309+
1310+
# To remove all categories
1311+
client.remove_categories(
1312+
project="product-review-mm",
1313+
categories="*"
1314+
)
1315+
"""
1316+
project = (
1317+
self.controller.get_project_by_id(project).data
1318+
if isinstance(project, int)
1319+
else self.controller.get_project(project)
1320+
)
1321+
self.controller.check_multimodal_project_categorization(project)
1322+
1323+
query = EmptyQuery()
1324+
if categories == "*":
1325+
query &= Filter("id", [0], OperatorEnum.GT)
1326+
elif categories and isinstance(categories, list):
1327+
categories = [c.lower() for c in categories]
1328+
all_categories = self.controller.service_provider.work_management.list_project_categories(
1329+
project_id=project.id, entity=ProjectCategoryEntity
1330+
)
1331+
categories_to_remove = [
1332+
c for c in all_categories.data if c.name.lower() in categories
1333+
]
1334+
query &= Filter("id", [c.id for c in categories_to_remove], OperatorEnum.IN)
1335+
else:
1336+
raise AppException("Categories should be a list of strings or '*'.")
1337+
1338+
response = (
1339+
self.controller.service_provider.work_management.remove_project_categories(
1340+
project_id=project.id, query=query
1341+
)
1342+
)
1343+
logger.info(
1344+
f"{len(response.data)} categories successfully removed from the project."
1345+
)
1346+
11971347
def create_folder(self, project: NotEmptyStr, folder_name: NotEmptyStr):
11981348
"""
11991349
Create a new folder in the project.

src/superannotate/lib/core/entities/items.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ class Config:
2626
extra = Extra.ignore
2727

2828

29+
class ProjectCategoryEntity(TimedBaseModel):
30+
id: int
31+
name: str
32+
project_id: int
33+
34+
class Config:
35+
extra = Extra.ignore
36+
37+
2938
class MultiModalItemCategoryEntity(TimedBaseModel):
3039
id: int = Field(None, alias="category_id")
3140
value: str = Field(None, alias="category_name")

src/superannotate/lib/core/service_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ class ListCategoryResponse(ServiceResponse):
234234
res_data: List[entities.CategoryEntity] = None
235235

236236

237+
class ListProjectCategoryResponse(ServiceResponse):
238+
res_data: List[entities.items.ProjectCategoryEntity] = None
239+
240+
237241
class WorkflowResponse(ServiceResponse):
238242
res_data: entities.WorkflowEntity = None
239243

src/superannotate/lib/core/serviceproviders.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
from typing import List
88
from typing import Literal
99
from typing import Optional
10+
from typing import Union
1011

1112
from lib.core import entities
1213
from lib.core.conditions import Condition
14+
from lib.core.entities import CategoryEntity
15+
from lib.core.entities.project_entities import BaseEntity
1316
from lib.core.enums import CustomFieldEntityEnum
1417
from lib.core.jsx_conditions import Query
1518
from lib.core.reporter import Reporter
@@ -18,6 +21,7 @@
1821
from lib.core.service_types import FolderResponse
1922
from lib.core.service_types import IntegrationListResponse
2023
from lib.core.service_types import ListCategoryResponse
24+
from lib.core.service_types import ListProjectCategoryResponse
2125
from lib.core.service_types import ProjectListResponse
2226
from lib.core.service_types import ProjectResponse
2327
from lib.core.service_types import ServiceResponse
@@ -137,13 +141,21 @@ def search_projects(
137141
raise NotImplementedError
138142

139143
@abstractmethod
140-
def list_project_categories(self, project_id: int) -> ListCategoryResponse:
144+
def list_project_categories(
145+
self, project_id: int, entity: BaseEntity = CategoryEntity
146+
) -> Union[ListCategoryResponse, ListProjectCategoryResponse]:
147+
raise NotImplementedError
148+
149+
@abstractmethod
150+
def remove_project_categories(
151+
self, project_id: int, query: Query
152+
) -> ListProjectCategoryResponse:
141153
raise NotImplementedError
142154

143155
@abstractmethod
144156
def create_project_categories(
145157
self, project_id: int, categories: List[str]
146-
) -> ServiceResponse:
158+
) -> ListProjectCategoryResponse:
147159
raise NotImplementedError
148160

149161
@abstractmethod

src/superannotate/lib/core/usecases/annotations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2165,10 +2165,10 @@ def _attach_categories(self, folder_id: int, item_id_category_map: Dict[int, str
21652165
self._service_provider.work_management.create_project_categories(
21662166
project_id=self._project.id,
21672167
categories=categories_to_create,
2168-
).data["data"]
2168+
).data
21692169
)
21702170
for c in _categories:
2171-
self._category_name_to_id_map[c["name"]] = c["id"]
2171+
self._category_name_to_id_map[c.name] = c.id
21722172
for item_id, category_name in item_id_category_map.items():
21732173
with suppress(KeyError):
21742174
item_id_category_id_map[item_id] = self._category_name_to_id_map[

src/superannotate/lib/infrastructure/controller.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,3 +1918,15 @@ def get_item(
19181918
return self.get_item_by_id(item_id=item, project=project)
19191919
else:
19201920
return self.items.get_by_name(project, folder, item)
1921+
1922+
def check_multimodal_project_categorization(self, project: ProjectEntity):
1923+
if project.type != ProjectType.MULTIMODAL:
1924+
raise AppException(
1925+
"This function is only supported for Multimodal projects."
1926+
)
1927+
project_settings = self.service_provider.projects.list_settings(project).data
1928+
if not next(
1929+
(i.value for i in project_settings if i.attribute == "CategorizeItems"),
1930+
None,
1931+
):
1932+
raise AppException("Item Category not enabled for project.")

src/superannotate/lib/infrastructure/services/work_management.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from typing import List
44
from typing import Literal
55
from typing import Optional
6+
from typing import Union
67

78
from lib.core.entities import CategoryEntity
89
from lib.core.entities import WorkflowEntity
10+
from lib.core.entities.project_entities import BaseEntity
911
from lib.core.entities.work_managament import WMProjectEntity
1012
from lib.core.entities.work_managament import WMProjectUserEntity
1113
from lib.core.entities.work_managament import WMScoreEntity
@@ -16,6 +18,7 @@
1618
from lib.core.jsx_conditions import OperatorEnum
1719
from lib.core.jsx_conditions import Query
1820
from lib.core.service_types import ListCategoryResponse
21+
from lib.core.service_types import ListProjectCategoryResponse
1922
from lib.core.service_types import ServiceResponse
2023
from lib.core.service_types import WMCustomFieldResponse
2124
from lib.core.service_types import WMProjectListResponse
@@ -74,21 +77,25 @@ def _generate_context(**kwargs):
7477
encoded_context = base64.b64encode(json.dumps(kwargs).encode("utf-8"))
7578
return encoded_context.decode("utf-8")
7679

77-
def list_project_categories(self, project_id: int) -> ListCategoryResponse:
78-
return self.client.paginate(
79-
self.URL_LIST_CATEGORIES,
80-
item_type=CategoryEntity,
80+
def list_project_categories(
81+
self, project_id: int, entity: BaseEntity = CategoryEntity
82+
) -> Union[ListCategoryResponse, ListProjectCategoryResponse]:
83+
response = self.client.paginate(
84+
url=self.URL_LIST_CATEGORIES,
85+
item_type=entity,
8186
query_params={"project_id": project_id},
8287
headers={
8388
"x-sa-entity-context": self._generate_context(
8489
team_id=self.client.team_id
8590
),
8691
},
8792
)
93+
response.raise_for_status()
94+
return response
8895

8996
def create_project_categories(
9097
self, project_id: int, categories: List[str]
91-
) -> ServiceResponse:
98+
) -> ListProjectCategoryResponse:
9299
response = self.client.request(
93100
method="post",
94101
url=self.URL_CREATE_CATEGORIES,
@@ -99,6 +106,26 @@ def create_project_categories(
99106
team_id=self.client.team_id, project_id=project_id
100107
),
101108
},
109+
content_type=ListProjectCategoryResponse,
110+
dispatcher="data",
111+
)
112+
response.raise_for_status()
113+
return response
114+
115+
def remove_project_categories(
116+
self, project_id: int, query: Query
117+
) -> ListProjectCategoryResponse:
118+
119+
response = self.client.request(
120+
method="delete",
121+
url=f"{self.URL_CREATE_CATEGORIES}?{query.build_query()}",
122+
headers={
123+
"x-sa-entity-context": self._generate_context(
124+
team_id=self.client.team_id, project_id=project_id
125+
),
126+
},
127+
content_type=ListProjectCategoryResponse,
128+
dispatcher="data",
102129
)
103130
response.raise_for_status()
104131
return response

0 commit comments

Comments
 (0)