jtftp/jtftp/netascii.py
2022-07-07 12:37:41 -05:00

187 lines
5.3 KiB
Python

# JTFTP - Python/AsyncIO TFTP Server
# Copyright (C) 2022 Jeffrey C. Ollie
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import os
import re
from enum import Enum
from jtftp.filesystem import FileProtocol
class NetAsciiBase(bytes, Enum):
CR: bytes
LF: bytes
CRLF: bytes
NUL: bytes
CRNUL: bytes
NL: bytes
_re_from_netascii: re.Pattern[bytes]
_re_to_netascii: re.Pattern[bytes]
@classmethod
def _convert_from_netascii(cls, match_obj: re.Match) -> bytes | None:
match match_obj.group(1):
case cls.CRNUL:
return cls.CR
case cls.CRLF:
return cls.NL
@classmethod
def _convert_to_netascii(cls, match_obj: re.Match) -> bytes | None:
match match_obj.group(1):
case cls.NL:
return cls.CRLF
case cls.CR:
return cls.CRNUL
@classmethod
def from_netascii(cls, data: bytes) -> bytes:
"""Convert a netascii-encoded string into a string with platform-specific
newlines.
"""
if not hasattr(cls, "_re_from_netascii"):
cls._re_from_netascii = re.compile(rb"(\x0d\x0a|\x0d\x00)")
return cls._re_from_netascii.sub(cls._convert_from_netascii, data)
@classmethod
def to_netascii(cls, data: bytes) -> bytes:
"""Convert a string with platform-specific newlines into netascii."""
if not hasattr(cls, "_re_from_netascii"):
cls._re_to_netascii = re.compile(rb"(" + cls.NL + rb"|\x0d\x00)")
return cls._re_to_netascii.sub(cls._convert_to_netascii, data)
class NetAsciiCR(NetAsciiBase):
    """Netascii constants for platforms whose newline is a bare CR."""

    CR = b"\r"
    LF = b"\n"
    CRLF = b"\r\n"  # carriage return, then line feed
    NUL = b"\0"
    CRNUL = b"\r\0"  # carriage return, then NUL
    NL = b"\r"  # local newline; same value as CR, so the enum makes it an alias
class NetAsciiLF(NetAsciiBase):
    """Netascii constants for platforms whose newline is a bare LF."""

    CR = b"\r"
    LF = b"\n"
    CRLF = b"\r\n"  # carriage return, then line feed
    NUL = b"\0"
    CRNUL = b"\r\0"  # carriage return, then NUL
    NL = b"\n"  # local newline; same value as LF, so the enum makes it an alias
class NetAsciiCRLF(NetAsciiBase):
    """Netascii constants for platforms whose newline is CR followed by LF."""

    CR = b"\r"
    LF = b"\n"
    CRLF = b"\r\n"  # carriage return, then line feed
    NUL = b"\0"
    CRNUL = b"\r\0"  # carriage return, then NUL
    NL = b"\r\n"  # local newline; same value as CRLF, so the enum makes it an alias
NetAscii: NetAsciiBase
match os.linesep:
case "\x0d":
NetAscii = NetAsciiCR
case "\x0a":
NetAscii = NetAsciiLF
case "\x0d\x0a":
NetAscii = NetAsciiCRLF
case _:
raise RuntimeError(f"{os.linesep!r} is not a supported line separator")
class NetAsciiReceiverProxy:
    """Write-only proxy that decodes incoming netascii before passing it on.

    Data handed to ``write`` is netascii. ``writer`` receives it with
    ``CR LF`` turned into the local newline and ``CR NUL`` turned into
    ``CR``. Seeking and reading are not supported.
    """

    writer: FileProtocol
    netascii: type[NetAsciiBase]
    # True when the last raw byte received was a CR that has not been decoded
    # yet, because the byte after it (LF or NUL) has not arrived.
    carry_cr: bool

    def __init__(
        self, writer: FileProtocol, netascii: type[NetAsciiBase] = NetAscii
    ):
        self.writer = writer
        self.netascii = netascii
        self.carry_cr = False

    @property
    def closed(self) -> bool:
        return self.writer.closed

    async def length(self) -> int | None:
        """Always ``None``: this proxy does not report a length."""
        return None

    async def seek(self, offset: int, whence: int) -> int:
        raise RuntimeError(f"{self.__class__.__name__} cannot seek")

    async def read(self, length: int) -> bytes:
        raise RuntimeError(f"{self.__class__.__name__} cannot read")

    async def write(self, data: bytes) -> int:
        """Decode one chunk of netascii *data* and write it to ``writer``.

        Returns whatever ``writer.write`` returns for the decoded bytes.
        """
        if self.carry_cr:
            data = self.netascii.CR + data
        # Decide on the RAW bytes whether to hold back a trailing CR. Checking
        # after decoding would mistake a decoded "CR NUL" (-> CR) for a pending
        # CR and wrongly merge it with a following LF into a newline.
        self.carry_cr = data.endswith(self.netascii.CR)
        if self.carry_cr:
            data = data[:-1]
        return await self.writer.write(self.netascii.from_netascii(data))

    async def close(self, complete: bool):
        """Flush a held-back trailing CR as a literal CR, then close ``writer``."""
        if self.carry_cr:
            # Reset first so a second close() does not write the CR again.
            self.carry_cr = False
            await self.writer.write(self.netascii.CR)
        await self.writer.close(complete)
class NetAsciiSenderProxy:
    """Read-only proxy that encodes data from ``reader`` as netascii.

    Encoding can expand the data, so encoded bytes beyond what the caller
    asked for are kept in ``buffer`` for the next ``read``.

    A CR at the end of a raw chunk is held in ``pending_cr`` until the next
    chunk arrives. Where the local newline is CR LF, that CR may be the first
    half of a newline split across two reads.
    """

    reader: FileProtocol
    netascii: type[NetAsciiBase]
    buffer: bytes  # encoded bytes not yet returned to the caller
    pending_cr: bytes  # raw trailing CR held back from the last chunk, or b""

    def __init__(
        self, reader: FileProtocol, netascii: type[NetAsciiBase] = NetAscii
    ):
        self.reader = reader
        self.netascii = netascii
        self.buffer = b""
        self.pending_cr = b""

    @property
    def closed(self) -> bool:
        return self.reader.closed

    async def length(self) -> int | None:
        """Always ``None``: this proxy does not report a length."""
        return None

    async def seek(self, offset: int, whence: int) -> int:
        raise RuntimeError(f"{self.__class__.__name__} cannot seek")

    async def read(self, length: int) -> bytes:
        """Return up to *length* bytes of netascii-encoded data.

        Keeps reading from ``reader`` until *length* encoded bytes are
        available, so a short result means the input is exhausted.
        NOTE(review): assumes ``reader.read`` returns ``b""`` at end of file.
        """
        while len(self.buffer) < length:
            chunk = await self.reader.read(length - len(self.buffer))
            raw = self.pending_cr + chunk
            self.pending_cr = b""
            if not chunk:
                # End of input: encode any held-back CR and stop.
                self.buffer += self.netascii.to_netascii(raw)
                break
            if raw.endswith(self.netascii.CR):
                raw, self.pending_cr = raw[:-1], raw[-1:]
            self.buffer += self.netascii.to_netascii(raw)
        data, self.buffer = self.buffer[:length], self.buffer[length:]
        return data

    async def write(self, data: bytes) -> int:
        raise RuntimeError(f"{self.__class__.__name__} cannot write")

    async def close(self, complete: bool):
        await self.reader.close(complete)