WIP draft of generate() outside of chat.answer() #432
base: main
Changes from all commits
684dc35
290f4c7
abb88c0
954c4af
82637dd
a11919e
ae59aee
986b34d
084b287
2e8ae2c
e767c1b
caa8eff
c9e1a7e
e4451a3
a1c912f
8e120c6
6288bb1
c290d07
5b3bfd5
```diff
@@ -1,6 +1,6 @@
-from typing import AsyncIterator, cast
+from typing import Any, AsyncIterator, Union, cast
 
-from ragna.core import Message, RagnaException, Source
+from ragna.core import Message, MessageRole, RagnaException, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
@@ -14,23 +14,59 @@ class CohereAssistant(HttpApiAssistant):
     def display_name(cls) -> str:
         return f"Cohere/{cls._MODEL}"
 
-    def _make_preamble(self) -> str:
+    def _make_rag_preamble(self) -> str:
         return (
             "You are a helpful assistant that answers user questions given the included context. "
             "If you don't know the answer, just say so. Don't try to make up an answer. "
             "Only use the included documents below to generate the answer."
         )
 
-    def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
+    def _make_rag_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
         return [{"title": source.id, "snippet": source.content} for source in sources]
 
-    async def answer(
-        self, messages: list[Message], *, max_new_tokens: int = 256
-    ) -> AsyncIterator[str]:
+    def _render_prompt(self, prompt: Union[str, list[Message]]) -> str:
+        """
+        Ingests a ragna messages list or a single string prompt and converts it to the assistant-appropriate format.
+
+        Returns:
+            prompt string
+        """
+        if isinstance(prompt, str):
+            messages = [Message(content=prompt, role=MessageRole.USER)]
+        else:
+            messages = prompt
+
+        for message in reversed(messages):
+            if message.role is MessageRole.USER:
+                return message.content
+        else:
+            raise RagnaException
+
+    async def generate(
+        self,
+        prompt: Union[str, list[Message]],
+        source_documents: list[dict[str, str]],
+        *,
+        system_prompt: str = "You are a helpful assistant.",
+        max_new_tokens: int = 256,
+    ) -> AsyncIterator[dict[str, Any]]:
+        """
+        Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer().
+        This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.
+
+        Args:
+            prompt: Either a single prompt string or a list of ragna messages
+            system_prompt: System prompt string
+            source_documents: List of source content dicts with 'title' and 'snippet' keys
+            max_new_tokens: Max number of completion tokens (default 256)
+
+        Returns:
+            async streamed inference response string chunks
+        """
         # See https://docs.cohere.com/docs/cochat-beta
         # See https://docs.cohere.com/reference/chat
         # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
-        prompt, sources = (message := messages[-1]).content, message.sources
         async with self._call_api(
             "POST",
             "https://api.cohere.ai/v1/chat",
@@ -40,23 +76,35 @@ async def answer(
                 "authorization": f"Bearer {self._api_key}",
             },
             json={
-                "preamble_override": self._make_preamble(),
-                "message": prompt,
+                "preamble_override": system_prompt,
+                "message": self._render_prompt(prompt),
```
Member: While the message indeed can only be a single string here, the endpoint has a `chat_history` field. I would let `_render_prompt` return the history as well: `chat_history, message = self._render_prompt(prompt)`.
Author: Conflicted on this one - it seems like this would fit specifically with the pre-processing pass, and doing it now puts this one specific assistant ahead of the others in terms of capabilities. It certainly doesn't hurt anything, so I'm happy to do it here, but I also see how we might want to first work out what a pre-processing stage looks like for all assistants and implement it in one go.
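A minimal sketch of that suggestion (hypothetical, not part of this diff; the tuple return type and the `USER`/`CHATBOT` role strings for Cohere's `chat_history` field are assumptions):

```python
from typing import Union

from ragna.core import Message, MessageRole, RagnaException


def _render_prompt(
    prompt: Union[str, list[Message]],
) -> tuple[list[dict[str, str]], str]:
    # Sketched as a free function; in the PR this would be a method on
    # CohereAssistant. Splits the prompt into Cohere-style chat history
    # plus the final user message.
    if isinstance(prompt, str):
        messages = [Message(content=prompt, role=MessageRole.USER)]
    else:
        messages = prompt

    # Walk backwards to the last USER message; everything before it becomes
    # `chat_history` entries (role names per Cohere's chat API).
    for idx in range(len(messages) - 1, -1, -1):
        if messages[idx].role is MessageRole.USER:
            chat_history = [
                {
                    "role": "USER" if m.role is MessageRole.USER else "CHATBOT",
                    "message": m.content,
                }
                for m in messages[:idx]
            ]
            return chat_history, messages[idx].content
    raise RagnaException
```

The call site in `generate()` would then send `chat_history` alongside `message` in the request JSON.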
```diff
                 "model": self._MODEL,
                 "stream": True,
                 "temperature": 0.0,
                 "max_tokens": max_new_tokens,
-                "documents": self._make_source_documents(sources),
+                "documents": source_documents,
             },
         ) as stream:
-            async for event in stream:
-                if event["event_type"] == "stream-end":
-                    if event["event_type"] == "COMPLETE":
-                        break
-
-                    raise RagnaException(event["error_message"])
-                if "text" in event:
-                    yield cast(str, event["text"])
+            async for data in stream:
+                yield data
+
+    async def answer(
+        self, messages: list[Message], *, max_new_tokens: int = 256
+    ) -> AsyncIterator[str]:
+        message = messages[-1]
+        async for data in self.generate(
+            prompt=message.content,
+            system_prompt=self._make_rag_preamble(),
+            source_documents=self._make_rag_source_documents(message.sources),
+            max_new_tokens=max_new_tokens,
+        ):
+            if data["event_type"] == "stream-end":
+                if data["event_type"] == "COMPLETE":
+                    break
+
+                raise RagnaException(data["error_message"])
+            if "text" in data:
+                yield cast(str, data["text"])
 
 
 class Command(CohereAssistant):
```
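For context on how the split is meant to be used, a hedged sketch of a one-off `generate()` call from outside `answer()` (the direct `Command()` construction and the environment setup are assumptions; the event handling mirrors the diff above):

```python
import asyncio

from ragna.assistants import Command  # Cohere "command" assistant from this diff


async def one_off_summary(text: str) -> str:
    # Hypothetical pre-processing task: no RAG sources, just an instruction.
    assistant = Command()  # assumes COHERE_API_KEY is set in the environment
    chunks: list[str] = []
    async for event in assistant.generate(
        prompt=f"Summarize in one sentence: {text}",
        source_documents=[],
        system_prompt="You are a concise summarizer.",
        max_new_tokens=64,
    ):
        # generate() now yields raw event dicts; answer() is the layer that
        # filters them down to text chunks and raises on error events.
        if "text" in event:
            chunks.append(event["text"])
    return "".join(chunks)


if __name__ == "__main__":
    print(asyncio.run(one_off_summary("Ragna is a RAG orchestration framework.")))
```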