"""Async interface to Home Assistant."""

import asyncio
import datetime
import logging
from pprint import pformat
from typing import Annotated
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import ClassVar
from typing import Literal
from typing import Union

import aiohttp
from hyperlink import URL
from pydantic import BaseModel
from pydantic import Field
from pydantic import ValidationError
from pydantic import parse_obj_as
from pydantic import parse_raw_as
from pydantic import validator
from typing_extensions import Self
import websockets.client
import websockets.exceptions

from greendeck.lib.util import task_done_callback

logger = logging.getLogger(__name__)

cache = {}
entities = []

# LAST_ID: int = 0


# def generate_id() -> int:
#     """Generate a unique ID for websocket requests."""
#     global LAST_ID  # NOQA: pylint(global-statement)
#     LAST_ID += 1
#     return LAST_ID


class Request(BaseModel):
    """Base class for websocket requests that carry a message id."""

    # Shared by every subclass: Home Assistant rejects a message whose id is
    # not greater than the previous id on the same connection.
    _last_id: ClassVar[int] = 0

    id: int | None

    @validator("id", always=True)
    def generate_id(cls: type[Self], v: int | None) -> int:
        """Assign the next message id when none was given."""
        if v is None:
            # Increment on Request itself: `cls._last_id += 1` would create a
            # separate counter on each subclass and hand out duplicate ids.
            Request._last_id += 1
            return Request._last_id
        return v


class AuthRequest(BaseModel):
    """Websocket authentication request."""

    type: str = "auth"
    access_token: str


class AuthRequiredResponse(BaseModel):
    """Authentication required response."""

    type: Literal["auth_required"]
    ha_version: str


class AuthOKResponse(BaseModel):
    """Authentication OK response."""

    type: Literal["auth_ok"]
    ha_version: str


class SubscribeEventsRequest(Request):
    """Subscribe to events request."""

    type: str = "subscribe_events"
    event_type: str


class GetStatesRequest(Request):
    """Request the current state of every entity."""

    type: str = "get_states"


class EventContext(BaseModel):
    id: str
    parent_id: str | None
    user_id: str | None


class EventState(BaseModel):
    entity_id: str
    state: str
    attributes: dict[str, Any]
    context: EventContext | None
    last_changed: datetime.datetime | None
    last_updated: datetime.datetime | None


class EventData(BaseModel):
    entity_id: str
    # Home Assistant sends null for new_state when an entity is removed and
    # for old_state when an entity is first added.
    new_state: EventState | None
    old_state: EventState | None


class Event(BaseModel):
    context: EventContext
    data: EventData
    event_type: str
    origin: str
    time_fired: datetime.datetime


class EventResponse(BaseModel):
    type: Literal["event"]
    id: int
    event: Event


class ResultResponse(BaseModel):
    type: Literal["result"]
    id: int
    result: Any
    success: bool


# Any message the websocket can deliver, selected by its "type" field.
Response = Annotated[
    Union[AuthRequiredResponse, AuthOKResponse, ResultResponse, EventResponse],
    Field(discriminator="type"),
]


class TargetData(BaseModel):
    entity_id: str


class CallServiceRequest(Request):
    type: str = "call_service"
    domain: str
    service: str
    service_data: Any
    target: TargetData


class HomeAssistant:
    """Client for the Home Assistant websocket and REST APIs."""

    websocket: websockets.client.WebSocketClientProtocol | None
    host: str
    port: int
    secure: bool
    token: str
    last_id: int
    event_callbacks: dict[str, set[Callable[[EventResponse], Awaitable[None]]]]
    response_callbacks: dict[int, Callable[[ResultResponse], Awaitable[None]]]
    websocket_task: asyncio.Task | None

    def __init__(self: Self, host: str, port: int | None, secure: bool, token: str):
        self.host = host
        # Fall back to the standard HTTPS/HTTP port when none is given.
        self.port = port if port is not None else 443 if secure else 80
        self.secure = secure
        self.token = token
        self.last_id = 0
        self.websocket = None
        self.event_callbacks = {}
        self.response_callbacks = {}
        self.websocket_task = None

    def generate_id(self: Self) -> int:
        """Return the next websocket message id.

        Uses the same counter as ``Request`` so that ids stay unique and
        increasing whether a message is built by hand or from a request model.
        """
        Request._last_id += 1
        return Request._last_id

    def websocket_url(self: Self) -> URL:
        return URL(
            scheme="wss" if self.secure else "ws",
            host=self.host,
            port=self.port,
            path=("api", "websocket"),
        )

    def rest_api_url(self: Self) -> URL:
        return URL(
            scheme="https" if self.secure else "http",
            host=self.host,
            port=self.port,
            path=("api",),
        )

    def add_event_callback(
        self: Self,
        entity_id: str,
        callback: Callable[[EventResponse], Awaitable[None]],
    ) -> None:
        """Run ``callback`` for every state_changed event of ``entity_id``."""
        self.event_callbacks.setdefault(entity_id, set()).add(callback)

    async def call_service(
        self: Self, domain: str, service: str, data: dict
    ) -> list[EventState] | None:
        """Call a service over the REST API.

        Returns the states that changed while the service ran, or None if the
        request failed.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                str(self.rest_api_url().child("services", domain, service)),
                headers={
                    "Authorization": f"Bearer {self.token}",
                    "Content-Type": "application/json",
                },
                json=data,
            ) as response:
                if response.status in (200, 201):
                    payload = await response.json()
                    return parse_obj_as(list[EventState], payload)
                logger.error(
                    "Home Assistant REST API error (HTTP %s):", response.status
                )
                for line in (await response.text()).splitlines():
                    logger.error(line)
                return None

    async def get_state(self: Self, entity_id: str) -> EventState | None:
        """Fetch the current state of ``entity_id`` over the REST API.

        Returns None if the request failed, for example with HTTP 404 for an
        unknown entity.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(
                str(self.rest_api_url().child("states", entity_id)),
                headers={
                    "Authorization": f"Bearer {self.token}",
                    "Content-Type": "application/json",
                },
            ) as response:
                if response.status in (200, 201):
                    data = await response.read()
                    return parse_raw_as(EventState, data)
                logger.error(
                    "Home Assistant REST API error (HTTP %s):", response.status
                )
                for line in (await response.text()).splitlines():
                    logger.error(line)
                return None

    async def start(self: Self) -> None:
        """Run the websocket client in a background task."""
        self.websocket_task = asyncio.create_task(self.websocket_runner())
        self.websocket_task.add_done_callback(task_done_callback)

    async def authenticated_callback(self: Self) -> None:
        """Subscribe to state_changed events once authentication succeeds."""
        request = SubscribeEventsRequest(event_type="state_changed")
        await self.websocket.send(request.json())

    async def get_events_callback(self: Self, response: ResultResponse) -> None:
        for line in pformat(response.result).splitlines():
            logger.debug(line)

    async def websocket_runner(self: Self) -> None:
        """Connect to the websocket API and dispatch incoming messages."""
        try:
            async with websockets.client.connect(
                str(self.websocket_url())
            ) as self.websocket:
                async for message in self.websocket:
                    try:
                        response = parse_raw_as(Response, message)
                        match response.type:
                            case "auth_required":
                                await self.websocket.send(
                                    AuthRequest(access_token=self.token).json()
                                )
                            case "auth_ok":
                                t = asyncio.create_task(
                                    self.authenticated_callback(),
                                    name="authentication ok callback",
                                )
                                t.add_done_callback(task_done_callback)
                            case "event":
                                entity_id = response.event.data.entity_id
                                for callback in self.event_callbacks.get(
                                    entity_id, ()
                                ):
                                    t = asyncio.create_task(
                                        callback(response),
                                        name=(
                                            "homeassistant event callback for "
                                            f"{entity_id}"
                                        ),
                                    )
                                    t.add_done_callback(task_done_callback)
                            case "result":
                                # Each result callback fires at most once.
                                callback = self.response_callbacks.pop(
                                    response.id, None
                                )
                                if callback is not None:
                                    t = asyncio.create_task(
                                        callback(response),
                                        name=f"result callback for {response.id}",
                                    )
                                    t.add_done_callback(task_done_callback)
                            case _:
                                logger.debug(
                                    "unknown message type %s", response.type
                                )
                                for line in pformat(message).splitlines():
                                    logger.debug(line)
                    except ValidationError as exception:
                        # Messages that match none of the Response models
                        # (for example auth_invalid) end up here.
                        for line in pformat(exception).splitlines():
                            logger.error(line)
                        for line in pformat(message).splitlines():
                            logger.error(line)
        except asyncio.CancelledError:
            logger.info("Home Assistant websocket task cancelled")
        except websockets.exceptions.ConnectionClosedError as exception:
            logger.error(
                "Home Assistant websocket connection closed with an error: %s",
                exception,
            )
        finally:
            # Clear state on every exit path, including a clean close.
            self.websocket = None
            self.websocket_task = None