# JTFTP - Python/AsyncIO TFTP Server # Copyright (C) 2022 Jeffrey C. Ollie # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . import asyncio import functools import itertools import logging from collections import OrderedDict from jtftp.datagram import ACKDatagram from jtftp.datagram import DATADatagram from jtftp.datagram import Datagram from jtftp.datagram import ERRORDatagram from jtftp.datagram import OACKDatagram from jtftp.datagram import TFTPError from jtftp.datagram import TFTPMode from jtftp.datagram import TFTPOption from jtftp.datagram import datagram_factory from jtftp.filesystem import FileProtocol from jtftp.log import handle_task_result from jtftp.util import timed_caller logger = logging.getLogger(__name__) class RemoteOriginReadProtocol(asyncio.DatagramProtocol): block_size: int timeout: tuple[int, int, int] transfer_size: int | None retransmit_task: asyncio.Task | None last_block_number: int last_block_sent: False offered_options: dict[TFTPOption, bytes] accepted_options: dict[TFTPOption, bytes] local_tid = tuple[str, int] remote_tid = tuple[str, int] def __init__( self, *, file: FileProtocol, options: dict[TFTPOption, bytes], loop: asyncio.AbstractEventLoop, ) -> None: super().__init__() self.file = file self.loop = loop self.block_size = 512 self.timeout = (1, 3, 7) self.transfer_size = None self.last_block_number = 0 self.last_block_sent = False self.offered_options = options self.accepted_options = OrderedDict() self.remote_tid = tuple[str, int] def connection_made(self, transport: asyncio.transports.DatagramTransport) -> None: self.transport = transport self.local_tid = transport.get_extra_info("sockname")[:2] self.remote_tid = transport.get_extra_info("peername")[:2] task = self.loop.create_task(self._connection_made()) task.add_done_callback(handle_task_result) async def _connection_made(self): logger.debug(f"new session between {self.local_tid} → {self.remote_tid}") for name, value in self.offered_options.items(): logger.debug(f"{name!r} {value!r}") match name: case TFTPOption.BLOCKSIZE: if value >= 8 and value <= 65464: self.block_size = value self.accepted_options[name] = value case TFTPOption.TIMEOUT: self.timeout = (value,) * 3 self.accepted_options[name] = value case TFTPOption.TRANSFER_SIZE: self.transfer_size = await self.file.length() if self.transfer_size is not None: self.accepted_options[name] = self.transfer_size case _: logger.warning(f"unknown option {name}: {value}") if len(self.offered_options) != 0: data = OACKDatagram(self.accepted_options) else: data = await self._get_next_block() self.retransmit_task = self.loop.create_task( timed_caller( itertools.chain((0,), self.timeout), functools.partial(self.send, data), self._timed_out, ) ) self.retransmit_task.add_done_callback(handle_task_result) def connection_lost(self, exc: Exception | None) -> None: return super().connection_lost(exc) async def send(self, datagram: Datagram): self.transport.sendto(datagram.to_wire()) def datagram_received( self, datagram: bytes, addr: tuple[str, int] | tuple[str, int, int, int] ) -> None: tid = addr[:2] if self.remote_tid != tid: logger.error(f"packet from unknown sender {tid}") return task = self.loop.create_task(self._datagram_received(datagram)) task.add_done_callback(handle_task_result) async def _datagram_received(self, datagram: bytes) -> None: datagram = datagram_factory(datagram) if isinstance(datagram, ACKDatagram): if datagram.block_number == self.last_block_number: if ( self.retransmit_task is not None and not self.retransmit_task.cancelled() ): self.retransmit_task.cancel() self.retransmit_task = None if self.last_block_sent: await self.file.close(complete=True) self.transport.close() return data = await self._get_next_block() self.retransmit_task = asyncio.create_task( timed_caller( itertools.chain((0,), self.timeout), functools.partial(self.send, data), self._timed_out, ) ) self.retransmit_task.add_done_callback(handle_task_result) else: logger.warning( f"received ack for block number {datagram.block_number} - was expecting {self.last_block_number}" ) else: await self.send(ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION)) async def _get_next_block(self) -> Datagram: payload = await self.file.read(self.block_size) if len(payload) < self.block_size: self.last_block_sent = True self.last_block_number = (self.last_block_number + 1) % 65536 return DATADatagram(self.last_block_number, payload) async def _timed_out(self): logger.debug("timed out") if self.retransmit_task is not None and not self.retransmit_task.cancelled(): self.retransmit_task.cancel() self.retransmit_task = None await self.file.close(complete=False) self.transport.close() def __del__(self): logger.debug("RemoteOriginReadProtocol __del__") def remote_origin_read_protocol( *, file: FileProtocol, options: dict[bytes, bytes], loop: asyncio.AbstractEventLoop, ) -> RemoteOriginReadProtocol: return RemoteOriginReadProtocol(file=file, options=options, loop=loop)