diff --git a/.gitignore b/.gitignore index 505a3b1..adf464f 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ wheels/ # Virtual environments .venv + +.nanocoder diff --git a/main.py b/main.py index 88e9f0c..8a68052 100644 --- a/main.py +++ b/main.py @@ -6,16 +6,35 @@ import random import secrets import string import zipfile +from base64 import b64encode from glob import glob from io import BytesIO from pathlib import Path -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Depends +from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse, StreamingResponse -app = FastAPI() -file_mapping = {} -indexers = [] + +AUTH_SCHEME = HTTPBasic() +expected_password: str | None = None + + +async def get_current_username(credentials: HTTPBasicCredentials = Depends(AUTH_SCHEME)) -> str: + """Verify Basic Authentication credentials""" + if expected_password is not None and credentials.password != expected_password: + raise HTTPException( + status_code=401, + detail="Incorrect password", + headers={"WWW-Authenticate": AUTH_SCHEME.getobfuscation_header()}, + ) + return credentials.username + + +def set_auth_password(password: str | None): + """Set the expected password for authentication""" + global expected_password + expected_password = password class FileIndexer: @@ -128,7 +147,7 @@ async def health_check(): @app.get("/api/{file_hash}/data") -async def get_file_data(file_hash: str): +async def get_file_data(file_hash: str, username: str = Depends(get_current_username)): """Serve a specific file by its hash""" if file_hash not in file_mapping: raise HTTPException(status_code=404, detail="File not found") @@ -152,14 +171,14 @@ async def get_file_data(file_hash: str): @app.get("/") -async def root(): +async def root(username: str = Depends(get_current_username)): """Redirect to a random file hash""" random_hash = _get_random_hash() return RedirectResponse(url="/{hash}".format(hash=random_hash)) @app.get("/{order}/{delay}") -async def order_delay(order: str, delay: int): +async def order_delay(order: str, delay: int, username: str = Depends(get_current_username)): """Redirect to random file with order and delay""" random_hash = _get_random_hash() return RedirectResponse( @@ -209,7 +228,7 @@ def _render_page( @app.get("/{file_hash}") -async def hash_page(file_hash: str): +async def hash_page(file_hash: str, username: str = Depends(get_current_username)): """Serve a page for a specific file hash with navigation""" if file_hash not in file_mapping: raise HTTPException(status_code=404, detail="File not found") @@ -222,7 +241,7 @@ async def hash_page(file_hash: str): @app.get("/{order}/{delay}/{file_hash}") -async def hash_page_with_refresh(order: str, delay: int, file_hash: str): +async def hash_page_with_refresh(order: str, delay: int, file_hash: str, username: str = Depends(get_current_username)): """Serve a page for a specific file hash with auto-refresh navigation""" if file_hash not in file_mapping: raise HTTPException(status_code=404, detail="File not found") @@ -278,9 +297,11 @@ if __name__ == "__main__": parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") parser.add_argument("--port", type=int, default=8000, help="Port to bind to") parser.add_argument("--salt", type=str, default=None, help="Salt for hashing file paths") + parser.add_argument("--password", type=str, default=None, help="Password for Basic Authentication") args = parser.parse_args() initialize_server(args) + set_auth_password(args.password) import uvicorn