mirror of
https://git.sr.ht/~tsileo/microblog.pub
synced 2025-01-22 04:44:27 +00:00
Add OAuth refresh token support
This commit is contained in:
parent
3fb36d6119
commit
ed214cf0e7
4 changed files with 94 additions and 17 deletions
|
@ -0,0 +1,36 @@
|
|||
"""Add OAuth refresh token support
|
||||
|
||||
Revision ID: a209f0333f5a
|
||||
Revises: 4ab54becec04
|
||||
Create Date: 2022-12-18 11:26:31.976348+00:00
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a209f0333f5a'
|
||||
down_revision = '4ab54becec04'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('indieauth_access_token', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('refresh_token', sa.String(), nullable=True))
|
||||
batch_op.add_column(sa.Column('was_refreshed', sa.Boolean(), server_default='0', nullable=False))
|
||||
batch_op.create_index(batch_op.f('ix_indieauth_access_token_refresh_token'), ['refresh_token'], unique=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('indieauth_access_token', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_indieauth_access_token_refresh_token'))
|
||||
batch_op.drop_column('was_refreshed')
|
||||
batch_op.drop_column('refresh_token')
|
||||
|
||||
# ### end Alembic commands ###
|
|
@ -270,29 +270,54 @@ async def indieauth_token_endpoint(
|
|||
form_data = await request.form()
|
||||
logger.info(f"{form_data=}")
|
||||
grant_type = form_data.get("grant_type", "authorization_code")
|
||||
if grant_type != "authorization_code":
|
||||
if grant_type not in ["authorization_code", "refresh_token"]:
|
||||
raise ValueError(f"Invalid grant_type {grant_type}")
|
||||
|
||||
code = form_data["code"]
|
||||
|
||||
# These must match the params from the first request
|
||||
client_id = form_data["client_id"]
|
||||
redirect_uri = form_data["redirect_uri"]
|
||||
# code_verifier is optional for backward compat
|
||||
code_verifier = form_data.get("code_verifier")
|
||||
|
||||
is_code_valid, auth_code_request = await _check_auth_code(
|
||||
db_session,
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
if not is_code_valid or (auth_code_request and not auth_code_request.scope):
|
||||
return JSONResponse(
|
||||
content={"error": "invalid_grant"},
|
||||
status_code=400,
|
||||
if grant_type == "authorization_code":
|
||||
code = form_data["code"]
|
||||
redirect_uri = form_data["redirect_uri"]
|
||||
# code_verifier is optional for backward compat
|
||||
is_code_valid, auth_code_request = await _check_auth_code(
|
||||
db_session,
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
if not is_code_valid or (auth_code_request and not auth_code_request.scope):
|
||||
return JSONResponse(
|
||||
content={"error": "invalid_grant"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
elif grant_type == "refresh_token":
|
||||
refresh_token = form_data["refresh_token"]
|
||||
access_token = (
|
||||
await db_session.scalars(
|
||||
select(models.IndieAuthAccessToken)
|
||||
.where(
|
||||
models.IndieAuthAccessToken.refresh_token == refresh_token,
|
||||
models.IndieAuthAccessToken.was_refreshed.is_(False),
|
||||
)
|
||||
.options(
|
||||
joinedload(
|
||||
models.IndieAuthAccessToken.indieauth_authorization_request
|
||||
)
|
||||
)
|
||||
)
|
||||
).one_or_none()
|
||||
if not access_token:
|
||||
raise ValueError("invalid refresh token")
|
||||
|
||||
if access_token.indieauth_authorization_request.client_id != client_id:
|
||||
raise ValueError("invalid client ID")
|
||||
|
||||
auth_code_request = access_token.indieauth_authorization_request
|
||||
access_token.was_refreshed = True
|
||||
|
||||
if not auth_code_request:
|
||||
raise ValueError("Should never happen")
|
||||
|
@ -300,6 +325,7 @@ async def indieauth_token_endpoint(
|
|||
access_token = models.IndieAuthAccessToken(
|
||||
indieauth_authorization_request_id=auth_code_request.id,
|
||||
access_token=secrets.token_urlsafe(32),
|
||||
refresh_token=secrets.token_urlsafe(32),
|
||||
expires_in=3600,
|
||||
scope=auth_code_request.scope,
|
||||
)
|
||||
|
@ -309,6 +335,7 @@ async def indieauth_token_endpoint(
|
|||
return JSONResponse(
|
||||
content={
|
||||
"access_token": access_token.access_token,
|
||||
"refresh_token": access_token.refresh_token,
|
||||
"token_type": "Bearer",
|
||||
"scope": auth_code_request.scope,
|
||||
"me": config.ID + "/",
|
||||
|
|
14
app/main.py
14
app/main.py
|
@ -631,6 +631,19 @@ async def outbox(
|
|||
)
|
||||
|
||||
|
||||
@app.post("/outbox")
|
||||
async def post_inbox(
|
||||
request: Request,
|
||||
db_session: AsyncSession = Depends(get_db_session),
|
||||
access_token_info: indieauth.AccessTokenInfo = Depends(
|
||||
indieauth.enforce_access_token
|
||||
),
|
||||
) -> ActivityPubResponse:
|
||||
payload = await request.json()
|
||||
logger.info(f"{payload=}")
|
||||
raise ValueError("TODO")
|
||||
|
||||
|
||||
@app.get("/featured")
|
||||
async def featured(
|
||||
db_session: AsyncSession = Depends(get_db_session),
|
||||
|
@ -1055,7 +1068,6 @@ async def get_inbox(
|
|||
page: bool | None = None,
|
||||
next_cursor: str | None = None,
|
||||
) -> ActivityPubResponse:
|
||||
logger.info(f"{page=}/{next_cursor=}")
|
||||
where = [
|
||||
models.InboxObject.ap_type.in_(
|
||||
["Create", "Follow", "Like", "Announce", "Undo", "Update"]
|
||||
|
|
|
@ -471,9 +471,11 @@ class IndieAuthAccessToken(Base):
|
|||
)
|
||||
|
||||
access_token = Column(String, nullable=False, unique=True, index=True)
|
||||
refresh_token = Column(String, nullable=True, unique=True, index=True)
|
||||
expires_in = Column(Integer, nullable=False)
|
||||
scope = Column(String, nullable=False)
|
||||
is_revoked = Column(Boolean, nullable=False, default=False)
|
||||
was_refreshed = Column(Boolean, nullable=False, default=False, server_default="0")
|
||||
|
||||
|
||||
class OAuthClient(Base):
|
||||
|
|
Loading…
Reference in a new issue