# JTFTP - Python/AsyncIO TFTP Server # Copyright (C) 2022 Jeffrey C. Ollie # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . import asyncio import functools import itertools import logging from collections import OrderedDict from jtftp.datagram import ACKDatagram from jtftp.datagram import DATADatagram from jtftp.datagram import Datagram from jtftp.datagram import ERRORDatagram from jtftp.datagram import OACKDatagram from jtftp.datagram import TFTPError from jtftp.datagram import TFTPMode from jtftp.datagram import TFTPOption from jtftp.datagram import datagram_factory from jtftp.filesystem import FileProtocol from jtftp.log import handle_task_result from jtftp.util import timed_caller logger = logging.getLogger(__name__) class RemoteOriginWriteProtocol(asyncio.DatagramProtocol): block_size: int timeout: tuple[int, int, int] transfer_size: int | None retransmit_task: asyncio.Task | None offered_options: dict[TFTPOption, bytes] accepted_options: dict[TFTPOption, bytes] local_tid: tuple[str, int] remote_tid: tuple[str, int] def __init__( self, *, file: FileProtocol, options: dict[TFTPOption, bytes], loop: asyncio.AbstractEventLoop, ) -> None: super().__init__() self.file = file self.offered_options = options self.loop = loop self.block_size = 512 self.timeout = (1, 3, 7) self.transfer_size = None self.retransmit_task = None self.accepted_options = OrderedDict() self.local_tid = None self.remote_tid = None def connection_made(self, transport: asyncio.transports.DatagramTransport) -> None: self.transport = transport self.local_tid = transport.get_extra_info("sockname")[:2] self.remote_tid = self.transport.get_extra_info("peername")[:2] task = self.loop.create_task(self._connection_made()) task.add_done_callback(handle_task_result) async def _connection_made(self): logger.debug(f"new session between {self.local_tid} → {self.remote_tid}") for name, value in self.offered_options.items(): logger.debug(f"{name!r} {value!r}") match name: case TFTPOption.BLOCKSIZE: self.block_size = int(value) self.accepted_options[name] = value case TFTPOption.TIMEOUT: self.timeout = (int(value),) * 3 self.accepted_options[name] = value case TFTPOption.TRANSFER_SIZE: self.transfer_size = int(value) self.accepted_options[name] = value case _: logger.warning(f"unknown option {name}: {value}") if len(self.offered_options) == 0: data = ACKDatagram(0).to_wire() else: data = OACKDatagram(self.accepted_options).to_wire() self.retransmit_task = asyncio.create_task( timed_caller( itertools.chain((0,), self.timeout), functools.partial(self.transport.sendto, data), self._timed_out, ) ) def connection_lost(self, exc: Exception | None) -> None: logger.debug(f"closed session between {self.local_tid} → {self.remote_tid}") return super().connection_lost(exc) async def send(self, datagram: Datagram) -> None: self.transport.sendto(datagram.to_wire()) def datagram_received( self, datagram: bytes, addr: tuple[str, int] | tuple[str, int, int, int] ) -> None: logger.debug( f"RemoteOriginWriteProtocol data_received {len(datagram)} {addr!r}" ) tid = addr[:2] task = self.loop.create_task(self._datagram_received(datagram, tid)) task.add_done_callback(handle_task_result) async def _datagram_received(self, datagram: bytes, tid: tuple[str, int]) -> None: if self.remote_tid != tid: logger.error(f"received packet from wrong address: {tid}") if ( self.retransmit_task is not None and not self.retransmit_task.cancelled() ): self.retransmit_task.cancel() self.retransmit_task = None await self.file.close(complete=False) self.transport.close() return datagram = datagram_factory(datagram) if not isinstance(datagram, DATADatagram): if ( self.retransmit_task is not None and not self.retransmit_task.cancelled() ): self.retransmit_task.cancel() self.retransmit_task = None await self.file.close(complete=False) await self.send(ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION)) self.transport.close() return logger.debug(f"RemoteOriginWriteProtocol data_received {datagram!r} {tid!r}") if self.retransmit_task is not None and not self.retransmit_task.cancelled(): logger.debug("cancelling old timeout task") self.retransmit_task.cancel() self.retransmit_task = None # need to check block numbers if len(datagram.payload) == self.block_size: self.file.write(datagram.payload) self.retransmit_task = asyncio.create_task( timed_caller( itertools.chain((0,), self.timeout), functools.partial( self.send, ACKDatagram(datagram.block_number), ), self._timed_out, ) ) self.retransmit_task.add_done_callback(handle_task_result) else: logger.debug(f"last data packet received") await self.file.write(datagram.payload) await self.file.close(complete=True) await self.send(ACKDatagram(datagram.block_number)) self.transport.close() async def _timed_out(self): logger.debug("timed out") if self.retransmit_task is not None and not self.retransmit_task.cancelled(): self.retransmit_task.cancel() self.retransmit_task = None await self.file.close(complete=False) self.transport.close() def __del__(self): logger.debug("RemoteOriginWriteProtocol __del__") def remote_origin_write_protocol( *, file: FileProtocol, options: dict[TFTPOption, bytes], loop: asyncio.AbstractEventLoop, ) -> RemoteOriginWriteProtocol: return RemoteOriginWriteProtocol(file=file, options=options, loop=loop)