jtftp/jtftp/datagram.py
2022-07-07 12:37:41 -05:00

518 lines
17 KiB
Python

# JTFTP - Python/AsyncIO TFTP Server
# Copyright (C) 2022 Jeffrey C. Ollie
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import itertools
import logging
import struct
from abc import ABC
from abc import abstractmethod
from collections import OrderedDict
from enum import Enum
from enum import IntEnum
from typing import Any
from jtftp.errors import InvalidErrorcodeError
from jtftp.errors import InvalidOpcodeError
from jtftp.errors import OptionsDecodeError
from jtftp.errors import PayloadDecodeError
from jtftp.errors import WireProtocolError
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class TFTPOption(bytes, Enum):
    """Option names understood by this implementation (negotiated per RFC 2347).

    Lookups are case-insensitive and accept either C{bytes} or C{str}, e.g.
    C{TFTPOption("BLKSIZE") is TFTPOption.BLOCKSIZE}.
    """

    BLOCKSIZE = b"blksize"
    TIMEOUT = b"timeout"
    TRANSFER_SIZE = b"tsize"

    @classmethod
    def _missing_(cls, value: object):
        """Fallback lookup used when C{value} is not an exact member value.

        C{str} input is encoded as ASCII (non-ASCII characters become C{?}),
        then the name is folded to lower case and compared again. Returns
        C{None} -- which makes C{Enum} raise C{ValueError} -- when nothing
        matches, including for input that is neither C{str} nor bytes-like.
        """
        if isinstance(value, str):
            value = value.encode("ascii", "replace")
        elif not isinstance(value, (bytes, bytearray)):
            # e.g. None or an int: cannot name an option. Returning None keeps
            # the documented ValueError instead of leaking an AttributeError.
            return None
        value = value.lower()
        for member in cls:
            if member.value == value:
                return member
        return None
class TFTPMode(bytes, Enum):
    """Transfer modes defined by RFC 1350.

    Lookups are case-insensitive and accept either C{bytes} or C{str}, e.g.
    C{TFTPMode("OCTET") is TFTPMode.OCTET}.
    """

    MAIL = b"mail"
    NETASCII = b"netascii"
    OCTET = b"octet"

    @classmethod
    def _missing_(cls, value: object):
        """Fallback lookup used when C{value} is not an exact member value.

        C{str} input is encoded as ASCII (non-ASCII characters become C{?}),
        then the mode is folded to lower case and compared again. Returns
        C{None} -- which makes C{Enum} raise C{ValueError} -- when nothing
        matches, including for input that is neither C{str} nor bytes-like.
        """
        if isinstance(value, str):
            value = value.encode("ascii", "replace")
        elif not isinstance(value, (bytes, bytearray)):
            # e.g. None or an int: cannot name a mode. Returning None keeps
            # the documented ValueError instead of leaking an AttributeError.
            return None
        value = value.lower()
        for member in cls:
            if member.value == value:
                return member
        return None
class TFTPOpcode(IntEnum):
    """The two-byte opcode at the start of every TFTP datagram.

    Values 1-5 come from RFC 1350; OACK (6) was added by RFC 2347.
    """

    RRQ = 0x01  # read request
    WRQ = 0x02  # write request
    DATA = 0x03  # data block
    ACK = 0x04  # acknowledgement of a data block
    ERROR = 0x05  # error report
    OACK = 0x06  # option acknowledgement
class TFTPError(IntEnum):
    """Error codes carried by ERROR datagrams (code 8 is from RFC 2347)."""

    NOT_DEFINED = 0
    FILE_NOT_FOUND = 1
    ACCESS_VIOLATION = 2
    DISK_FULL = 3
    ILLEGAL_OPERATION = 4
    TID_UNKNOWN = 5
    FILE_EXISTS = 6
    NO_SUCH_USER = 7
    TERM_OPTION = 8

    def message(self) -> bytes:
        """Return the default human-readable text for this error code.

        C{NOT_DEFINED} has no canned text, so it maps to an empty byte string.
        """
        err = type(self)
        default_texts = {
            err.NOT_DEFINED: b"",
            err.FILE_NOT_FOUND: b"File not found",
            err.ACCESS_VIOLATION: b"Access violation",
            err.DISK_FULL: b"Disk full or allocation exceeded",
            err.ILLEGAL_OPERATION: b"Illegal TFTP operation",
            err.TID_UNKNOWN: b"Unknown transfer ID",
            err.FILE_EXISTS: b"File already exists",
            err.NO_SUCH_USER: b"No such user",
            err.TERM_OPTION: b"Terminate transfer due to option negotiation",
        }
        return default_texts.get(self)
def split_opcode(datagram: bytes) -> tuple[TFTPOpcode, bytes]:
    """Separate a raw datagram into its opcode and the remaining payload.

    The opcode is the first two bytes of the datagram, in network
    (big-endian) byte order.

    @param datagram: raw datagram as received
    @type datagram: C{bytes}
    @return: C{(opcode, payload)}, where the payload is everything after the
        first two bytes
    @rtype: (L{TFTPOpcode}, C{bytes})
    @raise WireProtocolError: if the datagram is too short to hold an opcode
    @raise InvalidOpcodeError: if the opcode is not a known L{TFTPOpcode}
    """
    try:
        (raw_opcode,) = struct.unpack(b"!H", datagram[:2])
    except struct.error:
        raise WireProtocolError("failed to extract the opcode")
    try:
        opcode = TFTPOpcode(raw_opcode)
    except ValueError:
        raise InvalidOpcodeError(raw_opcode)
    return opcode, datagram[2:]
def assert_options(options: dict) -> None:
    """Debug-only check that every key of C{options} is a L{TFTPOption}.

    The check runs only when C{__debug__} is true (i.e. Python was not started
    with C{-O}); otherwise the whole loop is compiled away. Values are not
    inspected.

    @param options: mapping to validate
    @raise AssertionError: if a key is not a L{TFTPOption}
    """
    if __debug__:
        # Only keys are checked, so iterate the mapping directly instead of
        # unpacking (and discarding) every value via .items().
        for name in options:
            assert isinstance(
                name, TFTPOption
            ), f"{name} ({type(name)}) is not a TFTPOption"
def decode_options(parts: list[bytes]) -> dict[TFTPOption, Any]:
if parts and not parts[-1]:
parts.pop(-1)
# To maintain consistency during testing.
# The actual order of options is not important as per RFC2347
options = OrderedDict()
if len(parts) % 2:
raise OptionsDecodeError(f"no value for option {parts[-1]}")
iparts = iter(parts)
for name, value in [(name, next(iparts, None)) for name in iparts]:
try:
name = TFTPOption(name)
except ValueError:
raise OptionsDecodeError(
f"{name.decode('ascii', 'replace')!r} is not a valid option"
)
try:
match name:
case TFTPOption.BLOCKSIZE:
value = int(value)
if value < 8 or value > 65464:
raise OptionsDecodeError(
f"{value} is not a valid value for option {name.decode('ascii', 'replace')!r}"
)
case TFTPOption.TIMEOUT:
value = int(value)
if value < 1:
raise OptionsDecodeError(
f"{value} is not a valid value for option {name.decode('ascii', 'replace')!r}"
)
case TFTPOption.TRANSFER_SIZE:
value = int(value)
if value < 0:
raise OptionsDecodeError(
f"{value} is not a valid value for option {name.decode('ascii', 'replace')!r}"
)
except ValueError:
raise OptionsDecodeError(
f"{value.decode('ascii','replace')!r} is not a valid value for option {name.decode('ascii', 'replace')!r}"
)
if name in options:
raise OptionsDecodeError(
f"duplicate option specified: {name.decode('ascii', 'replace')!r}"
)
options[name] = value
return options
def encode_options(options: dict[TFTPOption, Any]) -> list[bytes]:
parts = []
for name, value in options.items():
match name:
case TFTPOption.BLOCKSIZE | TFTPOption.TIMEOUT | TFTPOption.TRANSFER_SIZE:
parts.append(bytes(name))
parts.append(f"{value:d}".encode("ascii"))
case _:
raise WireProtocolError(
f"unknown option {name.decode('ascii','replace')}"
)
return parts
class Datagram(ABC):
    """Abstract base class shared by all TFTP datagram types.

    Concrete subclasses set C{opcode} and implement both directions of the
    conversion: C{from_wire} (payload bytes to object) and C{to_wire}
    (object to datagram bytes).

    @cvar opcode: the opcode that identifies this kind of datagram
    @type opcode: L{TFTPOpcode}
    """

    opcode: TFTPOpcode

    @abstractmethod
    def to_wire(self) -> bytes:
        """Serialise this datagram into the bytes sent on the network.

        @rtype: C{bytes}
        """
        raise NotImplementedError("Subclasses must override this")

    @classmethod
    @abstractmethod
    def from_wire(cls, payload: bytes):
        """Build an instance of this datagram type from C{payload}.

        @param payload: the datagram's bytes with the leading opcode already
            removed
        @type payload: C{bytes}
        """
        raise NotImplementedError("Subclasses must override this")
class RQDatagram(Datagram):
    """Base class for "RQ" (request) datagrams.

    Wire layout after the opcode: C{filename NUL mode NUL}, optionally
    followed by C{name NUL value NUL} option pairs.

    @ivar filename: File name that corresponds to this request.
    @type filename: C{bytes}
    @ivar mode: Transfer mode (one of L{TFTPMode}; decoded case-insensitively).
    @type mode: L{TFTPMode}
    @ivar options: Any options that were requested by the client (as per
        U{RFC2347<http://tools.ietf.org/html/rfc2347>}).
    @type options: C{dict}
    """

    filename: bytes
    mode: TFTPMode
    options: dict[TFTPOption, Any]

    @classmethod
    def from_wire(cls, payload: bytes):
        """Parse the payload and return a RRQ/WRQ datagram object.

        @param payload: Binary representation of the payload (without the opcode)
        @type payload: C{bytes}
        @return: datagram object
        @rtype: L{RRQDatagram} or L{WRQDatagram}
        @raise OptionsDecodeError: if we failed to decode the options, requested
            by the client
        @raise PayloadDecodeError: if there were not enough fields in the
            payload (fields are terminated by NUL), or the mode is unknown.
        """
        parts = payload.split(b"\x00")
        try:
            filename, mode = parts.pop(0), parts.pop(0)
        except IndexError:
            raise PayloadDecodeError("Not enough fields in the payload")
        try:
            mode = TFTPMode(mode)
        except ValueError:
            raise PayloadDecodeError(
                f"{mode.decode('ascii', 'replace')!r} is not a valid mode"
            )
        # Whatever follows the filename and mode is the option list.
        options = decode_options(parts)
        return cls(filename, mode, options)

    def __init__(
        self, filename: bytes, mode: TFTPMode, options: dict[TFTPOption, Any]
    ):
        assert isinstance(filename, bytes)
        assert isinstance(mode, TFTPMode)
        assert_options(options)
        self.filename = filename
        self.mode = mode
        self.options = options

    def __repr__(self) -> str:
        if self.options:
            return f"<{self.__class__.__name__}(filename={self.filename}, mode={self.mode}, options={self.options})>"
        # The f prefix was missing here, so the template was returned verbatim.
        return f"<{self.__class__.__name__}(filename={self.filename}, mode={self.mode})>"

    def to_wire(self) -> bytes:
        """Return the wire form: opcode, then every field NUL-terminated."""
        opcode = struct.pack(b"!H", self.opcode)
        fields = [self.filename, self.mode]
        if self.options:
            fields.extend(encode_options(self.options))
        return b"".join((opcode, b"\x00".join(fields), b"\x00"))
class RRQDatagram(RQDatagram):
    """Read request (RRQ): the client asks to read C{filename}."""

    opcode: TFTPOpcode = TFTPOpcode.RRQ
class WRQDatagram(RQDatagram):
    """Write request (WRQ): the client asks to write C{filename}."""

    opcode: TFTPOpcode = TFTPOpcode.WRQ
class OACKDatagram(Datagram):
    """An OACK (option acknowledgement) datagram.

    @ivar options: Options being acknowledged (as per
        U{RFC2347<http://tools.ietf.org/html/rfc2347>}).
    @type options: C{dict}
    """

    opcode = TFTPOpcode.OACK
    # Annotation only; the value is set per instance in __init__.
    options: dict[TFTPOption, Any]

    @classmethod
    def from_wire(cls, payload: bytes):
        """Parse the payload and return an OACK datagram object.

        @param payload: Binary representation of the payload (without the opcode)
        @type payload: C{bytes}
        @return: datagram object
        @rtype: L{OACKDatagram}
        @raise OptionsDecodeError: if we failed to decode the options
        """
        parts = payload.split(b"\x00")
        options = decode_options(parts)
        return cls(options)

    def __init__(self, options: dict[TFTPOption, Any]):
        assert_options(options)
        self.options = options

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(options={self.options})>"

    def to_wire(self) -> bytes:
        """Return the opcode followed by NUL-terminated option fields (if any)."""
        opcode = struct.pack(b"!H", self.opcode)
        if self.options:
            options = b"\x00".join(encode_options(self.options))
            return b"".join((opcode, options, b"\x00"))
        return opcode
class DATADatagram(Datagram):
    """A DATA datagram.

    @ivar block_number: Block number that this chunk of data is associated with
    @type block_number: C{int}
    @ivar payload: The data carried by this block
    @type payload: C{bytes}
    """

    opcode = TFTPOpcode.DATA
    block_number: int
    payload: bytes

    @classmethod
    def from_wire(cls, payload: bytes):
        """Parse the payload and return a L{DATADatagram} object.

        @param payload: Binary representation of the payload (without the opcode)
        @type payload: C{bytes}
        @return: A L{DATADatagram} object
        @rtype: L{DATADatagram}
        @raise PayloadDecodeError: if the block number cannot be extracted
        """
        try:
            block_number = struct.unpack(b"!H", payload[:2])[0]
        except struct.error:
            # Same message as ACKDatagram for the same failure.
            raise PayloadDecodeError("Unable to extract the block number")
        # Everything after the two-byte block number is data (possibly empty).
        data = payload[2:]
        return cls(block_number, data)

    def __init__(self, block_number: int, payload: bytes):
        assert isinstance(payload, bytes)
        self.block_number = block_number
        self.payload = payload

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(blocknum={self.block_number}, {len(self.payload)} bytes of data)>"

    def to_wire(self) -> bytes:
        """Return opcode and block number (big-endian 16-bit each), then the data."""
        return b"".join(
            (struct.pack(b"!HH", self.opcode, self.block_number), self.payload)
        )
class ACKDatagram(Datagram):
    """An ACK datagram, acknowledging one DATA block.

    @ivar block_number: Number of the data block being acknowledged
    @type block_number: C{int}
    """

    opcode = TFTPOpcode.ACK
    block_number: int

    @classmethod
    def from_wire(cls, payload: bytes):
        """Build an L{ACKDatagram} from its payload.

        The payload (bytes after the opcode) must be exactly two bytes: the
        big-endian block number.

        @param payload: Binary representation of the payload (without the opcode)
        @type payload: C{bytes}
        @return: An L{ACKDatagram} object
        @rtype: L{ACKDatagram}
        @raise PayloadDecodeError: if the format of payload is incorrect
        """
        try:
            (number,) = struct.unpack(b"!H", payload)
        except struct.error:
            raise PayloadDecodeError("Unable to extract the block number")
        return cls(number)

    def __init__(self, blocknum: int):
        self.block_number = blocknum

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"<{name}(block_number={self.block_number})>"

    def to_wire(self) -> bytes:
        """Return opcode and block number, each as a big-endian 16-bit integer."""
        header_format = b"!HH"
        return struct.pack(header_format, self.opcode, self.block_number)
class ERRORDatagram(Datagram):
    """An ERROR datagram.

    @ivar error_code: A valid TFTP error code
    @type error_code: L{TFTPError}
    @ivar error_message: An error message describing the error condition in
        which this datagram was produced (without the NUL terminator)
    @type error_message: C{bytes}
    """

    opcode = TFTPOpcode.ERROR
    error_code: TFTPError
    error_message: bytes

    @classmethod
    def from_wire(cls, payload: bytes):
        """Parse the payload and return a L{ERRORDatagram} object.

        This method violates the standard a bit - if the error string was not
        extracted, a default error string is generated, based on the error code.

        @param payload: Binary representation of the payload (without the opcode)
        @type payload: C{bytes}
        @return: An L{ERRORDatagram} object
        @rtype: L{ERRORDatagram}
        @raise PayloadDecodeError: if the format of payload is incorrect
        @raise InvalidErrorcodeError: a more specific exception, raised if the
            error code was successfully extracted but does not correspond to
            any known/standardized error code value.
        """
        try:
            error_code = struct.unpack(b"!H", payload[:2])[0]
            try:
                error_code = TFTPError(error_code)
            except ValueError:
                raise InvalidErrorcodeError(error_code)
        except struct.error:
            raise PayloadDecodeError("Unable to extract the error code")
        # Only the text up to the first NUL is the message; anything after
        # the terminator is ignored.
        error_message = payload[2:].split(b"\x00")[0]
        if not error_message:
            error_message = error_code.message()
        return cls(error_code, error_message)

    @classmethod
    def from_code(cls, error_code: TFTPError, error_message: bytes | str | None = None):
        """Create an L{ERRORDatagram} from an error code and optional message.

        If no message is given, the default message for the error code is
        used. A C{str} message is encoded as ASCII, with unencodable
        characters replaced by C{?}.

        @param error_code: An error code (checked with C{assert} to be a
            L{TFTPError})
        @type error_code: L{TFTPError}
        @param error_message: An error message (optional)
        @type error_message: C{bytes} or C{str} or C{NoneType}
        @return: an L{ERRORDatagram}
        @rtype: L{ERRORDatagram}
        """
        assert isinstance(error_code, TFTPError)
        if isinstance(error_message, str):
            error_message = error_message.encode("ascii", "replace")
        elif error_message is None:
            error_message = error_code.message()
        assert isinstance(error_message, bytes)
        return cls(error_code, error_message)

    def __init__(self, error_code: TFTPError, error_message: bytes):
        assert isinstance(error_message, bytes)
        self.error_code = error_code
        self.error_message = error_message

    def __repr__(self) -> str:
        # Every other datagram type has a repr; keep ERROR debuggable too.
        return f"<{self.__class__.__name__}(error_code={self.error_code!r}, error_message={self.error_message!r})>"

    def to_wire(self) -> bytes:
        """Return opcode and error code (big-endian 16-bit), then the NUL-terminated message."""
        return b"".join(
            (
                struct.pack(b"!HH", self.opcode, self.error_code),
                self.error_message,
                b"\x00",
            )
        )
class _DatagramFactory:
    """Callable that parses raw datagrams into L{Datagram} instances.

    The opcode is split off with L{split_opcode} and the payload is handed to
    the C{from_wire} of the class registered for that opcode.
    """

    # Maps each opcode to the Datagram *subclass* (not an instance) parsing it.
    classes: dict[TFTPOpcode, type[Datagram]] = {
        TFTPOpcode.ACK: ACKDatagram,
        TFTPOpcode.DATA: DATADatagram,
        TFTPOpcode.ERROR: ERRORDatagram,
        TFTPOpcode.OACK: OACKDatagram,
        TFTPOpcode.RRQ: RRQDatagram,
        TFTPOpcode.WRQ: WRQDatagram,
    }

    def __call__(self, datagram: bytes) -> Datagram:
        """Parse C{datagram} and return the matching datagram object.

        @param datagram: raw datagram, opcode included
        @type datagram: C{bytes}
        @rtype: L{Datagram}
        @raise WireProtocolError: if the opcode cannot be extracted
        @raise InvalidOpcodeError: if the opcode is unknown or has no class
        Any exception raised by the selected class's C{from_wire} propagates.
        """
        opcode, payload = split_opcode(datagram)
        try:
            cls = self.classes[opcode]
        except KeyError:
            # With the table above every TFTPOpcode member is mapped, so this
            # is a safeguard for opcodes added later (or a subclass's table).
            raise InvalidOpcodeError(opcode.value)
        return cls.from_wire(payload)


# Shared parser instance: datagram_factory(raw_bytes) returns a Datagram.
datagram_factory = _DatagramFactory()