jtftp/jtftp/protocol/__init__.py
2022-07-07 12:37:41 -05:00

170 lines
6 KiB
Python

# JTFTP - Python/AsyncIO TFTP Server
# Copyright (C) 2022 Jeffrey C. Ollie
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import asyncio
import functools
import logging
from jtftp.datagram import Datagram
from jtftp.datagram import ERRORDatagram
from jtftp.datagram import RRQDatagram
from jtftp.datagram import TFTPError
from jtftp.datagram import TFTPMode
from jtftp.datagram import WRQDatagram
from jtftp.datagram import datagram_factory
from jtftp.errors import OptionsDecodeError
from jtftp.errors import PayloadDecodeError
from jtftp.filesystem import FileMode
from jtftp.filesystem import FilesystemProtocol
from jtftp.log import handle_task_result
from jtftp.netascii import NetAscii
from jtftp.netascii import NetAsciiReceiverProxy
from jtftp.netascii import NetAsciiSenderProxy
from jtftp.protocol.remote_origin_read import remote_origin_read_protocol
from jtftp.protocol.remote_origin_write import remote_origin_write_protocol
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class TftpServerProtocol(asyncio.DatagramProtocol):
filesystem: FilesystemProtocol
loop: asyncio.AbstractEventLoop
def __init__(
self, filesystem: FilesystemProtocol, loop: asyncio.AbstractEventLoop = None
):
super().__init__()
self.filesystem = filesystem
if loop is None:
loop = asyncio.get_event_loop()
self.loop = loop
def connection_made(self, transport: asyncio.transports.DatagramTransport) -> None:
logger.debug(f"listening on made {transport.get_extra_info('sockname')[:2]}")
self.transport = transport
async def send(self, datagram: Datagram, tid: tuple[str, int]) -> None:
self.transport.sendto(datagram.to_wire(), tid)
def datagram_received(
self, data: bytes, addr: tuple[str, int] | tuple[str, int, int, int]
) -> None:
logger.debug(f"datagram received: {data!r} {addr!r}")
tid = addr[:2]
task = self.loop.create_task(self._datagram_received(data, tid))
task.add_done_callback(handle_task_result)
async def _datagram_received(self, data: bytes, tid: tuple[str, int]) -> None:
try:
datagram = datagram_factory(data)
except OptionsDecodeError as e:
await self.send(
ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION, str(e)),
tid,
)
return
except PayloadDecodeError as e:
await self.send(
ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION, str(e)),
tid,
)
return
if not isinstance(datagram, (RRQDatagram, WRQDatagram)):
logger.warning(
f"Datagram with unexpected opcode {datagram.opcode} was received without establishing the session. Ignoring."
)
return
if datagram.mode == TFTPMode.MAIL:
errmsg = f"Usupported transfer mode '{datagram.mode.decode('ascii')}'"
await self.send(
ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION, errmsg),
tid,
)
return
match datagram:
case WRQDatagram():
try:
file = await self.filesystem.open(
datagram.filename, FileMode.BINARY_WRITE
)
if datagram.mode != TFTPMode.OCTET:
file = NetAsciiReceiverProxy(file)
protocol = functools.partial(
remote_origin_write_protocol,
file=file,
options=datagram.options,
loop=self.loop,
)
except PermissionError:
await self.send(
ERRORDatagram.from_code(TFTPError.ACCESS_VIOLATION),
tid,
)
return
except FileExistsError:
await self.send(
ERRORDatagram.from_code(TFTPError.FILE_NOT_FOUND),
tid,
)
return
case RRQDatagram():
try:
file = await self.filesystem.open(
datagram.filename, FileMode.BINARY_READ
)
if datagram.mode != TFTPMode.OCTET:
file = NetAsciiSenderProxy(file)
protocol = functools.partial(
remote_origin_read_protocol,
file=file,
options=datagram.options,
loop=self.loop,
)
except PermissionError:
await self.send(
ERRORDatagram.from_code(TFTPError.ACCESS_VIOLATION),
tid,
)
return
except FileNotFoundError:
await self.send(
ERRORDatagram.from_code(TFTPError.FILE_NOT_FOUND),
tid,
)
return
await self.loop.create_datagram_endpoint(
protocol,
local_addr=None,
remote_addr=tid,
)
def tftp_server_protocol_factory(
    *, filesystem: FilesystemProtocol, loop: asyncio.AbstractEventLoop | None = None
) -> TftpServerProtocol:
    """Create a :class:`TftpServerProtocol` from keyword-only arguments.

    Args:
        filesystem: Backend the protocol uses to open requested files.
        loop: Event loop handed to the protocol; ``None`` is passed through
            unchanged and lets the protocol pick its own default.

    Returns:
        A new ``TftpServerProtocol`` instance.
    """
    return TftpServerProtocol(filesystem=filesystem, loop=loop)