172 lines
5.1 KiB
Python
172 lines
5.1 KiB
Python
import argparse
import hashlib
import mimetypes
import os
import random
import secrets
import zipfile
from io import BytesIO
from pathlib import Path
from urllib.parse import quote

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse, StreamingResponse
|
|
|
|
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()

# Indexer used by the route handlers. Stays None until initialize_server()
# assigns it.
indexer = None
|
|
|
|
|
|
class FileIndexer:
    """Index every file under a directory and expose it through an opaque hash.

    Each file is keyed by the SHA-256 hex digest of its path concatenated
    with a random per-instance salt, so keys differ from one instance to the
    next. ``_index`` is called from ``__init__``; subclasses may override it
    (and the accessors) to index other kinds of sources.
    """

    # Number of bytes yielded per chunk when streaming file content.
    CHUNK_SIZE = 64 * 1024

    def __init__(self, path: str):
        """Store *path* and build the hash -> file mapping right away."""
        self.path = Path(path)
        self.file_mapping = {}
        self._salt = None
        self._index()

    @property
    def salt(self) -> str:
        """Return the per-instance salt, creating it on first access."""
        if self._salt is None:
            self._salt = secrets.token_hex(16)
        return self._salt

    def _hash_path(self, filepath: str) -> str:
        """Return the SHA-256 hex digest of *filepath* followed by the salt."""
        return hashlib.sha256((filepath + self.salt).encode()).hexdigest()

    def _index(self):
        """Walk ``self.path`` recursively and register every file found."""
        for root, _, files in os.walk(self.path):
            for name in files:
                filepath = os.path.join(root, name)
                self.file_mapping[self._hash_path(filepath)] = filepath

    def get_file_by_hash(self, file_hash: str):
        """Yield the content of the file registered under *file_hash*.

        This is a generator, so it always returns an iterator; an unknown
        hash produces an iterator that yields nothing.
        """
        file_path = self.file_mapping.get(file_hash)
        if file_path is None:
            return

        with open(file_path, "rb") as f:
            # Fixed-size reads: iterating a binary file splits on b"\n",
            # which gives arbitrary chunk sizes and loads a newline-free
            # file into memory in one piece.
            yield from iter(lambda: f.read(self.CHUNK_SIZE), b"")

    def get_filename_by_hash(self, file_hash: str) -> "str | None":
        """Return the path registered under *file_hash*, or None if unknown."""
        return self.file_mapping.get(file_hash)
|
|
|
|
|
|
class ZipFileIndexer(FileIndexer):
    """Indexer whose source is a ZIP archive instead of a directory.

    ``file_mapping`` maps the salted hash of each member name to that
    member's ``zipfile.ZipInfo``; directory entries are skipped.
    """

    # Number of bytes yielded per chunk when streaming a member.
    CHUNK_SIZE = 64 * 1024

    def _index(self):
        """Register every non-directory member of the archive at ``self.path``."""
        with zipfile.ZipFile(self.path, "r") as zip_file:
            self.file_mapping = {
                self._hash_path(file_info.filename): file_info
                for file_info in zip_file.infolist()
                if not file_info.is_dir()
            }

    def get_file_by_hash(self, file_hash: str):
        """Yield the content of the member registered under *file_hash*.

        This is a generator, so it always returns an iterator; an unknown
        hash produces an iterator that yields nothing. The archive is
        reopened for each call and closed when the generator finishes.
        """
        file_info = self.file_mapping.get(file_hash)
        if file_info is None:
            return

        with zipfile.ZipFile(self.path, "r") as zip_file:
            # Stream the member in fixed-size chunks rather than reading it
            # all into memory and splitting it on newlines.
            with zip_file.open(file_info.filename) as member:
                yield from iter(lambda: member.read(self.CHUNK_SIZE), b"")

    def get_filename_by_hash(self, file_hash: str) -> "str | None":
        """Return the member name registered under *file_hash*, or None."""
        file_info = self.file_mapping.get(file_hash)
        return None if file_info is None else file_info.filename
|
|
|
|
|
|
# Indexer class to use for a non-directory source, keyed by file suffix.
INDEXER_MAP = {
    ".zip": ZipFileIndexer,
}
|
|
|
|
|
|
def initialize_server(args: argparse.Namespace):
    """Build the module-level ``indexer`` from ``args.source``.

    A directory is indexed with ``FileIndexer``; any other path is dispatched
    on its suffix (compared case-insensitively) through ``INDEXER_MAP``.

    Raises:
        SystemExit: if the source path does not exist, or is a file whose
            suffix has no registered indexer.
    """
    global indexer

    src_path = Path(args.source)
    if not src_path.exists():
        raise SystemExit(f"Source path {src_path} does not exist")

    if src_path.is_dir():
        indexer = FileIndexer(src_path)
    else:
        indexer_cls = INDEXER_MAP.get(src_path.suffix.lower())
        if indexer_cls is None:
            # Report a usable message instead of a bare KeyError traceback.
            supported = ", ".join(sorted(INDEXER_MAP))
            raise SystemExit(
                f"Unsupported source type '{src_path.suffix}' for {src_path} "
                f"(expected a directory or one of: {supported})"
            )
        indexer = indexer_cls(src_path)

    print(f"Indexed {len(indexer.file_mapping)} files")
|
|
|
|
|
|
@app.get("/")
async def root():
    """Send visitors of the site root on to the random-file endpoint."""
    redirect = RedirectResponse(url="/random")
    return redirect
|
|
|
|
|
|
@app.get("/random")
async def get_random_file():
    """Serve a randomly chosen indexed file with client caching disabled.

    The media type is guessed from the file name and falls back to
    ``application/octet-stream``.

    Raises:
        HTTPException: 404 when no files are indexed.
    """
    if not indexer.file_mapping:
        raise HTTPException(status_code=404, detail="No files indexed")

    random_hash = random.choice(list(indexer.file_mapping))

    filename = indexer.get_filename_by_hash(random_hash)
    content_type, _ = mimetypes.guess_type(filename)

    # Header values must be latin-1 encodable and the filename parameter must
    # be a quoted string; names that need escaping (spaces, quotes, ';',
    # non-ASCII) are sent in the RFC 5987 ``filename*`` form instead.
    base_name = os.path.basename(filename)
    encoded_name = quote(base_name, safe="")
    if encoded_name == base_name:
        disposition = f'inline; filename="{base_name}"'
    else:
        disposition = f"inline; filename*=UTF-8''{encoded_name}"

    return StreamingResponse(
        indexer.get_file_by_hash(random_hash),
        media_type=content_type or "application/octet-stream",
        headers={
            "Content-Disposition": disposition,
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Pragma": "no-cache",
            "Expires": "0",
        },
    )
|
|
|
|
|
|
# Health check endpoint. It is registered before the catch-all
# "/{file_hash}" route because routes are matched in registration order;
# declared afterwards, "/health" would be treated as a file hash and get a 404.
@app.get("/health")
async def health_check():
    """Report service status and the number of indexed files."""
    return {"status": "healthy", "file_count": len(indexer.file_mapping)}


@app.get("/{file_hash}")
async def get_file_by_hash(file_hash: str):
    """Serve the indexed file registered under *file_hash*.

    The media type is guessed from the file name and falls back to
    ``application/octet-stream``.

    Raises:
        HTTPException: 404 when the hash is not in the index.
    """
    if file_hash not in indexer.file_mapping:
        raise HTTPException(status_code=404, detail="File not found")

    # The indexer returns only an iterator over the content, so the name and
    # media type are looked up separately.
    file_path = indexer.get_filename_by_hash(file_hash)
    content_type, _ = mimetypes.guess_type(file_path)

    # Header values must be latin-1 encodable and the filename parameter must
    # be a quoted string; names that need escaping (spaces, quotes, ';',
    # non-ASCII) are sent in the RFC 5987 ``filename*`` form instead.
    base_name = os.path.basename(file_path)
    encoded_name = quote(base_name, safe="")
    if encoded_name == base_name:
        disposition = f'inline; filename="{base_name}"'
    else:
        disposition = f"inline; filename*=UTF-8''{encoded_name}"

    return StreamingResponse(
        indexer.get_file_by_hash(file_hash),
        media_type=content_type or "application/octet-stream",
        headers={
            "Content-Disposition": disposition,
        },
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the command line, index the source, then serve.
    parser = argparse.ArgumentParser(description="Run the file server")
    parser.add_argument(
        "source",
        type=str,
        help="Path to directory or ZIP archive",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to",
    )
    args = parser.parse_args()

    # Indexing happens before uvicorn is imported, as in the original order.
    initialize_server(args)

    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port)
|