Add OAuth refresh token support

This commit is contained in:
Thomas Sileo 2022-12-18 12:55:24 +01:00
parent 3fb36d6119
commit ed214cf0e7
4 changed files with 94 additions and 17 deletions

View file

@ -0,0 +1,36 @@
"""Add OAuth refresh token support
Revision ID: a209f0333f5a
Revises: 4ab54becec04
Create Date: 2022-12-18 11:26:31.976348+00:00
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'a209f0333f5a'
down_revision = '4ab54becec04'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('indieauth_access_token', schema=None) as batch_op:
batch_op.add_column(sa.Column('refresh_token', sa.String(), nullable=True))
batch_op.add_column(sa.Column('was_refreshed', sa.Boolean(), server_default='0', nullable=False))
batch_op.create_index(batch_op.f('ix_indieauth_access_token_refresh_token'), ['refresh_token'], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('indieauth_access_token', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_indieauth_access_token_refresh_token'))
batch_op.drop_column('was_refreshed')
batch_op.drop_column('refresh_token')
# ### end Alembic commands ###

View file

@ -270,29 +270,54 @@ async def indieauth_token_endpoint(
form_data = await request.form() form_data = await request.form()
logger.info(f"{form_data=}") logger.info(f"{form_data=}")
grant_type = form_data.get("grant_type", "authorization_code") grant_type = form_data.get("grant_type", "authorization_code")
if grant_type != "authorization_code": if grant_type not in ["authorization_code", "refresh_token"]:
raise ValueError(f"Invalid grant_type {grant_type}") raise ValueError(f"Invalid grant_type {grant_type}")
code = form_data["code"]
# These must match the params from the first request # These must match the params from the first request
client_id = form_data["client_id"] client_id = form_data["client_id"]
redirect_uri = form_data["redirect_uri"]
# code_verifier is optional for backward compat
code_verifier = form_data.get("code_verifier") code_verifier = form_data.get("code_verifier")
is_code_valid, auth_code_request = await _check_auth_code( if grant_type == "authorization_code":
db_session, code = form_data["code"]
code=code, redirect_uri = form_data["redirect_uri"]
client_id=client_id, # code_verifier is optional for backward compat
redirect_uri=redirect_uri, is_code_valid, auth_code_request = await _check_auth_code(
code_verifier=code_verifier, db_session,
) code=code,
if not is_code_valid or (auth_code_request and not auth_code_request.scope): client_id=client_id,
return JSONResponse( redirect_uri=redirect_uri,
content={"error": "invalid_grant"}, code_verifier=code_verifier,
status_code=400,
) )
if not is_code_valid or (auth_code_request and not auth_code_request.scope):
return JSONResponse(
content={"error": "invalid_grant"},
status_code=400,
)
elif grant_type == "refresh_token":
refresh_token = form_data["refresh_token"]
access_token = (
await db_session.scalars(
select(models.IndieAuthAccessToken)
.where(
models.IndieAuthAccessToken.refresh_token == refresh_token,
models.IndieAuthAccessToken.was_refreshed.is_(False),
)
.options(
joinedload(
models.IndieAuthAccessToken.indieauth_authorization_request
)
)
)
).one_or_none()
if not access_token:
raise ValueError("invalid refresh token")
if access_token.indieauth_authorization_request.client_id != client_id:
raise ValueError("invalid client ID")
auth_code_request = access_token.indieauth_authorization_request
access_token.was_refreshed = True
if not auth_code_request: if not auth_code_request:
raise ValueError("Should never happen") raise ValueError("Should never happen")
@ -300,6 +325,7 @@ async def indieauth_token_endpoint(
access_token = models.IndieAuthAccessToken( access_token = models.IndieAuthAccessToken(
indieauth_authorization_request_id=auth_code_request.id, indieauth_authorization_request_id=auth_code_request.id,
access_token=secrets.token_urlsafe(32), access_token=secrets.token_urlsafe(32),
refresh_token=secrets.token_urlsafe(32),
expires_in=3600, expires_in=3600,
scope=auth_code_request.scope, scope=auth_code_request.scope,
) )
@ -309,6 +335,7 @@ async def indieauth_token_endpoint(
return JSONResponse( return JSONResponse(
content={ content={
"access_token": access_token.access_token, "access_token": access_token.access_token,
"refresh_token": access_token.refresh_token,
"token_type": "Bearer", "token_type": "Bearer",
"scope": auth_code_request.scope, "scope": auth_code_request.scope,
"me": config.ID + "/", "me": config.ID + "/",

View file

@ -631,6 +631,19 @@ async def outbox(
) )
@app.post("/outbox")
async def post_inbox(
request: Request,
db_session: AsyncSession = Depends(get_db_session),
access_token_info: indieauth.AccessTokenInfo = Depends(
indieauth.enforce_access_token
),
) -> ActivityPubResponse:
payload = await request.json()
logger.info(f"{payload=}")
raise ValueError("TODO")
@app.get("/featured") @app.get("/featured")
async def featured( async def featured(
db_session: AsyncSession = Depends(get_db_session), db_session: AsyncSession = Depends(get_db_session),
@ -1055,7 +1068,6 @@ async def get_inbox(
page: bool | None = None, page: bool | None = None,
next_cursor: str | None = None, next_cursor: str | None = None,
) -> ActivityPubResponse: ) -> ActivityPubResponse:
logger.info(f"{page=}/{next_cursor=}")
where = [ where = [
models.InboxObject.ap_type.in_( models.InboxObject.ap_type.in_(
["Create", "Follow", "Like", "Announce", "Undo", "Update"] ["Create", "Follow", "Like", "Announce", "Undo", "Update"]

View file

@ -471,9 +471,11 @@ class IndieAuthAccessToken(Base):
) )
access_token = Column(String, nullable=False, unique=True, index=True) access_token = Column(String, nullable=False, unique=True, index=True)
refresh_token = Column(String, nullable=True, unique=True, index=True)
expires_in = Column(Integer, nullable=False) expires_in = Column(Integer, nullable=False)
scope = Column(String, nullable=False) scope = Column(String, nullable=False)
is_revoked = Column(Boolean, nullable=False, default=False) is_revoked = Column(Boolean, nullable=False, default=False)
was_refreshed = Column(Boolean, nullable=False, default=False, server_default="0")
class OAuthClient(Base): class OAuthClient(Base):