1212from flask import current_app
1313from flask_rest_jsonapi .data_layers .base import BaseDataLayer
1414from flask_rest_jsonapi .exceptions import RelationNotFound , RelatedObjectNotFound , JsonApiException ,\
15- InvalidSort , ObjectNotFound
15+ InvalidSort , ObjectNotFound , InvalidInclude
1616from flask_rest_jsonapi .data_layers .filtering .alchemy import create_filters
17- from flask_rest_jsonapi .schema import get_model_field , get_related_schema , get_relationships
17+ from flask_rest_jsonapi .schema import get_model_field , get_related_schema , get_relationships , get_schema_field
1818
1919
2020class SqlalchemyDataLayer (BaseDataLayer ):
@@ -43,8 +43,8 @@ def create_object(self, data, view_kwargs):
4343 """
4444 self .before_create_object (data , view_kwargs )
4545
46- relationship_fields = get_relationships (self .resource .schema )
47- obj = self .model (** {get_model_field ( self . resource . schema , key ) : value
46+ relationship_fields = get_relationships (self .resource .schema , model_field = True )
47+ obj = self .model (** {key : value
4848 for (key , value ) in data .items () if key not in relationship_fields })
4949 self .apply_relationships (data , obj )
5050
@@ -131,10 +131,10 @@ def update_object(self, obj, data, view_kwargs):
131131
132132 self .before_update_object (obj , data , view_kwargs )
133133
134- relationship_fields = get_relationships (self .resource .schema )
134+ relationship_fields = get_relationships (self .resource .schema , model_field = True )
135135 for key , value in data .items ():
136- if hasattr (obj , get_model_field ( self . resource . schema , key ) ) and key not in relationship_fields :
137- setattr (obj , get_model_field ( self . resource . schema , key ) , value )
136+ if hasattr (obj , key ) and key not in relationship_fields :
137+ setattr (obj , key , value )
138138
139139 self .apply_relationships (data , obj )
140140
@@ -395,12 +395,12 @@ def apply_relationships(self, data, obj):
395395 :return boolean: True if relationship have changed else False
396396 """
397397 relationships_to_apply = []
398- relationship_fields = get_relationships (self .resource .schema )
398+ relationship_fields = get_relationships (self .resource .schema , model_field = True )
399399 for key , value in data .items ():
400400 if key in relationship_fields :
401- related_model = getattr (obj .__class__ ,
402- get_model_field (self .resource .schema , key )). property . mapper . class_
403- related_id_field = self .resource .schema ._declared_fields [key ].id_field
401+ related_model = getattr (obj .__class__ , key ). property . mapper . class_
402+ schema_field = get_schema_field (self .resource .schema , key )
403+ related_id_field = self .resource .schema ._declared_fields [schema_field ].id_field
404404
405405 if isinstance (value , list ):
406406 related_objects = []
@@ -419,7 +419,7 @@ def apply_relationships(self, data, obj):
419419 relationships_to_apply .append ({'field' : key , 'value' : related_object })
420420
421421 for relationship in relationships_to_apply :
422- setattr (obj , get_model_field ( self . resource . schema , relationship ['field' ]) , relationship ['value' ])
422+ setattr (obj , relationship ['field' ], relationship ['value' ])
423423
424424 def filter_query (self , query , filter_info , model ):
425425 """Filter query according to jsonapi 1.0
@@ -480,7 +480,10 @@ def eagerload_includes(self, query, qs):
480480 if '.' in include :
481481 current_schema = self .resource .schema
482482 for obj in include .split ('.' ):
483- field = get_model_field (current_schema , obj )
483+ try :
484+ field = get_model_field (current_schema , obj )
485+ except Exception as e :
486+ raise InvalidInclude (str (e ))
484487
485488 if joinload_object is None :
486489 joinload_object = joinedload (field , innerjoin = True )
@@ -496,7 +499,11 @@ def eagerload_includes(self, query, qs):
496499
497500 current_schema = related_schema_cls
498501 else :
499- field = get_model_field (self .resource .schema , include )
502+ try :
503+ field = get_model_field (self .resource .schema , include )
504+ except Exception as e :
505+ raise InvalidInclude (str (e ))
506+
500507 joinload_object = joinedload (field , innerjoin = True )
501508
502509 query = query .options (joinload_object )
0 commit comments