From 5774ef77790cc2f770db9d33f476a42eb358943e Mon Sep 17 00:00:00 2001
From: Timothy Farrell
Date: Mon, 9 Mar 2026 09:46:43 +0000
Subject: [PATCH] Allow a glob source

---

Review notes (text placed here, between the three-dash line and the
diffstat, is ignored when the patch is applied):

* Line breaks, indentation and blank context lines were restored from a
  whitespace-collapsed copy of this patch.  In the third hunk the blank
  context lines were placed to satisfy the hunk line counts; verify them
  against the base file before applying.
* Every indexer receives the same salt and ZipFileIndexer hashes only the
  archive member name, so identically named members of different archives
  get the same hash.  file_mapping.update() then keeps a single entry and
  _find_indexer_for_hash() returns the first indexer that has it, so the
  other copies cannot be served.
* Glob matches whose suffix is not in INDEXER_MAP are skipped without a
  message; when every match is skipped the server starts with no files.
* The "From:" header carries no e-mail address in this copy.

 main.py | 124 ++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 80 insertions(+), 44 deletions(-)

diff --git a/main.py b/main.py
index 4bed5c7..8543b14 100644
--- a/main.py
+++ b/main.py
@@ -5,6 +5,7 @@ import os
 import random
 import secrets
 import zipfile
+from glob import glob
 from io import BytesIO
 from pathlib import Path
 
@@ -12,15 +13,15 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import FileResponse, StreamingResponse
 
 app = FastAPI()
-indexer = None
+file_mapping = {}
+indexers = []
 
 
 class FileIndexer:
     def __init__(self, path: str, salt: str | None = None):
         self.path = Path(path)
-        self.file_mapping = {}
         self._salt = salt
-        self._index()
+        self._file_mapping = self._index()
 
     @property
     def salt(self) -> str:
@@ -33,77 +34,91 @@ class FileIndexer:
         """Generate a salted hash of the file path"""
         return hashlib.sha256((filepath + self.salt).encode()).hexdigest()
 
-    def _index(self):
-        """Index all files in the zip file"""
+    def _index(self) -> dict[str, str]:
+        """Index all files in the directory"""
+        mapping = {}
         for root, _, files in os.walk(self.path):
             for file in files:
                 filepath = os.path.join(root, file)
-                # Generate hash for the file path
                 file_hash = self._hash_path(filepath)
-                # Store mapping
-                self.file_mapping[file_hash] = filepath
+                mapping[file_hash] = filepath
+        return mapping
 
     def get_file_by_hash(self, file_hash: str):
         """Get file content by hash"""
-        if file_hash not in self.file_mapping:
+        if file_hash not in self._file_mapping:
             return None
 
-        file_path = self.file_mapping[file_hash]
+        file_path = self._file_mapping[file_hash]
         with open(file_path, "rb") as f:
             yield from f
 
-    def get_filename_by_hash(self, file_hash: str) -> str:
+    def get_filename_by_hash(self, file_hash: str) -> str | None:
         """Get filename by hash"""
-        if file_hash not in self.file_mapping:
+        if file_hash not in self._file_mapping:
             return None
 
-        return self.file_mapping[file_hash]
+        return self._file_mapping[file_hash]
 
 
 class ZipFileIndexer(FileIndexer):
-    def _index(self):
+    def _index(self) -> dict[str, str]:
         """Index all files in the zip file"""
+        mapping = {}
         with zipfile.ZipFile(self.path, "r") as zip_file:
-            self.file_mapping = {
-                self._hash_path(file_info.filename): file_info
-                for file_info in zip_file.infolist()
-                if not file_info.is_dir()
-            }
+            for file_info in zip_file.infolist():
+                if not file_info.is_dir():
+                    file_hash = self._hash_path(file_info.filename)
+                    mapping[file_hash] = file_info.filename
+        return mapping
 
     def get_file_by_hash(self, file_hash: str):
         """Get file content by hash"""
-        if file_hash not in self.file_mapping:
+        if file_hash not in self._file_mapping:
             return None
 
-        file_info = self.file_mapping[file_hash]
+        filename = self._file_mapping[file_hash]
         with zipfile.ZipFile(self.path, "r") as zip_file:
-            yield from BytesIO(zip_file.read(file_info.filename))
+            yield from BytesIO(zip_file.read(filename))
 
-    def get_filename_by_hash(self, file_hash: str) -> str:
+    def get_filename_by_hash(self, file_hash: str) -> str | None:
         """Get filename by hash"""
-        if file_hash not in self.file_mapping:
+        if file_hash not in self._file_mapping:
             return None
 
-        return self.file_mapping[file_hash].filename
+        return self._file_mapping[file_hash]
 
 
 INDEXER_MAP = {".zip": ZipFileIndexer}
 
 
 def initialize_server(args: argparse.Namespace):
-    """Initialize the server with directory indexing"""
-    global indexer
+    """Initialize the server with directory or glob indexing"""
+    global file_mapping, indexers
     src_path = Path(args.source)
-    if not src_path.exists():
-        raise SystemExit(f"Source path {src_path} does not exist")
 
-    indexer = (
-        FileIndexer(src_path, args.salt)
-        if src_path.is_dir()
-        else INDEXER_MAP[src_path.suffix](src_path, args.salt)
-    )
+    shared_salt = args.salt
+    if shared_salt is None:
+        shared_salt = secrets.token_hex(16)
 
-    print(f"Indexed {len(indexer.file_mapping)} files")
+    if src_path.is_dir():
+        indexer = FileIndexer(str(src_path), shared_salt)
+        indexers.append(indexer)
+        file_mapping.update(indexer._file_mapping)
+    else:
+        pattern = args.source
+        matching_files = glob(pattern)
+        if not matching_files:
+            raise SystemExit(f"No files match pattern {pattern}")
+
+        for file_path in matching_files:
+            file_ext = Path(file_path).suffix
+            if file_ext in INDEXER_MAP:
+                indexer = INDEXER_MAP[file_ext](file_path, shared_salt)
+                indexers.append(indexer)
+                file_mapping.update(indexer._file_mapping)
+
+    print(f"Indexed {len(file_mapping)} files from {len(indexers)} source(s)")
 
 
 @app.get("/")
@@ -115,10 +130,10 @@ async def root():
 @app.get("/random")
 async def get_random_file():
     """Get random file hashes from the mapping"""
-    if not indexer.file_mapping:
+    if not file_mapping:
         raise HTTPException(status_code=404, detail="No files indexed")
 
-    keys = list(indexer.file_mapping.keys())
+    keys = list(file_mapping.keys())
     random_idx = random.randint(0, len(keys) - 1)
     current = keys[random_idx]
     next_hash = keys[(random_idx + 1) % len(keys)]
@@ -126,14 +141,26 @@ async def get_random_file():
     return {"img": current, "next": next_hash, "previous": prev_hash}
 
 
+def _find_indexer_for_hash(file_hash: str):
+    """Find the indexer that contains the file with the given hash"""
+    for idx in indexers:
+        if file_hash in idx._file_mapping:
+            return idx
+    return None
+
+
 @app.get("/{file_hash}/data")
 async def get_file_data(file_hash: str):
     """Serve a specific file by its hash"""
-    if file_hash not in indexer.file_mapping:
+    if file_hash not in file_mapping:
+        raise HTTPException(status_code=404, detail="File not found")
+
+    indexer = _find_indexer_for_hash(file_hash)
+    if not indexer:
         raise HTTPException(status_code=404, detail="File not found")
 
     filename = indexer.get_filename_by_hash(file_hash)
-    content_type, _ = mimetypes.guess_type(filename)
+    content_type, _ = mimetypes.guess_type(filename or "")
     if not content_type:
         content_type = "application/octet-stream"
 
@@ -141,7 +168,7 @@ async def get_file_data(file_hash: str):
         indexer.get_file_by_hash(file_hash),
         media_type=content_type,
         headers={
-            "Content-Disposition": f"inline; filename={os.path.basename(filename)}",
+            "Content-Disposition": f"inline; filename={os.path.basename(filename or '')}",
         },
     )
 
@@ -149,13 +176,18 @@ async def get_file_data(file_hash: str):
 @app.get("/{file_hash}")
 async def get_file_info(file_hash: str):
     """Get file info by hash"""
-    if file_hash not in indexer.file_mapping:
+    if file_hash not in file_mapping:
         raise HTTPException(status_code=404, detail="File not found")
 
-    keys = list(indexer.file_mapping.keys())
+    keys = list(file_mapping.keys())
     idx = keys.index(file_hash)
     next_hash = keys[(idx + 1) % len(keys)]
     prev_hash = keys[idx - 1] if idx > 0 else keys[-1]
+
+    indexer = _find_indexer_for_hash(file_hash)
+    if not indexer:
+        raise HTTPException(status_code=404, detail="File not found")
+
     filename = indexer.get_filename_by_hash(file_hash)
     return {"img": file_hash, "next": next_hash, "previous": prev_hash, "filename": filename}
 
@@ -163,12 +195,16 @@ async def get_file_info(file_hash: str):
 # Optional: Add a health check endpoint
 @app.get("/health")
 async def health_check():
-    return {"status": "healthy", "file_count": len(indexer.file_mapping)}
+    return {"status": "healthy", "file_count": len(file_mapping)}
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the file server")
-    parser.add_argument("source", type=str, help="Path to directory or ZIP archive")
+    parser.add_argument(
+        "source",
+        type=str,
+        help="Path to directory, ZIP archive, or glob pattern (e.g., *.zip, path/to/zips/*.zip)",
+    )
     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
     parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
     parser.add_argument("--salt", type=str, default=None, help="Salt for hashing file paths")