# Scraped from a source viewer: 4,804 bytes, revision 7fcdb70.
# Source: https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/schema.py
from typing import List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, field_validator, model_validator
# Default system prompt text.
DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

# Field-name strings matching the attributes of ``Message``.
ROLE, CONTENT, REASONING_CONTENT, NAME = "role", "content", "reasoning_content", "name"

# Role values; ``Message.role_checker`` accepts exactly these four.
SYSTEM, USER, ASSISTANT, FUNCTION = "system", "user", "assistant", "function"

# Content-type names; they match the field names of ``ContentItem``.
FILE, IMAGE, AUDIO, VIDEO = "file", "image", "audio", "video"
class BaseModelCompatibleDict(BaseModel):
    """Pydantic model that can also be used like a ``dict``.

    Adds ``obj[key]``, ``obj[key] = value`` and ``obj.get(key, default)`` on top
    of normal attribute access, and serializes with ``exclude_none=True`` unless
    the caller passes ``exclude_none`` explicitly.
    """

    def __getitem__(self, item):
        # Item access is plain attribute access: a missing key raises
        # AttributeError rather than KeyError.
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def model_dump(self, **kwargs):
        """Dump to a dict, dropping ``None`` fields unless told otherwise."""
        kwargs.setdefault("exclude_none", True)
        return super().model_dump(**kwargs)

    def model_dump_json(self, **kwargs):
        """Dump to a JSON string, dropping ``None`` fields unless told otherwise."""
        kwargs.setdefault("exclude_none", True)
        return super().model_dump_json(**kwargs)

    def get(self, key, default=None):
        """Return attribute ``key``, or ``default`` if it is missing or falsy.

        Unlike ``dict.get``, a present but falsy value (``""``, ``[]``, ``0``,
        ``None``, ...) also yields ``default``.
        """
        return getattr(self, key, None) or default

    def __str__(self):
        return str(self.model_dump())
class FunctionCall(BaseModelCompatibleDict):
    """A function call: the function's name plus its arguments as one string."""

    name: str
    # Raw argument string; the model does not parse it (presumably JSON-encoded
    # by the producer -- confirm with callers).
    arguments: str

    def __init__(self, name: str, arguments: str):
        # Allows positional construction, which pydantic's keyword-only
        # BaseModel.__init__ does not.
        super().__init__(**{"name": name, "arguments": arguments})

    def __repr__(self):
        return "FunctionCall(" + str(self.model_dump()) + ")"
class ContentItem(BaseModelCompatibleDict):
    """One piece of multimodal message content.

    Exactly one of ``text``, ``image``, ``file``, ``audio`` or ``video`` must be
    provided. ``text`` counts as provided whenever it is not ``None`` (so empty
    text is allowed); the other fields count only when they are truthy.
    """

    text: Optional[str] = None
    image: Optional[str] = None
    file: Optional[str] = None
    audio: Optional[Union[str, dict]] = None
    video: Optional[Union[str, list]] = None

    def __init__(
        self,
        text: Optional[str] = None,
        image: Optional[str] = None,
        file: Optional[str] = None,
        audio: Optional[Union[str, dict]] = None,
        video: Optional[Union[str, list]] = None,
    ):
        super().__init__(text=text, image=image, file=file, audio=audio, video=video)

    def _provided_items(self) -> List[Tuple[str, Union[str, dict, list]]]:
        """Return the ``(field, value)`` pairs that count as provided.

        This single rule is shared by validation and by ``get_type_and_value``.
        Keeping it in one place means an item that passes validation always
        yields exactly one pair, even when another field holds an empty but
        non-``None`` value (e.g. ``image=""``) that ``model_dump`` would still
        emit.
        """
        items = []
        for field in ("text", "image", "file", "audio", "video"):
            value = getattr(self, field)
            # text: any non-None value counts; media fields: only truthy values.
            provided = value is not None if field == "text" else bool(value)
            if provided:
                items.append((field, value))
        return items

    @model_validator(mode="after")
    def check_exclusivity(self):
        """Ensure exactly one content field is provided.

        Raises:
            ValueError: if zero or several fields are provided.
        """
        if len(self._provided_items()) != 1:
            raise ValueError(
                "Exactly one of 'text', 'image', 'file', 'audio', or 'video' must be provided."
            )
        return self

    def __repr__(self):
        return f"ContentItem({self.model_dump()})"

    def get_type_and_value(
        self,
    ) -> Tuple[Literal["text", "image", "file", "audio", "video"], Union[str, dict, list]]:
        """Return ``(field_name, value)`` for the single provided field.

        ``value`` is a ``str`` for text/image/file; ``audio`` may also be a dict
        and ``video`` a list.

        Raises:
            ValueError: if the item does not have exactly one provided field,
                e.g. after being mutated via ``item[key] = value`` (assignment
                is not re-validated).
        """
        items = self._provided_items()
        if len(items) != 1:
            raise ValueError(
                "Exactly one of 'text', 'image', 'file', 'audio', or 'video' must be provided."
            )
        return items[0]

    @property
    def type(self) -> Literal["text", "image", "file", "audio", "video"]:
        """Name of the provided field."""
        return self.get_type_and_value()[0]

    @property
    def value(self) -> Union[str, dict, list]:
        """Value of the provided field."""
        return self.get_type_and_value()[1]
class Message(BaseModelCompatibleDict):
    """A single chat message.

    ``content`` and ``reasoning_content`` are either a plain string or a list of
    ``ContentItem`` parts. Passing ``None`` for either one stores ``""`` instead.
    Because ``reasoning_content`` is therefore never ``None``, it is always
    present in ``model_dump()`` output.
    """

    role: str
    content: Union[str, List[ContentItem]]
    reasoning_content: Optional[Union[str, List[ContentItem]]] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCall] = None
    extra: Optional[dict] = None

    def __init__(
        self,
        role: str,
        content: Union[str, List[ContentItem]],
        reasoning_content: Optional[Union[str, List[ContentItem]]] = None,
        name: Optional[str] = None,
        function_call: Optional[FunctionCall] = None,
        extra: Optional[dict] = None,
        **kwargs,
    ):
        # Any extra keyword arguments are accepted but silently discarded;
        # they are not stored on the model.
        super().__init__(
            role=role,
            content="" if content is None else content,
            reasoning_content="" if reasoning_content is None else reasoning_content,
            name=name,
            function_call=function_call,
            extra=extra,
        )

    def __repr__(self):
        return "Message({})".format(self.model_dump())

    @field_validator("role")
    @classmethod
    def role_checker(cls, value: str) -> str:
        """Accept only the user, assistant, system or function roles.

        Raises:
            ValueError: for any other role string.
        """
        allowed = [USER, ASSISTANT, SYSTEM, FUNCTION]
        if value in allowed:
            return value
        raise ValueError(f'{value} must be one of {",".join(allowed)}')
|