jtftp/jtftp/protocol/remote_origin_write.py
2022-07-07 12:37:41 -05:00

203 lines
7.1 KiB
Python

# JTFTP - Python/AsyncIO TFTP Server
# Copyright (C) 2022 Jeffrey C. Ollie
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import asyncio
import functools
import itertools
import logging
from collections import OrderedDict
from jtftp.datagram import ACKDatagram
from jtftp.datagram import DATADatagram
from jtftp.datagram import Datagram
from jtftp.datagram import ERRORDatagram
from jtftp.datagram import OACKDatagram
from jtftp.datagram import TFTPError
from jtftp.datagram import TFTPMode
from jtftp.datagram import TFTPOption
from jtftp.datagram import datagram_factory
from jtftp.filesystem import FileProtocol
from jtftp.log import handle_task_result
from jtftp.util import timed_caller
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class RemoteOriginWriteProtocol(asyncio.DatagramProtocol):
block_size: int
timeout: tuple[int, int, int]
transfer_size: int | None
retransmit_task: asyncio.Task | None
offered_options: dict[TFTPOption, bytes]
accepted_options: dict[TFTPOption, bytes]
local_tid: tuple[str, int]
remote_tid: tuple[str, int]
def __init__(
self,
*,
file: FileProtocol,
options: dict[TFTPOption, bytes],
loop: asyncio.AbstractEventLoop,
) -> None:
super().__init__()
self.file = file
self.offered_options = options
self.loop = loop
self.block_size = 512
self.timeout = (1, 3, 7)
self.transfer_size = None
self.retransmit_task = None
self.accepted_options = OrderedDict()
self.local_tid = None
self.remote_tid = None
def connection_made(self, transport: asyncio.transports.DatagramTransport) -> None:
self.transport = transport
self.local_tid = transport.get_extra_info("sockname")[:2]
self.remote_tid = self.transport.get_extra_info("peername")[:2]
task = self.loop.create_task(self._connection_made())
task.add_done_callback(handle_task_result)
async def _connection_made(self):
logger.debug(f"new session between {self.local_tid}{self.remote_tid}")
for name, value in self.offered_options.items():
logger.debug(f"{name!r} {value!r}")
match name:
case TFTPOption.BLOCKSIZE:
self.block_size = int(value)
self.accepted_options[name] = value
case TFTPOption.TIMEOUT:
self.timeout = (int(value),) * 3
self.accepted_options[name] = value
case TFTPOption.TRANSFER_SIZE:
self.transfer_size = int(value)
self.accepted_options[name] = value
case _:
logger.warning(f"unknown option {name}: {value}")
if len(self.offered_options) == 0:
data = ACKDatagram(0).to_wire()
else:
data = OACKDatagram(self.accepted_options).to_wire()
self.retransmit_task = asyncio.create_task(
timed_caller(
itertools.chain((0,), self.timeout),
functools.partial(self.transport.sendto, data),
self._timed_out,
)
)
def connection_lost(self, exc: Exception | None) -> None:
logger.debug(f"closed session between {self.local_tid}{self.remote_tid}")
return super().connection_lost(exc)
async def send(self, datagram: Datagram) -> None:
self.transport.sendto(datagram.to_wire())
def datagram_received(
self, datagram: bytes, addr: tuple[str, int] | tuple[str, int, int, int]
) -> None:
logger.debug(
f"RemoteOriginWriteProtocol data_received {len(datagram)} {addr!r}"
)
tid = addr[:2]
task = self.loop.create_task(self._datagram_received(datagram, tid))
task.add_done_callback(handle_task_result)
async def _datagram_received(self, datagram: bytes, tid: tuple[str, int]) -> None:
if self.remote_tid != tid:
logger.error(f"received packet from wrong address: {tid}")
if (
self.retransmit_task is not None
and not self.retransmit_task.cancelled()
):
self.retransmit_task.cancel()
self.retransmit_task = None
await self.file.close(complete=False)
self.transport.close()
return
datagram = datagram_factory(datagram)
if not isinstance(datagram, DATADatagram):
if (
self.retransmit_task is not None
and not self.retransmit_task.cancelled()
):
self.retransmit_task.cancel()
self.retransmit_task = None
await self.file.close(complete=False)
await self.send(ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION))
self.transport.close()
return
logger.debug(f"RemoteOriginWriteProtocol data_received {datagram!r} {tid!r}")
if self.retransmit_task is not None and not self.retransmit_task.cancelled():
logger.debug("cancelling old timeout task")
self.retransmit_task.cancel()
self.retransmit_task = None
# need to check block numbers
if len(datagram.payload) == self.block_size:
self.file.write(datagram.payload)
self.retransmit_task = asyncio.create_task(
timed_caller(
itertools.chain((0,), self.timeout),
functools.partial(
self.send,
ACKDatagram(datagram.block_number),
),
self._timed_out,
)
)
self.retransmit_task.add_done_callback(handle_task_result)
else:
logger.debug(f"last data packet received")
await self.file.write(datagram.payload)
await self.file.close(complete=True)
await self.send(ACKDatagram(datagram.block_number))
self.transport.close()
async def _timed_out(self):
logger.debug("timed out")
if self.retransmit_task is not None and not self.retransmit_task.cancelled():
self.retransmit_task.cancel()
self.retransmit_task = None
await self.file.close(complete=False)
self.transport.close()
def __del__(self):
logger.debug("RemoteOriginWriteProtocol __del__")
def remote_origin_write_protocol(
    *,
    file: FileProtocol,
    options: dict[TFTPOption, bytes],
    loop: asyncio.AbstractEventLoop,
) -> RemoteOriginWriteProtocol:
    """Create a RemoteOriginWriteProtocol for one incoming write request.

    Keyword-only factory that forwards its arguments unchanged to the
    protocol constructor.
    """
    protocol = RemoteOriginWriteProtocol(loop=loop, options=options, file=file)
    return protocol