# shuqianpinggu/huaxi/Lib/site-packages/openai/lib/streaming/chat/_completions.py
from __future__ import annotations
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
from typing_extensions import Self, Iterator, assert_never
from jiter import from_json
from ._types import ParsedChoiceSnapshot, ParsedChatCompletionSnapshot, ParsedChatCompletionMessageSnapshot
from ._events import (
ChunkEvent,
ContentDoneEvent,
RefusalDoneEvent,
ContentDeltaEvent,
RefusalDeltaEvent,
LogprobsContentDoneEvent,
LogprobsRefusalDoneEvent,
ChatCompletionStreamEvent,
LogprobsContentDeltaEvent,
LogprobsRefusalDeltaEvent,
FunctionToolCallArgumentsDoneEvent,
FunctionToolCallArgumentsDeltaEvent,
)
from .._deltas import accumulate_delta
from ...._types import NOT_GIVEN, IncEx, NotGiven
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
from ...._compat import model_dump
from ...._models import build, construct_type
from ..._parsing import (
ResponseFormatT,
has_parseable_input,
maybe_parse_content,
parse_chat_completion,
get_input_tool_by_name,
solve_response_format_t,
parse_function_tool_arguments,
)
from ...._streaming import Stream, AsyncStream
from ....types.chat import ChatCompletionChunk, ParsedChatCompletion, ChatCompletionToolParam
from ...._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ....types.chat.chat_completion import ChoiceLogprobs
from ....types.chat.chat_completion_chunk import Choice as ChoiceChunk
from ....types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
class ChatCompletionStream(Generic[ResponseFormatT]):
"""Wrapper over the Chat Completions streaming API that adds helpful
events such as `content.done`, supports automatically parsing
responses & tool calls, and accumulates a `ChatCompletion` object
from each individual chunk.
https://platform.openai.com/docs/api-reference/streaming
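Example (an illustrative sketch; the model name and message below are
placeholders, not part of this module):
```py
with client.beta.chat.completions.stream(
    model="gpt-4o",
    messages=[{"role": "user", "content": "..."}],
) as stream:
    for event in stream:
        if event.type == "content.delta":
            print(event.delta, end="")
    completion = stream.get_final_completion()
```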
"""
def __init__(
self,
*,
raw_stream: Stream[ChatCompletionChunk],
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
def __next__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
return self._iterator.__next__()
def __iter__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
for item in self._iterator:
yield item
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self._response.close()
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedChatCompletion` object.
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
property will be the content deserialised into that class, if there was any content returned
by the API.
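For example (a sketch assuming `MyModel` is a hypothetical pydantic model
passed as `response_format`):
```py
with client.beta.chat.completions.stream(..., response_format=MyModel) as stream:
    completion = stream.get_final_completion()
    print(completion.choices[0].message.parsed)  # a MyModel instance, or None
```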
"""
self.until_done()
return self._state.get_final_completion()
def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
consume_sync_iterator(self)
return self
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
return self._state.current_completion_snapshot
def __stream__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
for sse_event in self._raw_stream:
events_to_fire = self._state.handle_chunk(sse_event)
for event in events_to_fire:
yield event
class ChatCompletionStreamManager(Generic[ResponseFormatT]):
"""Context manager over a `ChatCompletionStream` that is returned by `.stream()`.
This context manager ensures the response cannot be leaked if you don't read
the stream to completion.
Usage:
```py
with client.beta.chat.completions.stream(...) as stream:
for event in stream:
...
```
"""
def __init__(
self,
api_request: Callable[[], Stream[ChatCompletionChunk]],
*,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
) -> None:
self.__stream: ChatCompletionStream[ResponseFormatT] | None = None
self.__api_request = api_request
self.__response_format = response_format
self.__input_tools = input_tools
def __enter__(self) -> ChatCompletionStream[ResponseFormatT]:
raw_stream = self.__api_request()
self.__stream = ChatCompletionStream(
raw_stream=raw_stream,
response_format=self.__response_format,
input_tools=self.__input_tools,
)
return self.__stream
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
self.__stream.close()
class AsyncChatCompletionStream(Generic[ResponseFormatT]):
"""Wrapper over the Chat Completions streaming API that adds helpful
events such as `content.done`, supports automatically parsing
responses & tool calls, and accumulates a `ChatCompletion` object
from each individual chunk.
https://platform.openai.com/docs/api-reference/streaming
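Example (an illustrative sketch; assumes an `AsyncOpenAI` client and uses a
placeholder model name):
```py
async with client.beta.chat.completions.stream(
    model="gpt-4o",
    messages=[{"role": "user", "content": "..."}],
) as stream:
    async for event in stream:
        if event.type == "content.delta":
            print(event.delta, end="")
```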
"""
def __init__(
self,
*,
raw_stream: AsyncStream[ChatCompletionChunk],
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
async def __anext__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
return await self._iterator.__anext__()
async def __aiter__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
async for item in self._iterator:
yield item
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self._response.aclose()
async def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedChatCompletion` object.
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
property will be the content deserialised into that class, if there was any content returned
by the API.
"""
await self.until_done()
return self._state.get_final_completion()
async def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
await consume_async_iterator(self)
return self
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
return self._state.current_completion_snapshot
async def __stream__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
async for sse_event in self._raw_stream:
events_to_fire = self._state.handle_chunk(sse_event)
for event in events_to_fire:
yield event
class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
"""Context manager over a `AsyncChatCompletionStream` that is returned by `.stream()`.
This context manager ensures the response cannot be leaked if you don't read
the stream to completion.
Usage:
```py
async with client.beta.chat.completions.stream(...) as stream:
async for event in stream:
...
```
"""
def __init__(
self,
api_request: Awaitable[AsyncStream[ChatCompletionChunk]],
*,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
) -> None:
self.__stream: AsyncChatCompletionStream[ResponseFormatT] | None = None
self.__api_request = api_request
self.__response_format = response_format
self.__input_tools = input_tools
async def __aenter__(self) -> AsyncChatCompletionStream[ResponseFormatT]:
raw_stream = await self.__api_request
self.__stream = AsyncChatCompletionStream(
raw_stream=raw_stream,
response_format=self.__response_format,
input_tools=self.__input_tools,
)
return self.__stream
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
await self.__stream.close()
class ChatCompletionStreamState(Generic[ResponseFormatT]):
"""Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
This is useful in cases where you can't always use the `.stream()` method, e.g.
```py
from openai.lib.streaming.chat import ChatCompletionStreamState
state = ChatCompletionStreamState()
stream = client.chat.completions.create(..., stream=True)
for chunk in stream:
state.handle_chunk(chunk)
# can also access the accumulated `ChatCompletion` mid-stream
state.current_completion_snapshot
print(state.get_final_completion())
```
"""
def __init__(
self,
*,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
) -> None:
self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
self.__choice_event_states: list[ChoiceEventState] = []
self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
self._response_format = response_format
self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Parse the final completion object.
Note that this does not provide any guarantee that the stream has actually finished; you must
only call this method when the stream is finished.
"""
return parse_chat_completion(
chat_completion=self.current_completion_snapshot,
response_format=self._rich_response_format,
input_tools=self._input_tools,
)
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
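"""The `ParsedChatCompletionSnapshot` accumulated so far.

Note: accessing this before any chunk has been handled raises an `AssertionError`.
"""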
assert self.__current_completion_snapshot is not None
return self.__current_completion_snapshot
def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
"""Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
self.__current_completion_snapshot = self._accumulate_chunk(chunk)
return self._build_events(
chunk=chunk,
completion_snapshot=self.__current_completion_snapshot,
)
def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState:
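# NOTE: assumes choices are streamed in ascending index order; state for a
# previously-unseen choice index is appended to the end of the list.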
try:
return self.__choice_event_states[choice.index]
except IndexError:
choice_state = ChoiceEventState(input_tools=self._input_tools)
self.__choice_event_states.append(choice_state)
return choice_state
def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
completion_snapshot = self.__current_completion_snapshot
if completion_snapshot is None:
return _convert_initial_chunk_into_snapshot(chunk)
for choice in chunk.choices:
try:
choice_snapshot = completion_snapshot.choices[choice.index]
previous_tool_calls = choice_snapshot.message.tool_calls or []
choice_snapshot.message = cast(
ParsedChatCompletionMessageSnapshot,
construct_type(
type_=ParsedChatCompletionMessageSnapshot,
value=accumulate_delta(
cast(
"dict[object, object]",
model_dump(
choice_snapshot.message,
# we don't want to serialise / deserialise our custom properties
# as they won't appear in the delta and we don't want to have to
# continuously reparse the content
exclude=cast(
# cast required as mypy isn't smart enough to narrow `True` here to `Literal[True]`
IncEx,
{
"parsed": True,
"tool_calls": {
idx: {"function": {"parsed_arguments": True}}
for idx, _ in enumerate(choice_snapshot.message.tool_calls or [])
},
},
),
),
),
cast("dict[object, object]", choice.delta.to_dict()),
),
),
)
# ensure tools that have already been parsed are added back into the newly
# constructed message snapshot
for tool_index, prev_tool in enumerate(previous_tool_calls):
new_tool = (choice_snapshot.message.tool_calls or [])[tool_index]
if prev_tool.type == "function":
assert new_tool.type == "function"
new_tool.function.parsed_arguments = prev_tool.function.parsed_arguments
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(prev_tool)
except IndexError:
choice_snapshot = cast(
ParsedChoiceSnapshot,
construct_type(
type_=ParsedChoiceSnapshot,
value={
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
"message": choice.delta.to_dict(),
},
),
)
completion_snapshot.choices.append(choice_snapshot)
if choice.finish_reason:
choice_snapshot.finish_reason = choice.finish_reason
if has_parseable_input(response_format=self._response_format, input_tools=self._input_tools):
if choice.finish_reason == "length":
# at the time of writing, `.usage` will always be `None` but
# we include it here in case that is changed in the future
raise LengthFinishReasonError(completion=completion_snapshot)
if choice.finish_reason == "content_filter":
raise ContentFilterFinishReasonError()
if (
choice_snapshot.message.content
and not choice_snapshot.message.refusal
and is_given(self._rich_response_format)
):
choice_snapshot.message.parsed = from_json(
bytes(choice_snapshot.message.content, "utf-8"),
partial_mode=True,
)
for tool_call_chunk in choice.delta.tool_calls or []:
tool_call_snapshot = (choice_snapshot.message.tool_calls or [])[tool_call_chunk.index]
if tool_call_snapshot.type == "function":
input_tool = get_input_tool_by_name(
input_tools=self._input_tools, name=tool_call_snapshot.function.name
)
if (
input_tool
and input_tool.get("function", {}).get("strict")
and tool_call_snapshot.function.arguments
):
tool_call_snapshot.function.parsed_arguments = from_json(
bytes(tool_call_snapshot.function.arguments, "utf-8"),
partial_mode=True,
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call_snapshot)
if choice.logprobs is not None:
if choice_snapshot.logprobs is None:
choice_snapshot.logprobs = build(
ChoiceLogprobs,
content=choice.logprobs.content,
refusal=choice.logprobs.refusal,
)
else:
if choice.logprobs.content:
if choice_snapshot.logprobs.content is None:
choice_snapshot.logprobs.content = []
choice_snapshot.logprobs.content.extend(choice.logprobs.content)
if choice.logprobs.refusal:
if choice_snapshot.logprobs.refusal is None:
choice_snapshot.logprobs.refusal = []
choice_snapshot.logprobs.refusal.extend(choice.logprobs.refusal)
completion_snapshot.usage = chunk.usage
completion_snapshot.system_fingerprint = chunk.system_fingerprint
return completion_snapshot
def _build_events(
self,
*,
chunk: ChatCompletionChunk,
completion_snapshot: ParsedChatCompletionSnapshot,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
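"""Build the events to fire for this chunk: always a `chunk` event first,
then per-choice delta events (content, refusal, tool-call arguments,
logprobs) plus any `*.done` events owed for that choice."""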
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
events_to_fire.append(
build(ChunkEvent, type="chunk", chunk=chunk, snapshot=completion_snapshot),
)
for choice in chunk.choices:
choice_state = self._get_choice_state(choice)
choice_snapshot = completion_snapshot.choices[choice.index]
if choice.delta.content is not None and choice_snapshot.message.content is not None:
events_to_fire.append(
build(
ContentDeltaEvent,
type="content.delta",
delta=choice.delta.content,
snapshot=choice_snapshot.message.content,
parsed=choice_snapshot.message.parsed,
)
)
if choice.delta.refusal is not None and choice_snapshot.message.refusal is not None:
events_to_fire.append(
build(
RefusalDeltaEvent,
type="refusal.delta",
delta=choice.delta.refusal,
snapshot=choice_snapshot.message.refusal,
)
)
if choice.delta.tool_calls:
tool_calls = choice_snapshot.message.tool_calls
assert tool_calls is not None
for tool_call_delta in choice.delta.tool_calls:
tool_call = tool_calls[tool_call_delta.index]
if tool_call.type == "function":
assert tool_call_delta.function is not None
events_to_fire.append(
build(
FunctionToolCallArgumentsDeltaEvent,
type="tool_calls.function.arguments.delta",
name=tool_call.function.name,
index=tool_call_delta.index,
arguments=tool_call.function.arguments,
parsed_arguments=tool_call.function.parsed_arguments,
arguments_delta=tool_call_delta.function.arguments or "",
)
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call)
if choice.logprobs is not None and choice_snapshot.logprobs is not None:
if choice.logprobs.content and choice_snapshot.logprobs.content:
events_to_fire.append(
build(
LogprobsContentDeltaEvent,
type="logprobs.content.delta",
content=choice.logprobs.content,
snapshot=choice_snapshot.logprobs.content,
),
)
if choice.logprobs.refusal and choice_snapshot.logprobs.refusal:
events_to_fire.append(
build(
LogprobsRefusalDeltaEvent,
type="logprobs.refusal.delta",
refusal=choice.logprobs.refusal,
snapshot=choice_snapshot.logprobs.refusal,
),
)
events_to_fire.extend(
choice_state.get_done_events(
choice_chunk=choice,
choice_snapshot=choice_snapshot,
response_format=self._response_format,
)
)
return events_to_fire
class ChoiceEventState:
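"""Tracks, per choice, which `*.done` events have already been fired so that
each one is emitted at most once over the lifetime of the stream."""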
def __init__(self, *, input_tools: list[ChatCompletionToolParam]) -> None:
self._input_tools = input_tools
self._content_done = False
self._refusal_done = False
self._logprobs_content_done = False
self._logprobs_refusal_done = False
self._done_tool_calls: set[int] = set()
self.__current_tool_call_index: int | None = None
def get_done_events(
self,
*,
choice_chunk: ChoiceChunk,
choice_snapshot: ParsedChoiceSnapshot,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
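"""Return any `*.done` events owed for this chunk: content / refusal /
logprobs done events once the choice finishes (or moves on to tool calls),
and a function tool call `arguments.done` event whenever the stream switches
to a different tool call index or the choice finishes mid tool call."""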
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
if choice_snapshot.finish_reason:
events_to_fire.extend(
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
)
if (
self.__current_tool_call_index is not None
and self.__current_tool_call_index not in self._done_tool_calls
):
self._add_tool_done_event(
events_to_fire=events_to_fire,
choice_snapshot=choice_snapshot,
tool_index=self.__current_tool_call_index,
)
for tool_call in choice_chunk.delta.tool_calls or []:
if self.__current_tool_call_index != tool_call.index:
events_to_fire.extend(
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
)
if self.__current_tool_call_index is not None:
self._add_tool_done_event(
events_to_fire=events_to_fire,
choice_snapshot=choice_snapshot,
tool_index=self.__current_tool_call_index,
)
self.__current_tool_call_index = tool_call.index
return events_to_fire
def _content_done_events(
self,
*,
choice_snapshot: ParsedChoiceSnapshot,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
if choice_snapshot.message.content and not self._content_done:
self._content_done = True
parsed = maybe_parse_content(
response_format=response_format,
message=choice_snapshot.message,
)
# update the parsed content to now use the richer `response_format`
# as opposed to the raw JSON-parsed object as the content is now
# complete and can be fully validated.
choice_snapshot.message.parsed = parsed
events_to_fire.append(
build(
# we do this dance so that when the `ContentDoneEvent` instance
# is printed at runtime the class name will include the solved
# type variable, e.g. `ContentDoneEvent[MyModelType]`
cast( # pyright: ignore[reportUnnecessaryCast]
"type[ContentDoneEvent[ResponseFormatT]]",
cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
),
type="content.done",
content=choice_snapshot.message.content,
parsed=parsed,
),
)
if choice_snapshot.message.refusal is not None and not self._refusal_done:
self._refusal_done = True
events_to_fire.append(
build(RefusalDoneEvent, type="refusal.done", refusal=choice_snapshot.message.refusal),
)
if (
choice_snapshot.logprobs is not None
and choice_snapshot.logprobs.content is not None
and not self._logprobs_content_done
):
self._logprobs_content_done = True
events_to_fire.append(
build(LogprobsContentDoneEvent, type="logprobs.content.done", content=choice_snapshot.logprobs.content),
)
if (
choice_snapshot.logprobs is not None
and choice_snapshot.logprobs.refusal is not None
and not self._logprobs_refusal_done
):
self._logprobs_refusal_done = True
events_to_fire.append(
build(LogprobsRefusalDoneEvent, type="logprobs.refusal.done", refusal=choice_snapshot.logprobs.refusal),
)
return events_to_fire
def _add_tool_done_event(
self,
*,
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]],
choice_snapshot: ParsedChoiceSnapshot,
tool_index: int,
) -> None:
if tool_index in self._done_tool_calls:
return
self._done_tool_calls.add(tool_index)
assert choice_snapshot.message.tool_calls is not None
tool_call_snapshot = choice_snapshot.message.tool_calls[tool_index]
if tool_call_snapshot.type == "function":
parsed_arguments = parse_function_tool_arguments(
input_tools=self._input_tools, function=tool_call_snapshot.function
)
# update the parsed content to potentially use a richer type
# as opposed to the raw JSON-parsed object as the content is now
# complete and can be fully validated.
tool_call_snapshot.function.parsed_arguments = parsed_arguments
events_to_fire.append(
build(
FunctionToolCallArgumentsDoneEvent,
type="tool_calls.function.arguments.done",
index=tool_index,
name=tool_call_snapshot.function.name,
arguments=tool_call_snapshot.function.arguments,
parsed_arguments=parsed_arguments,
)
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call_snapshot)
def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
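"""Seed the snapshot from the first chunk: each choice's `delta` is re-keyed
as `message` so that later deltas can be accumulated into it, and `object` is
rewritten from `"chat.completion.chunk"` to `"chat.completion"`."""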
data = chunk.to_dict()
choices = cast("list[object]", data["choices"])
for choice in chunk.choices:
choices[choice.index] = {
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
"message": choice.delta.to_dict(),
}
return cast(
ParsedChatCompletionSnapshot,
construct_type(
type_=ParsedChatCompletionSnapshot,
value={
"system_fingerprint": None,
**data,
"object": "chat.completion",
},
),
)