greendeck/greendeck/lib/homeassistant/__init__.py
2023-02-11 10:16:06 -06:00

326 lines
10 KiB
Python

"""Async interface to Home Assistant."""
import asyncio
import datetime
import logging
from pprint import pformat
from typing import Annotated
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Literal
from typing import Union
import aiohttp
from hyperlink import URL
from pydantic import BaseModel
from pydantic import Field
from pydantic import ValidationError
from pydantic import parse_obj_as
from pydantic import parse_raw_as
from pydantic import validator
from typing_extensions import Self
import websockets.client
import websockets.exceptions
from greendeck.lib.util import task_done_callback
logger = logging.getLogger(__name__)

# NOTE(review): `cache` and `entities` are not referenced by any code in this
# module that is visible here; presumably other modules import them -- confirm
# before removing.
cache: dict = {}
entities: list = []
class Request(BaseModel):
    """Base model for websocket commands that carry a message ``id``.

    When ``id`` is omitted it is filled in from a single counter shared by
    ``Request`` and all of its subclasses, so every generated ID is distinct
    and increasing regardless of the command type.
    """

    # Shared counter behind generated IDs. pydantic v1 does not turn
    # underscore-prefixed names into fields, so this stays a plain class
    # attribute. It is always read and written through ``Request`` rather than
    # ``cls``: inside the validator ``cls`` is the concrete subclass, and
    # ``cls._last_id += 1`` would create a separate per-subclass counter,
    # letting different command types reuse the same id.
    _last_id: int = 0

    id: int | None

    @validator("id", always=True)
    def generate_id(cls: type[Self], v: int | None) -> int:
        """Return ``v`` unchanged, or the next shared ID when ``v`` is None."""
        if v is not None:
            return v
        Request._last_id += 1
        return Request._last_id
class AuthRequest(BaseModel):
    """Websocket message carrying the access token.

    The HomeAssistant client sends one of these whenever the server reports
    ``auth_required``.
    """

    type: str = Field("auth")  # constant message type
    access_token: str
class AuthRequiredResponse(BaseModel):
    """Server message asking the client to authenticate."""

    type: Literal["auth_required"]
    ha_version: str = Field(...)  # version string reported by the server
class AuthOKResponse(BaseModel):
    """Server message reporting that authentication was accepted."""

    type: Literal["auth_ok"]
    ha_version: str = Field(...)  # version string reported by the server
class SubscribeEventsRequest(Request):
    """Websocket command subscribing to events of a single ``event_type``."""

    type: str = Field("subscribe_events")
    event_type: str  # e.g. "state_changed"
class GetStatesRequest(Request):
    """Websocket command of type ``get_states``; only ``id`` is inherited."""

    type: str = Field("get_states")
class EventContext(BaseModel):
    """Context identifiers attached to an event or to an entity state."""

    id: str
    parent_id: Union[str, None]
    user_id: Union[str, None]
class EventState(BaseModel):
    """Snapshot of one entity's state.

    Parsed from websocket event payloads (``new_state``/``old_state``) and
    from the REST responses handled by the HomeAssistant client.
    """

    entity_id: str
    state: str
    attributes: dict[str, Any]  # free-form attribute mapping
    context: Union[EventContext, None]
    last_changed: Union[datetime.datetime, None]
    last_updated: Union[datetime.datetime, None]
class EventData(BaseModel):
    """Event payload: the entity concerned plus its before/after states."""

    entity_id: str
    # Either state may be absent (presumably when the entity appears or
    # disappears -- confirm against the Home Assistant docs).
    new_state: Union[EventState, None]
    old_state: Union[EventState, None]
class Event(BaseModel):
    """A single event as delivered inside an ``event`` websocket message."""

    context: EventContext
    data: EventData
    event_type: str  # e.g. "state_changed"
    origin: str
    time_fired: datetime.datetime
class EventResponse(BaseModel):
    """Server message wrapping one ``Event``."""

    type: Literal["event"]
    id: int  # presumably the id of the originating subscribe request -- confirm
    event: Event
class ResultResponse(BaseModel):
    """Server reply to a command; matched to the request through ``id``."""

    type: Literal["result"]
    id: int  # id of the request being answered
    result: Any  # command-specific payload
    success: bool
# Every server message this client parses. ``type`` is the pydantic
# discriminator, so each payload is validated against exactly one member.
_ResponseMember = Union[
    AuthRequiredResponse,
    AuthOKResponse,
    ResultResponse,
    EventResponse,
]
Response = Annotated[_ResponseMember, Field(discriminator="type")]
class TargetData(BaseModel):
    """Target of a websocket service call: a single entity id."""

    entity_id: str = Field(...)
class CallServiceRequest(Request):
    """Websocket command invoking ``domain.service`` against ``target``."""

    type: str = Field("call_service")
    domain: str
    service: str
    service_data: Any  # service parameters, serialized as given
    target: TargetData
class HomeAssistant:
websocket: websockets.client.WebSocketClientProtocol
host: URL
port: int
secure: bool
token: str
last_id: int
event_callbacks: dict[str, set[Callable[[EventResponse], Awaitable[None]]]]
response_callbacks: dict[int, Callable[[ResultResponse], Awaitable[None]]]
websocket_task: asyncio.Task
def __init__(self: Self, host: str, port: int | None, secure: bool, token: str):
self.host = host
self.port = port if port is not None else 443 if secure else 80
self.secure = secure
self.token = token
self.last_id = 0
self.websocket = None
self.event_callbacks = {}
self.response_callbacks = {}
self.websocket_task = None
def generate_id(self: Self) -> int:
self.last_id += 1
return self.last_id
def websocket_url(self: Self) -> URL:
return URL(
scheme="wss" if self.secure else "ws",
host=self.host,
port=self.port,
path=("api", "websocket"),
)
def rest_api_url(self: Self) -> URL:
return URL(
scheme="https" if self.secure else "http",
host=self.host,
port=self.port,
path=("api",),
)
def add_event_callback(
self: Self,
entity_id: str,
callback: Callable[[EventResponse], Awaitable[None]],
):
callbacks = self.event_callbacks.get(entity_id, set())
callbacks.add(callback)
self.event_callbacks[entity_id] = callbacks
async def call_service(
self, domain: str, service: str, data: dict
) -> list[EventState]:
async with aiohttp.ClientSession() as session:
async with session.post(
str(self.rest_api_url().child("services", domain, service)),
headers={
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json",
},
json=data,
) as response:
if response.status in [200, 201]:
data = await response.json()
return parse_obj_as(list[EventState], data)
else:
logger.error("Home Assistant REST API error:")
for line in (await response.text()).splitlines():
logger.error(line)
async def get_state(self: Self, entity_id: str) -> EventState | None:
async with aiohttp.ClientSession() as session:
async with session.get(
str(self.rest_api_url().child("states", entity_id)),
headers={
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json",
},
) as response:
if response.status in [200, 201]:
data = await response.read()
return parse_raw_as(EventState, data)
else:
print(response.status)
print(await response.text())
return
async def start(self: Self):
self.websocket_task = asyncio.create_task(self.websocket_runner())
self.websocket_task.add_done_callback(task_done_callback)
async def authenticated_callback(self: Self):
request = SubscribeEventsRequest(event_type="state_changed")
await self.websocket.send(request.json())
async def get_events_callback(self: Self, response: Response) -> None:
print(response.result)
async def websocket_runner(self: Self) -> None:
while True:
try:
async with websockets.client.connect(
str(self.websocket_url())
) as self.websocket:
async for message in self.websocket:
try:
response = parse_raw_as(Response, message)
# pprint(response)
match response.type:
case "auth_required":
await self.websocket.send(
AuthRequest(access_token=self.token).json()
)
case "auth_ok":
t = asyncio.create_task(
self.authenticated_callback(),
name="authenticaion ok callback",
)
t.add_done_callback(task_done_callback)
case "event":
if (
response.event.data.entity_id
in self.event_callbacks
):
for callback in self.event_callbacks[
response.event.data.entity_id
]:
t = asyncio.create_task(
callback(response),
name=(
f"homeassistant event callback for "
f"{response.event.data.entity_id}"
),
)
t.add_done_callback(task_done_callback)
case "result":
if response.id in self.response_callbacks:
callback = self.response_callbacks[response.id]
t = asyncio.create_task(
callback(response),
name=f"result callback for {response.id}",
)
t.add_done_callback(task_done_callback)
del self.response_callbacks[response.id]
case _:
logger.debug(
f"unknown message type {response.type}"
)
for line in pformat(message).splitlines():
logger.debug(line)
except ValidationError as exception:
for line in pformat(exception).splitlines():
logger.error(line)
for line in pformat(message).splitlines():
logger.error(line)
except asyncio.CancelledError:
logger.error("homeassistant websocket task cancelled")
self.websocket_task = None
return
except websockets.exceptions.ConnectionClosedError as exception:
logger.error("websocket connection closed error")
for line in pformat(exception).splitlines():
logger.error(line)
# self.websocket_task = None
await asyncio.sleep(30.0)
except websockets.exceptions.InvalidStatusCode as exception:
logger.error("websocket invalid status code")
for line in pformat(exception).splitlines():
logger.error(line)
await asyncio.sleep(30.0)