# Source code for gpp_client.generated.async_base_client

import asyncio
import enum
import json
from collections.abc import AsyncIterator
from typing import IO, Any, Optional, TypeVar, cast
from uuid import uuid4

import httpx
from pydantic import BaseModel
from pydantic_core import to_jsonable_python

from .base_model import UNSET, Upload
from .exceptions import (
    GraphQLClientError,
    GraphQLClientGraphQLMultiError,
    GraphQLClientHttpError,
    GraphQLClientInvalidMessageFormat,
    GraphQLClientInvalidResponseError,
)

try:
    from websockets import (  # type: ignore[import-not-found,unused-ignore]
        ClientConnection,
    )
    from websockets import (  # type: ignore[import-not-found,unused-ignore]
        connect as ws_connect,
    )
    from websockets.typing import (  # type: ignore[import-not-found,unused-ignore]
        Data,
        Origin,
        Subprotocol,
    )
except ImportError:
    # ``websockets`` is optional: only subscriptions need it. These stand-ins let
    # the module import (and the HTTP client work) without it, while any attempt
    # to open a subscription fails with a clear NotImplementedError.
    from contextlib import asynccontextmanager

    @asynccontextmanager  # type: ignore
    async def ws_connect(*args, **kwargs):
        """Stand-in for ``websockets.connect``; always raises NotImplementedError."""
        raise NotImplementedError("Subscriptions require 'websockets' package.")
        yield  # unreachable; makes this an async generator for asynccontextmanager

    # Only used in type annotations, so ``Any`` is a sufficient placeholder.
    ClientConnection = Any  # type: ignore[misc,assignment,unused-ignore]
    Data = Any  # type: ignore[misc,assignment,unused-ignore]
    # ``Origin`` is *called* (``Origin(ws_origin)`` on an ``Optional[str]``) when a
    # client is constructed with ``ws_origin``, so it must be callable.
    # ``typing.Any`` raises TypeError when called, which broke constructing even an
    # HTTP-only client; ``str`` simply passes the origin string through.
    Origin = str  # type: ignore[misc,assignment,unused-ignore]

    def Subprotocol(*args, **kwargs):  # type: ignore # noqa: N802, N803
        """Stand-in for ``websockets.typing.Subprotocol``; always raises."""
        raise NotImplementedError("Subscriptions require 'websockets' package.")


# WebSocket subprotocol requested when opening a subscription connection.
GRAPHQL_TRANSPORT_WS = "graphql-transport-ws"

# Lets ``AsyncBaseClient.__aenter__`` return the concrete (sub)class type.
Self = TypeVar("Self", bound="AsyncBaseClient")


class GraphQLTransportWSMessageType(str, enum.Enum):
    """Values of the ``type`` field in ``graphql-transport-ws`` messages.

    Members also subclass ``str``, so each compares equal to its raw wire value
    (``GraphQLTransportWSMessageType.PING == "ping"``).
    """

    # Connection handshake.
    CONNECTION_INIT = "connection_init"
    CONNECTION_ACK = "connection_ack"
    # Keep-alive exchange.
    PING = "ping"
    PONG = "pong"
    # Operation lifecycle.
    SUBSCRIBE = "subscribe"
    NEXT = "next"
    ERROR = "error"
    COMPLETE = "complete"
class AsyncBaseClient:
    """Asynchronous GraphQL client base: HTTP operations and WS subscriptions.

    Queries and mutations are POSTed through an ``httpx.AsyncClient``, as JSON or
    as a multipart request when the variables contain ``Upload`` objects.
    Subscriptions use the ``graphql-transport-ws`` WebSocket subprotocol.
    """

    def __init__(
        self,
        url: str = "",
        headers: Optional[dict[str, str]] = None,
        http_client: Optional[httpx.AsyncClient] = None,
        ws_url: str = "",
        ws_headers: Optional[dict[str, Any]] = None,
        ws_origin: Optional[str] = None,
        ws_connection_init_payload: Optional[dict[str, Any]] = None,
    ) -> None:
        """Store endpoint settings and create an HTTP client if none is supplied.

        Parameters
        ----------
        url : str
            GraphQL HTTP endpoint used by ``execute``.
        headers : dict[str, str], optional
            Headers for the ``httpx.AsyncClient`` created here. They are only
            applied when ``http_client`` is not supplied.
        http_client : httpx.AsyncClient, optional
            Client to use instead of creating one. ``__aexit__`` closes it.
        ws_url : str
            GraphQL WebSocket endpoint used by ``execute_ws``.
        ws_headers : dict[str, Any], optional
            Extra headers sent with every WebSocket handshake.
        ws_origin : str, optional
            ``origin`` passed to the WebSocket connection.
        ws_connection_init_payload : dict[str, Any], optional
            ``payload`` of the ``connection_init`` message.
        """
        self.url = url
        self.headers = headers
        if http_client is None:
            http_client = httpx.AsyncClient(headers=headers)
        self.http_client = http_client
        self.ws_url = ws_url
        self.ws_headers = ws_headers or {}
        self.ws_origin = Origin(ws_origin) if ws_origin else None
        self.ws_connection_init_payload = ws_connection_init_payload

    async def __aenter__(self: Self) -> Self:
        """Return the client itself for use in ``async with``."""
        return self

    async def __aexit__(
        self,
        exc_type: object,
        exc_val: object,
        exc_tb: object,
    ) -> None:
        """Close ``http_client``, including one supplied by the caller."""
        await self.http_client.aclose()

    async def execute(
        self,
        query: str,
        operation_name: Optional[str] = None,
        variables: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> httpx.Response:
        """POST a query or mutation and return the raw HTTP response.

        Variables containing ``Upload`` objects are sent as a multipart request;
        otherwise the body is JSON. ``kwargs`` are forwarded to
        ``httpx.AsyncClient.post``.
        """
        processed_variables, files, files_map = self._process_variables(variables)

        if files and files_map:
            return await self._execute_multipart(
                query=query,
                operation_name=operation_name,
                variables=processed_variables,
                files=files,
                files_map=files_map,
                **kwargs,
            )

        return await self._execute_json(
            query=query,
            operation_name=operation_name,
            variables=processed_variables,
            **kwargs,
        )

    def get_data(self, response: httpx.Response) -> dict[str, Any]:
        """Extract the ``data`` member from a GraphQL HTTP response.

        The value is returned as-is; it is ``None`` when the server sent
        ``"data": null`` (or omitted it) without any errors.

        Raises
        ------
        GraphQLClientHttpError
            If ``response.is_success`` is false.
        GraphQLClientInvalidResponseError
            If the body is not JSON, not an object, or has neither ``data`` nor
            ``errors``.
        GraphQLClientGraphQLMultiError
            If the body contains a non-empty ``errors`` list.
        """
        if not response.is_success:
            raise GraphQLClientHttpError(
                status_code=response.status_code, response=response
            )

        try:
            response_json = response.json()
        except ValueError as exc:
            raise GraphQLClientInvalidResponseError(response=response) from exc

        if (not isinstance(response_json, dict)) or (
            "data" not in response_json and "errors" not in response_json
        ):
            raise GraphQLClientInvalidResponseError(response=response)

        data = response_json.get("data")
        errors = response_json.get("errors")

        if errors:
            raise GraphQLClientGraphQLMultiError.from_errors_dicts(
                errors_dicts=errors, data=data
            )

        return cast(dict[str, Any], data)

    async def execute_ws(
        self,
        query: str,
        operation_name: Optional[str] = None,
        variables: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
        """Run a subscription over WebSocket and yield the ``data`` of each result.

        ``kwargs`` are forwarded to ``ws_connect``; an ``additional_headers``
        entry is merged on top of ``self.ws_headers``.

        Raises
        ------
        GraphQLClientError
            If no ``connection_ack`` arrives within 5 seconds, or the connection
            ends before one arrives.
        GraphQLClientGraphQLMultiError
            On an ``error`` message, or a ``next`` result that carries ``errors``.
        GraphQLClientInvalidMessageFormat
            On a malformed server message.
        """
        headers = self.ws_headers.copy()
        headers.update(kwargs.pop("additional_headers", {}))

        merged_kwargs: dict[str, Any] = {"origin": self.ws_origin}
        merged_kwargs.update(kwargs)
        merged_kwargs["additional_headers"] = headers

        operation_id = str(uuid4())
        async with ws_connect(
            self.ws_url,
            subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)],
            **merged_kwargs,
        ) as websocket:
            await self._send_connection_init(websocket)
            # Wait for connection_ack; some servers (e.g. Hasura) send ping before
            # connection_ack, so we loop and handle pings until we get ack.
            try:
                await asyncio.wait_for(
                    self._wait_for_connection_ack(websocket),
                    timeout=5.0,
                )
            except asyncio.TimeoutError as exc:
                raise GraphQLClientError(
                    "Connection ack not received within 5 seconds"
                ) from exc

            await self._send_subscribe(
                websocket,
                operation_id=operation_id,
                query=query,
                operation_name=operation_name,
                variables=variables,
            )

            async for message in websocket:
                data = await self._handle_ws_message(message, websocket)
                # ``None`` means a control message (ping/complete) or null data;
                # a repeated connection_ack sentinel is not a result either.
                if data and "connection_ack" not in data:
                    yield data

    def _process_variables(
        self, variables: Optional[dict[str, Any]]
    ) -> tuple[
        dict[str, Any], dict[str, tuple[str, IO[bytes], str]], dict[str, list[str]]
    ]:
        """Make variables JSON-friendly and split out file uploads.

        Returns ``(variables, files, files_map)``. In ``variables``, each
        ``Upload`` is replaced by ``None``. ``files`` maps an index string to a
        ``(filename, content, content_type)`` tuple. ``files_map`` maps the same
        index to the variable paths that referenced the file.
        """
        if not variables:
            return {}, {}, {}

        serializable_variables = self._convert_dict_to_json_serializable(variables)
        return self._get_files_from_variables(serializable_variables)

    def _convert_dict_to_json_serializable(
        self, dict_: dict[str, Any]
    ) -> dict[str, Any]:
        """Convert each value via ``_convert_value``, dropping ``UNSET`` entries."""
        return {
            key: self._convert_value(value)
            for key, value in dict_.items()
            if value is not UNSET
        }

    def _convert_value(self, value: Any) -> Any:
        """Dump pydantic models (by alias, unset fields excluded), recursing lists."""
        if isinstance(value, BaseModel):
            return value.model_dump(by_alias=True, exclude_unset=True)
        if isinstance(value, list):
            return [self._convert_value(item) for item in value]
        return value

    def _get_files_from_variables(
        self, variables: dict[str, Any]
    ) -> tuple[
        dict[str, Any], dict[str, tuple[str, IO[bytes], str]], dict[str, list[str]]
    ]:
        """Replace ``Upload`` values with ``None`` and collect them as files.

        Paths are dotted (``variables.<key>.<index>...``). An ``Upload`` found
        more than once is stored once and mapped to every path where it appears.
        """
        files_map: dict[str, list[str]] = {}
        files_list: list[Upload] = []

        def separate_files(path: str, obj: Any) -> Any:
            # Rebuild containers so the caller's objects are not mutated.
            if isinstance(obj, list):
                nulled_list = []
                for index, value in enumerate(obj):
                    value = separate_files(f"{path}.{index}", value)
                    nulled_list.append(value)
                return nulled_list

            if isinstance(obj, dict):
                nulled_dict = {}
                for key, value in obj.items():
                    value = separate_files(f"{path}.{key}", value)
                    nulled_dict[key] = value
                return nulled_dict

            if isinstance(obj, Upload):
                if obj in files_list:
                    file_index = files_list.index(obj)
                    files_map[str(file_index)].append(path)
                else:
                    file_index = len(files_list)
                    files_list.append(obj)
                    files_map[str(file_index)] = [path]
                return None

            return obj

        nulled_variables = separate_files("variables", variables)
        files: dict[str, tuple[str, IO[bytes], str]] = {
            str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type)
            for i, file_ in enumerate(files_list)
        }
        return nulled_variables, files, files_map

    async def _execute_multipart(
        self,
        query: str,
        operation_name: Optional[str],
        variables: dict[str, Any],
        files: dict[str, tuple[str, IO[bytes], str]],
        files_map: dict[str, list[str]],
        **kwargs: Any,
    ) -> httpx.Response:
        """POST a multipart request with ``operations`` and ``map`` fields plus files."""
        data = {
            "operations": json.dumps(
                {
                    "query": query,
                    "operationName": operation_name,
                    "variables": variables,
                },
                default=to_jsonable_python,
            ),
            "map": json.dumps(files_map, default=to_jsonable_python),
        }

        return await self.http_client.post(
            url=self.url, data=data, files=files, **kwargs
        )

    async def _execute_json(
        self,
        query: str,
        operation_name: Optional[str],
        variables: dict[str, Any],
        **kwargs: Any,
    ) -> httpx.Response:
        """POST a JSON request body.

        ``Content-type: application/json`` is set first, so entries in
        ``kwargs["headers"]`` are applied on top of it.
        """
        headers: dict[str, str] = {"Content-type": "application/json"}
        headers.update(kwargs.get("headers", {}))

        merged_kwargs: dict[str, Any] = kwargs.copy()
        merged_kwargs["headers"] = headers

        return await self.http_client.post(
            url=self.url,
            content=json.dumps(
                {
                    "query": query,
                    "operationName": operation_name,
                    "variables": variables,
                },
                default=to_jsonable_python,
            ),
            **merged_kwargs,
        )

    async def _send_connection_init(self, websocket: ClientConnection) -> None:
        """Send ``connection_init``, with the configured payload when it is truthy."""
        payload: dict[str, Any] = {
            "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value
        }
        if self.ws_connection_init_payload:
            payload["payload"] = self.ws_connection_init_payload
        await websocket.send(json.dumps(payload))

    async def _wait_for_connection_ack(self, websocket: ClientConnection) -> None:
        """Read messages until connection_ack; handle ping/pong in between.

        Raises
        ------
        GraphQLClientError
            If the message stream ends before ``connection_ack`` is received.
        """
        async for message in websocket:
            data = await self._handle_ws_message(message, websocket)
            if data is not None and "connection_ack" in data:
                return
        # The stream ended without an ack. Fail clearly here instead of letting
        # the subsequent subscribe hit a closed connection.
        raise GraphQLClientError(
            "Connection closed before connection_ack was received"
        )

    async def _send_subscribe(
        self,
        websocket: ClientConnection,
        operation_id: str,
        query: str,
        operation_name: Optional[str] = None,
        variables: Optional[dict[str, Any]] = None,
    ) -> None:
        """Send the ``subscribe`` message for ``operation_id``."""
        payload_inner: dict[str, Any] = {
            "query": query,
            "operationName": operation_name,
        }
        if variables:
            payload_inner["variables"] = self._convert_dict_to_json_serializable(
                variables
            )
        payload: dict[str, Any] = {
            "id": operation_id,
            "type": GraphQLTransportWSMessageType.SUBSCRIBE.value,
            "payload": payload_inner,
        }
        # Same encoder fallback as the HTTP path. Without it, values the stdlib
        # json encoder cannot handle raise TypeError for subscriptions only.
        await websocket.send(json.dumps(payload, default=to_jsonable_python))

    async def _handle_ws_message(
        self,
        message: Data,
        websocket: ClientConnection,
        expected_type: Optional[GraphQLTransportWSMessageType] = None,
    ) -> Optional[dict[str, Any]]:
        """Decode one ``graphql-transport-ws`` message and act on it.

        Returns
        -------
        dict or None
            The ``data`` of a ``next`` message, the sentinel
            ``{"connection_ack": True}`` for ``connection_ack``, otherwise
            ``None``. A ``ping`` is answered with ``pong``; ``complete`` closes
            the socket.

        Raises
        ------
        GraphQLClientInvalidMessageFormat
            For non-JSON or non-object messages, an unknown ``type``, an
            ``expected_type`` mismatch, or a ``next`` payload without ``data``.
        GraphQLClientGraphQLMultiError
            For ``error`` messages and ``next`` payloads that carry ``errors``.
        """
        try:
            message_dict = json.loads(message)
        except ValueError as exc:
            # ValueError covers JSONDecodeError and UnicodeDecodeError (bad bytes).
            raise GraphQLClientInvalidMessageFormat(message=message) from exc

        # Valid JSON that is not an object (e.g. a list) has no "type" to read.
        if not isinstance(message_dict, dict):
            raise GraphQLClientInvalidMessageFormat(message=message)

        type_ = message_dict.get("type")
        payload = message_dict.get("payload", {})

        # The str check also guards the set lookup against unhashable values.
        if not isinstance(type_, str) or type_ not in {
            t.value for t in GraphQLTransportWSMessageType
        }:
            raise GraphQLClientInvalidMessageFormat(message=message)

        if expected_type and expected_type != type_:
            raise GraphQLClientInvalidMessageFormat(
                f"Invalid message received. Expected: {expected_type.value}"
            )

        if type_ == GraphQLTransportWSMessageType.NEXT:
            if not isinstance(payload, dict):
                raise GraphQLClientInvalidMessageFormat(message=message)
            errors = payload.get("errors")
            if errors:
                # Report result errors like ``get_data`` does for HTTP instead of
                # silently dropping them (null data was never yielded).
                raise GraphQLClientGraphQLMultiError.from_errors_dicts(
                    errors_dicts=errors, data=payload.get("data")
                )
            if "data" not in payload:
                raise GraphQLClientInvalidMessageFormat(message=message)
            return cast(dict[str, Any], payload["data"])

        if type_ == GraphQLTransportWSMessageType.COMPLETE:
            await websocket.close()
        elif type_ == GraphQLTransportWSMessageType.PING:
            await websocket.send(
                json.dumps({"type": GraphQLTransportWSMessageType.PONG.value})
            )
        elif type_ == GraphQLTransportWSMessageType.ERROR:
            raise GraphQLClientGraphQLMultiError.from_errors_dicts(
                errors_dicts=payload, data=message_dict
            )
        elif type_ == GraphQLTransportWSMessageType.CONNECTION_ACK:
            return {"connection_ack": True}

        return None