Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.system_prompt_records.system_prompt_records_sqlite import SqliteSystemPromptRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.users.users_default import UserService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
Expand Down Expand Up @@ -185,6 +186,7 @@ def initialize(
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
system_prompt_records = SqliteSystemPromptRecordsStorage(db=db)
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
users = UserService(db=db)
Expand Down Expand Up @@ -218,6 +220,7 @@ def initialize(
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
system_prompt_records=system_prompt_records,
workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
users=users,
Expand Down
117 changes: 117 additions & 0 deletions invokeai/app/api/routers/system_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Optional

from fastapi import APIRouter, Body, HTTPException, Path

from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.system_prompt_records.system_prompt_records_common import (
SystemPromptChanges,
SystemPromptNotFoundError,
SystemPromptRecordDTO,
SystemPromptWithoutId,
)

system_prompts_router = APIRouter(prefix="/v1/system_prompts", tags=["system_prompts"])


@system_prompts_router.get(
"/",
operation_id="list_system_prompts",
responses={200: {"model": list[SystemPromptRecordDTO]}},
)
async def list_system_prompts(current_user: CurrentUserOrDefault) -> list[SystemPromptRecordDTO]:
"""Lists system prompts visible to the current user (own + public)."""
config = ApiDependencies.invoker.services.configuration
# Admins (and single-user installs) see everything; multiuser non-admins are scoped to own + public.
user_id_filter: Optional[str] = None
if config.multiuser and not current_user.is_admin:
user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.system_prompt_records.get_many(user_id=user_id_filter)


@system_prompts_router.get(
"/i/{system_prompt_id}",
operation_id="get_system_prompt",
responses={200: {"model": SystemPromptRecordDTO}},
)
async def get_system_prompt(
current_user: CurrentUserOrDefault,
system_prompt_id: str = Path(description="The id of the system prompt to get"),
) -> SystemPromptRecordDTO:
"""Gets a system prompt by id."""
try:
prompt = ApiDependencies.invoker.services.system_prompt_records.get(system_prompt_id)
except SystemPromptNotFoundError:
raise HTTPException(status_code=404, detail="System prompt not found")

config = ApiDependencies.invoker.services.configuration
if config.multiuser:
is_owner = prompt.user_id == current_user.user_id
if not (is_owner or prompt.is_public or current_user.is_admin):
raise HTTPException(status_code=403, detail="Not authorized to access this system prompt")
return prompt


@system_prompts_router.post(
"/",
operation_id="create_system_prompt",
responses={200: {"model": SystemPromptRecordDTO}},
)
async def create_system_prompt(
current_user: CurrentUserOrDefault,
system_prompt: SystemPromptWithoutId = Body(description="The system prompt to create"),
) -> SystemPromptRecordDTO:
"""Creates a new system prompt owned by the current user."""
# Single-user: shared so legacy/single-user behaviour is unchanged. Multiuser: private by default.
config = ApiDependencies.invoker.services.configuration
is_public = not config.multiuser
return ApiDependencies.invoker.services.system_prompt_records.create(
system_prompt, user_id=current_user.user_id, is_public=is_public
)


@system_prompts_router.patch(
"/i/{system_prompt_id}",
operation_id="update_system_prompt",
responses={200: {"model": SystemPromptRecordDTO}},
)
async def update_system_prompt(
current_user: CurrentUserOrDefault,
system_prompt_id: str = Path(description="The id of the system prompt to update"),
changes: SystemPromptChanges = Body(description="The changes to apply"),
) -> SystemPromptRecordDTO:
"""Updates a system prompt. Only the owner or an admin may update."""
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
try:
existing = ApiDependencies.invoker.services.system_prompt_records.get(system_prompt_id)
except SystemPromptNotFoundError:
raise HTTPException(status_code=404, detail="System prompt not found")
if not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this system prompt")
user_id = None if current_user.is_admin else current_user.user_id
try:
return ApiDependencies.invoker.services.system_prompt_records.update(system_prompt_id, changes, user_id=user_id)
except SystemPromptNotFoundError:
raise HTTPException(status_code=404, detail="System prompt not found")


@system_prompts_router.delete(
"/i/{system_prompt_id}",
operation_id="delete_system_prompt",
)
async def delete_system_prompt(
current_user: CurrentUserOrDefault,
system_prompt_id: str = Path(description="The id of the system prompt to delete"),
) -> None:
"""Deletes a system prompt. Only the owner or an admin may delete."""
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
try:
existing = ApiDependencies.invoker.services.system_prompt_records.get(system_prompt_id)
except SystemPromptNotFoundError:
raise HTTPException(status_code=404, detail="System prompt not found")
if not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to delete this system prompt")
user_id = None if current_user.is_admin else current_user.user_id
ApiDependencies.invoker.services.system_prompt_records.delete(system_prompt_id, user_id=user_id)
2 changes: 2 additions & 0 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
recall_parameters,
session_queue,
style_presets,
system_prompts,
utilities,
virtual_boards,
workflows,
Expand Down Expand Up @@ -185,6 +186,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.include_router(system_prompts.system_prompts_router, prefix="/api")
app.include_router(client_state.client_state_router, prefix="/api")
app.include_router(recall_parameters.recall_parameters_router, prefix="/api")
app.include_router(custom_nodes.custom_nodes_router, prefix="/api")
Expand Down
6 changes: 6 additions & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ class StylePresetField(BaseModel):
style_preset_id: str = Field(description="The id of the style preset")


class SystemPromptField(BaseModel):
"""A system prompt primitive field"""

system_prompt_id: str = Field(description="The id of the system prompt")


class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""

Expand Down
93 changes: 79 additions & 14 deletions invokeai/app/invocations/text_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from transformers import AutoTokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, InputField, UIComponent
from invokeai.app.invocations.fields import FieldDescriptions, InputField, SystemPromptField, UIComponent
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
Expand All @@ -11,6 +11,31 @@
from invokeai.backend.util.devices import TorchDevice


def _run_text_llm(
context: InvocationContext,
text_llm_model: ModelIdentifierField,
prompt: str,
system_prompt: str,
max_tokens: int,
) -> str:
"""Shared LLM invocation body used by every text-LLM node in this module."""
model_config = context.models.get_config(text_llm_model)

with context.models.load(text_llm_model).model_on_device() as (_, model):
model_abs_path = context.models.get_absolute_path(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_abs_path, local_files_only=True)

pipeline = TextLLMPipeline(model, tokenizer)
model_device = next(model.parameters()).device
return pipeline.run(
prompt=prompt,
system_prompt=system_prompt,
max_new_tokens=max_tokens,
device=model_device,
dtype=TorchDevice.choose_torch_dtype(),
)


@invocation(
"text_llm",
title="Text LLM",
Expand Down Expand Up @@ -46,20 +71,60 @@ class TextLLMInvocation(BaseInvocation):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> StringOutput:
model_config = context.models.get_config(self.text_llm_model)
output = _run_text_llm(
context=context,
text_llm_model=self.text_llm_model,
prompt=self.prompt,
system_prompt=self.system_prompt,
max_tokens=self.max_tokens,
)
return StringOutput(value=output)


with context.models.load(self.text_llm_model).model_on_device() as (_, model):
model_abs_path = context.models.get_absolute_path(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_abs_path, local_files_only=True)
@invocation(
"text_llm_with_preset",
title="Text LLM (with System Prompt Preset)",
tags=["llm", "text", "prompt", "preset", "template"],
category="llm",
version="1.0.0",
classification=Classification.Beta,
)
class TextLLMWithPresetInvocation(BaseInvocation):
"""Run a text language model using a saved system prompt from the System Prompts library.

pipeline = TextLLMPipeline(model, tokenizer)
model_device = next(model.parameters()).device
output = pipeline.run(
prompt=self.prompt,
system_prompt=self.system_prompt,
max_new_tokens=self.max_tokens,
device=model_device,
dtype=TorchDevice.choose_torch_dtype(),
)
Behaves identically to the Text LLM node, but the system prompt is selected from a
DB-backed preset instead of being typed inline. Useful when you maintain a curated
library of expansion strategies and want to reuse them across workflows.
"""

prompt: str = InputField(
default="",
description="Input text prompt.",
ui_component=UIComponent.Textarea,
)
system_prompt: SystemPromptField = InputField(
description="The saved system prompt to use as the LLM's instruction.",
)
text_llm_model: ModelIdentifierField = InputField(
title="Text LLM Model",
description=FieldDescriptions.text_llm_model,
ui_model_type=ModelType.TextLLM,
)
max_tokens: int = InputField(
default=300,
ge=1,
le=2048,
description="Maximum number of tokens to generate.",
)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> StringOutput:
record = context._services.system_prompt_records.get(self.system_prompt.system_prompt_id)
output = _run_text_llm(
context=context,
text_llm_model=self.text_llm_model,
prompt=self.prompt,
system_prompt=record.content,
max_tokens=self.max_tokens,
)
return StringOutput(value=output)
3 changes: 3 additions & 0 deletions invokeai/app/services/invocation_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
from invokeai.app.services.system_prompt_records.system_prompt_records_base import SystemPromptRecordsStorageBase

if TYPE_CHECKING:
from logging import Logger
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
system_prompt_records: "SystemPromptRecordsStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
users: "UserServiceBase",
Expand Down Expand Up @@ -108,6 +110,7 @@ def __init__(
self.conditioning = conditioning
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files
self.system_prompt_records = system_prompt_records
self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence
self.users = users
2 changes: 2 additions & 0 deletions invokeai/app/services/shared/sqlite/sqlite_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_30 import build_migration_30
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_31 import build_migration_31
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_32 import build_migration_32
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


Expand Down Expand Up @@ -85,6 +86,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_29())
migrator.register_migration(build_migration_30())
migrator.register_migration(build_migration_31())
migrator.register_migration(build_migration_32())
migrator.run_migrations()

return db
Loading
Loading