192 lines
6.8 KiB
Python
192 lines
6.8 KiB
Python
|
# JTFTP - Python/AsyncIO TFTP Server
|
||
|
# Copyright (C) 2022 Jeffrey C. Ollie
|
||
|
#
|
||
|
# This program is free software: you can redistribute it and/or modify
|
||
|
# it under the terms of the GNU General Public License as published by
|
||
|
# the Free Software Foundation, either version 3 of the License, or
|
||
|
# (at your option) any later version.
|
||
|
#
|
||
|
# This program is distributed in the hope that it will be useful,
|
||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||
|
# GNU General Public License for more details.
|
||
|
#
|
||
|
# You should have received a copy of the GNU General Public License
|
||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||
|
|
||
|
import asyncio
|
||
|
import functools
|
||
|
import itertools
|
||
|
import logging
|
||
|
from collections import OrderedDict
|
||
|
|
||
|
from jtftp.datagram import ACKDatagram
|
||
|
from jtftp.datagram import DATADatagram
|
||
|
from jtftp.datagram import Datagram
|
||
|
from jtftp.datagram import ERRORDatagram
|
||
|
from jtftp.datagram import OACKDatagram
|
||
|
from jtftp.datagram import TFTPError
|
||
|
from jtftp.datagram import TFTPMode
|
||
|
from jtftp.datagram import TFTPOption
|
||
|
from jtftp.datagram import datagram_factory
|
||
|
from jtftp.filesystem import FileProtocol
|
||
|
from jtftp.log import handle_task_result
|
||
|
from jtftp.util import timed_caller
|
||
|
|
||
|
# Module-level logger named after this module; used by the protocol class below.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class RemoteOriginReadProtocol(asyncio.DatagramProtocol):
|
||
|
block_size: int
|
||
|
timeout: tuple[int, int, int]
|
||
|
transfer_size: int | None
|
||
|
retransmit_task: asyncio.Task | None
|
||
|
last_block_number: int
|
||
|
last_block_sent: False
|
||
|
offered_options: dict[TFTPOption, bytes]
|
||
|
accepted_options: dict[TFTPOption, bytes]
|
||
|
local_tid = tuple[str, int]
|
||
|
remote_tid = tuple[str, int]
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
*,
|
||
|
file: FileProtocol,
|
||
|
options: dict[TFTPOption, bytes],
|
||
|
loop: asyncio.AbstractEventLoop,
|
||
|
) -> None:
|
||
|
super().__init__()
|
||
|
|
||
|
self.file = file
|
||
|
self.loop = loop
|
||
|
|
||
|
self.block_size = 512
|
||
|
self.timeout = (1, 3, 7)
|
||
|
self.transfer_size = None
|
||
|
self.last_block_number = 0
|
||
|
self.last_block_sent = False
|
||
|
self.offered_options = options
|
||
|
self.accepted_options = OrderedDict()
|
||
|
self.remote_tid = tuple[str, int]
|
||
|
|
||
|
def connection_made(self, transport: asyncio.transports.DatagramTransport) -> None:
|
||
|
self.transport = transport
|
||
|
self.local_tid = transport.get_extra_info("sockname")[:2]
|
||
|
self.remote_tid = transport.get_extra_info("peername")[:2]
|
||
|
task = self.loop.create_task(self._connection_made())
|
||
|
task.add_done_callback(handle_task_result)
|
||
|
|
||
|
async def _connection_made(self):
|
||
|
logger.debug(f"new session between {self.local_tid} → {self.remote_tid}")
|
||
|
for name, value in self.offered_options.items():
|
||
|
logger.debug(f"{name!r} {value!r}")
|
||
|
match name:
|
||
|
case TFTPOption.BLOCKSIZE:
|
||
|
if value >= 8 and value <= 65464:
|
||
|
self.block_size = value
|
||
|
self.accepted_options[name] = value
|
||
|
|
||
|
case TFTPOption.TIMEOUT:
|
||
|
self.timeout = (value,) * 3
|
||
|
self.accepted_options[name] = value
|
||
|
|
||
|
case TFTPOption.TRANSFER_SIZE:
|
||
|
self.transfer_size = await self.file.length()
|
||
|
if self.transfer_size is not None:
|
||
|
self.accepted_options[name] = self.transfer_size
|
||
|
|
||
|
case _:
|
||
|
logger.warning(f"unknown option {name}: {value}")
|
||
|
|
||
|
if len(self.offered_options) != 0:
|
||
|
data = OACKDatagram(self.accepted_options)
|
||
|
else:
|
||
|
data = await self._get_next_block()
|
||
|
|
||
|
self.retransmit_task = self.loop.create_task(
|
||
|
timed_caller(
|
||
|
itertools.chain((0,), self.timeout),
|
||
|
functools.partial(self.send, data),
|
||
|
self._timed_out,
|
||
|
)
|
||
|
)
|
||
|
self.retransmit_task.add_done_callback(handle_task_result)
|
||
|
|
||
|
def connection_lost(self, exc: Exception | None) -> None:
|
||
|
return super().connection_lost(exc)
|
||
|
|
||
|
async def send(self, datagram: Datagram):
|
||
|
self.transport.sendto(datagram.to_wire())
|
||
|
|
||
|
def datagram_received(
|
||
|
self, datagram: bytes, addr: tuple[str, int] | tuple[str, int, int, int]
|
||
|
) -> None:
|
||
|
tid = addr[:2]
|
||
|
|
||
|
if self.remote_tid != tid:
|
||
|
logger.error(f"packet from unknown sender {tid}")
|
||
|
return
|
||
|
|
||
|
task = self.loop.create_task(self._datagram_received(datagram))
|
||
|
task.add_done_callback(handle_task_result)
|
||
|
|
||
|
async def _datagram_received(self, datagram: bytes) -> None:
|
||
|
datagram = datagram_factory(datagram)
|
||
|
if isinstance(datagram, ACKDatagram):
|
||
|
if datagram.block_number == self.last_block_number:
|
||
|
if (
|
||
|
self.retransmit_task is not None
|
||
|
and not self.retransmit_task.cancelled()
|
||
|
):
|
||
|
self.retransmit_task.cancel()
|
||
|
self.retransmit_task = None
|
||
|
|
||
|
if self.last_block_sent:
|
||
|
await self.file.close(complete=True)
|
||
|
self.transport.close()
|
||
|
return
|
||
|
|
||
|
data = await self._get_next_block()
|
||
|
self.retransmit_task = asyncio.create_task(
|
||
|
timed_caller(
|
||
|
itertools.chain((0,), self.timeout),
|
||
|
functools.partial(self.send, data),
|
||
|
self._timed_out,
|
||
|
)
|
||
|
)
|
||
|
self.retransmit_task.add_done_callback(handle_task_result)
|
||
|
else:
|
||
|
logger.warning(
|
||
|
f"received ack for block number {datagram.block_number} - was expecting {self.last_block_number}"
|
||
|
)
|
||
|
else:
|
||
|
await self.send(ERRORDatagram.from_code(TFTPError.ILLEGAL_OPERATION))
|
||
|
|
||
|
async def _get_next_block(self) -> Datagram:
|
||
|
payload = await self.file.read(self.block_size)
|
||
|
if len(payload) < self.block_size:
|
||
|
self.last_block_sent = True
|
||
|
|
||
|
self.last_block_number = (self.last_block_number + 1) % 65536
|
||
|
return DATADatagram(self.last_block_number, payload)
|
||
|
|
||
|
async def _timed_out(self):
|
||
|
logger.debug("timed out")
|
||
|
if self.retransmit_task is not None and not self.retransmit_task.cancelled():
|
||
|
self.retransmit_task.cancel()
|
||
|
self.retransmit_task = None
|
||
|
await self.file.close(complete=False)
|
||
|
self.transport.close()
|
||
|
|
||
|
def __del__(self):
|
||
|
logger.debug("RemoteOriginReadProtocol __del__")
|
||
|
|
||
|
|
||
|
def remote_origin_read_protocol(
    *,
    file: FileProtocol,
    options: dict[TFTPOption, bytes],
    loop: asyncio.AbstractEventLoop,
) -> RemoteOriginReadProtocol:
    """Create a ``RemoteOriginReadProtocol`` from keyword arguments.

    Args:
        file: File object the protocol will read from.
        options: Options offered by the peer; forwarded unchanged to the
            protocol constructor, which keys them by ``TFTPOption``.
        loop: Event loop the protocol schedules its tasks on.

    Returns:
        The newly constructed protocol instance.
    """
    return RemoteOriginReadProtocol(file=file, options=options, loop=loop)
|