Allow a glob source

This commit is contained in:
Timothy Farrell 2026-03-09 09:46:43 +00:00
parent 82c32c2a0b
commit 5774ef7779

124
main.py
View File

@ -5,6 +5,7 @@ import os
import random import random
import secrets import secrets
import zipfile import zipfile
from glob import glob
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
@ -12,15 +13,15 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse from fastapi.responses import FileResponse, StreamingResponse
# ASGI application that the route handlers in this module are registered on.
app = FastAPI()

# Merged hash -> filename lookup across every indexed source; filled in by
# initialize_server().
file_mapping: dict[str, str] = {}
# One indexer per indexed source (a directory or an archive), in the order
# the sources were indexed.
indexers: list["FileIndexer"] = []
class FileIndexer: class FileIndexer:
def __init__(self, path: str, salt: str | None = None): def __init__(self, path: str, salt: str | None = None):
self.path = Path(path) self.path = Path(path)
self.file_mapping = {}
self._salt = salt self._salt = salt
self._index() self._file_mapping = self._index()
@property @property
def salt(self) -> str: def salt(self) -> str:
@ -33,77 +34,91 @@ class FileIndexer:
"""Generate a salted hash of the file path""" """Generate a salted hash of the file path"""
return hashlib.sha256((filepath + self.salt).encode()).hexdigest() return hashlib.sha256((filepath + self.salt).encode()).hexdigest()
def _index(self) -> dict[str, str]:
    """Index all files in the directory.

    Walks ``self.path`` recursively and hashes every file path with
    ``self._hash_path``.

    Returns:
        Mapping of salted hash to file path, in sorted traversal order.
    """
    mapping: dict[str, str] = {}
    for root, dirs, files in os.walk(self.path):
        # os.walk reports entries in arbitrary filesystem order. Sorting
        # both lists (dirs in place, which steers the walk) keeps the key
        # order, and with it next/previous navigation, stable across runs.
        dirs.sort()
        for name in sorted(files):
            filepath = os.path.join(root, name)
            mapping[self._hash_path(filepath)] = filepath
    return mapping
def get_file_by_hash(self, file_hash: str, chunk_size: int = 64 * 1024):
    """Get file content by hash.

    This is a generator: it yields the file's bytes in blocks of at most
    *chunk_size* bytes and yields nothing when the hash is unknown.

    Args:
        file_hash: Salted hash identifying the file.
        chunk_size: Maximum number of bytes per yielded block.
    """
    file_path = self._file_mapping.get(file_hash)
    if file_path is None:
        return
    with open(file_path, "rb") as f:
        # Iterating a binary file object splits on b"\n", which produces
        # unpredictable block sizes for non-text data; read fixed-size
        # blocks instead.
        while chunk := f.read(chunk_size):
            yield chunk
def get_filename_by_hash(self, file_hash: str) -> str: def get_filename_by_hash(self, file_hash: str) -> str | None:
"""Get filename by hash""" """Get filename by hash"""
if file_hash not in self.file_mapping: if file_hash not in self._file_mapping:
return None return None
return self.file_mapping[file_hash] return self._file_mapping[file_hash]
class ZipFileIndexer(FileIndexer):
    """Indexer that serves the members of a single ZIP archive.

    ``self._file_mapping`` maps each salted hash to the member's name inside
    the archive at ``self.path``.
    """

    def _index(self) -> dict[str, str]:
        """Index all files in the zip file.

        Returns:
            Mapping of salted hash to archive member name; directory
            entries are skipped.
        """
        with zipfile.ZipFile(self.path, "r") as zip_file:
            # Hash the archive path together with the member name. The
            # indexers are created with one shared salt and their mappings
            # are merged, so hashing the member name alone would make
            # identically named members of different archives collide.
            return {
                self._hash_path(f"{self.path}/{info.filename}"): info.filename
                for info in zip_file.infolist()
                if not info.is_dir()
            }

    def get_file_by_hash(self, file_hash: str, chunk_size: int = 64 * 1024):
        """Get file content by hash.

        This is a generator: it yields the member's decompressed bytes in
        blocks of at most *chunk_size* bytes and yields nothing when the
        hash is unknown.

        Args:
            file_hash: Salted hash identifying the archive member.
            chunk_size: Maximum number of bytes per yielded block.
        """
        filename = self._file_mapping.get(file_hash)
        if filename is None:
            return
        with zipfile.ZipFile(self.path, "r") as zip_file:
            # Stream the member instead of loading it fully with read().
            with zip_file.open(filename) as member:
                while chunk := member.read(chunk_size):
                    yield chunk

    def get_filename_by_hash(self, file_hash: str) -> str | None:
        """Get filename by hash.

        Returns:
            The member name stored for *file_hash*, or ``None`` when the
            hash is not in this indexer's mapping.
        """
        return self._file_mapping.get(file_hash)
# Maps a file extension (lowercase, including the leading dot) to the indexer
# class used for files of that type.
INDEXER_MAP = {".zip": ZipFileIndexer}
def initialize_server(args: argparse.Namespace):
    """Initialize the server with directory or glob indexing.

    Args:
        args: Parsed command-line arguments. ``args.source`` is a directory,
            an archive path, or a glob pattern; ``args.salt`` is the
            optional hashing salt.

    Raises:
        SystemExit: If the pattern matches nothing, or none of the matches
            has an extension listed in ``INDEXER_MAP``.
    """
    # Every source is indexed with the same salt; a random one is generated
    # when none was supplied.
    shared_salt = args.salt if args.salt is not None else secrets.token_hex(16)

    src_path = Path(args.source)
    if src_path.is_dir():
        new_indexers = [FileIndexer(str(src_path), shared_salt)]
    else:
        pattern = args.source
        # glob() returns matches in arbitrary order; sort for a stable index.
        matching_files = sorted(glob(pattern))
        if not matching_files:
            raise SystemExit(f"No files match pattern {pattern}")

        new_indexers = []
        for file_path in matching_files:
            # Compare extensions case-insensitively so "A.ZIP" is accepted.
            indexer_cls = INDEXER_MAP.get(Path(file_path).suffix.lower())
            if indexer_cls is None:
                # Report skipped matches instead of dropping them silently.
                print(f"Skipping {file_path}: unsupported file type")
                continue
            new_indexers.append(indexer_cls(file_path, shared_salt))

        if not new_indexers:
            raise SystemExit(f"No supported sources match pattern {pattern}")

    # The module-level containers are mutated in place (never rebound), so no
    # ``global`` declaration is required.
    for indexer in new_indexers:
        indexers.append(indexer)
        file_mapping.update(indexer._file_mapping)

    print(f"Indexed {len(file_mapping)} files from {len(indexers)} source(s)")
@app.get("/") @app.get("/")
@ -115,10 +130,10 @@ async def root():
@app.get("/random") @app.get("/random")
async def get_random_file(): async def get_random_file():
"""Get random file hashes from the mapping""" """Get random file hashes from the mapping"""
if not indexer.file_mapping: if not file_mapping:
raise HTTPException(status_code=404, detail="No files indexed") raise HTTPException(status_code=404, detail="No files indexed")
keys = list(indexer.file_mapping.keys()) keys = list(file_mapping.keys())
random_idx = random.randint(0, len(keys) - 1) random_idx = random.randint(0, len(keys) - 1)
current = keys[random_idx] current = keys[random_idx]
next_hash = keys[(random_idx + 1) % len(keys)] next_hash = keys[(random_idx + 1) % len(keys)]
@ -126,14 +141,26 @@ async def get_random_file():
return {"img": current, "next": next_hash, "previous": prev_hash} return {"img": current, "next": next_hash, "previous": prev_hash}
def _find_indexer_for_hash(file_hash: str, sources=None):
    """Find the indexer that contains the file with the given hash.

    Args:
        file_hash: Salted hash identifying the file.
        sources: Indexers to search; defaults to the module-level
            ``indexers`` list.

    Returns:
        The first indexer whose mapping contains *file_hash*, or ``None``
        when no indexer knows the hash.
    """
    candidates = indexers if sources is None else sources
    return next(
        (candidate for candidate in candidates if file_hash in candidate._file_mapping),
        None,
    )
@app.get("/{file_hash}/data") @app.get("/{file_hash}/data")
async def get_file_data(file_hash: str): async def get_file_data(file_hash: str):
"""Serve a specific file by its hash""" """Serve a specific file by its hash"""
if file_hash not in indexer.file_mapping: if file_hash not in file_mapping:
raise HTTPException(status_code=404, detail="File not found")
indexer = _find_indexer_for_hash(file_hash)
if not indexer:
raise HTTPException(status_code=404, detail="File not found") raise HTTPException(status_code=404, detail="File not found")
filename = indexer.get_filename_by_hash(file_hash) filename = indexer.get_filename_by_hash(file_hash)
content_type, _ = mimetypes.guess_type(filename) content_type, _ = mimetypes.guess_type(filename or "")
if not content_type: if not content_type:
content_type = "application/octet-stream" content_type = "application/octet-stream"
@ -141,7 +168,7 @@ async def get_file_data(file_hash: str):
indexer.get_file_by_hash(file_hash), indexer.get_file_by_hash(file_hash),
media_type=content_type, media_type=content_type,
headers={ headers={
"Content-Disposition": f"inline; filename={os.path.basename(filename)}", "Content-Disposition": f"inline; filename={os.path.basename(filename or '')}",
}, },
) )
@ -149,13 +176,18 @@ async def get_file_data(file_hash: str):
@app.get("/{file_hash}")
async def get_file_info(file_hash: str):
    """Get file info by hash.

    Returns the hash itself, its neighbours in ``file_mapping`` order
    (wrapping around at both ends) and the stored filename.

    Raises:
        HTTPException: 404 when the hash is not in ``file_mapping``.
    """
    if file_hash not in file_mapping:
        raise HTTPException(status_code=404, detail="File not found")

    keys = list(file_mapping)
    idx = keys.index(file_hash)
    next_hash = keys[(idx + 1) % len(keys)]
    # Index -1 already addresses the last key, so the wrap-around for the
    # first entry needs no special case.
    prev_hash = keys[idx - 1]

    # file_mapping merges every indexer's hash -> filename mapping, so the
    # filename can be read directly without searching the indexers.
    filename = file_mapping[file_hash]
    return {"img": file_hash, "next": next_hash, "previous": prev_hash, "filename": filename}
@ -163,12 +195,16 @@ async def get_file_info(file_hash: str):
# Optional: Add a health check endpoint
# NOTE(review): this route is registered after the catch-all "/{file_hash}"
# route. Route matching is order-dependent, so a request for "/health" is
# probably handled by that route instead -- confirm, and move this handler
# above it if so.
@app.get("/health")
async def health_check():
    """Report service status and the number of indexed files."""
    return {"status": "healthy", "file_count": len(file_mapping)}
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the file server") parser = argparse.ArgumentParser(description="Run the file server")
parser.add_argument("source", type=str, help="Path to directory or ZIP archive") parser.add_argument(
"source",
type=str,
help="Path to directory, ZIP archive, or glob pattern (e.g., *.zip, path/to/zips/*.zip)",
)
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=8000, help="Port to bind to") parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
parser.add_argument("--salt", type=str, default=None, help="Salt for hashing file paths") parser.add_argument("--salt", type=str, default=None, help="Salt for hashing file paths")