Add basic password support
This commit is contained in:
parent
bb7730e8dc
commit
84db52c718
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,3 +8,5 @@ wheels/
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
|
||||
.nanocoder
|
||||
|
||||
39
main.py
39
main.py
@ -6,16 +6,35 @@ import random
|
||||
import secrets
|
||||
import string
|
||||
import zipfile
|
||||
from base64 import b64encode
|
||||
from glob import glob
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse, StreamingResponse
|
||||
|
||||
app = FastAPI()
|
||||
file_mapping = {}
|
||||
indexers = []
|
||||
|
||||
AUTH_SCHEME = HTTPBasic()
|
||||
expected_password: str | None = None
|
||||
|
||||
|
||||
async def get_current_username(credentials: HTTPBasicCredentials = Depends(AUTH_SCHEME)) -> str:
|
||||
"""Verify Basic Authentication credentials"""
|
||||
if expected_password is not None and credentials.password != expected_password:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Incorrect password",
|
||||
headers={"WWW-Authenticate": AUTH_SCHEME.getobfuscation_header()},
|
||||
)
|
||||
return credentials.username
|
||||
|
||||
|
||||
def set_auth_password(password: str | None):
|
||||
"""Set the expected password for authentication"""
|
||||
global expected_password
|
||||
expected_password = password
|
||||
|
||||
|
||||
class FileIndexer:
|
||||
@ -128,7 +147,7 @@ async def health_check():
|
||||
|
||||
|
||||
@app.get("/api/{file_hash}/data")
|
||||
async def get_file_data(file_hash: str):
|
||||
async def get_file_data(file_hash: str, username: str = Depends(get_current_username)):
|
||||
"""Serve a specific file by its hash"""
|
||||
if file_hash not in file_mapping:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
@ -152,14 +171,14 @@ async def get_file_data(file_hash: str):
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
async def root(username: str = Depends(get_current_username)):
|
||||
"""Redirect to a random file hash"""
|
||||
random_hash = _get_random_hash()
|
||||
return RedirectResponse(url="/{hash}".format(hash=random_hash))
|
||||
|
||||
|
||||
@app.get("/{order}/{delay}")
|
||||
async def order_delay(order: str, delay: int):
|
||||
async def order_delay(order: str, delay: int, username: str = Depends(get_current_username)):
|
||||
"""Redirect to random file with order and delay"""
|
||||
random_hash = _get_random_hash()
|
||||
return RedirectResponse(
|
||||
@ -209,7 +228,7 @@ def _render_page(
|
||||
|
||||
|
||||
@app.get("/{file_hash}")
|
||||
async def hash_page(file_hash: str):
|
||||
async def hash_page(file_hash: str, username: str = Depends(get_current_username)):
|
||||
"""Serve a page for a specific file hash with navigation"""
|
||||
if file_hash not in file_mapping:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
@ -222,7 +241,7 @@ async def hash_page(file_hash: str):
|
||||
|
||||
|
||||
@app.get("/{order}/{delay}/{file_hash}")
|
||||
async def hash_page_with_refresh(order: str, delay: int, file_hash: str):
|
||||
async def hash_page_with_refresh(order: str, delay: int, file_hash: str, username: str = Depends(get_current_username)):
|
||||
"""Serve a page for a specific file hash with auto-refresh navigation"""
|
||||
if file_hash not in file_mapping:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
@ -278,9 +297,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
|
||||
parser.add_argument("--salt", type=str, default=None, help="Salt for hashing file paths")
|
||||
parser.add_argument("--password", type=str, default=None, help="Password for Basic Authentication")
|
||||
args = parser.parse_args()
|
||||
|
||||
initialize_server(args)
|
||||
set_auth_password(args.password)
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user