Allow a glob source
This commit is contained in:
parent
82c32c2a0b
commit
5774ef7779
124
main.py
124
main.py
@ -5,6 +5,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import secrets
|
import secrets
|
||||||
import zipfile
|
import zipfile
|
||||||
|
from glob import glob
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -12,15 +13,15 @@ from fastapi import FastAPI, HTTPException
|
|||||||
from fastapi.responses import FileResponse, StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
indexer = None
|
file_mapping = {}
|
||||||
|
indexers = []
|
||||||
|
|
||||||
|
|
||||||
class FileIndexer:
|
class FileIndexer:
|
||||||
def __init__(self, path: str, salt: str | None = None):
|
def __init__(self, path: str, salt: str | None = None):
|
||||||
self.path = Path(path)
|
self.path = Path(path)
|
||||||
self.file_mapping = {}
|
|
||||||
self._salt = salt
|
self._salt = salt
|
||||||
self._index()
|
self._file_mapping = self._index()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def salt(self) -> str:
|
def salt(self) -> str:
|
||||||
@ -33,77 +34,91 @@ class FileIndexer:
|
|||||||
"""Generate a salted hash of the file path"""
|
"""Generate a salted hash of the file path"""
|
||||||
return hashlib.sha256((filepath + self.salt).encode()).hexdigest()
|
return hashlib.sha256((filepath + self.salt).encode()).hexdigest()
|
||||||
|
|
||||||
def _index(self):
|
def _index(self) -> dict[str, str]:
|
||||||
"""Index all files in the zip file"""
|
"""Index all files in the directory"""
|
||||||
|
mapping = {}
|
||||||
for root, _, files in os.walk(self.path):
|
for root, _, files in os.walk(self.path):
|
||||||
for file in files:
|
for file in files:
|
||||||
filepath = os.path.join(root, file)
|
filepath = os.path.join(root, file)
|
||||||
# Generate hash for the file path
|
|
||||||
file_hash = self._hash_path(filepath)
|
file_hash = self._hash_path(filepath)
|
||||||
# Store mapping
|
mapping[file_hash] = filepath
|
||||||
self.file_mapping[file_hash] = filepath
|
return mapping
|
||||||
|
|
||||||
def get_file_by_hash(self, file_hash: str):
|
def get_file_by_hash(self, file_hash: str):
|
||||||
"""Get file content by hash"""
|
"""Get file content by hash"""
|
||||||
if file_hash not in self.file_mapping:
|
if file_hash not in self._file_mapping:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
file_path = self.file_mapping[file_hash]
|
file_path = self._file_mapping[file_hash]
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
yield from f
|
yield from f
|
||||||
|
|
||||||
def get_filename_by_hash(self, file_hash: str) -> str:
|
def get_filename_by_hash(self, file_hash: str) -> str | None:
|
||||||
"""Get filename by hash"""
|
"""Get filename by hash"""
|
||||||
if file_hash not in self.file_mapping:
|
if file_hash not in self._file_mapping:
|
||||||
return None
|
return None
|
||||||
return self.file_mapping[file_hash]
|
return self._file_mapping[file_hash]
|
||||||
|
|
||||||
|
|
||||||
class ZipFileIndexer(FileIndexer):
|
class ZipFileIndexer(FileIndexer):
|
||||||
def _index(self):
|
def _index(self) -> dict[str, str]:
|
||||||
"""Index all files in the zip file"""
|
"""Index all files in the zip file"""
|
||||||
|
mapping = {}
|
||||||
with zipfile.ZipFile(self.path, "r") as zip_file:
|
with zipfile.ZipFile(self.path, "r") as zip_file:
|
||||||
self.file_mapping = {
|
for file_info in zip_file.infolist():
|
||||||
self._hash_path(file_info.filename): file_info
|
if not file_info.is_dir():
|
||||||
for file_info in zip_file.infolist()
|
file_hash = self._hash_path(file_info.filename)
|
||||||
if not file_info.is_dir()
|
mapping[file_hash] = file_info.filename
|
||||||
}
|
return mapping
|
||||||
|
|
||||||
def get_file_by_hash(self, file_hash: str):
|
def get_file_by_hash(self, file_hash: str):
|
||||||
"""Get file content by hash"""
|
"""Get file content by hash"""
|
||||||
if file_hash not in self.file_mapping:
|
if file_hash not in self._file_mapping:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
file_info = self.file_mapping[file_hash]
|
filename = self._file_mapping[file_hash]
|
||||||
|
|
||||||
with zipfile.ZipFile(self.path, "r") as zip_file:
|
with zipfile.ZipFile(self.path, "r") as zip_file:
|
||||||
yield from BytesIO(zip_file.read(file_info.filename))
|
yield from BytesIO(zip_file.read(filename))
|
||||||
|
|
||||||
def get_filename_by_hash(self, file_hash: str) -> str:
|
def get_filename_by_hash(self, file_hash: str) -> str | None:
|
||||||
"""Get filename by hash"""
|
"""Get filename by hash"""
|
||||||
if file_hash not in self.file_mapping:
|
if file_hash not in self._file_mapping:
|
||||||
return None
|
return None
|
||||||
return self.file_mapping[file_hash].filename
|
return self._file_mapping[file_hash]
|
||||||
|
|
||||||
|
|
||||||
INDEXER_MAP = {".zip": ZipFileIndexer}
|
INDEXER_MAP = {".zip": ZipFileIndexer}
|
||||||
|
|
||||||
|
|
||||||
def initialize_server(args: argparse.Namespace):
|
def initialize_server(args: argparse.Namespace):
|
||||||
"""Initialize the server with directory indexing"""
|
"""Initialize the server with directory or glob indexing"""
|
||||||
global indexer
|
global file_mapping, indexers
|
||||||
|
|
||||||
src_path = Path(args.source)
|
src_path = Path(args.source)
|
||||||
if not src_path.exists():
|
|
||||||
raise SystemExit(f"Source path {src_path} does not exist")
|
|
||||||
|
|
||||||
indexer = (
|
shared_salt = args.salt
|
||||||
FileIndexer(src_path, args.salt)
|
if shared_salt is None:
|
||||||
if src_path.is_dir()
|
shared_salt = secrets.token_hex(16)
|
||||||
else INDEXER_MAP[src_path.suffix](src_path, args.salt)
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Indexed {len(indexer.file_mapping)} files")
|
if src_path.is_dir():
|
||||||
|
indexer = FileIndexer(str(src_path), shared_salt)
|
||||||
|
indexers.append(indexer)
|
||||||
|
file_mapping.update(indexer._file_mapping)
|
||||||
|
else:
|
||||||
|
pattern = args.source
|
||||||
|
matching_files = glob(pattern)
|
||||||
|
if not matching_files:
|
||||||
|
raise SystemExit(f"No files match pattern {pattern}")
|
||||||
|
|
||||||
|
for file_path in matching_files:
|
||||||
|
file_ext = Path(file_path).suffix
|
||||||
|
if file_ext in INDEXER_MAP:
|
||||||
|
indexer = INDEXER_MAP[file_ext](file_path, shared_salt)
|
||||||
|
indexers.append(indexer)
|
||||||
|
file_mapping.update(indexer._file_mapping)
|
||||||
|
|
||||||
|
print(f"Indexed {len(file_mapping)} files from {len(indexers)} source(s)")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
@ -115,10 +130,10 @@ async def root():
|
|||||||
@app.get("/random")
|
@app.get("/random")
|
||||||
async def get_random_file():
|
async def get_random_file():
|
||||||
"""Get random file hashes from the mapping"""
|
"""Get random file hashes from the mapping"""
|
||||||
if not indexer.file_mapping:
|
if not file_mapping:
|
||||||
raise HTTPException(status_code=404, detail="No files indexed")
|
raise HTTPException(status_code=404, detail="No files indexed")
|
||||||
|
|
||||||
keys = list(indexer.file_mapping.keys())
|
keys = list(file_mapping.keys())
|
||||||
random_idx = random.randint(0, len(keys) - 1)
|
random_idx = random.randint(0, len(keys) - 1)
|
||||||
current = keys[random_idx]
|
current = keys[random_idx]
|
||||||
next_hash = keys[(random_idx + 1) % len(keys)]
|
next_hash = keys[(random_idx + 1) % len(keys)]
|
||||||
@ -126,14 +141,26 @@ async def get_random_file():
|
|||||||
return {"img": current, "next": next_hash, "previous": prev_hash}
|
return {"img": current, "next": next_hash, "previous": prev_hash}
|
||||||
|
|
||||||
|
|
||||||
|
def _find_indexer_for_hash(file_hash: str):
|
||||||
|
"""Find the indexer that contains the file with the given hash"""
|
||||||
|
for idx in indexers:
|
||||||
|
if file_hash in idx._file_mapping:
|
||||||
|
return idx
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@app.get("/{file_hash}/data")
|
@app.get("/{file_hash}/data")
|
||||||
async def get_file_data(file_hash: str):
|
async def get_file_data(file_hash: str):
|
||||||
"""Serve a specific file by its hash"""
|
"""Serve a specific file by its hash"""
|
||||||
if file_hash not in indexer.file_mapping:
|
if file_hash not in file_mapping:
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
indexer = _find_indexer_for_hash(file_hash)
|
||||||
|
if not indexer:
|
||||||
raise HTTPException(status_code=404, detail="File not found")
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
filename = indexer.get_filename_by_hash(file_hash)
|
filename = indexer.get_filename_by_hash(file_hash)
|
||||||
content_type, _ = mimetypes.guess_type(filename)
|
content_type, _ = mimetypes.guess_type(filename or "")
|
||||||
if not content_type:
|
if not content_type:
|
||||||
content_type = "application/octet-stream"
|
content_type = "application/octet-stream"
|
||||||
|
|
||||||
@ -141,7 +168,7 @@ async def get_file_data(file_hash: str):
|
|||||||
indexer.get_file_by_hash(file_hash),
|
indexer.get_file_by_hash(file_hash),
|
||||||
media_type=content_type,
|
media_type=content_type,
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": f"inline; filename={os.path.basename(filename)}",
|
"Content-Disposition": f"inline; filename={os.path.basename(filename or '')}",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -149,13 +176,18 @@ async def get_file_data(file_hash: str):
|
|||||||
@app.get("/{file_hash}")
|
@app.get("/{file_hash}")
|
||||||
async def get_file_info(file_hash: str):
|
async def get_file_info(file_hash: str):
|
||||||
"""Get file info by hash"""
|
"""Get file info by hash"""
|
||||||
if file_hash not in indexer.file_mapping:
|
if file_hash not in file_mapping:
|
||||||
raise HTTPException(status_code=404, detail="File not found")
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
keys = list(indexer.file_mapping.keys())
|
keys = list(file_mapping.keys())
|
||||||
idx = keys.index(file_hash)
|
idx = keys.index(file_hash)
|
||||||
next_hash = keys[(idx + 1) % len(keys)]
|
next_hash = keys[(idx + 1) % len(keys)]
|
||||||
prev_hash = keys[idx - 1] if idx > 0 else keys[-1]
|
prev_hash = keys[idx - 1] if idx > 0 else keys[-1]
|
||||||
|
|
||||||
|
indexer = _find_indexer_for_hash(file_hash)
|
||||||
|
if not indexer:
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
filename = indexer.get_filename_by_hash(file_hash)
|
filename = indexer.get_filename_by_hash(file_hash)
|
||||||
return {"img": file_hash, "next": next_hash, "previous": prev_hash, "filename": filename}
|
return {"img": file_hash, "next": next_hash, "previous": prev_hash, "filename": filename}
|
||||||
|
|
||||||
@ -163,12 +195,16 @@ async def get_file_info(file_hash: str):
|
|||||||
# Optional: Add a health check endpoint
|
# Optional: Add a health check endpoint
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
return {"status": "healthy", "file_count": len(indexer.file_mapping)}
|
return {"status": "healthy", "file_count": len(file_mapping)}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run the file server")
|
parser = argparse.ArgumentParser(description="Run the file server")
|
||||||
parser.add_argument("source", type=str, help="Path to directory or ZIP archive")
|
parser.add_argument(
|
||||||
|
"source",
|
||||||
|
type=str,
|
||||||
|
help="Path to directory, ZIP archive, or glob pattern (e.g., *.zip, path/to/zips/*.zip)",
|
||||||
|
)
|
||||||
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
||||||
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
|
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
|
||||||
parser.add_argument("--salt", type=str, default=None, help="Salt for hashing file paths")
|
parser.add_argument("--salt", type=str, default=None, help="Salt for hashing file paths")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user