image_server/main.py

218 lines
6.7 KiB
Python

import argparse
import hashlib
import mimetypes
import os
import random
import secrets
import zipfile
from glob import glob
from io import BytesIO
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
# ASGI application object; the route decorators further down register their handlers on it.
app = FastAPI()
# Combined hash -> file path / archive member name lookup, merged from every
# indexer by initialize_server().
file_mapping: dict[str, str] = {}
# Indexers created by initialize_server(), in creation order.
indexers: list["FileIndexer"] = []
class FileIndexer:
def __init__(self, path: str, salt: str | None = None):
self.path = Path(path)
self._salt = salt
self._file_mapping = self._index()
@property
def salt(self) -> str:
"""Generate a random salt for hashing"""
if self._salt is None:
self._salt = secrets.token_hex(16)
return self._salt
def _hash_path(self, filepath: str) -> str:
"""Generate a salted hash of the file path"""
return hashlib.sha256((filepath + self.salt).encode()).hexdigest()
def _index(self) -> dict[str, str]:
"""Index all files in the directory"""
mapping = {}
for root, _, files in os.walk(self.path):
for file in files:
filepath = os.path.join(root, file)
file_hash = self._hash_path(filepath)
mapping[file_hash] = filepath
return mapping
def get_file_by_hash(self, file_hash: str):
"""Get file content by hash"""
if file_hash not in self._file_mapping:
return None
file_path = self._file_mapping[file_hash]
with open(file_path, "rb") as f:
yield from f
def get_filename_by_hash(self, file_hash: str) -> str | None:
"""Get filename by hash"""
if file_hash not in self._file_mapping:
return None
return self._file_mapping[file_hash]
class ZipFileIndexer(FileIndexer):
    """Index the members of one ZIP archive by a salted hash of their names.

    ``self.path`` is the archive itself; ``_file_mapping`` maps each hash to
    the member name stored inside the archive. ``get_filename_by_hash`` is
    inherited unchanged from ``FileIndexer``.
    """

    # Size of the blocks yielded when streaming a decompressed member.
    _ZIP_CHUNK_SIZE = 64 * 1024

    def _index(self) -> dict[str, str]:
        """Return ``{hash: member name}`` for every non-directory member."""
        with zipfile.ZipFile(self.path, "r") as zip_file:
            return {
                self._hash_path(info.filename): info.filename
                for info in zip_file.infolist()
                if not info.is_dir()
            }

    def get_file_by_hash(self, file_hash: str):
        """Stream the decompressed content of the member for ``file_hash``.

        This is a generator function yielding ``bytes`` blocks. An unknown
        hash produces an empty stream; it never returns ``None`` to the
        caller. The archive is reopened for each call and closed when the
        generator finishes.
        """
        filename = self._file_mapping.get(file_hash)
        if filename is None:
            return
        with zipfile.ZipFile(self.path, "r") as zip_file:
            # Decompress incrementally instead of loading the whole member
            # into memory with ZipFile.read().
            with zip_file.open(filename) as member:
                while chunk := member.read(self._ZIP_CHUNK_SIZE):
                    yield chunk
# Maps a file extension (in the form returned by Path.suffix) to the indexer
# class that initialize_server() instantiates for matching files.
INDEXER_MAP: dict[str, type[FileIndexer]] = {".zip": ZipFileIndexer}
def initialize_server(args: argparse.Namespace):
    """Build the indexers for ``args.source`` and fill the module-level index.

    ``args.source`` is either a directory (indexed recursively with
    ``FileIndexer``) or a glob pattern whose matches are indexed with the
    class registered for their extension in ``INDEXER_MAP``. Every indexer
    created here is appended to ``indexers`` and its entries are merged into
    ``file_mapping``.

    Args:
        args: Parsed command-line arguments; ``source`` and ``salt`` are read.

    Raises:
        SystemExit: If ``source`` is not a directory and the pattern matches
            nothing.
    """
    src_path = Path(args.source)

    # One salt is shared by all indexers created in this call.
    shared_salt = args.salt if args.salt is not None else secrets.token_hex(16)

    if src_path.is_dir():
        created = [FileIndexer(str(src_path), shared_salt)]
    else:
        pattern = args.source
        # glob() returns matches in arbitrary order; sort so the index order
        # (and therefore next/previous navigation) is stable between runs.
        matching_files = sorted(glob(pattern))
        if not matching_files:
            raise SystemExit(f"No files match pattern {pattern}")

        created = []
        for file_path in matching_files:
            # Compare case-insensitively so e.g. "photos.ZIP" is recognised.
            indexer_cls = INDEXER_MAP.get(Path(file_path).suffix.lower())
            if indexer_cls is None:
                print(f"Skipping unsupported file {file_path}")
                continue
            created.append(indexer_cls(file_path, shared_salt))

    # NOTE(review): ZipFileIndexer hashes only the member name, so with a
    # shared salt two archives containing the same member name produce the
    # same hash; the later entry overwrites the earlier one in file_mapping.
    for indexer in created:
        indexers.append(indexer)
        file_mapping.update(indexer._file_mapping)

    print(f"Indexed {len(file_mapping)} files from {len(indexers)} source(s)")
@app.get("/")
async def root():
    """Return the frontend HTML page."""
    # Relative path: resolved against the process's working directory.
    frontend_page = "frontend.html"
    return FileResponse(frontend_page)
@app.get("/random")
async def get_random_file():
    """Pick one indexed file at random and report it with its neighbours.

    Responds with the chosen hash plus the hashes that follow and precede it
    in index order (wrapping around at both ends). Raises a 404 when nothing
    has been indexed.
    """
    if not file_mapping:
        raise HTTPException(status_code=404, detail="No files indexed")

    hashes = list(file_mapping)
    total = len(hashes)
    pick = random.randrange(total)
    return {
        "img": hashes[pick],
        "next": hashes[(pick + 1) % total],
        # Modulo wraps position 0 back to the last entry.
        "previous": hashes[(pick - 1) % total],
    }
def _find_indexer_for_hash(file_hash: str):
    """Return the first registered indexer that knows ``file_hash``, else None."""
    return next(
        (candidate for candidate in indexers if file_hash in candidate._file_mapping),
        None,
    )
@app.get("/{file_hash}/data")
async def get_file_data(file_hash: str):
    """Stream the content of the file identified by ``file_hash``.

    The media type is guessed from the file name and falls back to
    ``application/octet-stream``. The file's base name is sent in an inline
    ``Content-Disposition`` header.

    Raises:
        HTTPException: 404 when the hash is unknown or no indexer owns it.
    """
    # Local import: only this handler needs it.
    from urllib.parse import quote

    if file_hash not in file_mapping:
        raise HTTPException(status_code=404, detail="File not found")

    indexer = _find_indexer_for_hash(file_hash)
    if not indexer:
        raise HTTPException(status_code=404, detail="File not found")

    filename = indexer.get_filename_by_hash(file_hash) or ""
    content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"

    # File names come from disk or from archive members and may contain
    # spaces, quotes, semicolons or non-latin-1 characters, none of which are
    # valid in a bare header token. Plain names are sent quoted; anything that
    # needs escaping is sent percent-encoded via the RFC 5987 "filename*" form.
    display_name = os.path.basename(filename)
    encoded_name = quote(display_name)
    if encoded_name == display_name:
        disposition = f'inline; filename="{display_name}"'
    else:
        disposition = f"inline; filename*=UTF-8''{encoded_name}"

    return StreamingResponse(
        indexer.get_file_by_hash(file_hash),
        media_type=content_type,
        headers={"Content-Disposition": disposition},
    )
# Must be registered before the catch-all "/{file_hash}" route below: routes
# are matched in registration order, so declared after it, "/health" would be
# handled as a file hash and answered with 404.
@app.get("/health")
async def health_check():
    """Report that the server is up and how many files are indexed."""
    return {"status": "healthy", "file_count": len(file_mapping)}


@app.get("/{file_hash}")
async def get_file_info(file_hash: str):
    """Return metadata for ``file_hash``: its neighbours and its file name.

    ``next`` and ``previous`` are the adjacent hashes in index order, wrapping
    around at both ends.

    Raises:
        HTTPException: 404 when the hash is unknown or no indexer owns it.
    """
    if file_hash not in file_mapping:
        raise HTTPException(status_code=404, detail="File not found")

    indexer = _find_indexer_for_hash(file_hash)
    if not indexer:
        raise HTTPException(status_code=404, detail="File not found")

    keys = list(file_mapping)
    idx = keys.index(file_hash)
    return {
        "img": file_hash,
        "next": keys[(idx + 1) % len(keys)],
        # Index -1 is the last entry, so position 0 wraps around naturally.
        "previous": keys[idx - 1],
        "filename": indexer.get_filename_by_hash(file_hash),
    }
if __name__ == "__main__":
    # Command-line entry point: parse options, build the index, start serving.
    cli = argparse.ArgumentParser(description="Run the file server")
    cli.add_argument(
        "source",
        type=str,
        help="Path to directory, ZIP archive, or glob pattern (e.g., *.zip, path/to/zips/*.zip)",
    )
    # Optional flags as (name, type, default, help), registered in this order.
    for flag, kind, fallback, text in (
        ("--host", str, "0.0.0.0", "Host to bind to"),
        ("--port", int, 8000, "Port to bind to"),
        ("--salt", str, None, "Salt for hashing file paths"),
    ):
        cli.add_argument(flag, type=kind, default=fallback, help=text)

    args = cli.parse_args()
    initialize_server(args)

    # Imported only when run as a script, after indexing has succeeded.
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port)