123 lines
3.7 KiB
Python
123 lines
3.7 KiB
Python
|
# JTFTP - Python/AsyncIO TFTP Server
|
||
|
# Copyright (C) 2022 Jeffrey C. Ollie
|
||
|
|
||
|
# This program is free software: you can redistribute it and/or modify
|
||
|
# it under the terms of the GNU General Public License as published by
|
||
|
# the Free Software Foundation, either version 3 of the License, or
|
||
|
# (at your option) any later version.
|
||
|
#
|
||
|
# This program is distributed in the hope that it will be useful,
|
||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||
|
# GNU General Public License for more details.
|
||
|
#
|
||
|
# You should have received a copy of the GNU General Public License
|
||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||
|
|
||
|
#
|
||
|
# Warning!!! This does not actually work because aiofiles does not support
|
||
|
# usage outside of a context manager. See: https://github.com/Tinche/aiofiles/issues/139
|
||
|
#
|
||
|
|
||
|
import functools
|
||
|
import logging
|
||
|
import os
|
||
|
import pathlib
|
||
|
from typing import Callable
|
||
|
from typing import TypeVar
|
||
|
|
||
|
import aiofiles
|
||
|
import aiofiles.os
|
||
|
from jtftp.errors import AccessViolation
|
||
|
from jtftp.filesystem import FileMode
|
||
|
|
||
|
logger = logging.getLogger(__name__)

RT = TypeVar("RT")


def ensure_open(method: Callable[..., RT]) -> Callable[..., RT]:
    """Decorate an async method so its object is opened before it runs.

    The returned coroutine function inspects ``self.data``. When it is
    ``None`` it first awaits ``self.open()``, then awaits the wrapped method
    with the original arguments and returns its result. ``functools.wraps``
    copies the wrapped method's name/docstring onto the wrapper.
    """

    async def _wrapper(self, *args, **kwargs) -> RT:
        logger.debug("ensure open")
        needs_open = self.data is None
        if needs_open:
            await self.open()
        result = await method(self, *args, **kwargs)
        return result

    return functools.wraps(method)(_wrapper)
|
||
|
|
||
|
|
||
|
class OnDiskFile:
    """A lazily-opened asynchronous handle on a single file on local disk.

    Nothing is opened in ``__init__``; methods decorated with ``ensure_open``
    call ``open()`` on first use when no file object is held yet.
    """

    # Location of the file on disk (the value handed to ``__init__``).
    path: pathlib.PosixPath
    # Mode the file is opened with; ``mode.value`` is passed to aiofiles.open.
    mode: FileMode
    # The open aiofiles file object, or None while not open.
    # NOTE(review): for binary read/write modes aiofiles may return buffered
    # wrapper classes rather than AsyncFileIO; annotation may be imprecise.
    data: aiofiles.threadpool.binary.AsyncFileIO | None
    # True until open() succeeds, and again after close().
    closed: bool

    def __init__(self, path: pathlib.PosixPath, mode: FileMode):
        logger.debug(f"myfile {path} {mode}")
        self.path = path
        self.mode = mode
        self.data = None
        self.closed = True

    async def open(self) -> None:
        """Open the underlying file with aiofiles.

        ``aiofiles.open`` returns an awaitable; it must be awaited to obtain
        the actual file object, otherwise every later ``seek``/``read``/
        ``write`` call fails. Calling this while a file object is already
        held is a no-op, so the existing handle is not leaked.
        """
        logger.debug("diskfile open")
        if self.data is not None:
            return
        self.data = await aiofiles.open(self.path, self.mode.value)
        self.closed = False

    @ensure_open
    async def length(self) -> int:
        """Return the file size in bytes, preserving the current position.

        Seeks to the end to learn the size, then seeks back to where the
        file position was before the call.
        """
        logger.debug("diskfile length")
        cur = await self.data.seek(0, os.SEEK_CUR)
        logger.debug(f"diskfile cur {cur}")
        length = await self.data.seek(0, os.SEEK_END)
        logger.debug(f"diskfile len {length}")
        await self.data.seek(cur, os.SEEK_SET)
        return length

    @ensure_open
    async def seek(self, offset: int, whence: int) -> int:
        """Move the file position; returns the new absolute position.

        :param offset: byte offset, interpreted according to *whence*.
        :param whence: one of ``os.SEEK_SET``/``os.SEEK_CUR``/``os.SEEK_END``.
        """
        logger.debug(f"diskfile seek {offset} {whence}")
        return await self.data.seek(offset, whence)

    @ensure_open
    async def read(self, length: int) -> bytes:
        """Read up to *length* bytes from the current position."""
        logger.debug(f"diskfile read {length}")
        return await self.data.read(length)

    @ensure_open
    async def write(self, data: bytes) -> int:
        """Write *data* at the current position; returns the byte count written."""
        logger.debug(f"diskfile write {len(data)}")
        return await self.data.write(data)

    async def close(self, complete: bool) -> None:
        """Close the underlying file if one is open.

        Deliberately not decorated with ``ensure_open``: closing a file that
        was never opened (or was already closed) must not open it, since
        opening may create or truncate the file depending on ``mode``.

        :param complete: only logged here; not otherwise used by this class.
        """
        logger.debug(f"diskfile close {complete}")
        # Detach the handle first so state is reset even if close() raises.
        data, self.data = self.data, None
        self.closed = True
        if data is not None:
            await data.close()
|
||
|
|
||
|
|
||
|
class ReadOnlyOnDiskFilesystem:
|
||
|
def __init__(self, root: pathlib.PosixPath):
|
||
|
self.root = root.resolve()
|
||
|
|
||
|
async def open(self, filename: bytes, mode: FileMode) -> OnDiskFile:
|
||
|
path = self.root.joinpath(filename.decode("ascii")).resolve()
|
||
|
try:
|
||
|
path.relative_to(self.root)
|
||
|
except ValueError:
|
||
|
raise AccessViolation("illegal directory traversal")
|
||
|
|
||
|
logger.debug(f"ro ondisk open {filename} {mode}")
|
||
|
match mode:
|
||
|
case FileMode.BINARY_READ:
|
||
|
return OnDiskFile(path, mode)
|
||
|
|
||
|
case FileMode.BINARY_WRITE:
|
||
|
raise AccessViolation("read-only")
|
||
|
|
||
|
case _:
|
||
|
raise ValueError
|