From e2ad2aa82ddd5cb344bfc01b6f81f855884ea3fd Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 14 Jul 2025 16:52:20 +0300 Subject: [PATCH 1/3] refactor(core): Add base serializer for pagination params Introduced `BaseLimitOffsetPaginationSerializer` to handle the validation of `limit` and `offset` query parameters. --- promo_code/core/pagination.py | 14 +++++++++----- promo_code/core/serializers.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) create mode 100644 promo_code/core/serializers.py diff --git a/promo_code/core/pagination.py b/promo_code/core/pagination.py index 120f148..31a129d 100644 --- a/promo_code/core/pagination.py +++ b/promo_code/core/pagination.py @@ -1,6 +1,9 @@ +import rest_framework.exceptions import rest_framework.pagination import rest_framework.response +import core.serializers + class CustomLimitOffsetPagination( rest_framework.pagination.LimitOffsetPagination, @@ -9,12 +12,13 @@ class CustomLimitOffsetPagination( max_limit = 100 def get_limit(self, request): - raw_limit = request.query_params.get(self.limit_query_param) - - if raw_limit is None: - return self.default_limit + serializer = core.serializers.BaseLimitOffsetPaginationSerializer( + data=request.query_params, + ) + serializer.is_valid(raise_exception=True) - limit = int(raw_limit) + validated_data = serializer.validated_data + limit = validated_data.get('limit', self.default_limit) # Allow 0, otherwise cut by max_limit return 0 if limit == 0 else min(limit, self.max_limit) diff --git a/promo_code/core/serializers.py b/promo_code/core/serializers.py new file mode 100644 index 0000000..60cd85b --- /dev/null +++ b/promo_code/core/serializers.py @@ -0,0 +1,31 @@ +import rest_framework.exceptions +import rest_framework.serializers + + +class BaseLimitOffsetPaginationSerializer( + rest_framework.serializers.Serializer, +): + """ + Base serializer for common filtering and sorting parameters. + Pagination parameters (limit, offset) are handled by the pagination class. + """ + + limit = rest_framework.serializers.IntegerField( + min_value=0, + required=False, + ) + offset = rest_framework.serializers.IntegerField( + min_value=0, + required=False, + ) + + def validate(self, attrs): + errors = {} + for field in ('limit', 'offset'): + raw = self.initial_data.get(field, None) + if raw == '': + errors[field] = ['This field cannot be an empty string.'] + if errors: + raise rest_framework.exceptions.ValidationError(errors) + + return super().validate(attrs) From 42ca24364f73db375a77d3680968c9a67ea8dc92 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 15 Jul 2025 13:25:27 +0300 Subject: [PATCH 2/3] refactor(business): Overhaul promo serializers and validation logic This commit refactors the promotion-related serializers to improve structure, validation, and maintainability, following DRF best practices. Key changes: - Removed the separate `PromoValidator` class and integrated its validation logic directly into the serializers. This co-locates validation with the data representation, making the code easier to follow. - Introduced `BasePromoSerializer` to consolidate common fields and methods, reducing code duplication across create, detail, and read-only serializers. - Implemented custom `CountryField` and `MultiCountryField` to provide robust, reusable validation for ISO 3166-1 alpha-2 country codes. - Simplified `PromoListQuerySerializer` by leveraging the new `MultiCountryField` and inheriting from the base pagination serializer, removing manual parameter validation. --- promo_code/business/serializers.py | 472 ++++++++++++++--------------- promo_code/business/validators.py | 98 ------ 2 files changed, 224 insertions(+), 346 deletions(-) delete mode 100644 promo_code/business/validators.py diff --git a/promo_code/business/serializers.py b/promo_code/business/serializers.py index 3c9b282..72827e4 100644 --- a/promo_code/business/serializers.py +++ b/promo_code/business/serializers.py @@ -1,7 +1,6 @@ import uuid import django.contrib.auth.password_validation -import django.core.validators import django.db.transaction import pycountry import rest_framework.exceptions @@ -13,7 +12,7 @@ import business.constants import business.models import business.utils.tokens -import business.validators +import core.serializers import core.utils.auth @@ -104,7 +103,7 @@ def validate(self, attrs): ) company = self.get_active_company_from_token(refresh) - company = business.utils.auth.bump_company_token_version(company) + company = core.utils.auth.bump_token_version(company) return business.utils.tokens.generate_company_tokens(company) @@ -141,6 +140,62 @@ def get_active_company_from_token(self, token): return company +class CountryField(rest_framework.serializers.CharField): + """ + Custom field for validating country codes according to ISO 3166-1 alpha-2. + """ + + def __init__(self, **kwargs): + kwargs['allow_blank'] = False + kwargs['min_length'] = business.constants.TARGET_COUNTRY_CODE_LENGTH + kwargs['max_length'] = business.constants.TARGET_COUNTRY_CODE_LENGTH + super().__init__(**kwargs) + + def to_internal_value(self, data): + code = super().to_internal_value(data) + try: + pycountry.countries.lookup(code.upper()) + except LookupError: + raise rest_framework.serializers.ValidationError( + 'Invalid ISO 3166-1 alpha-2 country code.', + ) + return code + + +class MultiCountryField(rest_framework.serializers.ListField): + """ + Custom field for handling multiple country codes, + passed either as a comma-separated list or as multiple parameters. + """ + + def __init__(self, **kwargs): + kwargs['child'] = CountryField() + kwargs['allow_empty'] = False + super().__init__(**kwargs) + + def to_internal_value(self, data): + if not data or not isinstance(data, list): + raise rest_framework.serializers.ValidationError( + 'At least one country must be specified.', + ) + + # (&country=us,fr) + if len(data) == 1 and ',' in data[0]: + countries_str = data[0] + if '' in [s.strip() for s in countries_str.split(',')]: + raise rest_framework.serializers.ValidationError( + 'Invalid country format.', + ) + data = [country.strip() for country in countries_str.split(',')] + + if any(not item for item in data): + raise rest_framework.serializers.ValidationError( + 'Empty value for country is not allowed.', + ) + + return super().to_internal_value(data) + + class TargetSerializer(rest_framework.serializers.Serializer): age_from = rest_framework.serializers.IntegerField( min_value=business.constants.TARGET_AGE_MIN, @@ -152,15 +207,13 @@ class TargetSerializer(rest_framework.serializers.Serializer): max_value=business.constants.TARGET_AGE_MAX, required=False, ) - country = rest_framework.serializers.CharField( - max_length=business.constants.TARGET_COUNTRY_CODE_LENGTH, - min_length=business.constants.TARGET_COUNTRY_CODE_LENGTH, - required=False, - ) + country = CountryField(required=False) + categories = rest_framework.serializers.ListField( child=rest_framework.serializers.CharField( min_length=business.constants.TARGET_CATEGORY_MIN_LENGTH, max_length=business.constants.TARGET_CATEGORY_MAX_LENGTH, + allow_blank=False, ), max_length=business.constants.TARGET_CATEGORY_MAX_ITEMS, required=False, @@ -170,6 +223,7 @@ class TargetSerializer(rest_framework.serializers.Serializer): def validate(self, data): age_from = data.get('age_from') age_until = data.get('age_until') + if ( age_from is not None and age_until is not None @@ -178,60 +232,47 @@ def validate(self, data): raise rest_framework.serializers.ValidationError( {'age_until': 'Must be greater than or equal to age_from.'}, ) - - country = data.get('country') - if country: - try: - pycountry.countries.lookup(country.strip().upper()) - data['country'] = country - except LookupError: - raise rest_framework.serializers.ValidationError( - {'country': 'Invalid ISO 3166-1 alpha-2 country code.'}, - ) - return data -class PromoCreateSerializer(rest_framework.serializers.ModelSerializer): +class BasePromoSerializer(rest_framework.serializers.ModelSerializer): + """ + Base serializer for promo, containing validation and representation logic. + """ + + image_url = rest_framework.serializers.URLField( + required=False, + allow_blank=False, + max_length=business.constants.PROMO_IMAGE_URL_MAX_LENGTH, + ) description = rest_framework.serializers.CharField( min_length=business.constants.PROMO_DESC_MIN_LENGTH, max_length=business.constants.PROMO_DESC_MAX_LENGTH, required=True, ) - image_url = rest_framework.serializers.CharField( - required=False, - max_length=business.constants.PROMO_IMAGE_URL_MAX_LENGTH, - validators=[ - django.core.validators.URLValidator(schemes=['http', 'https']), - ], - ) target = TargetSerializer(required=True, allow_null=True) promo_common = rest_framework.serializers.CharField( min_length=business.constants.PROMO_COMMON_CODE_MIN_LENGTH, max_length=business.constants.PROMO_COMMON_CODE_MAX_LENGTH, required=False, allow_null=True, + allow_blank=False, ) promo_unique = rest_framework.serializers.ListField( child=rest_framework.serializers.CharField( min_length=business.constants.PROMO_UNIQUE_CODE_MIN_LENGTH, max_length=business.constants.PROMO_UNIQUE_CODE_MAX_LENGTH, + allow_blank=False, ), min_length=business.constants.PROMO_UNIQUE_LIST_MIN_ITEMS, max_length=business.constants.PROMO_UNIQUE_LIST_MAX_ITEMS, required=False, allow_null=True, ) - # headers - url = rest_framework.serializers.HyperlinkedIdentityField( - view_name='api-business:promo-detail', - lookup_field='id', - ) class Meta: model = business.models.Promo fields = ( - 'url', 'description', 'image_url', 'target', @@ -244,241 +285,176 @@ class Meta: ) def validate(self, data): - data = super().validate(data) - validator = business.validators.PromoValidator(data=data) - return validator.validate() + full_data = self._get_full_data(data) - def create(self, validated_data): - target_data = validated_data.pop('target') - promo_common = validated_data.pop('promo_common', None) - promo_unique = validated_data.pop('promo_unique', None) + mode = full_data.get('mode') - return business.models.Promo.objects.create_promo( - user=self.context['request'].user, - target_data=target_data, - promo_common=promo_common, - promo_unique=promo_unique, - **validated_data, - ) + if mode == business.constants.PROMO_MODE_COMMON: + self._validate_common(full_data) + + elif mode == business.constants.PROMO_MODE_UNIQUE: + self._validate_unique(full_data) + + else: + raise rest_framework.serializers.ValidationError( + {'mode': 'Invalid mode.'}, + ) + + return data def to_representation(self, instance): + """ + Controls the display of fields in the response. + """ data = super().to_representation(instance) - data['target'] = instance.target + + if not instance.image_url: + data.pop('image_url', None) if instance.mode == business.constants.PROMO_MODE_UNIQUE: - data['promo_unique'] = [ - code.code for code in instance.unique_codes.all() - ] data.pop('promo_common', None) + if 'promo_unique' in self.fields and isinstance( + self.fields['promo_unique'], + rest_framework.serializers.SerializerMethodField, + ): + data['promo_unique'] = self.get_promo_unique(instance) + else: + data['promo_unique'] = [ + code.code for code in instance.unique_codes.all() + ] else: data.pop('promo_unique', None) return data - -class PromoListQuerySerializer(rest_framework.serializers.Serializer): - """ - Serializer for validating query parameters of promo list requests. - """ - - limit = rest_framework.serializers.CharField( - required=False, - allow_blank=True, - ) - offset = rest_framework.serializers.CharField( - required=False, - allow_blank=True, - ) - sort_by = rest_framework.serializers.ChoiceField( - choices=['active_from', 'active_until'], - required=False, - ) - country = rest_framework.serializers.CharField( - required=False, - allow_blank=True, - ) - - _allowed_params = None - - def get_allowed_params(self): - if self._allowed_params is None: - self._allowed_params = set(self.fields.keys()) - return self._allowed_params - - def validate(self, attrs): - query_params = self.initial_data - allowed_params = self.get_allowed_params() - - unexpected_params = set(query_params.keys()) - allowed_params - if unexpected_params: - raise rest_framework.exceptions.ValidationError('Invalid params.') - - field_errors = {} - - attrs = self._validate_int_field('limit', attrs, field_errors) - attrs = self._validate_int_field('offset', attrs, field_errors) - - self._validate_country(query_params, attrs, field_errors) - - if field_errors: - raise rest_framework.exceptions.ValidationError(field_errors) - - return attrs - - def _validate_int_field(self, field_name, attrs, field_errors): - value_str = self.initial_data.get(field_name) - if value_str is None: - return attrs - - if value_str == '': - raise rest_framework.exceptions.ValidationError( - f'Invalid {field_name} format.', + def _get_full_data(self, data): + """ + Build the full data dict by merging existing instance data + with new input. + """ + if self.instance: + full_data = self.to_representation(self.instance) + full_data.update(data) + else: + full_data = data + return full_data + + def _validate_common(self, full_data): + """ + Validations for COMMON promo mode. + """ + promo_common = full_data.get('promo_common') + promo_unique = full_data.get('promo_unique') + max_count = full_data.get('max_count') + + if not promo_common: + raise rest_framework.serializers.ValidationError( + {'promo_common': 'This field is required for COMMON mode.'}, ) - try: - value_int = int(value_str) - if value_int < 0: - raise rest_framework.exceptions.ValidationError( - f'{field_name.capitalize()} cannot be negative.', - ) - attrs[field_name] = value_int - except (ValueError, TypeError): - raise rest_framework.exceptions.ValidationError( - f'Invalid {field_name} format.', + if promo_unique is not None: + raise rest_framework.serializers.ValidationError( + {'promo_unique': 'This field is not allowed for COMMON mode.'}, ) - return attrs - - def _validate_country(self, query_params, attrs, field_errors): - countries_raw = query_params.getlist('country', []) - - if '' in countries_raw: - raise rest_framework.exceptions.ValidationError( - 'Invalid country format.', + min_c = business.constants.PROMO_COMMON_MIN_COUNT + max_c = business.constants.PROMO_COMMON_MAX_COUNT + if not (min_c <= max_count <= max_c): + raise rest_framework.serializers.ValidationError( + { + 'max_count': ( + f'Must be between {min_c} and {max_c} for COMMON mode.' + ), + }, ) - country_codes = [] - invalid_codes = [] + def _validate_unique(self, full_data): + """ + Validations for UNIQUE promo mode. + """ + promo_common = full_data.get('promo_common') + promo_unique = full_data.get('promo_unique') + max_count = full_data.get('max_count') - for country_group in countries_raw: - if not country_group.strip(): - continue + if not promo_unique: + raise rest_framework.serializers.ValidationError( + {'promo_unique': 'This field is required for UNIQUE mode.'}, + ) - parts = [part.strip() for part in country_group.split(',')] + if promo_common is not None: + raise rest_framework.serializers.ValidationError( + {'promo_common': 'This field is not allowed for UNIQUE mode.'}, + ) - if '' in parts: - raise rest_framework.exceptions.ValidationError( - 'Invalid country format.', - ) + if max_count != business.constants.PROMO_UNIQUE_MAX_COUNT: + raise rest_framework.serializers.ValidationError( + { + 'max_count': ( + 'Must be equal to ' + f'{business.constants.PROMO_UNIQUE_MAX_COUNT}' + 'for UNIQUE mode.' + ), + }, + ) - country_codes.extend(parts) - country_codes_upper = [c.upper() for c in country_codes] +class PromoCreateSerializer(BasePromoSerializer): + url = rest_framework.serializers.HyperlinkedIdentityField( + view_name='api-business:promo-detail', + lookup_field='id', + ) - for code in country_codes_upper: - if len(code) != 2: - invalid_codes.append(code) - continue - try: - pycountry.countries.lookup(code) - except LookupError: - invalid_codes.append(code) + class Meta(BasePromoSerializer.Meta): + fields = ('url',) + BasePromoSerializer.Meta.fields - if invalid_codes: - field_errors['country'] = ( - f'Invalid country codes: {", ".join(invalid_codes)}' - ) + def create(self, validated_data): + target_data = validated_data.pop('target') + promo_common = validated_data.pop('promo_common', None) + promo_unique = validated_data.pop('promo_unique', None) - attrs['countries'] = country_codes - attrs.pop('country', None) + return business.models.Promo.objects.create_promo( + user=self.context['request'].user, + target_data=target_data, + promo_common=promo_common, + promo_unique=promo_unique, + **validated_data, + ) -class PromoReadOnlySerializer(rest_framework.serializers.ModelSerializer): - promo_id = rest_framework.serializers.UUIDField( - source='id', - read_only=True, - ) - company_id = rest_framework.serializers.UUIDField( - source='company.id', - read_only=True, - ) - company_name = rest_framework.serializers.CharField( - source='company.name', - read_only=True, - ) - target = TargetSerializer() +class PromoListQuerySerializer( + core.serializers.BaseLimitOffsetPaginationSerializer, +): + """ + Validates query parameters for the list of promotions. + """ - promo_unique = rest_framework.serializers.SerializerMethodField() - like_count = rest_framework.serializers.IntegerField( - source='get_like_count', - read_only=True, - ) - used_count = rest_framework.serializers.IntegerField( - source='get_used_codes_count', - read_only=True, - ) - comment_count = rest_framework.serializers.IntegerField( - source='get_comment_count', - read_only=True, - ) - active = rest_framework.serializers.BooleanField( - source='is_active', - read_only=True, + sort_by = rest_framework.serializers.ChoiceField( + choices=['active_from', 'active_until'], + required=False, ) + country = MultiCountryField(required=False) - class Meta: - model = business.models.Promo - fields = ( - 'promo_id', - 'company_id', - 'company_name', - 'description', - 'image_url', - 'target', - 'max_count', - 'active_from', - 'active_until', - 'mode', - 'promo_common', - 'promo_unique', - 'like_count', - 'comment_count', - 'used_count', - 'active', - ) + def validate(self, attrs): + query_params = self.initial_data.keys() + allowed_params = self.fields.keys() + unexpected_params = set(query_params) - set(allowed_params) - def get_promo_unique(self, obj): - return obj.get_available_unique_codes + if unexpected_params: + raise rest_framework.exceptions.ValidationError( + f'Invalid parameters: {", ".join(unexpected_params)}', + ) - def to_representation(self, instance): - data = super().to_representation(instance) - if instance.mode == business.constants.PROMO_MODE_COMMON: - data.pop('promo_unique', None) - else: - data.pop('promo_common', None) + if 'country' in attrs: + attrs['countries'] = attrs.pop('country') - return data + return attrs -class PromoDetailSerializer(rest_framework.serializers.ModelSerializer): +class PromoDetailSerializer(BasePromoSerializer): promo_id = rest_framework.serializers.UUIDField( source='id', read_only=True, ) - description = rest_framework.serializers.CharField( - min_length=business.constants.PROMO_DESC_MIN_LENGTH, - max_length=business.constants.PROMO_DESC_MAX_LENGTH, - required=True, - ) - image_url = rest_framework.serializers.CharField( - required=False, - max_length=business.constants.PROMO_IMAGE_URL_MAX_LENGTH, - validators=[ - django.core.validators.URLValidator(schemes=['http', 'https']), - ], - ) - target = TargetSerializer(allow_null=True, required=False) - promo_unique = rest_framework.serializers.SerializerMethodField() company_name = rest_framework.serializers.CharField( source='company.name', read_only=True, @@ -500,31 +476,26 @@ class PromoDetailSerializer(rest_framework.serializers.ModelSerializer): read_only=True, ) - class Meta: - model = business.models.Promo - fields = ( + promo_unique = rest_framework.serializers.SerializerMethodField() + + class Meta(BasePromoSerializer.Meta): + fields = BasePromoSerializer.Meta.fields + ( 'promo_id', - 'description', - 'image_url', - 'target', - 'max_count', - 'active_from', - 'active_until', - 'mode', - 'promo_common', - 'promo_unique', 'company_name', - 'active', 'like_count', 'comment_count', 'used_count', + 'active', ) def get_promo_unique(self, obj): - return obj.get_available_unique_codes + if obj.mode == business.constants.PROMO_MODE_UNIQUE: + return obj.get_available_unique_codes + return None def update(self, instance, validated_data): target_data = validated_data.pop('target', None) + for attr, value in validated_data.items(): setattr(instance, attr, value) @@ -534,13 +505,18 @@ def update(self, instance, validated_data): instance.save() return instance - def validate(self, data): - data = super().validate(data) - validator = business.validators.PromoValidator( - data=data, - instance=self.instance, - ) - return validator.validate() + +class PromoReadOnlySerializer(PromoDetailSerializer): + """Read-only serializer for promo.""" + + company_id = rest_framework.serializers.UUIDField( + source='company.id', + read_only=True, + ) + + class Meta(PromoDetailSerializer.Meta): + fields = PromoDetailSerializer.Meta.fields + ('company_id',) + read_only_fields = fields class CountryStatSerializer(rest_framework.serializers.Serializer): diff --git a/promo_code/business/validators.py b/promo_code/business/validators.py deleted file mode 100644 index b66ce07..0000000 --- a/promo_code/business/validators.py +++ /dev/null @@ -1,98 +0,0 @@ -import rest_framework.exceptions - -import business.constants - - -class PromoValidator: - def __init__(self, data, instance=None): - self.data = data - self.instance = instance - self.full_data = self._get_full_data() - - def _get_full_data(self): - full_data = {} - if self.instance is not None: - full_data.update( - { - 'mode': self.instance.mode, - 'promo_common': self.instance.promo_common, - 'promo_unique': None, - 'max_count': self.instance.max_count, - 'active_from': self.instance.active_from, - 'active_until': self.instance.active_until, - 'used_count': self.instance.used_count, - 'target': self.instance.target - if self.instance.target - else {}, - }, - ) - - full_data.update(self.data) - return full_data - - def validate(self): - mode = self.full_data.get('mode') - promo_common = self.full_data.get('promo_common') - promo_unique = self.full_data.get('promo_unique') - max_count = self.full_data.get('max_count') - used_count = self.full_data.get('used_count') - - if mode not in [ - business.constants.PROMO_MODE_COMMON, - business.constants.PROMO_MODE_UNIQUE, - ]: - raise rest_framework.exceptions.ValidationError( - {'mode': 'Invalid mode.'}, - ) - - if used_count and used_count > max_count: - raise rest_framework.exceptions.ValidationError( - {'mode': 'Invalid max_count.'}, - ) - - if mode == business.constants.PROMO_MODE_COMMON: - if not promo_common: - raise rest_framework.exceptions.ValidationError( - { - 'promo_common': ( - 'This field is required for COMMON mode.' - ), - }, - ) - if promo_unique is not None: - raise rest_framework.exceptions.ValidationError( - { - 'promo_unique': ( - 'This field is not allowed for COMMON mode.' - ), - }, - ) - if max_count is None or not ( - business.constants.PROMO_COMMON_MIN_COUNT - <= max_count - <= business.constants.PROMO_COMMON_MAX_COUNT - ): - raise rest_framework.exceptions.ValidationError( - { - 'max_count': ( - 'Must be between 0 and 100,000,000 ' - 'for COMMON mode.' - ), - }, - ) - - elif mode == business.constants.PROMO_MODE_UNIQUE: - if promo_common is not None: - raise rest_framework.exceptions.ValidationError( - { - 'promo_common': ( - 'This field is not allowed for UNIQUE mode.' - ), - }, - ) - if max_count != business.constants.PROMO_UNIQUE_MAX_COUNT: - raise rest_framework.exceptions.ValidationError( - {'max_count': 'Must be 1 for UNIQUE mode.'}, - ) - - return self.full_data From 6d2482af839775aed797bbec09e729b03a41567b Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 15 Jul 2025 14:52:27 +0300 Subject: [PATCH 3/3] refactor(business): Enhance promo validation and data representation This commit improves the validation logic within the promo serializers and refines how promotion data is handled and displayed. Key changes: - **Improved Validation:** The validation logic now correctly handles partial updates (`PATCH`) by checking the `mode` from the instance if it's not provided in the request data. It also prevents `max_count` from being set lower than the current `used_count` for a promo. - **Accurate Used Count:** The `get_used_codes_count` method on the `Promo` model now correctly returns the `used_count` for promotions in `COMMON` mode. --- promo_code/business/models.py | 3 +- promo_code/business/serializers.py | 142 ++++++++++++++++------------- 2 files changed, 81 insertions(+), 64 deletions(-) diff --git a/promo_code/business/models.py b/promo_code/business/models.py index 9a86fe3..e64858c 100644 --- a/promo_code/business/models.py +++ b/promo_code/business/models.py @@ -119,8 +119,7 @@ def get_comment_count(self) -> int: def get_used_codes_count(self) -> int: if self.mode == business.constants.PROMO_MODE_UNIQUE: return self.unique_codes.filter(is_used=True).count() - # TODO: COMMON Promo - return 0 + return self.used_count @property def get_available_unique_codes(self) -> list[str] | None: diff --git a/promo_code/business/serializers.py b/promo_code/business/serializers.py index 72827e4..ffebcb6 100644 --- a/promo_code/business/serializers.py +++ b/promo_code/business/serializers.py @@ -285,16 +285,21 @@ class Meta: ) def validate(self, data): - full_data = self._get_full_data(data) + """ + Main validation method. + Determines the mode and calls the corresponding validation method. + """ - mode = full_data.get('mode') + mode = data.get('mode', getattr(self.instance, 'mode', None)) if mode == business.constants.PROMO_MODE_COMMON: - self._validate_common(full_data) - + self._validate_common(data) elif mode == business.constants.PROMO_MODE_UNIQUE: - self._validate_unique(full_data) - + self._validate_unique(data) + elif mode is None: + raise rest_framework.serializers.ValidationError( + {'mode': 'This field is required.'}, + ) else: raise rest_framework.serializers.ValidationError( {'mode': 'Invalid mode.'}, @@ -302,64 +307,45 @@ def validate(self, data): return data - def to_representation(self, instance): - """ - Controls the display of fields in the response. - """ - data = super().to_representation(instance) - - if not instance.image_url: - data.pop('image_url', None) - - if instance.mode == business.constants.PROMO_MODE_UNIQUE: - data.pop('promo_common', None) - if 'promo_unique' in self.fields and isinstance( - self.fields['promo_unique'], - rest_framework.serializers.SerializerMethodField, - ): - data['promo_unique'] = self.get_promo_unique(instance) - else: - data['promo_unique'] = [ - code.code for code in instance.unique_codes.all() - ] - else: - data.pop('promo_unique', None) - - return data - - def _get_full_data(self, data): - """ - Build the full data dict by merging existing instance data - with new input. - """ - if self.instance: - full_data = self.to_representation(self.instance) - full_data.update(data) - else: - full_data = data - return full_data - - def _validate_common(self, full_data): + def _validate_common(self, data): """ Validations for COMMON promo mode. """ - promo_common = full_data.get('promo_common') - promo_unique = full_data.get('promo_unique') - max_count = full_data.get('max_count') - if not promo_common: + if 'promo_unique' in data and data['promo_unique'] is not None: raise rest_framework.serializers.ValidationError( - {'promo_common': 'This field is required for COMMON mode.'}, + {'promo_unique': 'This field is not allowed for COMMON mode.'}, ) - if promo_unique is not None: + if self.instance is None and not data.get('promo_common'): raise rest_framework.serializers.ValidationError( - {'promo_unique': 'This field is not allowed for COMMON mode.'}, + {'promo_common': 'This field is required for COMMON mode.'}, ) + new_max_count = data.get('max_count') + if self.instance and new_max_count is not None: + used_count = self.instance.get_used_codes_count + if used_count > new_max_count: + raise rest_framework.serializers.ValidationError( + { + 'max_count': ( + f'max_count ({new_max_count}) cannot be less than ' + f'used_count ({used_count}).' + ), + }, + ) + + effective_max_count = ( + new_max_count + if new_max_count is not None + else getattr(self.instance, 'max_count', None) + ) + min_c = business.constants.PROMO_COMMON_MIN_COUNT max_c = business.constants.PROMO_COMMON_MAX_COUNT - if not (min_c <= max_count <= max_c): + if effective_max_count is not None and not ( + min_c <= effective_max_count <= max_c + ): raise rest_framework.serializers.ValidationError( { 'max_count': ( @@ -368,35 +354,67 @@ def _validate_common(self, full_data): }, ) - def _validate_unique(self, full_data): + def _validate_unique(self, data): """ Validations for UNIQUE promo mode. """ - promo_common = full_data.get('promo_common') - promo_unique = full_data.get('promo_unique') - max_count = full_data.get('max_count') - if not promo_unique: + if 'promo_common' in data and data['promo_common'] is not None: raise rest_framework.serializers.ValidationError( - {'promo_unique': 'This field is required for UNIQUE mode.'}, + {'promo_common': 'This field is not allowed for UNIQUE mode.'}, ) - if promo_common is not None: + if self.instance is None and not data.get('promo_unique'): raise rest_framework.serializers.ValidationError( - {'promo_common': 'This field is not allowed for UNIQUE mode.'}, + {'promo_unique': 'This field is required for UNIQUE mode.'}, ) - if max_count != business.constants.PROMO_UNIQUE_MAX_COUNT: + effective_max_count = data.get( + 'max_count', + getattr(self.instance, 'max_count', None), + ) + + if ( + effective_max_count is not None + and effective_max_count + != business.constants.PROMO_UNIQUE_MAX_COUNT + ): raise rest_framework.serializers.ValidationError( { 'max_count': ( 'Must be equal to ' - f'{business.constants.PROMO_UNIQUE_MAX_COUNT}' + f'{business.constants.PROMO_UNIQUE_MAX_COUNT} ' 'for UNIQUE mode.' ), }, ) + def to_representation(self, instance): + """ + Controls the display of fields in the response. + """ + + data = super().to_representation(instance) + + if not instance.image_url: + data.pop('image_url', None) + + if instance.mode == business.constants.PROMO_MODE_UNIQUE: + data.pop('promo_common', None) + if 'promo_unique' in self.fields and isinstance( + self.fields['promo_unique'], + rest_framework.serializers.SerializerMethodField, + ): + data['promo_unique'] = self.get_promo_unique(instance) + else: + data['promo_unique'] = [ + code.code for code in instance.unique_codes.all() + ] + else: + data.pop('promo_unique', None) + + return data + class PromoCreateSerializer(BasePromoSerializer): url = rest_framework.serializers.HyperlinkedIdentityField(