Skip to content

Commit 2e37034

Browse files
committed
bug fixes and fix tests
1 parent 31264d9 commit 2e37034

File tree

3 files changed

+75
-45
lines changed

3 files changed

+75
-45
lines changed

flask_rest_jsonapi/data_layers/alchemy.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from flask import current_app
1313
from flask_rest_jsonapi.data_layers.base import BaseDataLayer
1414
from flask_rest_jsonapi.exceptions import RelationNotFound, RelatedObjectNotFound, JsonApiException,\
15-
InvalidSort, ObjectNotFound
15+
InvalidSort, ObjectNotFound, InvalidInclude
1616
from flask_rest_jsonapi.data_layers.filtering.alchemy import create_filters
17-
from flask_rest_jsonapi.schema import get_model_field, get_related_schema, get_relationships
17+
from flask_rest_jsonapi.schema import get_model_field, get_related_schema, get_relationships, get_schema_field
1818

1919

2020
class SqlalchemyDataLayer(BaseDataLayer):
@@ -43,8 +43,8 @@ def create_object(self, data, view_kwargs):
4343
"""
4444
self.before_create_object(data, view_kwargs)
4545

46-
relationship_fields = get_relationships(self.resource.schema)
47-
obj = self.model(**{get_model_field(self.resource.schema, key): value
46+
relationship_fields = get_relationships(self.resource.schema, model_field=True)
47+
obj = self.model(**{key: value
4848
for (key, value) in data.items() if key not in relationship_fields})
4949
self.apply_relationships(data, obj)
5050

@@ -131,10 +131,10 @@ def update_object(self, obj, data, view_kwargs):
131131

132132
self.before_update_object(obj, data, view_kwargs)
133133

134-
relationship_fields = get_relationships(self.resource.schema)
134+
relationship_fields = get_relationships(self.resource.schema, model_field=True)
135135
for key, value in data.items():
136-
if hasattr(obj, get_model_field(self.resource.schema, key)) and key not in relationship_fields:
137-
setattr(obj, get_model_field(self.resource.schema, key), value)
136+
if hasattr(obj, key) and key not in relationship_fields:
137+
setattr(obj, key, value)
138138

139139
self.apply_relationships(data, obj)
140140

@@ -395,12 +395,12 @@ def apply_relationships(self, data, obj):
395395
:return boolean: True if relationship have changed else False
396396
"""
397397
relationships_to_apply = []
398-
relationship_fields = get_relationships(self.resource.schema)
398+
relationship_fields = get_relationships(self.resource.schema, model_field=True)
399399
for key, value in data.items():
400400
if key in relationship_fields:
401-
related_model = getattr(obj.__class__,
402-
get_model_field(self.resource.schema, key)).property.mapper.class_
403-
related_id_field = self.resource.schema._declared_fields[key].id_field
401+
related_model = getattr(obj.__class__, key).property.mapper.class_
402+
schema_field = get_schema_field(self.resource.schema, key)
403+
related_id_field = self.resource.schema._declared_fields[schema_field].id_field
404404

405405
if isinstance(value, list):
406406
related_objects = []
@@ -419,7 +419,7 @@ def apply_relationships(self, data, obj):
419419
relationships_to_apply.append({'field': key, 'value': related_object})
420420

421421
for relationship in relationships_to_apply:
422-
setattr(obj, get_model_field(self.resource.schema, relationship['field']), relationship['value'])
422+
setattr(obj, relationship['field'], relationship['value'])
423423

424424
def filter_query(self, query, filter_info, model):
425425
"""Filter query according to jsonapi 1.0
@@ -480,7 +480,10 @@ def eagerload_includes(self, query, qs):
480480
if '.' in include:
481481
current_schema = self.resource.schema
482482
for obj in include.split('.'):
483-
field = get_model_field(current_schema, obj)
483+
try:
484+
field = get_model_field(current_schema, obj)
485+
except Exception as e:
486+
raise InvalidInclude(str(e))
484487

485488
if joinload_object is None:
486489
joinload_object = joinedload(field, innerjoin=True)
@@ -496,7 +499,11 @@ def eagerload_includes(self, query, qs):
496499

497500
current_schema = related_schema_cls
498501
else:
499-
field = get_model_field(self.resource.schema, include)
502+
try:
503+
field = get_model_field(self.resource.schema, include)
504+
except Exception as e:
505+
raise InvalidInclude(str(e))
506+
500507
joinload_object = joinedload(field, innerjoin=True)
501508

502509
query = query.options(joinload_object)

flask_rest_jsonapi/schema.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,18 @@ def get_model_field(schema, field):
9595
return field
9696

9797

98-
def get_relationships(schema):
98+
def get_relationships(schema, model_field=False):
9999
"""Return relationship fields of a schema
100100
101101
:param Schema schema: a marshmallow schema
102102
:param list: list of relationship fields of a schema
103103
"""
104-
return [key for (key, value) in schema._declared_fields.items() if isinstance(value, Relationship)]
104+
relationships = [key for (key, value) in schema._declared_fields.items() if isinstance(value, Relationship)]
105+
106+
if model_field is True:
107+
relationships = [get_model_field(schema, key) for key in relationships]
108+
109+
return relationships
105110

106111

107112
def get_related_schema(schema, field):
@@ -128,3 +133,18 @@ def get_schema_from_type(resource_type):
128133
pass
129134

130135
raise Exception("Couldn't find schema for type: {}".format(resource_type))
136+
137+
138+
def get_schema_field(schema, field):
139+
"""Get the schema field of a model field
140+
141+
:param Schema schema: a marshmallow schema
142+
:param str field: the name of the model field
143+
:return str: the name of the field in the schema
144+
"""
145+
schema_fields_to_model = {key: get_model_field(schema, key) for (key, value) in schema._declared_fields.items()}
146+
for key, value in schema_fields_to_model.items():
147+
if value == field:
148+
return key
149+
150+
raise Exception("Couldn't find schema field from {}".format(field))

tests/test_sqlalchemy_data_layer.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,11 @@ def __init__(self, kwargs):
328328
return get_object
329329

330330

331-
def test_add_pagination_links():
332-
qs = {'page[number]': '15', 'page[size]': '10'}
333-
qsm = QSManager(qs, None)
334-
add_pagination_links(dict(), 1000, qsm, str())
331+
def test_add_pagination_links(app):
332+
with app.app_context():
333+
qs = {'page[number]': '15', 'page[size]': '10'}
334+
qsm = QSManager(qs, None)
335+
add_pagination_links(dict(), 1000, qsm, str())
335336

336337

337338
def test_Node(person_model, person_schema, monkeypatch):
@@ -390,33 +391,35 @@ def test_query_string_manager(person_schema):
390391
qsm.sorting
391392

392393

393-
def test_resource(person_model, person_schema, session, monkeypatch):
394+
def test_resource(app, person_model, person_schema, session, monkeypatch):
394395
def schema_load_mock(*args):
395396
raise ValidationError(dict(errors=[dict(status=None, title=None)]))
396-
query_string = {'page[slumber]': '3'}
397-
app = type('app', (object,), dict(config=dict(DEBUG=True)))
398-
headers = {'Content-Type': 'application/vnd.api+json'}
399-
request = type('request', (object,), dict(method='POST',
400-
headers=headers,
401-
get_json=dict,
402-
args=query_string))
403-
dl = SqlalchemyDataLayer(dict(session=session, model=person_model))
404-
rl = ResourceList()
405-
rd = ResourceDetail()
406-
rl._data_layer = dl
407-
rl.schema = person_schema
408-
rd._data_layer = dl
409-
rd.schema = person_schema
410-
monkeypatch.setattr(flask_rest_jsonapi.resource, 'request', request)
411-
monkeypatch.setattr(flask_rest_jsonapi.resource, 'current_app', app)
412-
monkeypatch.setattr(flask_rest_jsonapi.decorators, 'request', request)
413-
monkeypatch.setattr(rl.schema, 'load', schema_load_mock)
414-
r = super(flask_rest_jsonapi.resource.Resource, ResourceList)\
415-
.__new__(ResourceList)
416-
with pytest.raises(Exception):
417-
r.dispatch_request()
418-
rl.post()
419-
rd.patch()
397+
398+
with app.app_context():
399+
query_string = {'page[slumber]': '3'}
400+
app = type('app', (object,), dict(config=dict(DEBUG=True)))
401+
headers = {'Content-Type': 'application/vnd.api+json'}
402+
request = type('request', (object,), dict(method='POST',
403+
headers=headers,
404+
get_json=dict,
405+
args=query_string))
406+
dl = SqlalchemyDataLayer(dict(session=session, model=person_model))
407+
rl = ResourceList()
408+
rd = ResourceDetail()
409+
rl._data_layer = dl
410+
rl.schema = person_schema
411+
rd._data_layer = dl
412+
rd.schema = person_schema
413+
monkeypatch.setattr(flask_rest_jsonapi.resource, 'request', request)
414+
monkeypatch.setattr(flask_rest_jsonapi.resource, 'current_app', app)
415+
monkeypatch.setattr(flask_rest_jsonapi.decorators, 'request', request)
416+
monkeypatch.setattr(rl.schema, 'load', schema_load_mock)
417+
r = super(flask_rest_jsonapi.resource.Resource, ResourceList)\
418+
.__new__(ResourceList)
419+
with pytest.raises(Exception):
420+
r.dispatch_request()
421+
rl.post()
422+
rd.patch()
420423

421424

422425
def test_compute_schema(person_schema):

0 commit comments

Comments
 (0)