# JTFTP - Python/AsyncIO TFTP Server # Copyright (C) 2022 Jeffrey C. Ollie # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . import itertools import logging import struct from abc import ABC from abc import abstractmethod from collections import OrderedDict from enum import Enum from enum import IntEnum from typing import Any from jtftp.errors import InvalidErrorcodeError from jtftp.errors import InvalidOpcodeError from jtftp.errors import OptionsDecodeError from jtftp.errors import PayloadDecodeError from jtftp.errors import WireProtocolError logger = logging.getLogger(__name__) class TFTPOption(bytes, Enum): BLOCKSIZE = b"blksize" TIMEOUT = b"timeout" TRANSFER_SIZE = b"tsize" @classmethod def _missing_(cls, value: bytes | str): if isinstance(value, str): value = value.encode("ascii", "replace") value = value.lower() for member in cls: if member.value == value: return member class TFTPMode(bytes, Enum): MAIL = b"mail" NETASCII = b"netascii" OCTET = b"octet" @classmethod def _missing_(cls, value: bytes | str): if isinstance(value, str): value = value.encode("ascii", "replace") value = value.lower() for member in cls: if member.value == value: return member class TFTPOpcode(IntEnum): RRQ = 1 WRQ = 2 DATA = 3 ACK = 4 ERROR = 5 OACK = 6 class TFTPError(IntEnum): NOT_DEFINED = 0 FILE_NOT_FOUND = 1 ACCESS_VIOLATION = 2 DISK_FULL = 3 ILLEGAL_OPERATION = 4 TID_UNKNOWN = 5 FILE_EXISTS = 6 NO_SUCH_USER = 7 TERM_OPTION = 8 def message(self) -> bytes: match self.value: case self.NOT_DEFINED: return b"" case self.FILE_NOT_FOUND: return b"File not found" case self.ACCESS_VIOLATION: return b"Access violation" case self.DISK_FULL: return b"Disk full or allocation exceeded" case self.ILLEGAL_OPERATION: return b"Illegal TFTP operation" case self.TID_UNKNOWN: return b"Unknown transfer ID" case self.FILE_EXISTS: return b"File already exists" case self.NO_SUCH_USER: return b"No such user" case self.TERM_OPTION: return b"Terminate transfer due to option negotiation" def split_opcode(datagram: bytes) -> tuple[TFTPOpcode, bytes]: """Split the raw datagram into opcode and payload. @param datagram: raw datagram @type datagram: C{bytes} @return: a 2-tuple, the first item is the opcode and the second item is the payload @rtype: (C{OP}, C{bytes}) @raise WireProtocolError: if the opcode cannot be extracted """ try: opcode = struct.unpack(b"!H", datagram[:2])[0] try: opcode = TFTPOpcode(opcode) except ValueError: raise InvalidOpcodeError(opcode) return opcode, datagram[2:] except struct.error: raise WireProtocolError("failed to extract the opcode") def assert_options(options: dict) -> None: if __debug__: for name, value in options.items(): assert isinstance( name, TFTPOption ), f"{name} ({type(name)}) is not a TFTPOption" def decode_options(parts: list[bytes]) -> dict[TFTPOption, Any]: if parts and not parts[-1]: parts.pop(-1) # To maintain consistency during testing. # The actual order of options is not important as per RFC2347 options = OrderedDict() if len(parts) % 2: raise OptionsDecodeError(f"no value for option {parts[-1]}") iparts = iter(parts) for name, value in [(name, next(iparts, None)) for name in iparts]: try: name = TFTPOption(name) except ValueError: raise OptionsDecodeError( f"{name.decode('ascii', 'replace')!r} is not a valid option" ) try: match name: case TFTPOption.BLOCKSIZE: value = int(value) if value < 8 or value > 65464: raise OptionsDecodeError( f"{value} is not a valid value for option {name.decode('ascii', 'replace')!r}" ) case TFTPOption.TIMEOUT: value = int(value) if value < 1: raise OptionsDecodeError( f"{value} is not a valid value for option {name.decode('ascii', 'replace')!r}" ) case TFTPOption.TRANSFER_SIZE: value = int(value) if value < 0: raise OptionsDecodeError( f"{value} is not a valid value for option {name.decode('ascii', 'replace')!r}" ) except ValueError: raise OptionsDecodeError( f"{value.decode('ascii','replace')!r} is not a valid value for option {name.decode('ascii', 'replace')!r}" ) if name in options: raise OptionsDecodeError( f"duplicate option specified: {name.decode('ascii', 'replace')!r}" ) options[name] = value return options def encode_options(options: dict[TFTPOption, Any]) -> list[bytes]: parts = [] for name, value in options.items(): match name: case TFTPOption.BLOCKSIZE | TFTPOption.TIMEOUT | TFTPOption.TRANSFER_SIZE: parts.append(bytes(name)) parts.append(f"{value:d}".encode("ascii")) case _: raise WireProtocolError( f"unknown option {name.decode('ascii','replace')}" ) return parts class Datagram(ABC): """Base class for datagrams @cvar opcode: The opcode, corresponding to this datagram @type opcode: C{Opcode} """ opcode: TFTPOpcode @classmethod @abstractmethod def from_wire(cls, payload: bytes): """Parse the payload and return a datagram object @param payload: Binary representation of the payload (without the opcode) @type payload: C{bytes} """ raise NotImplementedError("Subclasses must override this") @abstractmethod def to_wire(self) -> bytes: """Return the wire representation of the datagram. @rtype: C{bytes} """ raise NotImplementedError("Subclasses must override this") class RQDatagram(Datagram): """Base class for "RQ" (request) datagrams. @ivar filename: File name, that corresponds to this request. @type filename: C{bytes} @ivar mode: Transfer mode. Valid values are C{netascii} and C{octet}. Case-insensitive. @type mode: C{bytes} @ivar options: Any options, that were requested by the client (as per U{RFC2374} @type options: C{dict} """ filename: bytes mode: TFTPMode options: dict[TFTPOption, Any] @classmethod def from_wire(cls, payload: bytes): """Parse the payload and return a RRQ/WRQ datagram object. @return: datagram object @rtype: L{RRQDatagram} or L{WRQDatagram} @raise OptionsDecodeError: if we failed to decode the options, requested by the client @raise PayloadDecodeError: if there were not enough fields in the payload. Fields are terminated by NUL. """ parts = payload.split(b"\x00") try: filename, mode = parts.pop(0), parts.pop(0) try: mode = TFTPMode(mode) except ValueError: raise PayloadDecodeError( f"{mode.decode('ascii', 'replace')!r} is not a valid mode" ) except IndexError: raise PayloadDecodeError("Not enough fields in the payload") options = decode_options(parts) return cls(filename, mode, options) def __init__( self, filename: bytes, mode: TFTPMode, options: dict[TFTPOption, bytes] ): assert isinstance(filename, bytes) assert isinstance(mode, TFTPMode) assert_options(options) self.filename = filename self.mode = mode self.options = options def __repr__(self): if self.options: return f"<{self.__class__.__name__}(filename={self.filename}, mode={self.mode}, options={self.options})>" return "<{self.__class__.__name__}(filename={self.filename}, mode={self.mode})>" def to_wire(self): opcode = struct.pack(b"!H", self.opcode) if self.options: options = b"\x00".join(encode_options(self.options)) return b"".join( (opcode, self.filename, b"\x00", self.mode, b"\x00", options, b"\x00") ) else: return b"".join((opcode, self.filename, b"\x00", self.mode, b"\x00")) class RRQDatagram(RQDatagram): opcode = TFTPOpcode.RRQ class WRQDatagram(RQDatagram): opcode = TFTPOpcode.WRQ class OACKDatagram(Datagram): """An OACK datagram @ivar options: Any options, that were requested by the client (as per U{RFC2374} @type options: C{dict} """ opcode = TFTPOpcode.OACK options = dict[TFTPOption, Any] @classmethod def from_wire(cls, payload: bytes): """Parse the payload and return an OACK datagram object. @return: datagram object @rtype: L{OACKDatagram} @raise OptionsDecodeError: if we failed to decode the options """ parts = payload.split(b"\x00") options = decode_options(parts) return cls(options) def __init__(self, options: dict[TFTPOption, Any]): assert_options(options) self.options = options def __repr__(self) -> str: return f"<{self.__class__.__name__}(options={self.options})>" def to_wire(self) -> bytes: opcode = struct.pack(b"!H", self.opcode) if self.options: options = b"\x00".join(encode_options(self.options)) return b"".join((opcode, options, b"\x00")) else: return opcode class DATADatagram(Datagram): """A DATA datagram @ivar blocknum: A block number, that this chunk of data is associated with @type blocknum: C{int} @ivar data: binary data @type data: C{bytes} """ opcode = TFTPOpcode.DATA payload: bytes @classmethod def from_wire(cls, payload: bytes): """Parse the payload and return a L{DATADatagram} object. @param payload: Binary representation of the payload (without the opcode) @type payload: C{bytes} @return: A L{DATADatagram} object @rtype: L{DATADatagram} @raise PayloadDecodeError: if the format of payload is incorrect """ try: block_number, payload = struct.unpack(b"!H", payload[:2])[0], payload[2:] except struct.error: raise PayloadDecodeError() return cls(block_number, payload) def __init__(self, block_number: int, payload: bytes): assert isinstance(payload, bytes) self.block_number = block_number self.payload = payload def __repr__(self) -> str: return f"<{self.__class__.__name__}(blocknum={self.block_number}, {len(self.payload)} bytes of data)>" def to_wire(self) -> bytes: return b"".join( (struct.pack(b"!HH", self.opcode, self.block_number), self.payload) ) class ACKDatagram(Datagram): """An ACK datagram. @ivar blocknum: Block number of the data chunk, which this datagram is supposed to acknowledge @type blocknum: C{int} """ opcode = TFTPOpcode.ACK block_number: int @classmethod def from_wire(cls, payload: bytes): """Parse the payload and return a L{ACKDatagram} object. @param payload: Binary representation of the payload (without the opcode) @type payload: C{bytes} @return: An L{ACKDatagram} object @rtype: L{ACKDatagram} @raise PayloadDecodeError: if the format of payload is incorrect """ try: block_number = struct.unpack(b"!H", payload)[0] except struct.error: raise PayloadDecodeError("Unable to extract the block number") return cls(block_number) def __init__(self, blocknum: int): self.block_number = blocknum def __repr__(self) -> str: return f"<{self.__class__.__name__}(block_number={self.block_number})>" def to_wire(self) -> bytes: return struct.pack(b"!HH", self.opcode, self.block_number) class ERRORDatagram(Datagram): """An ERROR datagram. @ivar errorcode: A valid TFTP error code @type errorcode: C{int} @ivar errmsg: An error message, describing the error condition in which this datagram was produced @type errmsg: C{bytes} """ opcode = TFTPOpcode.ERROR error_code: TFTPError error_message: bytes @classmethod def from_wire(cls, payload: bytes): """Parse the payload and return a L{ERRORDatagram} object. This method violates the standard a bit - if the error string was not extracted, a default error string is generated, based on the error code. @param payload: Binary representation of the payload (without the opcode) @type payload: C{bytes} @return: An L{ERRORDatagram} object @rtype: L{ERRORDatagram} @raise PayloadDecodeError: if the format of payload is incorrect @raise InvalidErrorcodeError: a more specific exception, that is raised if the error code was successfully, extracted, but it does not correspond to any known/standartized error code values. """ try: error_code = struct.unpack(b"!H", payload[:2])[0] try: error_code = TFTPError(error_code) except ValueError as e: raise InvalidErrorcodeError(error_code) except struct.error: raise PayloadDecodeError("Unable to extract the error code") error_message = payload[2:].split(b"\x00")[0] if not error_message: error_message = error_code.message() return cls(error_code, error_message) @classmethod def from_code(cls, error_code: TFTPError, error_message: bytes | str | None = None): """Create an L{ERRORDatagram}, given an error code and, optionally, an error message to go with it. If not provided, default error message for the given error code is used. @param error_code: An error code @type error_code: L{TFTPError} @param error_message: An error message (optional) @type error_message: C{bytes} or C{str} or C{NoneType} @raise InvalidErrorcodeError: if the error code is not known @return: an L{ERRORDatagram} @rtype: L{ERRORDatagram} """ assert isinstance(error_code, TFTPError) if isinstance(error_message, str): error_message = error_message.encode("ascii", "replace") elif error_message is None: error_message = error_code.message() assert isinstance(error_message, bytes) return cls(error_code, error_message) def __init__(self, error_code: TFTPError, error_message: bytes): assert isinstance(error_message, bytes) self.error_code = error_code self.error_message = error_message def to_wire(self) -> bytes: return b"".join( ( struct.pack(b"!HH", self.opcode, self.error_code), self.error_message, b"\x00", ) ) class _DatagramFactory: classes: dict[TFTPOpcode, Datagram] = { TFTPOpcode.ACK: ACKDatagram, TFTPOpcode.DATA: DATADatagram, TFTPOpcode.ERROR: ERRORDatagram, TFTPOpcode.OACK: OACKDatagram, TFTPOpcode.RRQ: RRQDatagram, TFTPOpcode.WRQ: WRQDatagram, } def __call__(self, datagram: bytes) -> Datagram: opcode, payload = split_opcode(datagram) try: cls = self.classes[opcode] except KeyError: raise InvalidOpcodeError(opcode.value) return cls.from_wire(payload) datagram_factory = _DatagramFactory()