In this article, we will set up user authentication for a Reflex app entirely locally, with users and sessions stored in the app's own database. Users will be able to register, log in, and log out.
This article is based on the local_auth example in the reflex-examples GitHub repository: https://github.com/reflex-dev/reflex-examples/tree/main/local_auth
Outline
- Create a new folder, open it with a code editor
- Create a virtual environment and activate
- Install requirements
- reflex setup
- local_auth.py
- auth_session.py
- base_state.py
- login.py
- registration.py
- user.py
- run app
- conclusion
Create a new folder, open it with a code editor
Create a new folder named local_auth, then open it with a code editor like VS Code.
Create a virtual environment and activate
Open the terminal and use the following commands to create a virtual environment named .venv and activate it:
python3 -m venv .venv
source .venv/bin/activate
Install requirements
We will install reflex to build the app, passlib to simplify securely hashing and verifying passwords, and bcrypt, the library that provides the bcrypt hashing algorithm passlib uses here. Run the following command in the terminal:
pip install reflex==0.2.9 passlib==1.7.4 bcrypt==4.0.1
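Because the rest of the article relies on these exact versions, it can help to confirm what actually got installed. The following is a minimal sketch using only the standard library's importlib.metadata; the helper name check_pins is hypothetical and not part of the example project.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the pip command above.
PINNED = {"reflex": "0.2.9", "passlib": "1.7.4", "bcrypt": "4.0.1"}


def check_pins(pins: dict[str, str]) -> dict[str, str]:
    """Return a {package: problem} mapping for every pin that is not satisfied."""
    problems = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if installed != wanted:
            problems[name] = f"found {installed}, expected {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_pins(PINNED)
    for pkg, problem in issues.items():
        print(f"{pkg}: {problem}")
    if not issues:
        print("All pinned packages are installed.")
```

Run it inside the activated virtual environment; an empty result means every pin matches.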
reflex setup
Now, create the project with Reflex. Run the following command in the local_auth directory to initialize the template app:
reflex init
local_auth.py
Next, we will build the app's homepage. Go to the local_auth subdirectory, open the local_auth.py file, and add the following code to it:
"""Main app module to demo local authentication."""
import reflex as rx

from .base_state import State
from .login import require_login
from .registration import registration_page as registration_page


def index() -> rx.Component:
    """Render the index page.

    Returns:
        A reflex component.
    """
    return rx.fragment(
        rx.color_mode_button(rx.color_mode_icon(), float="right"),
        rx.vstack(
            rx.heading("Welcome to my homepage!", font_size="2em"),
            rx.link("Protected Page", href="/protected"),
            spacing="1.5em",
            padding_top="10%",
        ),
    )


@require_login
def protected() -> rx.Component:
    """Render a protected page.

    The `require_login` decorator will redirect to the login page if the user is
    not authenticated.

    Returns:
        A reflex component.
    """
    return rx.vstack(
        rx.heading(
            "Protected Page for ", State.authenticated_user.username, font_size="2em"
        ),
        rx.link("Home", href="/"),
        rx.link("Logout", href="/", on_click=State.do_logout),
    )


app = rx.App()
app.add_page(index)
app.add_page(protected)
app.compile()
index(): This function renders the application's index page and returns a Reflex component. The page contains a color mode toggle button, a greeting heading, and a link to the protected page, with spacing and top padding for layout.
protected(): This function renders a page that requires authentication. It is decorated with @require_login, so only authenticated users can view it. The page shows a heading personalized with the authenticated user's username, a link back to the home page, and a Logout link that calls State.do_logout.
The above code renders the following page:
auth_session.py
Create a new file auth_session.py in the local_auth subdirectory and add the following code:
import datetime

from sqlmodel import Column, DateTime, Field, func

import reflex as rx


class AuthSession(
    rx.Model,
    table=True,  # type: ignore
):
    """Correlate a session_id with an arbitrary user_id."""

    user_id: int = Field(index=True, nullable=False)
    session_id: str = Field(unique=True, index=True, nullable=False)
    expiration: datetime.datetime = Field(
        # nullable is set on the Column itself: SQLModel does not apply Field
        # options such as nullable when an explicit sa_column is supplied.
        sa_column=Column(
            DateTime(timezone=True), server_default=func.now(), nullable=False
        ),
    )
In this code, an AuthSession class is defined to manage authentication sessions by linking a session ID to a user ID.
The user_id attribute is an indexed, non-nullable integer field that stores the ID of the user the session belongs to.
The session_id attribute is a unique, indexed, non-nullable string field, so each session has a distinct identifier.
The expiration attribute is a datetime.datetime. Instead of letting SQLModel derive the column, it passes an explicit SQLAlchemy Column through sa_column, using DateTime(timezone=True) so the timestamp is stored with its time zone. The server_default=func.now() setting tells the database to use the current time if a row is inserted without an expiration; in this app, _login always sets it explicitly to the current time plus the session lifetime. Every session is meant to have an expiration time, so the column should not be nullable.
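To see what these constraints mean at the database level, here is a rough plain-SQLite equivalent of the authsession table, built with the standard library's sqlite3 module. This is an approximation for illustration, not the exact DDL that SQLModel and Alembic generate; in particular, SQLite has no native time-zone-aware datetime type, and the id primary key stands in for the one rx.Model provides.

```python
import sqlite3

# Approximate SQLite version of the AuthSession model (illustrative only).
SCHEMA = """
CREATE TABLE authsession (
    id INTEGER PRIMARY KEY,
    user_id INTEGER NOT NULL,
    session_id TEXT NOT NULL UNIQUE,
    expiration TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX ix_authsession_user_id ON authsession (user_id);
"""


def make_db() -> sqlite3.Connection:
    """Create an in-memory database containing only the authsession table."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(SCHEMA)
    return conn


if __name__ == "__main__":
    conn = make_db()
    conn.execute("INSERT INTO authsession (user_id, session_id) VALUES (1, 'tok-a')")
    # No expiration was given, so the server-side default filled it in.
    print(conn.execute("SELECT expiration FROM authsession").fetchone()[0])
    try:
        # A second row with the same session_id violates the UNIQUE constraint.
        conn.execute("INSERT INTO authsession (user_id, session_id) VALUES (2, 'tok-a')")
    except sqlite3.IntegrityError as exc:
        print("rejected duplicate session_id:", exc)
```

The UNIQUE constraint on session_id is what guarantees that a given browser token maps to at most one session row.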
base_state.py
Create a new file base_state.py in the local_auth subdirectory and add the following code:
import datetime

from sqlmodel import select

import reflex as rx

from .auth_session import AuthSession
from .user import User

AUTH_TOKEN_LOCAL_STORAGE_KEY = "_auth_tokens"
DEFAULT_AUTH_SESSION_EXPIRATION_DELTA = datetime.timedelta(days=7)


class State(rx.State):
    # The auth_token is stored in local storage to persist across tab and browser sessions.
    auth_token: str = rx.LocalStorage(name=AUTH_TOKEN_LOCAL_STORAGE_KEY)

    @rx.cached_var
    def authenticated_user(self) -> User:
        """The currently authenticated user, or a dummy user if not authenticated.

        Returns:
            A User instance with id=-1 if not authenticated, or the User instance
            corresponding to the currently authenticated user.
        """
        with rx.session() as session:
            result = session.exec(
                select(User, AuthSession).where(
                    AuthSession.session_id == self.auth_token,
                    AuthSession.expiration
                    >= datetime.datetime.now(datetime.timezone.utc),
                    User.id == AuthSession.user_id,
                ),
            ).first()
            if result:
                user, session = result
                return user
        return User(id=-1)  # type: ignore

    @rx.cached_var
    def is_authenticated(self) -> bool:
        """Whether the current user is authenticated.

        Returns:
            True if the authenticated user has a non-negative user ID, False otherwise.
        """
        return self.authenticated_user.id >= 0

    def do_logout(self) -> None:
        """Destroy AuthSessions associated with the auth_token."""
        with rx.session() as session:
            for auth_session in session.exec(
                AuthSession.select.where(AuthSession.session_id == self.auth_token)
            ).all():
                session.delete(auth_session)
            session.commit()
        # Reassign auth_token so Reflex marks it as changed and recomputes
        # the cached vars that depend on it.
        self.auth_token = self.auth_token

    def _login(
        self,
        user_id: int,
        expiration_delta: datetime.timedelta = DEFAULT_AUTH_SESSION_EXPIRATION_DELTA,
    ) -> None:
        """Create an AuthSession for the given user_id.

        If the auth_token is already associated with an AuthSession, it will be
        logged out first.

        Args:
            user_id: The user ID to associate with the AuthSession.
            expiration_delta: The amount of time before the AuthSession expires.
        """
        if self.is_authenticated:
            self.do_logout()
        if user_id < 0:
            return
        self.auth_token = self.auth_token or self.get_token()
        with rx.session() as session:
            session.add(
                AuthSession(  # type: ignore
                    user_id=user_id,
                    session_id=self.auth_token,
                    expiration=datetime.datetime.now(datetime.timezone.utc)
                    + expiration_delta,
                )
            )
            session.commit()
The above code defines a State class that extends rx.State. It holds the authentication token together with the properties and methods that manage user sessions.
auth_token: The authentication token, stored in the browser's local storage so that it persists across tabs and browser sessions.
authenticated_user(self): A cached var that returns the currently authenticated user, or a dummy user with id=-1 if nobody is logged in. Inside an rx.session() context manager, it runs a SQL query that joins User and AuthSession, keeping only the session whose session_id matches auth_token, whose expiration has not passed, and whose user_id matches the user. If a row is found, it returns that user.
is_authenticated(self): A cached var that returns True when the authenticated user's ID is greater than or equal to 0, and False otherwise (the dummy user has id=-1).
do_logout(self): Opens a database session, deletes every AuthSession whose session_id matches auth_token, and commits, which logs the user out.
_login(self, user_id, expiration_delta): A private helper that creates an AuthSession for the given user. If the current token is already authenticated, it first calls do_logout(). It then reuses auth_token, or obtains a new one from get_token() if it is empty, and stores a new AuthSession linking this token to user_id, with an expiration of the current UTC time plus expiration_delta (7 days by default).
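The login, lookup, and logout logic is easier to follow without Reflex in the way. Below is a minimal sketch of the same flow using only sqlite3, under these assumptions: table and column names mirror the article, expirations are stored as Unix timestamps, and secrets.token_urlsafe stands in for Reflex's get_token(). The functions login, logout, and authenticated_user_id are hypothetical helpers, not part of the example project.

```python
import secrets
import sqlite3
import time

SESSION_TTL = 7 * 24 * 60 * 60  # seconds; mirrors the 7-day default above


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE "user" (id INTEGER PRIMARY KEY, username TEXT UNIQUE NOT NULL);
        CREATE TABLE authsession (
            id INTEGER PRIMARY KEY,
            user_id INTEGER NOT NULL,
            session_id TEXT UNIQUE NOT NULL,
            expiration REAL NOT NULL  -- Unix timestamp in seconds
        );
        """
    )
    return conn


def authenticated_user_id(conn: sqlite3.Connection, token: str) -> int:
    """Return the logged-in user's id, or -1 like the article's dummy user."""
    row = conn.execute(
        'SELECT "user".id FROM "user" JOIN authsession '
        'ON "user".id = authsession.user_id '
        "WHERE authsession.session_id = ? AND authsession.expiration >= ?",
        (token, time.time()),
    ).fetchone()
    return row[0] if row else -1


def logout(conn: sqlite3.Connection, token: str) -> None:
    """Delete every session row for this token."""
    conn.execute("DELETE FROM authsession WHERE session_id = ?", (token,))
    conn.commit()


def login(conn: sqlite3.Connection, token: str, user_id: int,
          ttl: float = SESSION_TTL) -> str:
    """Create a session row for user_id and return the (possibly new) token."""
    if token:
        # Deliberate difference from _login: clear any row for this token,
        # expired or not, so a stale row cannot hit the UNIQUE session_id.
        logout(conn, token)
    if user_id < 0:
        return token
    token = token or secrets.token_urlsafe(32)
    conn.execute(
        "INSERT INTO authsession (user_id, session_id, expiration) VALUES (?, ?, ?)",
        (user_id, token, time.time() + ttl),
    )
    conn.commit()
    return token


if __name__ == "__main__":
    conn = make_db()
    conn.execute('INSERT INTO "user" (id, username) VALUES (1, ?)', ("alice",))
    tok = login(conn, "", 1)
    print(authenticated_user_id(conn, tok))  # the logged-in user's id
    logout(conn, tok)
    print(authenticated_user_id(conn, tok))  # back to the dummy id
```

As in the article, an expired row simply stops matching the query, so an expired session behaves exactly like a logged-out one.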
login.py
Create a new file login.py in the local_auth subdirectory and add the following code:
"""Login page and authentication logic."""
# Lets the `rx.event.EventSpec | None` annotation below work on Python < 3.10.
from __future__ import annotations

import reflex as rx

from .base_state import State
from .user import User

LOGIN_ROUTE = "/login"
REGISTER_ROUTE = "/register"


class LoginState(State):
    """Handle login form submission and redirect to proper routes after authentication."""

    error_message: str = ""
    redirect_to: str = ""

    def on_submit(self, form_data) -> rx.event.EventSpec:
        """Handle login form on_submit.

        Args:
            form_data: A dict of form fields and values.
        """
        self.error_message = ""
        username = form_data["username"]
        password = form_data["password"]
        with rx.session() as session:
            user = session.exec(
                User.select.where(User.username == username)
            ).one_or_none()
        if user is not None and not user.enabled:
            self.error_message = "This account is disabled."
            return rx.set_value("password", "")
        if user is None or not user.verify(password):
            self.error_message = "There was a problem logging in, please try again."
            return rx.set_value("password", "")
        if (
            user is not None
            and user.id is not None
            and user.enabled
            and user.verify(password)
        ):
            # mark the user as logged in
            self._login(user.id)
        self.error_message = ""
        return LoginState.redir()  # type: ignore

    def redir(self) -> rx.event.EventSpec | None:
        """Redirect to the redirect_to route if logged in, or to the login page if not."""
        if not self.is_hydrated:
            # wait until after hydration to ensure auth_token is known
            return LoginState.redir()  # type: ignore
        page = self.get_current_page()
        if not self.is_authenticated and page != LOGIN_ROUTE:
            self.redirect_to = page
            return rx.redirect(LOGIN_ROUTE)
        elif page == LOGIN_ROUTE:
            return rx.redirect(self.redirect_to or "/")


@rx.page(route=LOGIN_ROUTE)
def login_page() -> rx.Component:
    """Render the login page.

    Returns:
        A reflex component.
    """
    login_form = rx.form(
        rx.input(placeholder="username", id="username"),
        rx.password(placeholder="password", id="password"),
        rx.button("Login", type_="submit"),
        width="80vw",
        on_submit=LoginState.on_submit,
    )

    return rx.fragment(
        rx.cond(
            LoginState.is_hydrated,  # type: ignore
            rx.vstack(
                rx.cond(  # conditionally show error messages
                    LoginState.error_message != "",
                    rx.text(LoginState.error_message),
                ),
                login_form,
                rx.link("Register", href=REGISTER_ROUTE),
                padding_top="10vh",
            ),
        )
    )


def require_login(page: rx.app.ComponentCallable) -> rx.app.ComponentCallable:
    """Decorator to require authentication before rendering a page.

    If the user is not authenticated, then redirect to the login page.

    Args:
        page: The page to wrap.

    Returns:
        The wrapped page component.
    """

    def protected_page():
        return rx.fragment(
            rx.cond(
                State.is_hydrated & State.is_authenticated,  # type: ignore
                page(),
                rx.center(
                    # When this spinner mounts, it will redirect to the login page
                    rx.spinner(on_mount=LoginState.redir),
                ),
            )
        )

    protected_page.__name__ = page.__name__
    return protected_page
The above code defines the login page and the authentication logic. The LoginState class handles login form submissions: it looks up the user, rejects disabled accounts and wrong passwords, logs the user in, and redirects to the page they originally requested (or / if there is none). The require_login decorator protects pages: authenticated users see the wrapped page, while everyone else sees a spinner whose on_mount event calls LoginState.redir, which sends them to the login page.
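The redirect bookkeeping in LoginState.redir and the wrapping done by require_login can be modeled in plain Python. The sketch below is a simplification under stated assumptions: FakeSession is a hypothetical stand-in for LoginState, redirects are returned as strings instead of Reflex events, and routes are assumed to be "/" plus the function name. In Reflex itself the branch must be expressed with rx.cond because the page is compiled ahead of time; here a plain if takes its place.

```python
from dataclasses import dataclass
from functools import wraps
from typing import Callable, Optional

LOGIN_ROUTE = "/login"


@dataclass
class FakeSession:
    """Hypothetical stand-in for LoginState: only the fields redir() uses."""

    is_authenticated: bool = False
    redirect_to: str = ""


def redir(state: FakeSession, page: str) -> Optional[str]:
    """Same branching as LoginState.redir, returning a route instead of an event."""
    if not state.is_authenticated and page != LOGIN_ROUTE:
        state.redirect_to = page  # remember where the user wanted to go
        return LOGIN_ROUTE
    if page == LOGIN_ROUTE:
        return state.redirect_to or "/"
    return None


def require_login(page: Callable[[], str]) -> Callable[[FakeSession], str]:
    """Render page() for authenticated users; otherwise report a redirect."""

    @wraps(page)  # copies __name__, which the article sets by hand
    def protected_page(state: FakeSession) -> str:
        if state.is_authenticated:
            return page()
        # Assumed route convention for this sketch: "/" + function name.
        return f"redirect:{redir(state, '/' + page.__name__)}"

    return protected_page


@require_login
def protected() -> str:
    return "Protected Page"


if __name__ == "__main__":
    s = FakeSession()
    print(protected(s))            # sent to the login page
    s.is_authenticated = True      # pretend the login form succeeded
    print(redir(s, LOGIN_ROUTE))   # back to the originally requested page
    print(protected(s))
```

Copying __name__ matters in the real decorator because Reflex derives a page's default route from the function name; without it, every protected page would be registered as protected_page.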
The login page code above renders the following page:
registration.py
Create a new file registration.py in the local_auth subdirectory and add the following code:
"""New user registration form and validation logic."""
from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator

import reflex as rx

from .base_state import State
from .login import LOGIN_ROUTE, REGISTER_ROUTE
from .user import User


class RegistrationState(State):
    """Handle registration form submission and redirect to login page after registration."""

    success: bool = False
    error_message: str = ""

    async def handle_registration(
        self, form_data
    ) -> AsyncGenerator[rx.event.EventSpec | list[rx.event.EventSpec] | None, None]:
        """Handle registration form on_submit.

        Set error_message appropriately based on validation results.

        Args:
            form_data: A dict of form fields and values.
        """
        with rx.session() as session:
            username = form_data["username"]
            if not username:
                self.error_message = "Username cannot be empty"
                yield rx.set_focus("username")
                return
            existing_user = session.exec(
                User.select.where(User.username == username)
            ).one_or_none()
            if existing_user is not None:
                self.error_message = (
                    f"Username {username} is already registered. Try a different name"
                )
                yield [rx.set_value("username", ""), rx.set_focus("username")]
                return
            password = form_data["password"]
            if not password:
                self.error_message = "Password cannot be empty"
                yield rx.set_focus("password")
                return
            if password != form_data["confirm_password"]:
                self.error_message = "Passwords do not match"
                yield [
                    rx.set_value("confirm_password", ""),
                    rx.set_focus("confirm_password"),
                ]
                return
            # Create the new user and add it to the database.
            new_user = User()  # type: ignore
            new_user.username = username
            new_user.password_hash = User.hash_password(password)
            new_user.enabled = True
            session.add(new_user)
            session.commit()
        # Set success and redirect to login page after a brief delay.
        self.error_message = ""
        self.success = True
        yield
        await asyncio.sleep(0.5)
        yield [rx.redirect(LOGIN_ROUTE), RegistrationState.set_success(False)]


@rx.page(route=REGISTER_ROUTE)
def registration_page() -> rx.Component:
    """Render the registration page.

    Returns:
        A reflex component.
    """
    register_form = rx.form(
        rx.input(placeholder="username", id="username"),
        rx.password(placeholder="password", id="password"),
        rx.password(placeholder="confirm", id="confirm_password"),
        rx.button("Register", type_="submit"),
        width="80vw",
        on_submit=RegistrationState.handle_registration,
    )
    return rx.fragment(
        rx.cond(
            RegistrationState.success,
            rx.vstack(
                rx.text("Registration successful!"),
                rx.spinner(),
            ),
            rx.vstack(
                rx.cond(  # conditionally show error messages
                    RegistrationState.error_message != "",
                    rx.text(RegistrationState.error_message),
                ),
                register_form,
                padding_top="10vh",
            ),
        )
    )
The above code builds the user registration form and its validation logic.
The RegistrationState class handles user registration. It tracks whether registration succeeded and holds any error message.
The handle_registration method is an async event handler that processes the registration form. It checks that the username is not empty and not already taken, that the password is not empty, and that the two password fields match; on the first failure, it sets error_message and moves focus to the relevant field. If all checks pass, it creates the new user with a hashed password and saves it to the database. It then sets the success flag and yields so the page shows the confirmation, waits half a second, and finally redirects to the login page while resetting the flag.
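The "set a flag, yield, sleep, then redirect" pattern works because each yield from an async event handler pushes an intermediate state update to the browser. The following is a minimal model of that success path using plain asyncio; the event strings are hypothetical placeholders for the Reflex state updates and events, not real Reflex output.

```python
import asyncio
from collections.abc import AsyncGenerator


async def registration_events(delay: float = 0.5) -> AsyncGenerator[str, None]:
    """Hypothetical model of handle_registration's success path.

    Each yield stands for one update sent to the browser.
    """
    yield "success=True"  # page swaps in the confirmation text and spinner
    await asyncio.sleep(delay)  # brief pause so the user can read it
    yield "redirect:/login; success=False"  # leave the page and reset the flag


async def collect(delay: float) -> list[str]:
    """Drain the generator, the way an event loop consumer would."""
    return [event async for event in registration_events(delay)]


if __name__ == "__main__":
    print(asyncio.run(collect(0.5)))
```

Without the bare yield before the sleep, the browser would only receive the final redirect and the confirmation message would never appear.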
The registration_page function is a Reflex page that renders the registration form, with fields for the username, password, and password confirmation, plus a Register button. Form submission is handled by RegistrationState.handle_registration.
The page displays different content depending on whether registration has succeeded. After a successful registration, it shows a success message and a spinner. Otherwise, it shows any error message and the registration form, with some top padding for spacing.
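The validation rules are independent of Reflex, so one way to make them easy to unit test is to pull them into a pure function. This is a sketch of that refactoring, assuming the same check order and messages as handle_registration; validate_registration is a hypothetical name, and existing_usernames replaces the database lookup.

```python
from typing import Optional


def validate_registration(
    form: dict[str, str], existing_usernames: set[str]
) -> Optional[tuple[str, str]]:
    """Return (error_message, field_to_focus) for the first failed check, or None.

    The checks run in the same order as handle_registration.
    """
    username = form.get("username", "")
    if not username:
        return "Username cannot be empty", "username"
    if username in existing_usernames:
        return (
            f"Username {username} is already registered. Try a different name",
            "username",
        )
    password = form.get("password", "")
    if not password:
        return "Password cannot be empty", "password"
    if password != form.get("confirm_password", ""):
        return "Passwords do not match", "confirm_password"
    return None  # every check passed


if __name__ == "__main__":
    print(validate_registration({"username": "amy", "password": "a",
                                 "confirm_password": "b"}, set()))
```

Inside the event handler, a non-None result would map directly onto setting error_message and yielding rx.set_focus for the returned field.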
The registration page code above renders the following page:
user.py
Create a new file user.py in the local_auth subdirectory and add the following code:
from passlib.context import CryptContext
from sqlmodel import Field

import reflex as rx

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class User(
    rx.Model,
    table=True,  # type: ignore
):
    """A local User model with bcrypt password hashing."""

    username: str = Field(unique=True, nullable=False, index=True)
    password_hash: str = Field(nullable=False)
    enabled: bool = False

    @staticmethod
    def hash_password(secret: str) -> str:
        """Hash the secret using bcrypt.

        Args:
            secret: The password to hash.

        Returns:
            The hashed password.
        """
        return pwd_context.hash(secret)

    def verify(self, secret: str) -> bool:
        """Validate the user's password.

        Args:
            secret: The password to check.

        Returns:
            True if the hashed secret matches this user's password_hash.
        """
        return pwd_context.verify(
            secret,
            self.password_hash,
        )
The above code defines a User model for storing user data. It has fields for the username, the password hash, and an enabled flag, plus two helpers built on passlib's CryptContext configured for bcrypt: hash_password() hashes a plain-text password, and verify() checks a password against the stored hash. Note that enabled defaults to False; the registration handler sets it to True for new users. Storing only the hash, never the plain-text password, is what makes this a secure foundation for authentication.
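To make the "hash on register, verify on login" idea concrete without extra dependencies, here is a sketch using only the standard library. Note the substitution: this uses PBKDF2-HMAC-SHA256 from hashlib, not bcrypt, and the hash string format and iteration count are assumptions chosen for illustration. The article's code should keep using passlib with bcrypt.

```python
import hashlib
import hmac
import secrets

# Stand-in only: PBKDF2-HMAC-SHA256, NOT the bcrypt scheme the article uses.
DEFAULT_ITERATIONS = 600_000


def hash_password(secret: str, iterations: int = DEFAULT_ITERATIONS) -> str:
    """Return 'pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>'."""
    salt = secrets.token_bytes(16)  # fresh random salt for every password
    digest = hashlib.pbkdf2_hmac("sha256", secret.encode(), salt, iterations)
    return f"pbkdf2_sha256${iterations}${salt.hex()}${digest.hex()}"


def verify(secret: str, password_hash: str) -> bool:
    """Re-derive the digest with the stored salt and compare in constant time."""
    try:
        scheme, iterations, salt_hex, digest_hex = password_hash.split("$")
    except ValueError:
        return False  # not in the expected format
    if scheme != "pbkdf2_sha256":
        return False
    digest = hashlib.pbkdf2_hmac(
        "sha256", secret.encode(), bytes.fromhex(salt_hex), int(iterations)
    )
    return hmac.compare_digest(digest.hex(), digest_hex)


if __name__ == "__main__":
    stored = hash_password("s3cret")
    print(verify("s3cret", stored), verify("wrong", stored))
```

As with bcrypt hashes, the salt and cost parameter travel inside the hash string itself, which is why the User model needs only a single password_hash column.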
run app
Run the following commands in the terminal. The first initializes Alembic, the second generates a migration script for the current schema in the alembic/versions directory, and the third applies the migration scripts to bring the database up to date:
reflex db init
reflex db makemigrations --message 'something changed'
reflex db migrate
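To confirm that the migration created the tables, you can list what is in the database. This sketch assumes the default SQLite file reflex.db in the project root (the db_url that reflex init typically writes into rxconfig.py; check yours), and list_tables is a hypothetical helper. After a successful migration, you would expect to see user and authsession (SQLModel lowercases the class name for the table name) alongside Alembic's alembic_version table.

```python
import os
import sqlite3
import sys
from contextlib import closing


def list_tables(db_path: str) -> list[str]:
    """Return the table names in a SQLite database file, sorted by name."""
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
        ).fetchall()
    return [name for (name,) in rows]


if __name__ == "__main__":
    # Assumption: the database lives at reflex.db unless a path is given.
    path = sys.argv[1] if len(sys.argv) > 1 else "reflex.db"
    if not os.path.exists(path):
        # sqlite3.connect would silently create an empty file, so check first.
        sys.exit(f"{path} not found; run the migration commands first")
    print(list_tables(path))
```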
To start the app, run the following:
reflex run
When you go to http://localhost:3000/, you should see the following interface:
When you click the Protected Page link, you are taken to the login page, where you can either log in or register. After a successful login, you can access the protected page and log out from there.
conclusion
The complete code is available in the local_auth example of the reflex-examples repository: https://github.com/reflex-dev/reflex-examples/tree/main/local_auth