"""This module defines generic classes for models in the Fabricatio library, providing a foundation for various model functionalities."""
from abc import ABC, abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Any, Callable, Iterable, List, Optional, Self, Set, Union, Unpack, final, overload
import orjson
from pydantic import (
BaseModel,
ConfigDict,
Field,
NonNegativeFloat,
PositiveFloat,
PositiveInt,
)
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from fabricatio_core.journal import logger
from fabricatio_core.models.kwargs_types import EmbeddingKwargs, LLMKwargs, RerankerKwargs, ValidateKwargs
from fabricatio_core.rust import CONFIG, TEMPLATE_MANAGER, blake3_hash, detect_language, is_likely_text
from fabricatio_core.utils import first_available, ok
class Base(BaseModel, ABC):
    """Root Pydantic model shared by every Fabricatio model class.

    Subclasses inherit one common configuration: ``use_attribute_docstrings=True`` makes
    Pydantic read each field's description from the string literal written directly below
    the field, instead of requiring an explicit ``Field(description=...)``.
    """

    model_config = ConfigDict(use_attribute_docstrings=True)
class Display(Base, ABC):
    """Mixin that renders a model as JSON text for prompts, debugging and logging.

    Provides a pretty-printed form (:meth:`display`) and a minified form (:meth:`compact`).
    It also provides :meth:`seq_display`, which renders many models between start/end
    boundary markers. Field aliases are used as JSON keys in every form.
    """

    def display(self) -> str:
        """Generate a pretty-printed JSON representation.

        Returns:
            str: JSON string indented by one space per nesting level, keyed by field aliases.
        """
        return self.model_dump_json(indent=1, by_alias=True)

    def compact(self) -> str:
        """Generate a compact JSON representation.

        Returns:
            str: Minified JSON string without extra whitespace, keyed by field aliases.
        """
        return self.model_dump_json(by_alias=True)

    @staticmethod
    def seq_display(seq: Iterable["Display"], compact: bool = False) -> str:
        """Generate a formatted display for a sequence of Display objects.

        Args:
            seq (Iterable[Display]): Objects to render, in iteration order.
            compact (bool): Render each object with :meth:`compact` instead of :meth:`display`.

        Returns:
            str: The start marker, each rendered object and the end marker, separated by
            newlines. For an empty ``seq`` only the two markers are returned.
        """
        rendered = (d.compact() if compact else d.display() for d in seq)
        # Join markers and items with newlines so each marker sits on its own line instead of
        # being glued to the first/last rendered item.
        return "\n".join(
            ("--- Start of Extra Info Sequence ---", *rendered, "--- End of Extra Info Sequence ---")
        )
class Named(Base, ABC):
    """Mixin that adds a ``name`` field, intended to serve as the object's unique identifier."""

    name: str
    """The name of this object,briefly and conclusively."""
class Described(Base, ABC):
    """Mixin that adds a ``description`` field carrying additional context about the object."""

    description: str
    """A comprehensive description of this object, including its purpose, scope, and context.
    This should clearly explain what this object is about, why it exists, and in what situations
    it applies. The description should be detailed enough to provide full understanding of
    this object's intent and application."""
class Titled(Base, ABC):
    """Mixin that adds a ``title`` field to a model."""

    title: str
    """The title of this object, make it professional and concise.No prefixed heading number should be included."""
class WithBriefing(Named, Described, ABC):
    """Mixin combining ``name`` and ``description`` with an auto-generated briefing.

    Equality and hashing use ``name`` only. Two objects of the same class with the same name
    are therefore treated as the same key in sets and dicts, even if their descriptions differ.
    NOTE: the visible ``Base`` config does not freeze models. Reassigning ``name`` after the
    object was put into a set or used as a dict key breaks those lookups.
    """

    @property
    def briefing(self) -> str:
        """Get a one-line briefing of the object.

        Returns:
            str: ``"<name>: <description>"``, or just the name when the description is empty.
        """
        return f"{self.name}: {self.description}" if self.description else self.name

    def __eq__(self, other: object) -> bool:
        """Compare two objects by name.

        Returns:
            bool: ``True`` if ``other`` is an instance of this object's class with the same
            name. For any other type ``NotImplemented`` is returned, so Python can try the
            reflected comparison (falling back to identity, i.e. ``False``).
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        """Hash by name, consistent with :meth:`__eq__`."""
        return hash(self.name)
[docs]
class WithDependency(Base, ABC):
"""Class that manages file dependencies.
This class includes methods to manage file dependencies required for reading or writing.
"""
dependencies: List[str] = Field(default_factory=list)
"""The file dependencies which is needed to read or write to meet a specific requirement, a list of file paths."""
[docs]
def add_dependency[P: str | Path](self, dependency: P | List[P]) -> Self:
"""Add a file dependency to the task.
Args:
dependency (str | Path | List[str | Path]): The file dependency to add to the task.
Returns:
Self: The current instance of the task.
"""
if not isinstance(dependency, list):
dependency = [dependency]
self.dependencies.extend(Path(d).as_posix() for d in dependency)
return self
[docs]
def remove_dependency[P: str | Path](self, dependency: P | List[P]) -> Self:
"""Remove a file dependency from the task.
Args:
dependency (str | Path | List[str | Path]): The file dependency to remove from the task.
Returns:
Self: The current instance of the task.
"""
if not isinstance(dependency, list):
dependency = [dependency]
for d in dependency:
self.dependencies.remove(Path(d).as_posix())
return self
[docs]
def clear_dependencies(self) -> Self:
"""Clear all file dependencies from the task.
Returns:
Self: The current instance of the task.
"""
self.dependencies.clear()
return self
[docs]
def override_dependencies[P: str | Path](self, dependencies: List[P] | P) -> Self:
"""Override the file dependencies of the task.
Args:
dependencies (List[str | Path] | str | Path): The file dependencies to override the task's dependencies.
Returns:
Self: The current instance of the task.
"""
return self.clear_dependencies().add_dependency(dependencies)
[docs]
def read_dependency[T](
self, idx: int = -1, reader: Callable[[str], T] = lambda p: Path(p).read_text(encoding="utf-8", errors="ignore")
) -> T:
"""Read the content of a file dependency.
Args:
idx (int): Index of the dependency to read. Defaults to -1 (last dependency).
reader (Callable[[str], T]): Function to use for reading the file.
Returns:
T: The content of the file read using the provided reader function.
"""
return reader(self.dependencies[idx])
@property
def dependencies_prompt(self) -> str:
"""Generate a prompt for the task based on the file dependencies.
Returns:
str: The generated prompt for the task.
"""
return TEMPLATE_MANAGER.render_template(
CONFIG.templates.dependencies_template,
{
(pth := Path(p).absolute().relative_to(Path.cwd())).name: {
"path": pth.as_posix(),
"exists": (exi := pth.exists()),
"is_text": (is_f := is_likely_text(pth)),
"size": f"{pth.stat().st_size / 1024 if exi and pth.is_file() else 0:.3f} KiB",
"content": (text := pth.read_text(encoding="utf-8", errors="ignore") if is_f else ""),
"lines": len(text.splitlines()) if is_f else 0,
"checksum": blake3_hash(pth.read_bytes()) if exi and pth.is_file() else "unknown",
}
for p in self.dependencies
},
)
class Vectorizable(ABC):
    """Abstract mixin for objects that can be turned into text for vectorization.

    Subclasses implement :meth:`_prepare_vectorization_inner`. Callers use the public, final
    :meth:`prepare_vectorization`. No length or token limit is enforced here; any such check
    must be done by the subclass implementation.
    """

    @abstractmethod
    def _prepare_vectorization_inner(self) -> str:
        """Return the text to vectorize. Must be implemented by subclasses."""

    @final
    def prepare_vectorization(self) -> str:
        """Prepare the text used to vectorize this object.

        Returns:
            str: The text produced by :meth:`_prepare_vectorization_inner`.

        Raises:
            Exception: Anything raised by :meth:`_prepare_vectorization_inner` propagates
                unchanged. This method performs no checks of its own.
        """
        return self._prepare_vectorization_inner()
class ScopedConfig(Base, ABC):
    """Configuration holder with a hierarchical fallback mechanism.

    A field whose current value is ``None`` is treated as "not set at this scope". Such fields
    can be filled from another scope with :meth:`fallback_to`. This scope's values can be
    pushed into other scopes with :meth:`hold_to`.
    """

    @final
    def fallback_to(self, other: Union["ScopedConfig", Any], exclude: Optional[Set[str]] = None) -> Self:
        """Fill unset fields of this config from ``other``.

        A field is copied when all of the following hold: it belongs to this config's class,
        it is not excluded, it is ``None`` here, and it is present and non-``None`` on
        ``other``. If ``other`` is not a :class:`ScopedConfig`, nothing happens.

        Args:
            other (ScopedConfig | Any): Configuration to fall back to.
            exclude (Optional[Set[str]]): Field names that must not be filled.

        Returns:
            Self: This instance, for chaining.
        """
        if not isinstance(other, ScopedConfig):
            return self

        exclude = exclude or set()
        # noinspection PydanticTypeChecker,PyTypeChecker
        for attr_name in self.__class__.model_fields:
            if attr_name in exclude:
                logger.trace(f"Excluding `{attr_name}` from fallback")
                continue
            # `hasattr` guards against `other` being a ScopedConfig subclass without this field.
            if (
                hasattr(other, attr_name)
                and getattr(self, attr_name) is None
                and (attr := getattr(other, attr_name)) is not None
            ):
                logger.trace(f"Falling back `{attr_name}` to `{attr}`")
                setattr(self, attr_name, attr)
        return self

    @final
    def hold_to(
        self,
        others: Union["ScopedConfig", Any] | Iterable[Union["ScopedConfig", Any]],
        exclude: Optional[Set[str]] = None,
    ) -> Self:
        """Propagate this config's non-null values into other configurations.

        Each target :class:`ScopedConfig` gets its ``None`` fields filled from this config
        (see :meth:`fallback_to`). Non-:class:`ScopedConfig` targets are skipped.

        Args:
            others (ScopedConfig | Iterable): A single target configuration or an iterable of them.
            exclude (Optional[Set[str]]): Field names to exclude from propagation.

        Returns:
            Self: This instance, unchanged.
        """
        # Pydantic models define ``__iter__`` (yielding ``(field, value)`` pairs), so a single
        # ScopedConfig also passes the ``Iterable`` check. Wrap it explicitly; otherwise its
        # field tuples would be iterated and no target would ever be updated.
        if isinstance(others, ScopedConfig) or not isinstance(others, Iterable):
            others = [others]
        for other in others:
            if isinstance(other, ScopedConfig):
                other.fallback_to(self, exclude=exclude)
        return self
class EmbeddingScopedConfig(ScopedConfig):
    """Scoped configuration for embedding requests.

    Fields left as ``None`` defer to ``CONFIG.embedding`` when parameters are resolved, and
    can be filled by :meth:`ScopedConfig.fallback_to`.
    """

    embedding_send_to: Optional[str] = None
    """The group name of which the requests will be sent."""
    # Optional/None (not ``False``) so an unset value can defer to CONFIG.embedding.no_cache and
    # be filled by `fallback_to`, like the reranker/LLM ``no_cache`` fields.
    embedding_no_cache: Optional[bool] = None
    """Whether to disable caching for embeddings."""
    embedding_ndim: Optional[int] = None
    """The dimensionality of the output embeddings. Must match between search and store."""

    def _resolve_embedding_params(
        self, send_to: str | None = None, ndim: int | None = None, no_cache: bool | None = None, **_
    ) -> EmbeddingKwargs:
        """Resolve embedding parameters.

        Each value is taken from the explicit argument, then this scope's ``embedding_*``
        field, then ``CONFIG.embedding``. ``send_to`` is chosen by truthiness (``or``) and
        checked with ``ok``. ``ndim`` and ``no_cache`` go through ``first_available``.
        ``no_cache`` is coerced to ``False`` when unresolved. Extra keyword arguments are
        ignored.

        Returns:
            EmbeddingKwargs: The resolved ``send_to``, ``ndim`` and ``no_cache``.

        Raises:
            Whatever ``ok`` raises when no ``send_to`` is available.
            For ``ndim``, ``first_available`` is called without ``raise_exception=False``, so
            it presumably raises when no level provides a value (TODO confirm).
        """
        return EmbeddingKwargs(
            send_to=ok(
                send_to or self.embedding_send_to or CONFIG.embedding.send_to,
                "send_to is not specified at any where",
            ),
            ndim=first_available((ndim, self.embedding_ndim, CONFIG.embedding.ndim)),
            no_cache=first_available(
                (no_cache, self.embedding_no_cache, CONFIG.embedding.no_cache), raise_exception=False
            )
            or False,
        )
class RerankerScopedConfig(ScopedConfig):
    """Scoped configuration for reranker requests; ``None`` fields defer to ``CONFIG.reranker``."""

    reranker_send_to: Optional[str] = None
    """The group name of which the requests will be sent."""
    reranker_no_cache: Optional[bool] = None
    """Whether to disable caching for the reranker."""

    def _resolve_reranker_params(
        self, send_to: Optional[str] = None, no_cache: Optional[bool] = None, **_
    ) -> RerankerKwargs:
        """Resolve reranker parameters.

        Each value is taken from the explicit argument, then this scope's ``reranker_*`` field,
        then ``CONFIG.reranker``. ``no_cache`` is coerced to ``False`` when unresolved. Extra
        keyword arguments are ignored.

        Returns:
            RerankerKwargs: The resolved ``send_to`` and ``no_cache``.
        """
        target = ok(
            send_to or self.reranker_send_to or CONFIG.reranker.send_to,
            "send_to is not specified at any where",
        )
        cache_disabled = first_available(
            (no_cache, self.reranker_no_cache, CONFIG.reranker.no_cache), raise_exception=False
        )
        return RerankerKwargs(send_to=target, no_cache=cache_disabled or False)
[docs]
class LLMScopedConfig(ScopedConfig):
"""Configuration for LLM-related settings."""
llm_send_to: Optional[str] = None
"""The group name of which the requests will be sent."""
llm_top_p: Optional[NonNegativeFloat] = None
"""The top p of the LLM model."""
llm_temperature: Optional[NonNegativeFloat] = None
"""The temperature of the LLM model."""
llm_stream: Optional[bool] = None
"""Whether to stream the LLM model's response."""
llm_max_completion_tokens: Optional[PositiveInt] = None
"""The maximum number of tokens to generate."""
llm_presence_penalty: Optional[PositiveFloat] = None
"""The presence penalty of the LLM model."""
llm_frequency_penalty: Optional[PositiveFloat] = None
"""The frequency penalty of the LLM model."""
llm_no_cache: Optional[bool] = None
"""Whether to disable caching for the LLM model."""
def _resolve_completion_params(
self,
send_to: Optional[str] = None,
stream: Optional[bool] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_completion_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
no_cache: Optional[bool] = None,
**_,
) -> LLMKwargs:
"""Resolve LLM completion parameters from kwargs, instance defaults, and CONFIG."""
return LLMKwargs(
send_to=ok(send_to or self.llm_send_to or CONFIG.llm.send_to, "`send_to` is not specified at any where!"),
stream=first_available((stream, self.llm_stream, CONFIG.llm.stream), raise_exception=False) or False,
top_p=first_available((top_p, self.llm_top_p, CONFIG.llm.top_p), raise_exception=False),
temperature=first_available(
(temperature, self.llm_temperature, CONFIG.llm.temperature), raise_exception=False
),
max_completion_tokens=first_available(
(max_completion_tokens, self.llm_max_completion_tokens, CONFIG.llm.max_completion_tokens),
raise_exception=False,
),
presence_penalty=first_available(
(presence_penalty, self.llm_presence_penalty, CONFIG.llm.presence_penalty), raise_exception=False
),
frequency_penalty=first_available(
(frequency_penalty, self.llm_frequency_penalty, CONFIG.llm.frequency_penalty), raise_exception=False
),
no_cache=first_available((no_cache, self.llm_no_cache, CONFIG.llm.no_cache), raise_exception=False)
or False,
)
def _resolve_validation_params[T](
self, default: None | T = None, max_validations: PositiveInt = 3, **kwargs: Unpack[LLMKwargs]
) -> ValidateKwargs[T]:
res = self._resolve_completion_params(**kwargs)
return ValidateKwargs(default=default, max_validations=max_validations, **res)
class UnsortGenerate(GenerateJsonSchema):
    """JSON schema generator that preserves the original key order.

    It overrides the generator's ``sort`` hook with a no-op, so the schema keeps the order in
    which keys were produced.
    """

    def sort(self, value: JsonSchemaValue, parent_key: str | None = None) -> JsonSchemaValue:
        """Return the schema value untouched instead of sorting it.

        Args:
            value (JsonSchemaValue): The JSON schema value that would otherwise be sorted.
            parent_key (str | None): Key of the parent node; not used.

        Returns:
            JsonSchemaValue: ``value`` itself, in its original order.
        """
        unsorted = value
        return unsorted
class CreateJsonObjPrompt(WithFormatedJsonSchema, ABC):
    """Mixin that builds prompts asking for a JSON object that matches this model's schema."""

    @classmethod
    @overload
    def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...

    @classmethod
    @overload
    def create_json_prompt(cls, requirement: str) -> str: ...

    @classmethod
    def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
        """Render the JSON-creation prompt for one requirement or for each of several.

        The prompt comes from ``CONFIG.templates.create_json_obj_template``, filled with the
        requirement and the class's formatted JSON schema.

        Args:
            requirement (str | List[str]): A single requirement, or a list of requirements.

        Returns:
            str | List[str]: One prompt for a string input; otherwise one prompt per
            requirement, in the same order.
        """

        def render(req: str) -> str:
            return TEMPLATE_MANAGER.render_template(
                CONFIG.templates.create_json_obj_template,
                {"requirement": req, "json_schema": cls.formated_json_schema()},
            )

        if isinstance(requirement, str):
            return render(requirement)
        return [render(req) for req in requirement]
class InstantiateFromString(Base, ABC):
    """Mixin that builds an instance of the model from a string via ``json_parser.convert``."""

    @classmethod
    def instantiate_from_string(cls, string: str) -> Self | None:
        """Instantiate the class from a string.

        The string is converted to Python data with ``json_parser.convert``, then validated
        against this model.

        Args:
            string (str): The text to parse.

        Returns:
            Self | None: The validated instance. Returns ``None`` if the conversion yields
            ``None`` or the converted data fails model validation.
        """
        from pydantic import ValidationError

        from fabricatio_core.rust import json_parser

        converted = json_parser.convert(string)
        if converted is None:
            logger.debug(f"Instantiate `{cls.__name__}` from string, Failed (conversion returned None).")
            return None
        try:
            obj = cls.model_validate(converted)
        except ValidationError as e:
            # model_validate raises rather than returning None; honour the documented contract.
            logger.debug(f"Instantiate `{cls.__name__}` from string, Failed (validation error: {e}).")
            return None
        logger.debug(f"Instantiate `{cls.__name__}` from string, Success.")
        return obj
class ProposedAble(CreateJsonObjPrompt, InstantiateFromString, ABC):
    """Mixin for models that can be proposed from a requirement.

    Combines building a JSON-creation prompt from the model's schema
    (:class:`CreateJsonObjPrompt`) with turning the returned string back into an instance
    (:class:`InstantiateFromString`).
    """
class Language:
    """Mixin that detects an object's natural language from its text fields."""

    @cached_property
    def language(self) -> str:
        """Detect the object's language; computed once per instance and cached.

        The first non-empty field among ``description`` (for :class:`Described`), ``title``
        (for :class:`Titled`) and ``name`` (for :class:`Named`), checked in that order, is
        passed to ``detect_language``.

        Returns:
            str: The value returned by ``detect_language``.

        Raises:
            RuntimeError: If the object has none of these fields, or all of them are empty.
        """
        for kind, attr in ((Described, "description"), (Titled, "title"), (Named, "name")):
            if isinstance(self, kind) and (text := getattr(self, attr)):
                return detect_language(text)
        raise RuntimeError(f"Cannot determine language! class that not support language: {self.__class__.__name__}")
class SketchedAble(ProposedAble, Display, ABC):
    """Mixin for models that can be proposed and displayed.

    Adds :class:`Display`'s JSON rendering helpers on top of :class:`ProposedAble`'s
    prompt-building and string-parsing support.
    """