Generic REST API with Flask Class-Based Views
By Jonathan Machado
December 29, 2020
Flask 0.7 introduced pluggable views, inspired by Django's generic views, which are based on classes instead of functions. The main idea is that you can replace parts of an implementation, which makes the views customizable and pluggable.
In this article, we're exploring a potential use case for Flask pluggable views (a.k.a. class-based views) and their advantages over plain view functions.
The benefits of using classes might not be obvious when working on small projects, but as your application grows you may find that classes scale better than view functions.
For our example, let's imagine we're creating an application similar to Medium.com, where users can create and publish stories and publications, bookmark content, etc.
When building an API that exposes a large number of models, we often start repeating code: fetching a single object, fetching a list of objects, inserting, updating and deleting. The same logic gets rewritten for every endpoint we create.
Also, whenever we add a new model/entity, we usually end up copying and pasting all of the methods above and updating the queries to use the new model.
Flask pluggable views can come in very handy in this scenario.
This is what a pluggable view looks like:
```python
# Flask pluggable view
from flask.views import MethodView


class UserAPI(MethodView):
    def get(self, user_id):
        if user_id is None:
            # return a list of users
            pass
        else:
            # expose a single user
            pass

    def post(self):
        # create a new user
        pass

    def delete(self, user_id):
        # delete a single user
        pass

    def put(self, user_id):
        # update a single user
        pass
```
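The skeleton above only outlines the dispatch: Flask calls the method whose name matches the HTTP verb. Below is a minimal runnable version, assuming a hypothetical in-memory dict (`USERS`) in place of a database and implementing only `get` and `post`. The URL rules follow the same pattern used later in this article, where `defaults={"user_id": None}` lets a single `get` method serve both the list route and the detail route:

```python
from flask import Flask, jsonify, request
from flask.views import MethodView

app = Flask(__name__)

# Hypothetical in-memory store standing in for a database table
USERS = {1: {"id": 1, "name": "Ada"}}


class UserAPI(MethodView):
    def get(self, user_id):
        if user_id is None:
            # GET /users/ -> list every user
            return jsonify(list(USERS.values()))
        user = USERS.get(user_id)
        if user is None:
            return jsonify({"msg": "Not found"}), 404
        # GET /users/<id> -> a single user
        return jsonify(user)

    def post(self):
        # POST /users/ -> create a user from the JSON body, assigning the next free id
        new_id = max(USERS, default=0) + 1
        USERS[new_id] = {**request.get_json(), "id": new_id}
        return jsonify(USERS[new_id]), 201


user_view = UserAPI.as_view("user_api")
# `defaults` makes GET /users/ call get(user_id=None)
app.add_url_rule("/users/", defaults={"user_id": None}, view_func=user_view, methods=["GET"])
app.add_url_rule("/users/", view_func=user_view, methods=["POST"])
app.add_url_rule("/users/<int:user_id>", view_func=user_view, methods=["GET"])
```

With `app.test_client()`, a GET to `/users/` returns the list, a POST to `/users/` with a JSON body creates an entry, and a GET to `/users/2` then returns that new entry.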
Getting started
For our example, let's suppose we have quite a number of models that we would like to expose via our API:
- Publications
- Stories
- StoryCategories
- UserBookmarks
- Podcasts
- EditorsChoice
- Sections
- ContentSections
Some of our APIs are going to be public (List and Read), some will require the user to be authenticated, and others will be restricted to admin users.
With that in mind, I decided to use three different base views that our views will inherit from:
- PublicListReadView - List all entries or fetch a particular entry.
- UserAuthCUDView - Allow authenticated users to create, update and delete.
- AdminCUDView - Allow administrators to create, update and delete.
PS: The CUD part stands for Create, Update and Delete. I didn't spend much time thinking about the best possible names, so you can probably come up with something clearer.
The final code for this project can be found here: https://github.com/jonathanmach/flask-class-based-views
Quick recipe for creating our project environment:
```bash
# Create a Python virtual environment
python3 -m venv venv
# Activate the virtual environment
. venv/bin/activate
# Install Flask
pip install Flask
```
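For this example, we will use Flask-SQLAlchemy as our ORM:
```bash
pip install Flask-SQLAlchemy
```
And lastly, we will use Marshmallow for object serialization. The `ma.SQLAlchemyAutoSchema` class used below also requires marshmallow-sqlalchemy:
```bash
pip install flask-marshmallow marshmallow-sqlalchemy
```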
Defining models
We're going to start by defining all our models. I'm creating very basic models, mainly with `id` and `title` fields, since we're not interested in data modelling itself, but in the class-based views 😄
```python
"""
Models and Schemas
"""


class Publications(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)


class PublicationsSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Publications


class Stories(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)
    content = db.Column(db.Text, nullable=True)


class StoriesSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Stories


class StoryCategories(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)


class StoryCategoriesSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = StoryCategories


class UserBookmarks(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    # title = db.Column(db.String, nullable=False)


class UserBookmarksSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = UserBookmarks


class Podcasts(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)


class PodcastsSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Podcasts


class EditorsChoice(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    # title = db.Column(db.String, nullable=False)


class EditorsChoiceSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = EditorsChoice


class Sections(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    # title = db.Column(db.String, nullable=False)


class SectionsSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Sections


class ContentSections(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)


class ContentSectionsSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = ContentSections
```
Creating our Base Views
As mentioned earlier, we're going to have 3 different base views:
- PublicListReadView - List all entries or fetch a particular entry.
- UserAuthCUDView - Allow authenticated users to create, update and delete.
- AdminCUDView - Allow administrators to create, update and delete.
We're going to check if the user is authenticated or is an administrator using Python decorators.
Now, let's write our initial base views:
```python
from flask import jsonify
from flask.views import MethodView


class BaseView(MethodView):
    model = None
    schema = None

    def __init__(self):
        """
        If schema is not overridden when inheriting from this class, we'll try to
        find its respective schema class using the 'Schema' suffix.
        Ex: Stories -> StoriesSchema
        """
        if self.schema is None:
            _cls_name = f"{self.model.__name__}Schema"
            # globals() only sees schemas defined in this same module
            self.schema = globals()[_cls_name]


class PublicListReadView(BaseView):
    def get(self, entry_id):
        schema = self.schema()
        # Query database and return data
        if entry_id is None:
            # Return list of all entries
            result = self.model.query.all()
            return jsonify(schema.dump(result, many=True))
        else:
            # Return a single object, or a 404 response if the id doesn't exist
            result = self.model.query.get_or_404(entry_id)
            return jsonify(schema.dump(result))
```
We will override the `model` attribute later, when subclassing this view.
The `get` method returns all entries when no `entry_id` is given, or a single object when one is.
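The `globals()` lookup in `BaseView.__init__` has two consequences worth knowing: it only finds schema classes defined in the same module as `BaseView`, and it runs on every request, because Flask instantiates the view class per request by default. A sketch of the same `<Model>Schema` naming convention resolved once, when each concrete view class is defined, using `__init_subclass__` and an explicit registry. This is a swapped-in technique, not the original project's code, and `SchemaRegistry`/`schemas` are hypothetical names:

```python
class SchemaRegistry(dict):
    """Hypothetical helper mapping schema class names to schema classes."""

    def register(self, cls):
        self[cls.__name__] = cls
        return cls  # returned unchanged, so it works as a class decorator


schemas = SchemaRegistry()


class BaseView:  # in the real app this would subclass flask.views.MethodView
    model = None
    schema = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Intermediate bases leave `model` as None and are skipped
        if cls.schema is None and cls.model is not None:
            # A missing schema raises KeyError at import time instead of per request
            cls.schema = schemas[f"{cls.model.__name__}Schema"]


# Placeholder classes standing in for a model and its schema
class Stories:
    pass


@schemas.register
class StoriesSchema:
    pass


class StoriesAPI(BaseView):
    model = Stories
```

The trade-off is one extra decorator per schema in exchange for schemas that can live in any module and errors that surface as soon as the views are imported.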
For example, this is how we would create a view based on our `PublicListReadView` class:
```python
# Example: inheriting from PublicListReadView
class UsersAPI(PublicListReadView):
    model = Users
```
And after registering it, we would have the following endpoints:
- `/users/` would return all the entries.
- `/users/123` would return a single entry with id 123.
Cool, now to our next base view: UserAuthCUDView.
Since we want to add a decorator to it, let's declare our decorator before creating our class:
```python
from functools import wraps

from flask import jsonify, request


# An example decorator
def user_required(f):
    """Checks whether user is logged in or raises error 401."""
    @wraps(f)  # keep the wrapped function's name and docstring
    def decorator(*args, **kwargs):
        # TODO: NotImplemented
        print("user_required decorator triggered!")
        return f(*args, **kwargs)
    return decorator


class UserAuthCUDView(BaseView):
    decorators = [user_required]

    def post(self):
        # create a new entry
        new_entry = self.model(**request.get_json())
        db.session.add(new_entry)
        db.session.commit()
        schema = self.schema()
        return jsonify(schema.dump(new_entry))

    def delete(self, entry_id):
        # delete a single entry (404 if it doesn't exist)
        entry = self.model.query.get_or_404(entry_id)
        db.session.delete(entry)
        db.session.commit()
        return {"msg": "Entry successfully deleted."}, 200

    def put(self, entry_id):
        # update a single entry
        raise NotImplementedError
```
There's our view, which handles POST and DELETE requests and stubs out PUT. Of course, we could add further validation, such as only allowing users to update or delete entries they created.
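The `user_required` decorator above is only a placeholder. One way it could be filled in, as a sketch: it assumes clients send an `Authorization: Bearer <token>` header and checks the token against a hypothetical in-memory `VALID_TOKENS` set, aborting with 401 otherwise. A real application would validate a session or a signed token instead. `PingAPI` is a throwaway view that exists only to exercise the decorator:

```python
from functools import wraps

from flask import Flask, abort, jsonify, request
from flask.views import MethodView

# Hypothetical token store; a real app would verify a session or signed token
VALID_TOKENS = {"secret-token"}


def user_required(f):
    """Abort with 401 unless the request carries a known bearer token."""
    @wraps(f)
    def decorator(*args, **kwargs):
        auth = request.headers.get("Authorization", "")
        token = auth[len("Bearer "):] if auth.startswith("Bearer ") else None
        if token not in VALID_TOKENS:
            abort(401)
        return f(*args, **kwargs)
    return decorator


class PingAPI(MethodView):  # hypothetical view used only for this demo
    decorators = [user_required]

    def post(self):
        return jsonify({"msg": "ok"})


app = Flask(__name__)
app.add_url_rule("/ping/", view_func=PingAPI.as_view("ping_api"))
```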
Now, for our `AdminCUDView`, the main difference from the previous class is the decorator we are going to use. In this case, we can inherit from `UserAuthCUDView` itself and override the `decorators` attribute, passing a different decorator that is responsible for checking whether the user has admin permissions:
```python
from functools import wraps


# Admin decorator example
def admin_required(f):
    """Checks whether user is admin."""
    @wraps(f)
    def decorator(*args, **kwargs):
        print("admin_required decorator triggered!")
        return f(*args, **kwargs)
    return decorator


# Named AdminCUDView so the API classes below can inherit from it
class AdminCUDView(UserAuthCUDView):
    decorators = [admin_required]
```
That was easy and clean!
Now we choose which base views each of our actual views inherits from, depending on the restrictions we want. Notice that some of them inherit from `UserAuthCUDView` and others from `AdminCUDView`:
```python
"""
APIs
"""


class PublicationsAPI(UserAuthCUDView, PublicListReadView):
    model = Publications


class StoriesAPI(UserAuthCUDView, PublicListReadView):
    model = Stories


class UserBookmarksAPI(UserAuthCUDView, PublicListReadView):
    model = UserBookmarks


class PodcastsAPI(AdminCUDView, PublicListReadView):
    model = Podcasts


class StoryCategoriesAPI(UserAuthCUDView, PublicListReadView):
    model = StoryCategories


class EditorsChoiceAPI(AdminCUDView, PublicListReadView):
    model = EditorsChoice


class SectionsAPI(AdminCUDView, PublicListReadView):
    model = Sections


class ContentSectionsAPI(AdminCUDView, PublicListReadView):
    model = ContentSections
```
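One behaviour worth checking with this layout: Flask's `View.as_view()` wraps the whole generated view function with every entry in `decorators`, regardless of the HTTP method. Because `UserAuthCUDView` comes first in `PublicationsAPI`'s bases, `PublicationsAPI.decorators` resolves to `[user_required]`, so the supposedly public GET requests also pass through `user_required` (and through `admin_required` for the `AdminCUDView` ones). The placeholder decorators only print, so this goes unnoticed here. If GET must stay public, one option, sketched below and not taken from the original project, is to decorate the write methods themselves. `user_required` here is a hypothetical check on an `X-User` header:

```python
from functools import wraps

from flask import Flask, abort, jsonify, request
from flask.views import MethodView


def user_required(f):
    """Hypothetical auth check: 401 unless an X-User header is present."""
    @wraps(f)
    def decorator(*args, **kwargs):
        if "X-User" not in request.headers:
            abort(401)
        return f(*args, **kwargs)
    return decorator


class PublicReadView(MethodView):
    def get(self):
        return jsonify({"msg": "public read"})


class DecoratedWriteView(MethodView):
    # as_view() wraps the whole view function, so GET gets checked too
    decorators = [user_required]

    def post(self):
        return jsonify({"msg": "created"}), 201


class MethodDecoratedWriteView(MethodView):
    # Only post() is wrapped, so a mixed-in get() stays public
    @user_required
    def post(self):
        return jsonify({"msg": "created"}), 201


class PitfallAPI(DecoratedWriteView, PublicReadView):
    pass


class NotesAPI(MethodDecoratedWriteView, PublicReadView):
    pass


app = Flask(__name__)
app.add_url_rule("/pitfall/", view_func=PitfallAPI.as_view("pitfall_api"))
app.add_url_rule("/notes/", view_func=NotesAPI.as_view("notes_api"))
```

With the test client, an anonymous GET to `/pitfall/` is rejected with 401, while the same GET to `/notes/` succeeds and only POST requires the header.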
That's it. The only thing left to do is to register our views:
```python
def register_api(view, endpoint, url):
    view_func = view.as_view(endpoint)
    # GET '/podcasts/' lists entries: entry_id defaults to None
    app.add_url_rule(url, defaults={'entry_id': None}, view_func=view_func, methods=['GET',])
    # POST '/podcasts/' creates an entry
    app.add_url_rule(url, view_func=view_func, methods=['POST',])
    # Creates a rule like: '/podcasts/<int:entry_id>'
    app.add_url_rule(f'{url}<int:entry_id>', view_func=view_func, methods=['GET', 'PUT', 'DELETE'])


# Register APIs
register_api(PublicationsAPI, 'publications_api', '/publications/')
register_api(PodcastsAPI, 'podcasts_api', '/podcasts/')
register_api(StoriesAPI, 'stories_api', '/stories/')
register_api(UserBookmarksAPI, 'user_bookmarks_api', '/bookmarks/')
register_api(StoryCategoriesAPI, 'story_categories_api', '/categories/')
register_api(EditorsChoiceAPI, 'editors_choice_api', '/editors-choice/')
register_api(SectionsAPI, 'sections_api', '/sections/')
register_api(ContentSectionsAPI, 'content_sections_api', '/content-sections/')
```
All our endpoints should be functional right now!
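To double-check what `register_api` produces, we can inspect Flask's URL map. This sketch uses a stand-in `PodcastsAPI` whose methods just echo `entry_id` (no database, not the real view) together with the same `register_api` function:

```python
from flask import Flask, jsonify
from flask.views import MethodView

app = Flask(__name__)


class PodcastsAPI(MethodView):  # stand-in with the same method signatures
    def get(self, entry_id):
        return jsonify({"entry_id": entry_id})

    def post(self):
        return jsonify({"msg": "created"}), 201

    def put(self, entry_id):
        return jsonify({"updated": entry_id})

    def delete(self, entry_id):
        return jsonify({"deleted": entry_id})


def register_api(view, endpoint, url):
    view_func = view.as_view(endpoint)
    app.add_url_rule(url, defaults={"entry_id": None}, view_func=view_func, methods=["GET"])
    app.add_url_rule(url, view_func=view_func, methods=["POST"])
    app.add_url_rule(f"{url}<int:entry_id>", view_func=view_func, methods=["GET", "PUT", "DELETE"])


register_api(PodcastsAPI, "podcasts_api", "/podcasts/")

# Print each rule with the methods it accepts (ignoring the automatic HEAD/OPTIONS)
for rule in app.url_map.iter_rules(endpoint="podcasts_api"):
    print(rule.rule, sorted(rule.methods - {"HEAD", "OPTIONS"}))
```

This should list `/podcasts/` twice (once for GET, once for POST) and `/podcasts/<int:entry_id>` for DELETE, GET and PUT. A DELETE sent to `/podcasts/` gets a 405, because no rule allows that method on the collection URL.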
Benefits of class-based views
Now, since we have a generic API, we can implement additional features in one place and have them reflected in all our endpoints.
Let's add the option to choose which fields the API returns.
Imagine that our client only needs the `id` and `title` from the `/stories/` endpoint, in order to show a list of stories on the home page. It doesn't need the content at this point.
How would we go about implementing this?
The plan is to define the desired fields as URL query parameters in our request:
`/stories/?fields=title,id` will return only the `id` and `title` fields, instead of all fields.
To implement this feature for all our endpoints, let's make the following changes to our `PublicListReadView`. We use Marshmallow's `only=` option when instantiating the schema to pass in the list of fields we want:
```python
from flask import jsonify, request


class PublicListReadView(BaseView):
    def get(self, entry_id):
        # Handle any '?fields=' params received in the request - Ex: /stories/?fields=title,id
        fields = request.args.get("fields")
        if fields:
            schema = self.schema(only=fields.split(","))
        else:
            schema = self.schema()
        # Query database and return data
        if entry_id is None:
            # Return list of all entries
            result = self.model.query.all()
            return jsonify(schema.dump(result, many=True))
        else:
            # Return a single object, or a 404 response if the id doesn't exist
            result = self.model.query.get_or_404(entry_id)
            return jsonify(schema.dump(result))
```
And that's all. Now this feature will be available on all our endpoints!
If we needed different logic for a specific endpoint, we could simply override the relevant method in that view and implement the new logic there.