Generic REST API with Flask Class-Based Views

Using Flask Pluggable Views to create a flexible and extensible API

By Jonathan Machado

December 29, 2020


Flask 0.7 introduced pluggable views, inspired by Django's generic views, which are based on classes instead of functions. The main intention is that you can replace parts of the implementation and, in this way, get customizable pluggable views.

In this article, we're exploring a potential use case for Flask Pluggable Views (a.k.a. Class-Based Views) and their advantages over normal view functions.

The benefits of using classes might not be obvious when working on small projects, but as your application grows you may find that classes scale better than view functions.

For our example, let's imagine we're creating an application similar to Medium.com, where users can create and publish Stories and Publications, bookmark content, and so on.

When building an API with a large number of models, we often find ourselves repeating the same code: fetching a single object, fetching a list of objects, and inserting, updating and deleting entries. This logic gets repeated for every endpoint we create.

Also, if we have to create a new model/entity, we will usually find ourselves copying and pasting all the methods mentioned above and updating the queries with the new model.

Flask Pluggable Views can come in very handy in this scenario.

This is what a pluggable view looks like:

# Flask pluggable view
from flask.views import MethodView

class UserAPI(MethodView):

    def get(self, user_id):
        if user_id is None:
            # return a list of users
            pass
        else:
            # expose a single user
            pass

    def post(self):
        # create a new user
        pass

    def delete(self, user_id):
        # delete a single user
        pass

    def put(self, user_id):
        # update a single user
        pass
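Flask's MethodView dispatches each request to the method named after the HTTP verb (get, post, put, delete). The sketch below fills in the skeleton above so it can actually run: an in-memory USERS dict stands in for a database (an assumption for illustration only), and the view is registered with three URL rules, following the pattern shown in the Flask documentation, so the same class serves both /users/ and /users/<id>.

```python
from flask import Flask, abort, jsonify, request
from flask.views import MethodView

# In-memory stand-in for a users table (assumption for this demo)
USERS = {1: {"id": 1, "name": "Ada"}}


class UserAPI(MethodView):
    def get(self, user_id):
        if user_id is None:
            # GET /users/ -> return a list of users
            return jsonify(list(USERS.values()))
        # GET /users/<id> -> expose a single user
        if user_id not in USERS:
            abort(404)
        return jsonify(USERS[user_id])

    def post(self):
        # POST /users/ -> create a new user from the JSON body
        new_id = max(USERS, default=0) + 1
        USERS[new_id] = {**request.get_json(), "id": new_id}
        return jsonify(USERS[new_id]), 201

    def delete(self, user_id):
        # DELETE /users/<id> -> delete a single user
        USERS.pop(user_id, None)
        return "", 204

    def put(self, user_id):
        # PUT /users/<id> -> update a single user
        if user_id not in USERS:
            abort(404)
        USERS[user_id].update(request.get_json())
        return jsonify(USERS[user_id])


app = Flask(__name__)
user_view = UserAPI.as_view("user_api")
# GET on the collection URL passes user_id=None via `defaults`
app.add_url_rule("/users/", defaults={"user_id": None}, view_func=user_view, methods=["GET"])
app.add_url_rule("/users/", view_func=user_view, methods=["POST"])
app.add_url_rule("/users/<int:user_id>", view_func=user_view, methods=["GET", "PUT", "DELETE"])
```

The `defaults={"user_id": None}` rule is what lets a single get() method handle both the list and the detail case.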

Getting started

For our example, let's suppose we have quite a number of models that we would like to expose via our API:

  • Publications
  • Stories
  • StoryCategories
  • UserBookmarks
  • Podcasts
  • EditorsChoice
  • Sections
  • ContentSections

Some of our APIs are going to be public (List and Read), some will require the user to be authenticated, and others will be restricted to admin users.

With that in mind, I decided to use three different base views that our views can inherit from:

  • PublicListReadView - List all entries or fetch a particular entry.
  • UserAuthCUDView - Allow authenticated users to create, update and delete.
  • AdminCUDView - Allow administrators to create, update and delete.

PS: CUD stands for Create, Update and Delete. I didn't spend much time looking for the best names, so you can probably come up with something clearer.

The final code for this project can be found here: https://github.com/jonathanmach/flask-class-based-views

Quick recipe for creating our project environment:

python3 -m venv venv - Creates a Python virtual environment

. venv/bin/activate - Activates the virtual environment

pip install Flask - Installs Flask

For this example, we will use Flask-SQLAlchemy as our ORM:

pip install Flask-SQLAlchemy

Finally, we will use Marshmallow, through flask-marshmallow, for object serialization. The SQLAlchemyAutoSchema class used below also requires marshmallow-sqlalchemy:

pip install flask-marshmallow marshmallow-sqlalchemy

Defining models

We're going to start by defining all our models. I'm keeping them very basic, mostly just id and title fields, since we're interested in the class-based views rather than in data modelling itself 😄

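from flask import Flask, jsonify, request
from flask.views import MethodView
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow

app = Flask(__name__)
# SQLite file database, assumed here for local development
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///app.db"
db = SQLAlchemy(app)
ma = Marshmallow(app)  # must be initialized after Flask-SQLAlchemy

"""
    Models and Schemas
"""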

class Publications(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)

class PublicationsSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Publications

class Stories(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)
    content = db.Column(db.Text, nullable=True)

class StoriesSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Stories

class StoryCategories(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)

class StoryCategoriesSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = StoryCategories

class UserBookmarks(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    # title = db.Column(db.String, nullable=False)

class UserBookmarksSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = UserBookmarks

class Podcasts(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)

class PodcastsSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Podcasts

class EditorsChoice(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    # title = db.Column(db.String, nullable=False)

class EditorsChoiceSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = EditorsChoice

class Sections(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    # title = db.Column(db.String, nullable=False)

class SectionsSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = Sections

class ContentSections(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)

class ContentSectionsSchema(ma.SQLAlchemyAutoSchema):
    class Meta:
        model = ContentSections

Creating our Base Views

As mentioned earlier, we're going to have 3 different base views:

  • PublicListReadView - List all entries or fetch a particular entry.
  • UserAuthCUDView - Allow authenticated users to create, update and delete.
  • AdminCUDView - Allow administrators to create, update and delete.

We're going to check if the user is authenticated or is an administrator using Python decorators.

Now, let's write our initial base views:

class BaseView(MethodView):
    model = None
    schema = None

    def __init__(self):
        """
        If schema is not overridden when inheriting from this class, we'll try to find its respective schema class using the 'Schema' suffix.
        Ex: Stories -> StoriesSchema
        """
        if self.schema is None:
            _cls_name = f"{self.model.__name__}Schema"
            self.schema = globals()[_cls_name]

class PublicListReadView(BaseView):
    def get(self, entry_id):
        schema = self.schema()

        # Query database and return data
        if entry_id is None:
            # Return list of all entries
            result = self.model.query.all()
            return jsonify(schema.dump(result, many=True))
        else:
            # Return a single object, or a 404 response if it doesn't exist
            result = self.model.query.get_or_404(entry_id)
            return jsonify(schema.dump(result))

We will later override the model attribute when subclassing from this view.

It returns all entries when no entry_id is given, or a single object when one is.
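The schema lookup in BaseView.__init__ relies on a naming convention: the schema for Stories must be called StoriesSchema, and it must live in the same module as BaseView, because globals() returns the namespace of the module where the method is defined. A missing schema only surfaces as a bare KeyError when the first request arrives. The hypothetical find_schema helper below is a minimal sketch that keeps the same convention but fails with a clearer message; the empty classes are stand-ins for the real models and schemas.

```python
def find_schema(model, namespace):
    """Return the '<ModelName>Schema' class from `namespace` (e.g. globals())."""
    name = f"{model.__name__}Schema"
    try:
        return namespace[name]
    except KeyError:
        # Re-raise with a message that names the class we expected to find
        raise LookupError(
            f"No schema found for {model.__name__}: expected a class named {name}"
        ) from None


# Stand-ins for a model and its schema
class Stories:
    pass


class StoriesSchema:
    pass


class Podcasts:  # deliberately has no PodcastsSchema
    pass


print(find_schema(Stories, {"StoriesSchema": StoriesSchema}).__name__)  # StoriesSchema
```

Inside BaseView.__init__, the lookup would then become `self.schema = find_schema(self.model, globals())`.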

For example, this is how we would create a view based on our PublicListReadView class:

# Example: inheriting from PublicListReadView
# (assumes a Users model and a matching UsersSchema are defined)
class UsersAPI(PublicListReadView):
    model = Users

And after registering it, we would have the following endpoints:

/users/ - would return all the entries

/users/123 - would return a single entry with id 123.

Cool, now to our next base view: UserAuthCUDView.

Since we want to add a decorator to it, let's declare our decorator before creating our class:

# An example decorator
from functools import wraps

def user_required(f):
    """Checks whether user is logged in or raises error 401."""
    @wraps(f)
    def decorator(*args, **kwargs):
        # TODO: not implemented yet; a real check would return a 401 response here
        print("user_required decorator triggered!")
        return f(*args, **kwargs)

    return decorator

class UserAuthCUDView(BaseView):
    decorators = [user_required]

    def post(self):
        # create a new entry
        new_entry = self.model(**request.get_json())
        db.session.add(new_entry)
        db.session.commit()
        schema = self.schema()
        return jsonify(schema.dump(new_entry))

    def delete(self, entry_id):
        # delete a single entry (404 if it doesn't exist)
        entry = self.model.query.get_or_404(entry_id)
        db.session.delete(entry)
        db.session.commit()
        return {"msg": "Entry successfully deleted."}, 200

    def put(self, entry_id):
        # update a single entry
        raise NotImplementedError

That's our view for handling POST, DELETE and PUT requests (PUT is left unimplemented for now). Of course, we could add further validation, such as only allowing users to update or delete entries they created.
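put() above simply raises NotImplementedError. One way to fill it in, sketched here as an assumption rather than the project's actual code, is a small hypothetical apply_updates helper that copies the JSON body onto the fetched entry, but only for an explicit set of allowed columns, so a client can't overwrite id. The commented put() shows how it might plug into UserAuthCUDView; the runnable part uses a SimpleNamespace in place of a model instance.

```python
from types import SimpleNamespace


def apply_updates(entry, data, allowed):
    """Copy the keys of a JSON body onto `entry`, refusing fields not in `allowed`."""
    invalid = sorted(set(data) - set(allowed))
    if invalid:
        raise ValueError(f"Cannot update fields: {', '.join(invalid)}")
    for key, value in data.items():
        setattr(entry, key, value)
    return entry


# Sketch of a put() built on it (hypothetical, for UserAuthCUDView):
#     def put(self, entry_id):
#         entry = self.model.query.get_or_404(entry_id)
#         allowed = set(self.model.__table__.columns.keys()) - {"id"}
#         try:
#             apply_updates(entry, request.get_json() or {}, allowed)
#         except ValueError as exc:
#             return {"msg": str(exc)}, 400
#         db.session.commit()
#         return jsonify(self.schema().dump(entry))

# Quick check with a plain object standing in for a Stories row
story = SimpleNamespace(id=1, title="Old title", content=None)
apply_updates(story, {"title": "New title"}, allowed={"title", "content"})
print(story.title)  # New title
```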

For our AdminCUDView, the only difference from the previous class is the decorator we use. We can therefore inherit from UserAuthCUDView itself and override the decorators attribute with a different decorator, one responsible for checking whether the user has admin permissions:

# Admin decorator example
from functools import wraps

def admin_required(f):
    """Checks whether user is admin."""

    @wraps(f)
    def decorator(*args, **kwargs):
        print("admin_required decorator triggered!")
        return f(*args, **kwargs)

    return decorator

class AdminCUDView(UserAuthCUDView):
    decorators = [admin_required]

That was easy and clean!

Now we choose which base views our actual views inherit from, depending on the restrictions we want. Notice that some of them inherit from UserAuthCUDView and others from AdminCUDView:

"""
    APIs
"""

class PublicationsAPI(UserAuthCUDView, PublicListReadView):
    model = Publications

class StoriesAPI(UserAuthCUDView, PublicListReadView):
    model = Stories

class UserBookmarksAPI(UserAuthCUDView, PublicListReadView):
    model = UserBookmarks

class PodcastsAPI(AdminCUDView, PublicListReadView):
    model = Podcasts

class StoryCategoriesAPI(UserAuthCUDView, PublicListReadView):
    model = StoryCategories

class EditorsChoiceAPI(AdminCUDView, PublicListReadView):
    model = EditorsChoice

class SectionsAPI(AdminCUDView, PublicListReadView):
    model = Sections

class ContentSectionsAPI(AdminCUDView, PublicListReadView):
    model = ContentSections
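One caveat worth checking: Flask applies the decorators list to the single view function returned by as_view(), so on these combined classes every method, including GET, goes through user_required or admin_required. In other words, the List and Read endpoints are not actually public. One way to keep GET open, sketched below under our own assumptions, is to apply the checks per HTTP method inside dispatch_request. ProtectedWritesView, write_decorators, write_methods, ItemsAPI and the X-User header check are hypothetical names for illustration, not part of Flask or the original project.

```python
from functools import wraps

from flask import Flask, request
from flask.views import MethodView


def user_required(f):
    """Hypothetical auth check: rejects requests without an X-User header."""

    @wraps(f)
    def decorator(*args, **kwargs):
        if not request.headers.get("X-User"):
            return {"msg": "Authentication required."}, 401
        return f(*args, **kwargs)

    return decorator


class ProtectedWritesView(MethodView):
    # Decorators applied only to write methods; GET and HEAD stay public
    write_decorators = [user_required]
    write_methods = {"POST", "PUT", "DELETE"}

    def dispatch_request(self, **kwargs):
        # MethodView.dispatch_request picks get/post/... based on request.method
        handler = super().dispatch_request
        if request.method in self.write_methods:
            for decorator in self.write_decorators:
                handler = decorator(handler)
        return handler(**kwargs)


class ItemsAPI(ProtectedWritesView):
    def get(self):
        return {"items": []}

    def post(self):
        return {"msg": "Entry created."}, 201


app = Flask(__name__)
app.add_url_rule("/items/", view_func=ItemsAPI.as_view("items_api"))
```

An admin-only variant would then only need to override write_decorators, mirroring how AdminCUDView overrides decorators.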

That's it. The only thing left to do is to register our views:

def register_api(view, endpoint, url):
    view_func = view.as_view(endpoint)
    # GET /podcasts/ -> list all entries (entry_id defaults to None)
    app.add_url_rule(url, defaults={'entry_id': None}, view_func=view_func, methods=['GET',])
    # POST /podcasts/ -> create a new entry
    app.add_url_rule(url, view_func=view_func, methods=['POST',])
    # Creates a rule like: '/podcasts/<int:entry_id>' for GET, PUT and DELETE on one entry
    app.add_url_rule(f'{url}<int:entry_id>', view_func=view_func, methods=['GET', 'PUT', 'DELETE'])

# Register APIs

register_api(PublicationsAPI, 'publications_api', '/publications/')
register_api(PodcastsAPI, 'podcasts_api', '/podcasts/')
register_api(StoriesAPI, 'stories_api', '/stories/')
register_api(UserBookmarksAPI, 'user_bookmarks_api', '/bookmarks/')
register_api(StoryCategoriesAPI, 'story_categories_api', '/categories/')
register_api(EditorsChoiceAPI, 'editors_choice_api', '/editors-choice/')
register_api(SectionsAPI, 'sections_api', '/sections/')
register_api(ContentSectionsAPI, 'content_sections_api', '/content-sections/')
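To double-check which URL rules register_api produces, the sketch below registers a hypothetical ThingsAPI view that simply echoes the method and entry_id it receives, then lists the rules from app.url_map. Note that register_api builds the detail rule as f'{url}<int:entry_id>', so the url argument has to end with a slash.

```python
from flask import Flask
from flask.views import MethodView

app = Flask(__name__)


def register_api(view, endpoint, url):
    # Same rules as above: collection on `url`, single entry on `url<int:entry_id>`
    view_func = view.as_view(endpoint)
    app.add_url_rule(url, defaults={"entry_id": None}, view_func=view_func, methods=["GET"])
    app.add_url_rule(url, view_func=view_func, methods=["POST"])
    app.add_url_rule(f"{url}<int:entry_id>", view_func=view_func, methods=["GET", "PUT", "DELETE"])


class ThingsAPI(MethodView):
    """Hypothetical view that echoes what Flask passed in."""

    def get(self, entry_id):
        return {"method": "GET", "entry_id": entry_id}

    def post(self):
        return {"method": "POST", "entry_id": None}

    def put(self, entry_id):
        return {"method": "PUT", "entry_id": entry_id}

    def delete(self, entry_id):
        return {"method": "DELETE", "entry_id": entry_id}


register_api(ThingsAPI, "things_api", "/things/")

for rule in app.url_map.iter_rules(endpoint="things_api"):
    # HEAD and OPTIONS are added automatically by Flask, so hide them
    print(rule.rule, sorted(rule.methods - {"HEAD", "OPTIONS"}))
```

A PUT or DELETE sent to the collection URL (e.g. /things/) is answered with 405 Method Not Allowed, because only GET and POST are registered there.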

All our endpoints should be functional right now!

Benefits of class-based views

Since our API is now generic, we can implement additional features in one place and have them reflected in all our endpoints.

Let's add the ability to choose which fields the API returns.

Imagine that our client only needs the id and title from the /stories/ endpoint, in order to show a list of stories on the home page. It doesn't need the content at this point.

How would we go about implementing this?

The plan is to define the desired fields as URL query parameters in our request:

/stories/?fields=title,id - will return only the id and title fields, instead of all fields.

To implement this feature for all our endpoints, let's make the following changes to our PublicListReadView view. We use Marshmallow's only= option when instantiating the schema to pass the list of fields we want:

class PublicListReadView(BaseView):
    def get(self, entry_id):
        # Handle any '?fields=' params received in the request - Ex: /stories/?fields=title,id
        fields = request.args.get("fields")
        if fields:
            schema = self.schema(only=fields.split(","))
        else:
            schema = self.schema()

        # Query database and return data
        if entry_id is None:
            # Return list of all entries
            result = self.model.query.all()
            return jsonify(schema.dump(result, many=True))
        else:
            # Return a single object, or a 404 response if it doesn't exist
            result = self.model.query.get_or_404(entry_id)
            return jsonify(schema.dump(result))

And that's all. Now this feature will be available on all our endpoints!
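One gap in this version: Marshmallow raises a ValueError when only= contains a name the schema does not declare, so a request like /stories/?fields=bogus ends in a 500 error. The hypothetical parse_fields helper below is a minimal sketch that trims whitespace, ignores empty entries and rejects unknown names before the schema is built. Inside get(), the available names could come from self.schema().fields, and the ValueError could be turned into a 400 response.

```python
def parse_fields(raw, available):
    """Parse a '?fields=' value such as 'title, id' into a list of field names.

    Returns None when no usable names were given, so the caller can fall back
    to serializing every field. Raises ValueError for names not in `available`.
    """
    if not raw:
        return None
    fields = [name.strip() for name in raw.split(",") if name.strip()]
    unknown = [name for name in fields if name not in available]
    if unknown:
        raise ValueError(f"Unknown fields: {', '.join(unknown)}")
    return fields or None


# Sketch of use inside PublicListReadView.get() (hypothetical):
#     try:
#         only = parse_fields(request.args.get("fields"), self.schema().fields)
#     except ValueError as exc:
#         return {"msg": str(exc)}, 400
#     schema = self.schema(only=only)

# Example: the field names a StoriesSchema would expose
story_fields = {"id", "title", "content"}
print(parse_fields("title, id", story_fields))  # ['title', 'id']
```

Passing only=None is the same as not passing it, so the fallback to all fields needs no extra branch.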

If a specific endpoint needs different logic, we can simply override the relevant method in its view class and implement the new logic there.
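As a small illustration, the sketch below uses a hypothetical ListView as a stand-in for PublicListReadView and an in-memory STORIES list instead of the database (both assumptions). AllStoriesAPI keeps the inherited get() unchanged, while PublishedStoriesAPI overrides it with endpoint-specific filtering.

```python
from flask import Flask, jsonify
from flask.views import MethodView

# In-memory data standing in for the Stories table (assumption for the demo)
STORIES = [
    {"id": 1, "title": "Draft", "published": False},
    {"id": 2, "title": "Launch day", "published": True},
]


class ListView(MethodView):
    """Generic list endpoint shared by many resources."""

    items = []

    def get(self):
        return jsonify(self.items)


class AllStoriesAPI(ListView):
    items = STORIES  # inherits the generic get() unchanged


class PublishedStoriesAPI(ListView):
    items = STORIES

    def get(self):
        # Endpoint-specific logic replaces the inherited get()
        return jsonify([story for story in self.items if story["published"]])


app = Flask(__name__)
app.add_url_rule("/stories/", view_func=AllStoriesAPI.as_view("all_stories"))
app.add_url_rule("/published-stories/", view_func=PublishedStoriesAPI.as_view("published_stories"))
```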

Additional resources

https://flask.palletsprojects.com/en/1.1.x/views/

