from __future__ import annotations

import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
from typing_extensions import Self, Iterator, assert_never

from jiter import from_json

from ._types import ParsedChoiceSnapshot, ParsedChatCompletionSnapshot, ParsedChatCompletionMessageSnapshot
from ._events import (
    ChunkEvent,
    ContentDoneEvent,
    RefusalDoneEvent,
    ContentDeltaEvent,
    RefusalDeltaEvent,
    LogprobsContentDoneEvent,
    LogprobsRefusalDoneEvent,
    ChatCompletionStreamEvent,
    LogprobsContentDeltaEvent,
    LogprobsRefusalDeltaEvent,
    FunctionToolCallArgumentsDoneEvent,
    FunctionToolCallArgumentsDeltaEvent,
)
from .._deltas import accumulate_delta
from ...._types import NOT_GIVEN, IncEx, NotGiven
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
from ...._compat import model_dump
from ...._models import build, construct_type
from ..._parsing import (
    ResponseFormatT,
    has_parseable_input,
    maybe_parse_content,
    parse_chat_completion,
    get_input_tool_by_name,
    solve_response_format_t,
    parse_function_tool_arguments,
)
from ...._streaming import Stream, AsyncStream
from ....types.chat import ChatCompletionChunk, ParsedChatCompletion, ChatCompletionToolParam
from ...._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ....types.chat.chat_completion import ChoiceLogprobs
from ....types.chat.chat_completion_chunk import Choice as ChoiceChunk
from ....types.chat.completion_create_params import ResponseFormat as ResponseFormatParam


class ChatCompletionStream(Generic[ResponseFormatT]):
    """Wrapper over the Chat Completions streaming API that adds helpful
    events such as `content.done`, supports automatically parsing
    responses & tool calls and accumulates a `ChatCompletion` object
    from each individual chunk.

    https://platform.openai.com/docs/api-reference/streaming
    """

    def __init__(
        self,
        *,
        raw_stream: Stream[ChatCompletionChunk],
        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
        input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
    ) -> None:
        """Wrap `raw_stream` and set up the state used to accumulate its chunks."""
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        # `__stream__` is a generator function, so nothing is read from the raw
        # stream (and `self._state` is not touched) until iteration starts.
        self._iterator = self.__stream__()
        self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)

    def __next__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
        """Return the next event from the underlying event iterator."""
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
        """Yield the remaining events from the underlying event iterator."""
        for item in self._iterator:
            yield item

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self._response.close()

    def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
        """Waits until the stream has been read to completion and returns
        the accumulated `ParsedChatCompletion` object.

        If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
        property will be the content deserialised into that class, if there was any content returned
        by the API.
        """
        self.until_done()
        return self._state.get_final_completion()

    def until_done(self) -> Self:
        """Blocks until the stream has been consumed."""
        consume_sync_iterator(self)
        return self

    @property
    def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
        """The completion accumulated so far, as tracked by the stream state."""
        return self._state.current_completion_snapshot

    def __stream__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
        """Feed every raw chunk into the state and yield the events it produces."""
        for sse_event in self._raw_stream:
            events_to_fire = self._state.handle_chunk(sse_event)
            for event in events_to_fire:
                yield event


class ChatCompletionStreamManager(Generic[ResponseFormatT]):
    """Context manager over a `ChatCompletionStream` that is returned by `.stream()`.

    This context manager ensures the response cannot be leaked if you don't read
    the stream to completion.

    Usage:
    ```py
    with client.beta.chat.completions.stream(...) as stream:
        for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[ChatCompletionChunk]],
        *,
        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
        input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
    ) -> None:
        """Store the request callable; it is only invoked in `__enter__`."""
        self.__stream: ChatCompletionStream[ResponseFormatT] | None = None
        self.__api_request = api_request
        self.__response_format = response_format
        self.__input_tools = input_tools

    def __enter__(self) -> ChatCompletionStream[ResponseFormatT]:
        """Call the stored request callable and wrap its result in a `ChatCompletionStream`."""
        raw_stream = self.__api_request()

        self.__stream = ChatCompletionStream(
            raw_stream=raw_stream,
            response_format=self.__response_format,
            input_tools=self.__input_tools,
        )

        return self.__stream

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # `__stream` stays `None` if `__enter__` never completed.
        if self.__stream is not None:
            self.__stream.close()


class AsyncChatCompletionStream(Generic[ResponseFormatT]):
    """Wrapper over the Chat Completions streaming API that adds helpful
    events such as `content.done`, supports automatically parsing
    responses & tool calls and accumulates a `ChatCompletion` object
    from each individual chunk.

    https://platform.openai.com/docs/api-reference/streaming
    """

    def __init__(
        self,
        *,
        raw_stream: AsyncStream[ChatCompletionChunk],
        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
        input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
    ) -> None:
        """Wrap `raw_stream` and set up the state used to accumulate its chunks."""
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        # `__stream__` is an async generator function, so nothing is read from
        # the raw stream (and `self._state` is not touched) until iteration starts.
        self._iterator = self.__stream__()
        self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)

    async def __anext__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
        """Return the next event from the underlying event iterator."""
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
        """Yield the remaining events from the underlying event iterator."""
        async for item in self._iterator:
            yield item

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self._response.aclose()

    async def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
        """Waits until the stream has been read to completion and returns
        the accumulated `ParsedChatCompletion` object.

        If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
        property will be the content deserialised into that class, if there was any content returned
        by the API.
        """
        await self.until_done()
        return self._state.get_final_completion()

    async def until_done(self) -> Self:
        """Blocks until the stream has been consumed."""
        await consume_async_iterator(self)
        return self

    @property
    def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
        """The completion accumulated so far, as tracked by the stream state."""
        return self._state.current_completion_snapshot

    async def __stream__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
        """Feed every raw chunk into the state and yield the events it produces."""
        async for sse_event in self._raw_stream:
            events_to_fire = self._state.handle_chunk(sse_event)
            for event in events_to_fire:
                yield event


class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
    """Context manager over a `AsyncChatCompletionStream` that is returned by `.stream()`.

    This context manager ensures the response cannot be leaked if you don't read
    the stream to completion.

    Usage:
    ```py
    async with client.beta.chat.completions.stream(...) as stream:
        async for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[ChatCompletionChunk]],
        *,
        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
        input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
    ) -> None:
        """Store the request awaitable; it is only awaited in `__aenter__`."""
        self.__stream: AsyncChatCompletionStream[ResponseFormatT] | None = None
        self.__api_request = api_request
        self.__response_format = response_format
        self.__input_tools = input_tools

    async def __aenter__(self) -> AsyncChatCompletionStream[ResponseFormatT]:
        """Await the stored request and wrap its result in an `AsyncChatCompletionStream`."""
        raw_stream = await self.__api_request

        self.__stream = AsyncChatCompletionStream(
            raw_stream=raw_stream,
            response_format=self.__response_format,
            input_tools=self.__input_tools,
        )

        return self.__stream

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # `__stream` stays `None` if `__aenter__` never completed.
        if self.__stream is not None:
            await self.__stream.close()


class ChatCompletionStreamState(Generic[ResponseFormatT]):
    """Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.

    This is useful in cases where you can't always use the `.stream()` method, e.g.

    ```py
    from openai.lib.streaming.chat import ChatCompletionStreamState

    state = ChatCompletionStreamState()

    stream = client.chat.completions.create(..., stream=True)
    for chunk in stream:
        state.handle_chunk(chunk)

        # can also access the accumulated `ChatCompletion` mid-stream
        state.current_completion_snapshot

    print(state.get_final_completion())
    ```
    """

    def __init__(
        self,
        *,
        input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
    ) -> None:
        """Initialise an empty accumulator.

        `input_tools` is copied into a list (empty when not given) and
        `response_format` is kept as-is; a separate "rich" response format is
        only recorded when `response_format` is a class.
        """
        self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
        self.__choice_event_states: list[ChoiceEventState] = []

        self._input_tools = list(input_tools) if is_given(input_tools) else []
        self._response_format = response_format
        self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN

    def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
        """Parse the final completion object.

        Note this does not provide any guarantees that the stream has actually finished, you must
        only call this method when the stream is finished.
        """
        return parse_chat_completion(
            chat_completion=self.current_completion_snapshot,
            response_format=self._rich_response_format,
            input_tools=self._input_tools,
        )

    @property
    def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
        """The completion accumulated so far.

        Asserts that a snapshot exists, i.e. that `handle_chunk()` has been
        called at least once.
        """
        assert self.__current_completion_snapshot is not None
        return self.__current_completion_snapshot

    def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
        """Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
        self.__current_completion_snapshot = self._accumulate_chunk(chunk)

        return self._build_events(
            chunk=chunk,
            completion_snapshot=self.__current_completion_snapshot,
        )

    def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState:
        """Return the event state tracked for `choice.index`, creating it if needed."""
        states = self.__choice_event_states

        # Grow the list until `choice.index` is addressable so that every state
        # lives at the position of the choice it belongs to. Simply appending a
        # single entry would file a state under the wrong index whenever a
        # choice is first seen out of order.
        while len(states) <= choice.index:
            states.append(ChoiceEventState(input_tools=self._input_tools))

        return states[choice.index]

    def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
        """Merge `chunk` into the current snapshot and return the snapshot.

        The very first chunk is converted into the initial snapshot and returned
        immediately. For later chunks, each choice delta is merged into the
        matching choice snapshot (or appended as a new one), partial JSON is
        parsed for content / strict function tool arguments, logprobs are
        extended, and `usage` / `system_fingerprint` are copied from the chunk.

        Raises:
            LengthFinishReasonError: a choice finished with `"length"` while
                `has_parseable_input(...)` is true for this state.
            ContentFilterFinishReasonError: a choice finished with
                `"content_filter"` while `has_parseable_input(...)` is true.
        """
        completion_snapshot = self.__current_completion_snapshot

        if completion_snapshot is None:
            return _convert_initial_chunk_into_snapshot(chunk)

        for choice in chunk.choices:
            try:
                choice_snapshot = completion_snapshot.choices[choice.index]
                previous_tool_calls = choice_snapshot.message.tool_calls or []

                choice_snapshot.message = cast(
                    ParsedChatCompletionMessageSnapshot,
                    construct_type(
                        type_=ParsedChatCompletionMessageSnapshot,
                        value=accumulate_delta(
                            cast(
                                "dict[object, object]",
                                model_dump(
                                    choice_snapshot.message,
                                    # we don't want to serialise / deserialise our custom properties
                                    # as they won't appear in the delta and we don't want to have to
                                    # continuously reparse the content
                                    exclude=cast(
                                        # cast required as mypy isn't smart enough to infer `True` here to `Literal[True]`
                                        IncEx,
                                        {
                                            "parsed": True,
                                            "tool_calls": {
                                                idx: {"function": {"parsed_arguments": True}}
                                                for idx, _ in enumerate(choice_snapshot.message.tool_calls or [])
                                            },
                                        },
                                    ),
                                ),
                            ),
                            cast("dict[object, object]", choice.delta.to_dict()),
                        ),
                    ),
                )

                # ensure tools that have already been parsed are added back into the newly
                # constructed message snapshot
                for tool_index, prev_tool in enumerate(previous_tool_calls):
                    new_tool = (choice_snapshot.message.tool_calls or [])[tool_index]

                    if prev_tool.type == "function":
                        assert new_tool.type == "function"
                        new_tool.function.parsed_arguments = prev_tool.function.parsed_arguments
                    elif TYPE_CHECKING:  # type: ignore[unreachable]
                        assert_never(prev_tool)
            except IndexError:
                # no snapshot exists for this choice yet, build one from the delta
                choice_snapshot = cast(
                    ParsedChoiceSnapshot,
                    construct_type(
                        type_=ParsedChoiceSnapshot,
                        value={
                            **choice.model_dump(exclude_unset=True, exclude={"delta"}),
                            "message": choice.delta.to_dict(),
                        },
                    ),
                )
                completion_snapshot.choices.append(choice_snapshot)

            if choice.finish_reason:
                choice_snapshot.finish_reason = choice.finish_reason

                if has_parseable_input(response_format=self._response_format, input_tools=self._input_tools):
                    if choice.finish_reason == "length":
                        # at the time of writing, `.usage` will always be `None` but
                        # we include it here in case that is changed in the future
                        raise LengthFinishReasonError(completion=completion_snapshot)

                    if choice.finish_reason == "content_filter":
                        raise ContentFilterFinishReasonError()

            if (
                choice_snapshot.message.content
                and not choice_snapshot.message.refusal
                and is_given(self._rich_response_format)
            ):
                # the content may still be incomplete JSON at this point,
                # hence `partial_mode`
                choice_snapshot.message.parsed = from_json(
                    bytes(choice_snapshot.message.content, "utf-8"),
                    partial_mode=True,
                )

            for tool_call_chunk in choice.delta.tool_calls or []:
                tool_call_snapshot = (choice_snapshot.message.tool_calls or [])[tool_call_chunk.index]

                if tool_call_snapshot.type == "function":
                    input_tool = get_input_tool_by_name(
                        input_tools=self._input_tools, name=tool_call_snapshot.function.name
                    )

                    if (
                        input_tool
                        and input_tool.get("function", {}).get("strict")
                        and tool_call_snapshot.function.arguments
                    ):
                        tool_call_snapshot.function.parsed_arguments = from_json(
                            bytes(tool_call_snapshot.function.arguments, "utf-8"),
                            partial_mode=True,
                        )
                elif TYPE_CHECKING:  # type: ignore[unreachable]
                    assert_never(tool_call_snapshot)

            if choice.logprobs is not None:
                if choice_snapshot.logprobs is None:
                    choice_snapshot.logprobs = build(
                        ChoiceLogprobs,
                        content=choice.logprobs.content,
                        refusal=choice.logprobs.refusal,
                    )
                else:
                    if choice.logprobs.content:
                        if choice_snapshot.logprobs.content is None:
                            choice_snapshot.logprobs.content = []

                        choice_snapshot.logprobs.content.extend(choice.logprobs.content)

                    if choice.logprobs.refusal:
                        if choice_snapshot.logprobs.refusal is None:
                            choice_snapshot.logprobs.refusal = []

                        choice_snapshot.logprobs.refusal.extend(choice.logprobs.refusal)

        completion_snapshot.usage = chunk.usage
        completion_snapshot.system_fingerprint = chunk.system_fingerprint

        return completion_snapshot

    def _build_events(
        self,
        *,
        chunk: ChatCompletionChunk,
        completion_snapshot: ParsedChatCompletionSnapshot,
    ) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
        """Build the list of events to emit for `chunk`.

        A `chunk` event always comes first. For every choice in the chunk the
        applicable delta events (content, refusal, function tool call arguments,
        logprobs) follow, and then whatever `*.done` events the choice's
        `ChoiceEventState` reports.
        """
        events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []

        events_to_fire.append(
            build(ChunkEvent, type="chunk", chunk=chunk, snapshot=completion_snapshot),
        )

        for choice in chunk.choices:
            choice_state = self._get_choice_state(choice)
            choice_snapshot = completion_snapshot.choices[choice.index]

            if choice.delta.content is not None and choice_snapshot.message.content is not None:
                events_to_fire.append(
                    build(
                        ContentDeltaEvent,
                        type="content.delta",
                        delta=choice.delta.content,
                        snapshot=choice_snapshot.message.content,
                        parsed=choice_snapshot.message.parsed,
                    )
                )

            if choice.delta.refusal is not None and choice_snapshot.message.refusal is not None:
                events_to_fire.append(
                    build(
                        RefusalDeltaEvent,
                        type="refusal.delta",
                        delta=choice.delta.refusal,
                        snapshot=choice_snapshot.message.refusal,
                    )
                )

            if choice.delta.tool_calls:
                tool_calls = choice_snapshot.message.tool_calls
                assert tool_calls is not None

                for tool_call_delta in choice.delta.tool_calls:
                    tool_call = tool_calls[tool_call_delta.index]

                    if tool_call.type == "function":
                        assert tool_call_delta.function is not None
                        events_to_fire.append(
                            build(
                                FunctionToolCallArgumentsDeltaEvent,
                                type="tool_calls.function.arguments.delta",
                                name=tool_call.function.name,
                                index=tool_call_delta.index,
                                arguments=tool_call.function.arguments,
                                parsed_arguments=tool_call.function.parsed_arguments,
                                arguments_delta=tool_call_delta.function.arguments or "",
                            )
                        )
                    elif TYPE_CHECKING:  # type: ignore[unreachable]
                        assert_never(tool_call)

            if choice.logprobs is not None and choice_snapshot.logprobs is not None:
                if choice.logprobs.content and choice_snapshot.logprobs.content:
                    events_to_fire.append(
                        build(
                            LogprobsContentDeltaEvent,
                            type="logprobs.content.delta",
                            content=choice.logprobs.content,
                            snapshot=choice_snapshot.logprobs.content,
                        ),
                    )

                if choice.logprobs.refusal and choice_snapshot.logprobs.refusal:
                    events_to_fire.append(
                        build(
                            LogprobsRefusalDeltaEvent,
                            type="logprobs.refusal.delta",
                            refusal=choice.logprobs.refusal,
                            snapshot=choice_snapshot.logprobs.refusal,
                        ),
                    )

            events_to_fire.extend(
                choice_state.get_done_events(
                    choice_chunk=choice,
                    choice_snapshot=choice_snapshot,
                    response_format=self._response_format,
                )
            )

        return events_to_fire


class ChoiceEventState:
    """Tracks, for a single choice, which `*.done` events have already been emitted."""

    def __init__(self, *, input_tools: list[ChatCompletionToolParam]) -> None:
        self._input_tools = input_tools

        # each flag flips to `True` once the matching `*.done` event was emitted
        self._content_done = False
        self._refusal_done = False
        self._logprobs_content_done = False
        self._logprobs_refusal_done = False
        # indices of tool calls whose done event was already emitted
        self._done_tool_calls: set[int] = set()
        # index of the tool call seen in the most recent tool call delta
        self.__current_tool_call_index: int | None = None

    def get_done_events(
        self,
        *,
        choice_chunk: ChoiceChunk,
        choice_snapshot: ParsedChoiceSnapshot,
        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
    ) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
        """Return the `*.done` events that became due with `choice_chunk`.

        When the snapshot has a finish reason, the content-level done events
        and the done event of the current tool call (if any, and not emitted
        yet) are produced. In addition, whenever a tool call delta carries an
        index different from the current one, the content-level done events and
        the done event of the previous tool call are produced before the
        current index is advanced.
        """
        events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []

        if choice_snapshot.finish_reason:
            events_to_fire.extend(
                self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
            )

            if (
                self.__current_tool_call_index is not None
                and self.__current_tool_call_index not in self._done_tool_calls
            ):
                self._add_tool_done_event(
                    events_to_fire=events_to_fire,
                    choice_snapshot=choice_snapshot,
                    tool_index=self.__current_tool_call_index,
                )

        for tool_call in choice_chunk.delta.tool_calls or []:
            if self.__current_tool_call_index != tool_call.index:
                events_to_fire.extend(
                    self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
                )

                if self.__current_tool_call_index is not None:
                    self._add_tool_done_event(
                        events_to_fire=events_to_fire,
                        choice_snapshot=choice_snapshot,
                        tool_index=self.__current_tool_call_index,
                    )

            self.__current_tool_call_index = tool_call.index

        return events_to_fire

    def _content_done_events(
        self,
        *,
        choice_snapshot: ParsedChoiceSnapshot,
        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
    ) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
        """Return the content / refusal / logprobs done events not emitted yet.

        Each event is produced at most once per choice; the corresponding flag
        on this state is set when the event is created.
        """
        events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []

        if choice_snapshot.message.content and not self._content_done:
            self._content_done = True

            parsed = maybe_parse_content(
                response_format=response_format,
                message=choice_snapshot.message,
            )

            # update the parsed content to now use the richer `response_format`
            # as opposed to the raw JSON-parsed object as the content is now
            # complete and can be fully validated.
            choice_snapshot.message.parsed = parsed

            events_to_fire.append(
                build(
                    # we do this dance so that when the `ContentDoneEvent` instance
                    # is printed at runtime the class name will include the solved
                    # type variable, e.g. `ContentDoneEvent[MyModelType]`
                    cast(  # pyright: ignore[reportUnnecessaryCast]
                        "type[ContentDoneEvent[ResponseFormatT]]",
                        cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
                    ),
                    type="content.done",
                    content=choice_snapshot.message.content,
                    parsed=parsed,
                ),
            )

        if choice_snapshot.message.refusal is not None and not self._refusal_done:
            self._refusal_done = True
            events_to_fire.append(
                build(RefusalDoneEvent, type="refusal.done", refusal=choice_snapshot.message.refusal),
            )

        if (
            choice_snapshot.logprobs is not None
            and choice_snapshot.logprobs.content is not None
            and not self._logprobs_content_done
        ):
            self._logprobs_content_done = True
            events_to_fire.append(
                build(LogprobsContentDoneEvent, type="logprobs.content.done", content=choice_snapshot.logprobs.content),
            )

        if (
            choice_snapshot.logprobs is not None
            and choice_snapshot.logprobs.refusal is not None
            and not self._logprobs_refusal_done
        ):
            self._logprobs_refusal_done = True
            events_to_fire.append(
                build(LogprobsRefusalDoneEvent, type="logprobs.refusal.done", refusal=choice_snapshot.logprobs.refusal),
            )

        return events_to_fire

    def _add_tool_done_event(
        self,
        *,
        events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]],
        choice_snapshot: ParsedChoiceSnapshot,
        tool_index: int,
    ) -> None:
        """Append the done event for the tool call at `tool_index` to `events_to_fire`.

        Does nothing when the done event for that tool call was already
        emitted. Otherwise the tool call is marked as done, its arguments are
        parsed and stored back on the snapshot, and the event is appended.
        """
        if tool_index in self._done_tool_calls:
            return

        self._done_tool_calls.add(tool_index)

        assert choice_snapshot.message.tool_calls is not None
        tool_call_snapshot = choice_snapshot.message.tool_calls[tool_index]

        if tool_call_snapshot.type == "function":
            parsed_arguments = parse_function_tool_arguments(
                input_tools=self._input_tools, function=tool_call_snapshot.function
            )

            # update the parsed content to potentially use a richer type
            # as opposed to the raw JSON-parsed object as the content is now
            # complete and can be fully validated.
            tool_call_snapshot.function.parsed_arguments = parsed_arguments

            events_to_fire.append(
                build(
                    FunctionToolCallArgumentsDoneEvent,
                    type="tool_calls.function.arguments.done",
                    index=tool_index,
                    name=tool_call_snapshot.function.name,
                    arguments=tool_call_snapshot.function.arguments,
                    parsed_arguments=parsed_arguments,
                )
            )
        elif TYPE_CHECKING:  # type: ignore[unreachable]
            assert_never(tool_call_snapshot)


def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
    """Build the initial completion snapshot from the first chunk of a stream.

    Each choice's `delta` is re-keyed as `message`, `object` is forced to
    `"chat.completion"` and `system_fingerprint` defaults to `None` when the
    chunk's dict form does not carry one.
    """
    data = chunk.to_dict()
    choices = cast("list[object]", data["choices"])

    for choice in chunk.choices:
        choices[choice.index] = {
            **choice.model_dump(exclude_unset=True, exclude={"delta"}),
            "message": choice.delta.to_dict(),
        }

    return cast(
        ParsedChatCompletionSnapshot,
        construct_type(
            type_=ParsedChatCompletionSnapshot,
            value={
                # placed before `**data` so a fingerprint present in the chunk wins
                "system_fingerprint": None,
                **data,
                "object": "chat.completion",
            },
        ),
    )