318 lines
10 KiB
Python
318 lines
10 KiB
Python
"""Async interface to Home Assistant."""
|
|
import asyncio
|
|
import datetime
|
|
import logging
|
|
from pprint import pformat
|
|
from typing import Annotated
|
|
from typing import Any
|
|
from typing import Awaitable
|
|
from typing import Callable
|
|
from typing import Literal
|
|
from typing import Union
|
|
|
|
import aiohttp
|
|
from hyperlink import URL
|
|
from pydantic import BaseModel
|
|
from pydantic import Field
|
|
from pydantic import ValidationError
|
|
from pydantic import parse_obj_as
|
|
from pydantic import parse_raw_as
|
|
from pydantic import validator
|
|
from typing_extensions import Self
|
|
import websockets.client
|
|
import websockets.exceptions
|
|
|
|
from greendeck.lib.util import task_done_callback
|
|
|
|
# Module-wide logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Module-level containers. Nothing else in this module references them;
# presumably they are used by importers — verify before changing.
cache: dict = dict()

entities: list = list()
|
|
|
|
# LAST_ID: int = 0
|
|
|
|
|
|
# def generate_id() -> int:
|
|
# """Generate a unique ID for websocket requests."""
|
|
# global LAST_ID # NOQA: pylint(global-statement)
|
|
# LAST_ID += 1
|
|
# return LAST_ID
|
|
|
|
|
|
class Request(BaseModel):
    """Base model for websocket requests that carry a numeric ``id``.

    When a request is built without an explicit ``id``, the validator below
    assigns the next value of a single counter shared by *all* ``Request``
    subclasses. Ids therefore never repeat across request types within the
    process (the Home Assistant websocket API expects unique, increasing ids
    on a connection).
    """

    # Shared id counter. The leading underscore keeps pydantic (v1) from
    # treating it as a model field, so it stays a plain class attribute.
    _last_id: int = 0

    # Message id; omitted/None means "allocate one automatically".
    id: int | None

    @validator("id", always=True)
    def generate_id(cls: type[Self], v: int | None) -> int:
        """Return ``v`` unchanged, or the next shared id when ``v`` is None."""
        if v is not None:
            return v
        # ``cls`` is the concrete subclass being validated (e.g.
        # SubscribeEventsRequest). Incrementing ``cls._last_id`` would create
        # a separate counter on each subclass and hand out duplicate ids, so
        # always bump the counter stored on the base class.
        Request._last_id += 1
        return Request._last_id
|
|
|
|
|
|
class AuthRequest(BaseModel):
    """Websocket authentication request.

    Serialized to JSON and sent in reply to an ``auth_required`` message.
    """

    # Message type tag; defaults to "auth".
    type: str = "auth"
    # Access token (this module passes ``HomeAssistant.token``).
    access_token: str
|
|
|
|
|
|
class AuthRequiredResponse(BaseModel):
    """Authentication required response.

    Member of the ``Response`` union, selected when ``type`` is
    "auth_required".
    """

    # Discriminator value for the ``Response`` union.
    type: Literal["auth_required"]
    # Version string from the message (presumably the server's Home
    # Assistant version).
    ha_version: str
|
|
|
|
|
|
class AuthOKResponse(BaseModel):
    """Authentication OK response.

    Member of the ``Response`` union, selected when ``type`` is "auth_ok".
    """

    # Discriminator value for the ``Response`` union.
    type: Literal["auth_ok"]
    # Version string from the message (presumably the server's Home
    # Assistant version).
    ha_version: str
|
|
|
|
|
|
class SubscribeEventsRequest(Request):
    """Subscribe to events request.

    Inherits the auto-generated ``id`` from ``Request``. This module sends it
    with ``event_type="state_changed"`` after an ``auth_ok`` message.
    """

    # Message type tag; defaults to "subscribe_events".
    type: str = "subscribe_events"
    # Event type to subscribe to.
    event_type: str
|
|
|
|
|
|
class GetStatesRequest(Request):
    """Websocket request whose ``type`` is "get_states".

    Inherits the auto-generated ``id`` from ``Request``. It is not
    constructed anywhere in this module.
    """

    # Message type tag; defaults to "get_states".
    type: str = "get_states"
|
|
|
|
|
|
class EventContext(BaseModel):
    """Context object, used as ``Event.context`` and ``EventState.context``."""

    id: str
    # Optional: pydantic v1 sets these to None when missing or null.
    parent_id: str | None
    user_id: str | None
|
|
|
|
|
|
class EventState(BaseModel):
    """State of a single entity.

    Parsed from the REST helpers ``HomeAssistant.get_state`` and
    ``HomeAssistant.call_service``, and nested as ``old_state`` /
    ``new_state`` inside ``EventData``.
    """

    entity_id: str
    # State value, kept as a string.
    state: str
    # Entity attributes: arbitrary JSON values keyed by attribute name.
    attributes: dict[str, Any]
    # Optional fields below become None when missing or null.
    context: EventContext | None
    # pydantic parses these timestamps into ``datetime`` objects.
    last_changed: datetime.datetime | None
    last_updated: datetime.datetime | None
|
|
|
|
|
|
class EventData(BaseModel):
    """Payload (``event.data``) of a ``state_changed`` event."""

    # Entity whose state changed; the websocket client uses it to look up
    # the callbacks registered for that entity.
    entity_id: str
    # State after the change. Home Assistant sends null here when the entity
    # is removed; a required field would make the whole message fail
    # validation instead of being dispatched, so it must be optional.
    new_state: EventState | None
    # State before the change; null when the entity has just been added.
    old_state: EventState | None
|
|
|
|
|
|
class Event(BaseModel):
    """Event carried inside an ``EventResponse`` message."""

    context: EventContext
    # Event payload, validated against the ``EventData`` layout.
    data: EventData
    event_type: str
    origin: str
    # Parsed into a ``datetime`` by pydantic.
    time_fired: datetime.datetime
|
|
|
|
|
|
class EventResponse(BaseModel):
    """Websocket message whose ``type`` is "event".

    The websocket client dispatches it to the callbacks registered for
    ``event.data.entity_id``.
    """

    # Discriminator value for the ``Response`` union.
    type: Literal["event"]
    # Presumably the id of the subscription request that produced this
    # event — confirm against the Home Assistant websocket API docs.
    id: int
    event: Event
|
|
|
|
|
|
class ResultResponse(BaseModel):
    """Websocket message whose ``type`` is "result".

    The websocket client looks up ``id`` in
    ``HomeAssistant.response_callbacks`` and runs the matching callback once,
    if one is registered.
    """

    # Discriminator value for the ``Response`` union.
    type: Literal["result"]
    # Id of the request this result answers.
    id: int
    # Payload; its shape depends on the request, so it is left untyped.
    result: Any
    success: bool
|
|
|
|
|
|
# Every incoming websocket message type this module understands. pydantic
# uses the ``type`` field as a discriminator to pick the matching model; a
# message with any other ``type`` (or one that does not fit the chosen model)
# raises ``ValidationError`` when parsed with ``parse_raw_as``.
Response = Annotated[
    Union[AuthRequiredResponse, AuthOKResponse, ResultResponse, EventResponse],
    Field(discriminator="type"),
]
|
|
|
|
|
|
class TargetData(BaseModel):
    """``target`` of a ``CallServiceRequest``: a single entity id."""

    entity_id: str
|
|
|
|
|
|
class CallServiceRequest(Request):
    """Websocket request whose ``type`` is "call_service".

    Inherits the auto-generated ``id`` from ``Request``. It is not
    constructed anywhere in this module; ``HomeAssistant.call_service`` uses
    the REST API instead.
    """

    # Message type tag; defaults to "call_service".
    type: str = "call_service"
    domain: str
    service: str
    # Passed through without validation.
    service_data: Any
    target: TargetData
|
|
|
|
|
|
class HomeAssistant:
|
|
websocket: websockets.client.WebSocketClientProtocol
|
|
host: URL
|
|
port: int
|
|
secure: bool
|
|
token: str
|
|
last_id: int
|
|
event_callbacks: dict[str, set[Callable[[EventResponse], Awaitable[None]]]]
|
|
response_callbacks: dict[int, Callable[[ResultResponse], Awaitable[None]]]
|
|
websocket_task: asyncio.Task
|
|
|
|
def __init__(self: Self, host: str, port: int | None, secure: bool, token: str):
|
|
self.host = host
|
|
self.port = port if port is not None else 443 if secure else 80
|
|
self.secure = secure
|
|
self.token = token
|
|
self.last_id = 0
|
|
self.websocket = None
|
|
self.event_callbacks = {}
|
|
self.response_callbacks = {}
|
|
self.websocket_task = None
|
|
|
|
def generate_id(self: Self) -> int:
|
|
self.last_id += 1
|
|
return self.last_id
|
|
|
|
def websocket_url(self: Self) -> URL:
|
|
return URL(
|
|
scheme="wss" if self.secure else "ws",
|
|
host=self.host,
|
|
port=self.port,
|
|
path=("api", "websocket"),
|
|
)
|
|
|
|
def rest_api_url(self: Self) -> URL:
|
|
return URL(
|
|
scheme="https" if self.secure else "http",
|
|
host=self.host,
|
|
port=self.port,
|
|
path=("api",),
|
|
)
|
|
|
|
def add_event_callback(
|
|
self: Self,
|
|
entity_id: str,
|
|
callback: Callable[[EventResponse], Awaitable[None]],
|
|
):
|
|
callbacks = self.event_callbacks.get(entity_id, set())
|
|
callbacks.add(callback)
|
|
self.event_callbacks[entity_id] = callbacks
|
|
|
|
async def call_service(
|
|
self, domain: str, service: str, data: dict
|
|
) -> list[EventState]:
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.post(
|
|
str(self.rest_api_url().child("services", domain, service)),
|
|
headers={
|
|
"Authorization": f"Bearer {self.token}",
|
|
"Content-Type": "application/json",
|
|
},
|
|
json=data,
|
|
) as response:
|
|
if response.status in [200, 201]:
|
|
data = await response.json()
|
|
return parse_obj_as(list[EventState], data)
|
|
else:
|
|
logger.error("Home Assistant REST API error:")
|
|
for line in (await response.text()).splitlines():
|
|
logger.error(line)
|
|
|
|
async def get_state(self: Self, entity_id: str) -> EventState | None:
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(
|
|
str(self.rest_api_url().child("states", entity_id)),
|
|
headers={
|
|
"Authorization": f"Bearer {self.token}",
|
|
"Content-Type": "application/json",
|
|
},
|
|
) as response:
|
|
if response.status in [200, 201]:
|
|
data = await response.read()
|
|
return parse_raw_as(EventState, data)
|
|
else:
|
|
print(response.status)
|
|
print(await response.text())
|
|
return
|
|
|
|
async def start(self: Self):
|
|
self.websocket_task = asyncio.create_task(self.websocket_runner())
|
|
self.websocket_task.add_done_callback(task_done_callback)
|
|
|
|
async def authenticated_callback(self: Self):
|
|
request = SubscribeEventsRequest(event_type="state_changed")
|
|
await self.websocket.send(request.json())
|
|
|
|
async def get_events_callback(self: Self, response: Response) -> None:
|
|
print(response.result)
|
|
|
|
async def websocket_runner(self: Self) -> None:
|
|
while True:
|
|
try:
|
|
async with websockets.client.connect(
|
|
str(self.websocket_url())
|
|
) as self.websocket:
|
|
|
|
async for message in self.websocket:
|
|
try:
|
|
response = parse_raw_as(Response, message)
|
|
# pprint(response)
|
|
match response.type:
|
|
case "auth_required":
|
|
await self.websocket.send(
|
|
AuthRequest(access_token=self.token).json()
|
|
)
|
|
|
|
case "auth_ok":
|
|
t = asyncio.create_task(
|
|
self.authenticated_callback(),
|
|
name="authenticaion ok callback",
|
|
)
|
|
t.add_done_callback(task_done_callback)
|
|
|
|
case "event":
|
|
if (
|
|
response.event.data.entity_id
|
|
in self.event_callbacks
|
|
):
|
|
for callback in self.event_callbacks[
|
|
response.event.data.entity_id
|
|
]:
|
|
t = asyncio.create_task(
|
|
callback(response),
|
|
name=(
|
|
f"homeassistant event callback for "
|
|
f"{response.event.data.entity_id}"
|
|
),
|
|
)
|
|
t.add_done_callback(task_done_callback)
|
|
|
|
case "result":
|
|
if response.id in self.response_callbacks:
|
|
callback = self.response_callbacks[response.id]
|
|
t = asyncio.create_task(
|
|
callback(response),
|
|
name=f"result callback for {response.id}",
|
|
)
|
|
t.add_done_callback(task_done_callback)
|
|
del self.response_callbacks[response.id]
|
|
|
|
case _:
|
|
logger.debug(
|
|
f"unknown message type {response.type}"
|
|
)
|
|
for line in pformat(message).splitlines():
|
|
logger.debug(line)
|
|
|
|
except ValidationError as exception:
|
|
for line in pformat(exception).splitlines():
|
|
logger.error(line)
|
|
for line in pformat(message).splitlines():
|
|
logger.error(line)
|
|
|
|
except asyncio.CancelledError:
|
|
logger.error("homeassistant websocket task cancelled")
|
|
self.websocket_task = None
|
|
return
|
|
|
|
except websockets.exceptions.ConnectionClosedError:
|
|
logger.warning("websocket connection closed error")
|
|
# self.websocket_task = None
|
|
await asyncio.sleep(30.0)
|