jtftp/jtftp/protocol/remote_origin_read.py
2022-07-07 12:37:41 -05:00

192 lines
6.8 KiB
Python

# JTFTP - Python/AsyncIO TFTP Server
# Copyright (C) 2022 Jeffrey C. Ollie
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import asyncio
import functools
import itertools
import logging
from collections import OrderedDict
from jtftp.datagram import ACKDatagram
from jtftp.datagram import DATADatagram
from jtftp.datagram import Datagram
from jtftp.datagram import ERRORDatagram
from jtftp.datagram import OACKDatagram
from jtftp.datagram import TFTPError
from jtftp.datagram import TFTPMode
from jtftp.datagram import TFTPOption
from jtftp.datagram import datagram_factory
from jtftp.filesystem import FileProtocol
from jtftp.log import handle_task_result
from jtftp.util import timed_caller
# Module logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
class RemoteOriginReadProtocol(asyncio.DatagramProtocol):
block_size: int
timeout: tuple[int, int, int]
transfer_size: int | None
retransmit_task: asyncio.Task | None
last_block_number: int
last_block_sent: False
offered_options: dict[TFTPOption, bytes]
accepted_options: dict[TFTPOption, bytes]
local_tid = tuple[str, int]
remote_tid = tuple[str, int]
def __init__(
self,
*,
file: FileProtocol,
options: dict[TFTPOption, bytes],
loop: asyncio.AbstractEventLoop,
) -> None:
super().__init__()
self.file = file
self.loop = loop
self.block_size = 512
self.timeout = (1, 3, 7)
self.transfer_size = None
self.last_block_number = 0
self.last_block_sent = False
self.offered_options = options
self.accepted_options = OrderedDict()
self.remote_tid = tuple[str, int]
def connection_made(self, transport: asyncio.transports.DatagramTransport) -> None:
self.transport = transport
self.local_tid = transport.get_extra_info("sockname")[:2]
self.remote_tid = transport.get_extra_info("peername")[:2]
task = self.loop.create_task(self._connection_made())
task.add_done_callback(handle_task_result)
async def _connection_made(self):
logger.debug(f"new session between {self.local_tid}{self.remote_tid}")
for name, value in self.offered_options.items():
logger.debug(f"{name!r} {value!r}")
match name:
case TFTPOption.BLOCKSIZE:
if value >= 8 and value <= 65464:
self.block_size = value
self.accepted_options[name] = value
case TFTPOption.TIMEOUT:
self.timeout = (value,) * 3
self.accepted_options[name] = value
case TFTPOption.TRANSFER_SIZE:
self.transfer_size = await self.file.length()
if self.transfer_size is not None:
self.accepted_options[name] = self.transfer_size
case _:
logger.warning(f"unknown option {name}: {value}")
if len(self.offered_options) != 0:
data = OACKDatagram(self.accepted_options)
else:
data = await self._get_next_block()
self.retransmit_task = self.loop.create_task(
timed_caller(
itertools.chain((0,), self.timeout),
functools.partial(self.send, data),
self._timed_out,
)
)
self.retransmit_task.add_done_callback(handle_task_result)
def connection_lost(self, exc: Exception | None) -> None:
return super().connection_lost(exc)
async def send(self, datagram: Datagram):
self.transport.sendto(datagram.to_wire())
def datagram_received(
self, datagram: bytes, addr: tuple[str, int] | tuple[str, int, int, int]
) -> None:
tid = addr[:2]
if self.remote_tid != tid:
logger.error(f"packet from unknown sender {tid}")
return
task = self.loop.create_task(self._datagram_received(datagram))
task.add_done_callback(handle_task_result)
async def _datagram_received(self, datagram: bytes) -> None:
datagram = datagram_factory(datagram)
if isinstance(datagram, ACKDatagram):
if datagram.block_number == self.last_block_number:
if (
self.retransmit_task is not None
and not self.retransmit_task.cancelled()
):
self.retransmit_task.cancel()
self.retransmit_task = None
if self.last_block_sent:
await self.file.close(complete=True)
self.transport.close()
return
data = await self._get_next_block()
self.retransmit_task = asyncio.create_task(
timed_caller(
itertools.chain((0,), self.timeout),
functools.partial(self.send, data),
self._timed_out,
)
)
self.retransmit_task.add_done_callback(handle_task_result)
else:
logger.warning(
f"received ack for block number {datagram.block_number} - was expecting {self.last_block_number}"
)
else:
await self.send(ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION))
async def _get_next_block(self) -> Datagram:
payload = await self.file.read(self.block_size)
if len(payload) < self.block_size:
self.last_block_sent = True
self.last_block_number = (self.last_block_number + 1) % 65536
return DATADatagram(self.last_block_number, payload)
async def _timed_out(self):
logger.debug("timed out")
if self.retransmit_task is not None and not self.retransmit_task.cancelled():
self.retransmit_task.cancel()
self.retransmit_task = None
await self.file.close(complete=False)
self.transport.close()
def __del__(self):
logger.debug("RemoteOriginReadProtocol __del__")
def remote_origin_read_protocol(
    *,
    file: FileProtocol,
    options: dict[TFTPOption, bytes],
    loop: asyncio.AbstractEventLoop,
) -> RemoteOriginReadProtocol:
    """Build a RemoteOriginReadProtocol for one read transfer.

    This is a keyword-only factory; all arguments are forwarded unchanged.
    ``options`` is typed to match ``RemoteOriginReadProtocol.__init__``.
    """
    return RemoteOriginReadProtocol(file=file, options=options, loop=loop)