11import uuid
22
33import django .contrib .auth .password_validation
4- import django .core .exceptions
54import django .core .validators
5+ import django .db .transaction
66import pycountry
77import rest_framework .exceptions
88import rest_framework .serializers
1111import rest_framework_simplejwt .tokens
1212
1313import business .constants
14+ import business .models
1415import business .models as business_models
16+ import business .utils .auth
17+ import business .utils .tokens
1518import business .validators
1619
1720
@@ -21,9 +24,9 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
2124 write_only = True ,
2225 required = True ,
2326 validators = [django .contrib .auth .password_validation .validate_password ],
27+ style = {'input_type' : 'password' },
2428 min_length = business .constants .COMPANY_PASSWORD_MIN_LENGTH ,
2529 max_length = business .constants .COMPANY_PASSWORD_MAX_LENGTH ,
26- style = {'input_type' : 'password' },
2730 )
2831 name = rest_framework .serializers .CharField (
2932 required = True ,
@@ -44,30 +47,18 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
4447
4548 class Meta :
4649 model = business_models .Company
47- fields = (
48- 'id' ,
49- 'name' ,
50- 'email' ,
51- 'password' ,
52- )
50+ fields = ('id' , 'name' , 'email' , 'password' )
5351
52+ @django .db .transaction .atomic
5453 def create (self , validated_data ):
55- try :
56- company = business_models .Company .objects .create_company (
57- email = validated_data ['email' ],
58- name = validated_data ['name' ],
59- password = validated_data ['password' ],
60- )
61- company .token_version += 1
62- company .save ()
63- return company
64- except django .core .exceptions .ValidationError as e :
65- raise rest_framework .serializers .ValidationError (e .messages )
54+ company = business_models .Company .objects .create_company (
55+ ** validated_data ,
56+ )
6657
58+ return business .utils .auth .bump_company_token_version (company )
6759
68- class CompanySignInSerializer (
69- rest_framework .serializers .Serializer ,
70- ):
60+
61+ class CompanySignInSerializer (rest_framework .serializers .Serializer ):
7162 email = rest_framework .serializers .EmailField (required = True )
7263 password = rest_framework .serializers .CharField (
7364 required = True ,
@@ -80,16 +71,15 @@ def validate(self, attrs):
8071 password = attrs .get ('password' )
8172
8273 if not email or not password :
83- raise rest_framework .exceptions .ValidationError (
84- {'detail' : 'Both email and password are required' },
85- code = 'required' ,
74+ raise rest_framework .serializers .ValidationError (
75+ 'Both email and password are required.' ,
8676 )
8777
8878 try :
8979 company = business_models .Company .objects .get (email = email )
9080 except business_models .Company .DoesNotExist :
9181 raise rest_framework .serializers .ValidationError (
92- 'Invalid credentials' ,
82+ 'Invalid credentials. ' ,
9383 )
9484
9585 if not company .is_active or not company .check_password (password ):
@@ -98,53 +88,55 @@ def validate(self, attrs):
9888 code = 'authentication_failed' ,
9989 )
10090
91+ attrs ['company' ] = company
10192 return attrs
10293
10394
10495class CompanyTokenRefreshSerializer (
10596 rest_framework_simplejwt .serializers .TokenRefreshSerializer ,
10697):
10798 def validate (self , attrs ):
99+ attrs = super ().validate (attrs )
108100 refresh = rest_framework_simplejwt .tokens .RefreshToken (
109101 attrs ['refresh' ],
110102 )
111- user_type = refresh .payload .get ('user_type' , 'user' )
103+ company = self .get_active_company_from_token (refresh )
104+
105+ company = business .utils .auth .bump_company_token_version (company )
112106
113- if user_type != 'company' :
107+ return business .utils .tokens .generate_company_tokens (company )
108+
109+ def get_active_company_from_token (self , token ):
110+ if token .payload .get ('user_type' ) != 'company' :
114111 raise rest_framework_simplejwt .exceptions .InvalidToken (
115112 'This refresh endpoint is for company tokens only' ,
116113 )
117114
118- company_id = refresh .payload .get ('company_id' )
119- if not company_id :
115+ company_id = token .payload .get ('company_id' )
116+ try :
117+ company_uuid = uuid .UUID (company_id )
118+ except (TypeError , ValueError ):
120119 raise rest_framework_simplejwt .exceptions .InvalidToken (
121- 'Company ID missing in token' ,
120+ 'Invalid or missing company_id in token' ,
122121 )
123122
124123 try :
125- company = business_models .Company .objects .get (
126- id = uuid .UUID (company_id ),
124+ company = business .models .Company .objects .get (
125+ id = company_uuid ,
126+ is_active = True ,
127127 )
128- except business_models .Company .DoesNotExist :
128+ except business . models .Company .DoesNotExist :
129129 raise rest_framework_simplejwt .exceptions .InvalidToken (
130- 'Company not found' ,
130+ 'Company not found or inactive ' ,
131131 )
132132
133- token_version = refresh .payload .get ('token_version' , 0 )
133+ token_version = token .payload .get ('token_version' , 0 )
134134 if company .token_version != token_version :
135135 raise rest_framework_simplejwt .exceptions .InvalidToken (
136136 'Token is blacklisted' ,
137137 )
138138
139- new_refresh = rest_framework_simplejwt .tokens .RefreshToken ()
140- new_refresh ['user_type' ] = 'company'
141- new_refresh ['company_id' ] = str (company .id )
142- new_refresh ['token_version' ] = company .token_version
143-
144- return {
145- 'access' : str (new_refresh .access_token ),
146- 'refresh' : str (new_refresh ),
147- }
139+ return company
148140
149141
150142class TargetSerializer (rest_framework .serializers .Serializer ):
0 commit comments