jtftp/jtftp/filesystem/ondisk.py
2022-07-07 12:37:41 -05:00

123 lines
3.7 KiB
Python

# JTFTP - Python/AsyncIO TFTP Server
# Copyright (C) 2022 Jeffrey C. Ollie
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Warning!!! This does not actually work because aiofiles does not support
# usage outside of a context manager. See: https://github.com/Tinche/aiofiles/issues/139
#
import functools
import logging
import os
import pathlib
from typing import Awaitable
from typing import Callable
from typing import TypeVar

import aiofiles
import aiofiles.os

from jtftp.errors import AccessViolation
from jtftp.filesystem import FileMode
logger = logging.getLogger(__name__)

# Result type produced by the coroutine that ensure_open wraps.
RT = TypeVar("RT")


def ensure_open(
    method: Callable[..., Awaitable[RT]],
) -> Callable[..., Awaitable[RT]]:
    """Decorate an async method so the object's file is opened before it runs.

    When ``self.data`` is ``None`` the wrapper first awaits ``self.open()``;
    it then awaits the wrapped method and returns its result.

    Both the decorated method and the returned wrapper are coroutine
    functions, so they are typed as returning ``Awaitable[RT]`` rather than
    ``RT`` directly.

    :param method: an ``async def`` method whose instance exposes a ``data``
        attribute and an awaitable ``open()`` method.
    :returns: the wrapping coroutine function (metadata copied via
        ``functools.wraps``).
    """

    @functools.wraps(method)
    async def _impl(self, *args, **kwargs) -> RT:
        logger.debug("ensure open")
        if self.data is None:
            await self.open()
        return await method(self, *args, **kwargs)

    return _impl
class OnDiskFile:
    """A file on local disk accessed asynchronously through aiofiles.

    The underlying handle is opened lazily: ``length``, ``seek``, ``read``
    and ``write`` are wrapped by ``ensure_open``, which calls ``open()`` when
    no handle is held. ``close()`` releases the handle.
    """

    path: pathlib.PosixPath  # location of the file on disk
    mode: FileMode  # its ``value`` is passed as the mode to aiofiles.open
    # Open aiofiles handle, or None when the file is not currently open.
    data: aiofiles.threadpool.binary.AsyncFileIO | None
    closed: bool  # True until open() succeeds, and again after close()

    def __init__(self, path: pathlib.PosixPath, mode: FileMode):
        logger.debug(f"myfile {path} {mode}")
        self.path = path
        self.mode = mode
        self.data = None
        self.closed = True

    async def open(self) -> None:
        """Open the underlying file if it is not already open.

        Calling this while a handle is held is a no-op, so the existing
        handle is not leaked.
        """
        logger.debug("diskfile open")
        if self.data is not None:
            return
        # aiofiles.open() returns an awaitable wrapper, not the file itself;
        # awaiting it yields the async file object with seek/read/write/close.
        self.data = await aiofiles.open(self.path, self.mode.value)
        self.closed = False

    @ensure_open
    async def length(self) -> int:
        """Return the file size in bytes, preserving the current position."""
        logger.debug("diskfile length")
        cur = await self.data.seek(0, os.SEEK_CUR)
        logger.debug(f"diskfile cur {cur}")
        length = await self.data.seek(0, os.SEEK_END)
        logger.debug(f"diskfile len {length}")
        # Restore the position the caller had before measuring.
        await self.data.seek(cur, os.SEEK_SET)
        return length

    @ensure_open
    async def seek(self, offset: int, whence: int) -> int:
        """Move the file position (``os.SEEK_*`` semantics); return the new one."""
        logger.debug(f"diskfile seek {offset} {whence}")
        return await self.data.seek(offset, whence)

    @ensure_open
    async def read(self, length: int) -> bytes:
        """Read up to *length* bytes from the current position."""
        logger.debug(f"diskfile read {length}")
        return await self.data.read(length)

    @ensure_open
    async def write(self, data: bytes) -> int:
        """Write *data* at the current position; return the count written."""
        logger.debug(f"diskfile write {len(data)}")
        return await self.data.write(data)

    async def close(self, complete: bool) -> None:
        """Release the underlying handle, if one is held.

        Deliberately not wrapped in ``ensure_open``: closing a file that was
        never opened must not open it first (in a write mode that could
        create or truncate the file on disk).

        :param complete: only logged here.
        """
        logger.debug(f"diskfile close {complete}")
        self.closed = True
        # Drop the reference before awaiting so a failing close() is not
        # retried later on an already-dead handle.
        data, self.data = self.data, None
        if data is not None:
            await data.close()
class ReadOnlyOnDiskFilesystem:
def __init__(self, root: pathlib.PosixPath):
self.root = root.resolve()
async def open(self, filename: bytes, mode: FileMode) -> OnDiskFile:
path = self.root.joinpath(filename.decode("ascii")).resolve()
try:
path.relative_to(self.root)
except ValueError:
raise AccessViolation("illegal directory traversal")
logger.debug(f"ro ondisk open {filename} {mode}")
match mode:
case FileMode.BINARY_READ:
return OnDiskFile(path, mode)
case FileMode.BINARY_WRITE:
raise AccessViolation("read-only")
case _:
raise ValueError