# JTFTP - Python/AsyncIO TFTP Server # Copyright (C) 2022 Jeffrey C. Ollie # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . import asyncio import functools import logging from jtftp.datagram import Datagram from jtftp.datagram import ERRORDatagram from jtftp.datagram import RRQDatagram from jtftp.datagram import TFTPError from jtftp.datagram import TFTPMode from jtftp.datagram import WRQDatagram from jtftp.datagram import datagram_factory from jtftp.errors import OptionsDecodeError from jtftp.errors import PayloadDecodeError from jtftp.filesystem import FileMode from jtftp.filesystem import FilesystemProtocol from jtftp.log import handle_task_result from jtftp.netascii import NetAscii from jtftp.netascii import NetAsciiReceiverProxy from jtftp.netascii import NetAsciiSenderProxy from jtftp.protocol.remote_origin_read import remote_origin_read_protocol from jtftp.protocol.remote_origin_write import remote_origin_write_protocol logger = logging.getLogger(__name__) class TftpServerProtocol(asyncio.DatagramProtocol): filesystem: FilesystemProtocol loop: asyncio.AbstractEventLoop def __init__( self, filesystem: FilesystemProtocol, loop: asyncio.AbstractEventLoop = None ): super().__init__() self.filesystem = filesystem if loop is None: loop = asyncio.get_event_loop() self.loop = loop def connection_made(self, transport: asyncio.transports.DatagramTransport) -> None: logger.debug(f"listening on made {transport.get_extra_info('sockname')[:2]}") self.transport = transport async def send(self, datagram: Datagram, tid: tuple[str, int]) -> None: self.transport.sendto(datagram.to_wire(), tid) def datagram_received( self, data: bytes, addr: tuple[str, int] | tuple[str, int, int, int] ) -> None: logger.debug(f"datagram received: {data!r} {addr!r}") tid = addr[:2] task = self.loop.create_task(self._datagram_received(data, tid)) task.add_done_callback(handle_task_result) async def _datagram_received(self, data: bytes, tid: tuple[str, int]) -> None: try: datagram = datagram_factory(data) except OptionsDecodeError as e: await self.send( ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION, str(e)), tid, ) return except PayloadDecodeError as e: await self.send( ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION, str(e)), tid, ) return if not isinstance(datagram, (RRQDatagram, WRQDatagram)): logger.warning( f"Datagram with unexpected opcode {datagram.opcode} was received without establishing the session. Ignoring." ) return if datagram.mode == TFTPMode.MAIL: errmsg = f"Usupported transfer mode '{datagram.mode.decode('ascii')}'" await self.send( ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION, errmsg), tid, ) return match datagram: case WRQDatagram(): try: file = await self.filesystem.open( datagram.filename, FileMode.BINARY_WRITE ) if datagram.mode != TFTPMode.OCTET: file = NetAsciiReceiverProxy(file) protocol = functools.partial( remote_origin_write_protocol, file=file, options=datagram.options, loop=self.loop, ) except PermissionError: await self.send( ERRORDatagram.from_code(TFTPError.ACCESS_VIOLATION), tid, ) return except FileExistsError: await self.send( ERRORDatagram.from_code(TFTPError.FILE_NOT_FOUND), tid, ) return case RRQDatagram(): try: file = await self.filesystem.open( datagram.filename, FileMode.BINARY_READ ) if datagram.mode != TFTPMode.OCTET: file = NetAsciiSenderProxy(file) protocol = functools.partial( remote_origin_read_protocol, file=file, options=datagram.options, loop=self.loop, ) except PermissionError: await self.send( ERRORDatagram.from_code(TFTPError.ACCESS_VIOLATION), tid, ) return except FileNotFoundError: await self.send( ERRORDatagram.from_code(TFTPError.FILE_NOT_FOUND), tid, ) return await self.loop.create_datagram_endpoint( protocol, local_addr=None, remote_addr=tid, ) def tftp_server_protocol_factory( *, filesystem: FilesystemProtocol, loop: asyncio.AbstractEventLoop = None ) -> TftpServerProtocol: return TftpServerProtocol(filesystem=filesystem, loop=loop)