diff --git a/main.py b/main.py index ca87aa0..ba25107 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ import secrets import zipfile from pathlib import Path from fastapi import FastAPI, HTTPException -from fastapi.responses import FileResponse, StreamingResponse +from fastapi.responses import FileResponse, StreamingResponse, RedirectResponse import mimetypes app = FastAPI() @@ -49,7 +49,7 @@ class FileIndexer: return None file_path = self.file_mapping[file_hash] - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: yield from f def get_filename_by_hash(self, file_hash: str) -> str: @@ -62,7 +62,7 @@ class FileIndexer: class ZipFileIndexer(FileIndexer): def _index(self): """Index all files in the zip file""" - with zipfile.ZipFile(self.path, 'r') as zip_file: + with zipfile.ZipFile(self.path, "r") as zip_file: self.file_mapping = { self._hash_path(file_info.filename): file_info for file_info in zip_file.infolist() @@ -76,7 +76,7 @@ class ZipFileIndexer(FileIndexer): file_info = self.file_mapping[file_hash] - with zipfile.ZipFile(self.path, 'r') as zip_file: + with zipfile.ZipFile(self.path, "r") as zip_file: yield from BytesIO(zip_file.read(file_info.filename)) def get_filename_by_hash(self, file_hash: str) -> str: @@ -85,9 +85,8 @@ class ZipFileIndexer(FileIndexer): return None return self.file_mapping[file_hash].filename -INDEXER_MAP = { - ".zip": ZipFileIndexer -} + +INDEXER_MAP = {".zip": ZipFileIndexer} def initialize_server(args: argparse.Namespace): @@ -105,7 +104,14 @@ def initialize_server(args: argparse.Namespace): print(f"Indexed {len(indexer.file_mapping)} files") + @app.get("/") +async def root(): + """Redirect to /random""" + return RedirectResponse(url="/random") + + +@app.get("/random") async def get_random_file(): """Serve a random file from the mapping""" if not indexer.file_mapping: @@ -125,11 +131,12 @@ async def get_random_file(): "Content-Disposition": f"inline; filename={os.path.basename(filename)}", "Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", - "Expires": "0" - } + "Expires": "0", + }, ) return response + @app.get("/{file_hash}") async def get_file_by_hash(file_hash: str): """Serve a specific file by its hash""" @@ -143,14 +150,16 @@ async def get_file_by_hash(file_hash: str): media_type=content_type, headers={ "Content-Disposition": f"inline; filename={os.path.basename(file_path)}", - } + }, ) + # Optional: Add a health check endpoint @app.get("/health") async def health_check(): return {"status": "healthy", "file_count": len(file_mapping)} + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the file server") parser.add_argument("source", type=str, help="Path to directory or ZIP archive") @@ -161,4 +170,5 @@ if __name__ == "__main__": initialize_server(args) import uvicorn + uvicorn.run(app, host=args.host, port=args.port)