image_server/main.py

159 lines
4.7 KiB
Python

import argparse
import hashlib
import mimetypes
import os
import random
import secrets
import zipfile
from io import BytesIO
from pathlib import Path
from urllib.parse import quote

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
# FastAPI application; the route handlers below register on it via @app.get.
app = FastAPI()
# Active indexer, assigned by initialize_server(); None until then.
# NOTE(review): request handlers dereference it without a None check, so
# initialize_server() must run before the app starts serving requests.
indexer: "FileIndexer | None" = None
class FileIndexer:
    """Index every file under a directory and serve it by an opaque hash.

    Each file path is mapped to a salted SHA-256 hex digest, so clients only
    ever see hashes, never real paths. The salt is generated per instance,
    so the hashes differ every time a new indexer is built.
    """

    def __init__(self, path: str):
        """Index the directory at *path* (a ``pathlib.Path`` also works)."""
        self.path = Path(path)
        # hex digest -> filesystem path string (subclasses may store other values)
        self.file_mapping = {}
        self._salt = None
        self._index()

    @property
    def salt(self) -> str:
        """Return the per-instance random salt, generating it on first use."""
        if not self._salt:
            self._salt = secrets.token_hex(16)
        return self._salt

    def _hash_path(self, filepath: str) -> str:
        """Return the hex SHA-256 digest of *filepath* followed by the salt."""
        return hashlib.sha256((filepath + self.salt).encode()).hexdigest()

    def _index(self):
        """Walk ``self.path`` recursively and record every non-directory entry."""
        for root, _, files in os.walk(self.path):
            for name in files:
                filepath = os.path.join(root, name)
                self.file_mapping[self._hash_path(filepath)] = filepath

    def get_file_by_hash(self, file_hash: str, chunk_size: int = 64 * 1024):
        """Return an iterator over the file's bytes, or ``None`` if unknown.

        The file is opened lazily when iteration starts and is read in
        pieces of at most *chunk_size* bytes.

        Raises:
            ValueError: if *chunk_size* is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        file_path = self.file_mapping.get(file_hash)
        if file_path is None:
            return None
        # Returned (not yielded) so a missing hash really produces None;
        # the original generator body could never return None to its caller.
        return self._iter_file(file_path, chunk_size)

    @staticmethod
    def _iter_file(file_path, chunk_size):
        """Yield the contents of *file_path* in ``chunk_size``-byte pieces."""
        with open(file_path, "rb") as f:
            # Fixed-size reads: iterating a binary file yields b"\n"-split
            # "lines", whose sizes are arbitrary for image data.
            yield from iter(lambda: f.read(chunk_size), b"")

    def get_filename_by_hash(self, file_hash: str) -> "str | None":
        """Return the indexed path for *file_hash*, or ``None`` if unknown."""
        return self.file_mapping.get(file_hash)
class ZipFileIndexer(FileIndexer):
    """Index the members of a ZIP archive and stream them by hash.

    Unlike the directory indexer, ``file_mapping`` values here are
    ``zipfile.ZipInfo`` objects rather than path strings.
    """

    def _index(self):
        """Record every non-directory member of the archive at ``self.path``."""
        with zipfile.ZipFile(self.path, "r") as zip_file:
            self.file_mapping = {
                self._hash_path(info.filename): info
                for info in zip_file.infolist()
                if not info.is_dir()
            }

    def get_file_by_hash(self, file_hash: str, chunk_size: int = 64 * 1024):
        """Return an iterator over the member's bytes, or ``None`` if unknown.

        The archive is opened lazily when iteration starts and the member
        is decompressed incrementally, at most *chunk_size* bytes at a time,
        instead of being loaded into memory in full.

        Raises:
            ValueError: if *chunk_size* is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        info = self.file_mapping.get(file_hash)
        if info is None:
            return None
        # Returned (not yielded) so a missing hash really produces None.
        return self._iter_member(info, chunk_size)

    def _iter_member(self, info, chunk_size):
        """Yield the decompressed bytes of archive member *info* in chunks."""
        # Opened per stream, so no ZipFile handle is held between requests.
        with zipfile.ZipFile(self.path, "r") as zip_file, zip_file.open(info) as member:
            yield from iter(lambda: member.read(chunk_size), b"")

    def get_filename_by_hash(self, file_hash: str) -> "str | None":
        """Return the member name for *file_hash*, or ``None`` if unknown."""
        info = self.file_mapping.get(file_hash)
        return None if info is None else info.filename
# Indexer classes for single-file sources, keyed by file suffix including the
# leading dot (compared against Path.suffix); directories use FileIndexer.
INDEXER_MAP: "dict[str, type[FileIndexer]]" = {".zip": ZipFileIndexer}
def initialize_server(args: argparse.Namespace):
    """Build the module-level ``indexer`` for ``args.source`` and report its size.

    A directory is indexed with FileIndexer. A regular file is dispatched on
    its suffix (case-insensitively) through INDEXER_MAP.

    Raises:
        SystemExit: if the source does not exist or its type is unsupported.
    """
    global indexer
    src_path = Path(args.source)
    if not src_path.exists():
        raise SystemExit(f"Source path {src_path} does not exist")
    if src_path.is_dir():
        indexer_cls = FileIndexer
    else:
        # Lower-cased so "photos.ZIP" is accepted like "photos.zip".
        indexer_cls = INDEXER_MAP.get(src_path.suffix.lower())
        if indexer_cls is None:
            # Clean CLI error instead of a bare KeyError traceback.
            supported = ", ".join(sorted(INDEXER_MAP))
            raise SystemExit(
                f"Unsupported source {src_path}: expected a directory "
                f"or a file with one of these suffixes: {supported}"
            )
    indexer = indexer_cls(src_path)
    print(f"Indexed {len(indexer.file_mapping)} files")
@app.get("/")
async def root():
    """Serve the frontend app (``frontend.html``).

    The page is looked up next to this module first, so the server works no
    matter which directory it is launched from. If it is not there, the
    function falls back to ``frontend.html`` relative to the current working
    directory.
    """
    bundled = Path(__file__).resolve().with_name("frontend.html")
    return FileResponse(bundled if bundled.is_file() else "frontend.html")
@app.get("/random")
async def get_random_file():
    """Choose one indexed file uniformly at random.

    Returns:
        ``{"img": <hash>}`` for the chosen file.

    Raises:
        HTTPException: 404 when nothing has been indexed.
    """
    candidates = list(indexer.file_mapping)
    if not candidates:
        raise HTTPException(status_code=404, detail="No files indexed")
    return {"img": random.choice(candidates)}
# Health check endpoint. It must be registered BEFORE the catch-all
# "/{file_hash}" route: routes are matched in declaration order, so if
# declared after it, "/health" would be treated as a file hash and return 404.
@app.get("/health")
async def health_check():
    """Report liveness and the number of indexed files."""
    return {"status": "healthy", "file_count": len(indexer.file_mapping)}


def _content_disposition(filename: str) -> str:
    """Build an ``inline`` Content-Disposition header value for *filename*.

    Header values are encoded as latin-1, and an unquoted name breaks on
    spaces, ``;`` or ``"``. The plain ``filename`` parameter therefore gets a
    quoted, printable-ASCII-only fallback. The exact name is carried
    percent-encoded in ``filename*`` (RFC 6266 / RFC 5987).
    """
    name = os.path.basename(filename)
    fallback = "".join(
        ch if 32 <= ord(ch) < 127 and ch not in '"\\' else "_" for ch in name
    )
    return f"inline; filename=\"{fallback}\"; filename*=UTF-8''{quote(name, safe='')}"


@app.get("/{file_hash}")
async def get_file_by_hash(file_hash: str):
    """Stream the indexed file identified by *file_hash*.

    The media type is guessed from the file name and defaults to
    ``application/octet-stream``.

    Raises:
        HTTPException: 404 when the hash is not in the index.
    """
    if file_hash not in indexer.file_mapping:
        raise HTTPException(status_code=404, detail="File not found")
    filename = indexer.get_filename_by_hash(file_hash)
    content_type, _ = mimetypes.guess_type(filename)
    return StreamingResponse(
        indexer.get_file_by_hash(file_hash),
        media_type=content_type or "application/octet-stream",
        headers={"Content-Disposition": _content_disposition(filename)},
    )
if __name__ == "__main__":
    # Command-line entry point: parse options, build the index, then serve.
    cli = argparse.ArgumentParser(description="Run the file server")
    cli.add_argument("source", type=str, help="Path to directory or ZIP archive")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    cli.add_argument("--port", type=int, default=8000, help="Port to bind to")
    options = cli.parse_args()

    initialize_server(options)

    # Imported here rather than at module top, so importing this module
    # does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=options.host, port=options.port)