# JTFTP - Python/AsyncIO TFTP Server # Copyright (C) 2022 Jeffrey C. Ollie # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Warning!!! This does not actually work because aiofiles does not support # usage outside of a context manager. See: https://github.com/Tinche/aiofiles/issues/139 # import functools import logging import os import pathlib from typing import Callable from typing import TypeVar import aiofiles import aiofiles.os from jtftp.errors import AccessViolation from jtftp.filesystem import FileMode logger = logging.getLogger(__name__) RT = TypeVar("RT") def ensure_open(method: Callable[..., RT]) -> Callable[..., RT]: @functools.wraps(method) async def _impl(self, *args, **kwargs) -> RT: logger.debug("ensure open") if self.data is None: await self.open() return await method(self, *args, **kwargs) return _impl class OnDiskFile: path: bytes mode: FileMode data: aiofiles.threadpool.binary.AsyncFileIO | None closed: bool def __init__(self, path: pathlib.PosixPath, mode: FileMode): logger.debug(f"myfile {path} {mode}") self.path = path self.mode = mode self.data = None self.closed = True async def open(self) -> None: logger.debug("diskfile open") self.data = aiofiles.open(self.path, self.mode.value) self.closed = False @ensure_open async def length(self) -> int: logger.debug("diskfile length") cur = await self.data.seek(0, os.SEEK_CUR) logger.debug(f"diskfile cur {cur}") length = await self.data.seek(0, os.SEEK_END) logger.debug(f"diskfile len {length}") await self.data.seek(cur, os.SEEK_SET) return length @ensure_open async def seek(self, offset: int, whence: int) -> int: logger.debug(f"diskfile seek {offset} {whence}") return await self.data.seek(offset, whence) @ensure_open async def read(self, length: int) -> bytes: logger.debug(f"diskfile read {length}") return await self.data.read(length) @ensure_open async def write(self, data: bytes) -> int: logger.debug(f"diskfile write {len(data)}") return await self.data.write(data) @ensure_open async def close(self, complete: bool) -> None: logger.debug(f"diskfile close {complete}") self.closed = True if self.data is not None: await self.data.close() self.data = None class ReadOnlyOnDiskFilesystem: def __init__(self, root: pathlib.PosixPath): self.root = root.resolve() async def open(self, filename: bytes, mode: FileMode) -> OnDiskFile: path = self.root.joinpath(filename.decode("ascii")).resolve() try: path.relative_to(self.root) except ValueError: raise AccessViolation("illegal directory traversal") logger.debug(f"ro ondisk open {filename} {mode}") match mode: case FileMode.BINARY_READ: return OnDiskFile(path, mode) case FileMode.BINARY_WRITE: raise AccessViolation("read-only") case _: raise ValueError