Spaces:
Running on CPU Upgrade
Add dataset uploads to Hub (#255)
Browse files
* Add dataset uploads to Hub
Co-authored-by: OpenAI Codex <codex@openai.com>
* Address dataset upload review items
Co-authored-by: OpenAI Codex <codex@openai.com>
* Show dataset upload progress
Co-authored-by: OpenAI Codex <codex@openai.com>
* Move dataset upload button to composer corner
Co-authored-by: OpenAI Codex <codex@openai.com>
* Fix dataset upload file handoff
Co-authored-by: OpenAI Codex <codex@openai.com>
* Show upload alerts below composer
Co-authored-by: OpenAI Codex <codex@openai.com>
* Improve dataset upload errors and chips
Co-authored-by: OpenAI Codex <codex@openai.com>
* Link dataset chips to repo
Co-authored-by: OpenAI Codex <codex@openai.com>
* Expose dataset uploads as configs
Co-authored-by: OpenAI Codex <codex@openai.com>
---------
Co-authored-by: OpenAI Codex <codex@openai.com>
- backend/dataset_uploads.py +305 -0
- backend/models.py +18 -1
- backend/routes/agent.py +147 -1
- frontend/src/components/Chat/ChatInput.tsx +224 -11
- frontend/src/components/SessionChat.tsx +11 -1
- frontend/src/hooks/useAgentChat.ts +43 -0
- frontend/src/utils/api.ts +70 -18
- pyproject.toml +1 -0
- tests/unit/test_dataset_uploads.py +465 -0
- uv.lock +2 -0
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers for session-scoped dataset uploads to the Hugging Face Hub."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import uuid
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from urllib.parse import quote
|
| 9 |
+
|
| 10 |
+
from fastapi import HTTPException, UploadFile
|
| 11 |
+
from huggingface_hub import HfApi
|
| 12 |
+
|
| 13 |
+
# Largest dataset file accepted for upload, in bytes (100 * 1024 * 1024).
MAX_DATASET_UPLOAD_BYTES = 100 * 1024 * 1024
# Lower-case file extensions (without the leading dot) accepted as datasets.
ALLOWED_DATASET_EXTENSIONS = {"csv", "json", "jsonl"}
# Matches each run of characters that gets replaced when sanitizing a filename.
_SAFE_FILENAME_RE = re.compile(r"[^A-Za-z0-9._-]+")
# A namespace starts with an alphanumeric character and is 1-96 characters of
# letters, digits, ".", "_" or "-".
_SAFE_NAMESPACE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,95}$")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
|
| 20 |
+
class DatasetUpload:
|
| 21 |
+
session_id: str
|
| 22 |
+
repo_id: str
|
| 23 |
+
repo_type: str
|
| 24 |
+
private: bool
|
| 25 |
+
upload_id: str
|
| 26 |
+
config_name: str
|
| 27 |
+
filename: str
|
| 28 |
+
original_filename: str
|
| 29 |
+
path_in_repo: str
|
| 30 |
+
size_bytes: int
|
| 31 |
+
format: str
|
| 32 |
+
hub_url: str
|
| 33 |
+
load_dataset_snippet: str
|
| 34 |
+
|
| 35 |
+
def response_payload(self) -> dict[str, str | int | bool]:
|
| 36 |
+
return {
|
| 37 |
+
"session_id": self.session_id,
|
| 38 |
+
"repo_id": self.repo_id,
|
| 39 |
+
"repo_type": self.repo_type,
|
| 40 |
+
"private": self.private,
|
| 41 |
+
"upload_id": self.upload_id,
|
| 42 |
+
"config_name": self.config_name,
|
| 43 |
+
"filename": self.filename,
|
| 44 |
+
"path_in_repo": self.path_in_repo,
|
| 45 |
+
"size_bytes": self.size_bytes,
|
| 46 |
+
"format": self.format,
|
| 47 |
+
"hub_url": self.hub_url,
|
| 48 |
+
"load_dataset_snippet": self.load_dataset_snippet,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def sanitize_dataset_filename(filename: str | None) -> str:
    """Return a Hub-safe basename while preserving the extension.

    Directory components are dropped, every run of characters outside
    ``[A-Za-z0-9._-]`` becomes a single ``-``, leading and trailing ``.``,
    ``-`` and ``_`` are stripped, and the extension is lower-cased.  Missing
    pieces fall back to ``dataset`` and ``.csv``.  The result is never longer
    than 96 characters.
    """
    max_len = 96
    default_stem, default_ext = "dataset", ".csv"
    default_name = default_stem + default_ext

    raw = os.path.basename(filename or "").strip() or default_name
    safe = _SAFE_FILENAME_RE.sub("-", raw).strip(".-_") or default_name

    stem, ext = os.path.splitext(safe)
    stem = stem or default_stem
    ext = ext.lower() or default_ext

    # Cap the extension so the stem budget below always leaves room for the
    # fallback stem.  Without this, a very long extension makes the budget
    # zero or negative, and a negative slice bound keeps most of the stem
    # instead of truncating it, so the length limit would be exceeded.
    max_ext_len = max_len - len(default_stem)
    if len(ext) > max_ext_len:
        ext = ext[:max_ext_len].rstrip(".-_") or default_ext

    stem = stem[: max_len - len(ext)].strip(".-_") or default_stem
    return f"{stem}{ext}"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def display_filename(filename: str | None, fallback: str) -> str:
|
| 74 |
+
raw = os.path.basename(filename or "").strip()
|
| 75 |
+
if not raw:
|
| 76 |
+
return fallback
|
| 77 |
+
cleaned = "".join(char for char in raw if ord(char) >= 32)
|
| 78 |
+
return cleaned[:160] or fallback
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def dataset_format_from_filename(filename: str) -> str:
    """Return the lower-cased extension of ``filename`` without the dot.

    Raises:
        HTTPException: 400 when the extension is not one of
            ``ALLOWED_DATASET_EXTENSIONS``.
    """
    _, suffix = os.path.splitext(filename)
    ext = suffix.lower().lstrip(".")
    if ext in ALLOWED_DATASET_EXTENSIONS:
        return ext
    raise HTTPException(
        status_code=400,
        detail="Only .csv, .json, and .jsonl dataset files are supported.",
    )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def session_dataset_repo_id(hf_username: str | None, session_id: str) -> str:
    """Build the repo id of the dataset repository used for a session.

    The id has the form ``<namespace>/ml-intern-<slug>-datasets`` where
    ``<slug>`` is at most 8 characters derived from ``session_id``.

    Raises:
        HTTPException: 400 when ``hf_username`` is empty or does not match
            ``_SAFE_NAMESPACE_RE``.
    """
    namespace = (hf_username or "").strip()
    if not namespace or not _SAFE_NAMESPACE_RE.fullmatch(namespace):
        raise HTTPException(
            status_code=400,
            detail="Could not determine a valid Hugging Face namespace.",
        )

    slug = re.sub(r"[^A-Za-z0-9]+", "-", session_id).strip("-")
    # Strip again after truncating: a cut that lands right after a separator
    # would leave a trailing "-" and yield "--" in the repo name, which the
    # Hub's repo id validation rejects.
    slug = slug[:8].strip("-")
    if not slug:
        # No usable characters in the session id: fall back to a random slug.
        slug = uuid.uuid4().hex[:8]
    return f"{namespace}/ml-intern-{slug}-datasets"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
async def upload_size_bytes(upload: UploadFile) -> int:
    """Return the size of the uploaded file in bytes.

    The size is measured by seeking to the end of ``upload.file``; the file
    position is reset to the start afterwards.  The blocking file calls run in
    a worker thread, in a single hand-off rather than one per call.
    """

    def _measure() -> int:
        stream = upload.file
        stream.seek(0, os.SEEK_END)
        end = stream.tell()
        stream.seek(0)
        return int(end)

    return await asyncio.to_thread(_measure)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
async def validate_dataset_upload(upload: UploadFile) -> tuple[str, str, int]:
    """Validate an uploaded dataset file.

    Returns:
        ``(safe_filename, dataset_format, size_in_bytes)``.

    Raises:
        HTTPException: 400 for an unsupported extension or an empty file,
            413 when the file is larger than ``MAX_DATASET_UPLOAD_BYTES``.
    """
    # The format check runs on the raw name first so unsupported files are
    # rejected before the file is measured.
    dataset_format = dataset_format_from_filename(upload.filename or "")
    safe_filename = sanitize_dataset_filename(upload.filename)

    size = await upload_size_bytes(upload)
    if size <= 0:
        raise HTTPException(status_code=400, detail="Uploaded dataset file is empty.")
    if size > MAX_DATASET_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail="Dataset upload exceeds the 100 MB limit.",
        )
    return safe_filename, dataset_format, size
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def dataset_hub_url(repo_id: str, path_in_repo: str) -> str:
    """Return the Hub browser URL of a file on the ``main`` branch of a dataset repo.

    Both the repo id and the path are percent-encoded segment by segment
    (``/`` is kept), so an unexpected character cannot alter the URL
    structure.  Ids made only of letters, digits, ``.``, ``_`` and ``-`` are
    left unchanged.
    """
    quoted_repo = quote(repo_id, safe="/")
    quoted_path = quote(path_in_repo, safe="/")
    return f"https://huggingface.co/datasets/{quoted_repo}/blob/main/{quoted_path}"
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def dataset_config_name(upload_id: str) -> str:
    """Return the dataset config name for an upload id.

    Runs of non-alphanumeric characters become ``_``, the value is
    lower-cased, and at most 32 characters are kept after the ``upload_``
    prefix.  An id with no usable characters yields ``upload_dataset``.
    """
    slug = re.sub(r"[^A-Za-z0-9]+", "_", upload_id).strip("_").lower()
    slug = slug or "dataset"
    return f"upload_{slug[:32]}"
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def dataset_config_name_from_path(path_in_repo: str) -> str:
    """Return the dataset config name for a file path inside the repo.

    For paths of the form ``uploads/<upload_id>/...`` the name is derived from
    ``<upload_id>``; for any other path it is derived from the file's stem.
    """
    segments = path_in_repo.split("/")
    if len(segments) >= 3 and segments[0] == "uploads":
        return dataset_config_name(segments[1])

    stem, _ = os.path.splitext(os.path.basename(path_in_repo))
    return dataset_config_name(stem)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def is_dataset_upload_path(path_in_repo: str) -> bool:
    """Return True for paths of the form ``uploads/<upload_id>/<file>``.

    The path must have exactly three non-empty segments and the file must
    carry one of ``ALLOWED_DATASET_EXTENSIONS``.
    """
    segments = path_in_repo.split("/")
    well_formed = (
        len(segments) == 3
        and segments[0] == "uploads"
        and bool(segments[1])
        and bool(segments[2])
    )
    if not well_formed:
        return False

    _, suffix = os.path.splitext(path_in_repo)
    return suffix.lower().lstrip(".") in ALLOWED_DATASET_EXTENSIONS
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def unique_dataset_upload_paths(paths: list[str]) -> list[str]:
    """Return the dataset upload paths in ``paths``, de-duplicated.

    Paths that are not upload paths (see ``is_dataset_upload_path``) are
    dropped; the first occurrence of each remaining path is kept, in the
    original order.
    """
    # dict preserves insertion order, so this is an order-keeping de-duplication.
    return list(dict.fromkeys(path for path in paths if is_dataset_upload_path(path)))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def load_dataset_snippet(repo_id: str, config_name: str) -> str:
    """Return a Python snippet that loads one config of the dataset repo.

    The snippet loads the ``train`` split and passes ``token=True``.
    """
    import_line = "from datasets import load_dataset"
    load_line = (
        f'dataset = load_dataset("{repo_id}", "{config_name}", '
        'split="train", token=True)'
    )
    return f"{import_line}\n\n{load_line}"
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def dataset_repo_card(repo_id: str, upload_paths: list[str]) -> bytes:
    """Render the README.md (dataset card) for a session dataset repo.

    The YAML front matter carries the ``ml-intern`` / ``uploaded-dataset``
    tags and, for every upload path in ``upload_paths``, one dataset config
    whose ``train`` split points at that file.  Paths that are not upload
    paths are ignored.

    Returns:
        The card content encoded as UTF-8.
    """
    # Function-scope import keeps this block self-contained.
    import json

    front_matter = ["---", "tags:", "- ml-intern", "- uploaded-dataset"]
    paths = unique_dataset_upload_paths(upload_paths)
    if paths:
        front_matter.append("configs:")
    for path in paths:
        front_matter += [
            f"- config_name: {dataset_config_name_from_path(path)}",
            # Keys of one config entry must line up under "config_name".
            "  data_files:",
            "  - split: train",
            # A JSON string is a valid YAML double-quoted scalar, so quotes and
            # backslashes in a path cannot break the front matter.
            f"    path: {json.dumps(path, ensure_ascii=False)}",
        ]
    front_matter.append("---")

    body = f"""
# {repo_id}

Private dataset files uploaded through ML Intern.

Files are stored under `uploads/<upload_id>/` and are attached to the
corresponding ML Intern session context by Hub reference, not by copying file
contents into the chat.

Each uploaded file is exposed as its own dataset config so files with different
schemas can coexist in the same session repo.
"""
    content = "\n".join(front_matter) + "\n" + body
    return content.encode("utf-8")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def dataset_context_note(upload: DatasetUpload) -> str:
    """Return the note that tells the agent about an uploaded dataset file.

    The note lists the Hub reference of the file (repo, config, path, names,
    format, size, URL) and the snippet that loads it.
    """
    # NOTE(review): ``original_filename`` is user-supplied text embedded in a
    # "[SYSTEM: ...]" note; it is cleaned by ``display_filename`` only.
    return f"""[SYSTEM: The user uploaded a dataset file for this session.

Use this Hugging Face Hub dataset reference when the task needs the uploaded data.
Do not look for the uploaded file on local disk and do not ask the user to
upload it again unless this Hub reference fails.

- Repo ID: {upload.repo_id}
- Repo type: dataset
- Dataset config: {upload.config_name}
- File in repo: {upload.path_in_repo}
- Original filename: {upload.original_filename}
- Stored filename: {upload.filename}
- Format: {upload.format}
- Size: {upload.size_bytes} bytes
- Hub URL: {upload.hub_url}

Load it with:
```python
{upload.load_dataset_snippet}
```
]"""
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
async def push_dataset_upload_to_hub(
    *,
    upload: UploadFile,
    session_id: str,
    hf_username: str,
    hf_token: str,
) -> DatasetUpload:
    """Validate an uploaded dataset file and store it in the session's Hub repo.

    Steps: validate the file, make sure the private dataset repo exists,
    upload the file to ``uploads/<upload_id>/<filename>``, then rewrite
    ``README.md`` so every upload in the repo is exposed as a dataset config.

    Args:
        upload: The uploaded file.
        session_id: Session the upload belongs to.
        hf_username: Hub namespace that owns the session dataset repo.
        hf_token: Token used for all Hub calls.

    Returns:
        A ``DatasetUpload`` describing the stored file.

    Raises:
        HTTPException: When the file or the namespace fails validation.
        Errors raised by the Hub client are not caught here.
    """
    safe_filename, dataset_format, size = await validate_dataset_upload(upload)
    original_filename = display_filename(upload.filename, safe_filename)
    upload_id = uuid.uuid4().hex[:12]
    config_name = dataset_config_name(upload_id)
    repo_id = session_dataset_repo_id(hf_username, session_id)
    path_in_repo = f"uploads/{upload_id}/{safe_filename}"
    hub_url = dataset_hub_url(repo_id, path_in_repo)
    snippet = load_dataset_snippet(repo_id, config_name)
    api = HfApi(token=hf_token)

    await asyncio.to_thread(
        api.create_repo,
        repo_id=repo_id,
        repo_type="dataset",
        private=True,
        exist_ok=True,
    )
    # Also applied when the repo already existed (exist_ok=True above).
    await asyncio.to_thread(
        api.update_repo_settings,
        repo_id=repo_id,
        repo_type="dataset",
        private=True,
    )
    # List the existing files before uploading so the card below can cover
    # earlier uploads as well as the new one.
    repo_files = await asyncio.to_thread(
        api.list_repo_files,
        repo_id=repo_id,
        repo_type="dataset",
    )
    upload_paths = unique_dataset_upload_paths([*repo_files, path_in_repo])

    # Rewind first: validation moved the file position while measuring.
    await asyncio.to_thread(upload.file.seek, 0)
    file_bytes = await asyncio.to_thread(upload.file.read)
    await asyncio.to_thread(
        api.upload_file,
        path_or_fileobj=file_bytes,
        path_in_repo=path_in_repo,
        repo_id=repo_id,
        repo_type="dataset",
        commit_message=f"Upload dataset file {safe_filename}",
    )
    await asyncio.to_thread(
        api.upload_file,
        path_or_fileobj=dataset_repo_card(repo_id, upload_paths),
        path_in_repo="README.md",
        repo_id=repo_id,
        repo_type="dataset",
        commit_message="Update ML Intern dataset upload configs",
    )

    return DatasetUpload(
        session_id=session_id,
        repo_id=repo_id,
        repo_type="dataset",
        private=True,
        upload_id=upload_id,
        config_name=config_name,
        filename=safe_filename,
        original_filename=original_filename,
        path_in_repo=path_in_repo,
        size_bytes=size,
        format=dataset_format,
        hub_url=hub_url,
        load_dataset_snippet=snippet,
    )
|
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""Pydantic models for API requests and responses."""
|
| 2 |
|
| 3 |
from enum import Enum
|
| 4 |
-
from typing import Any
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
|
@@ -120,6 +120,23 @@ class SessionYoloRequest(BaseModel):
|
|
| 120 |
cost_cap_usd: float | None = Field(default=None, ge=0)
|
| 121 |
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
class HealthResponse(BaseModel):
|
| 124 |
"""Health check response."""
|
| 125 |
|
|
|
|
| 1 |
"""Pydantic models for API requests and responses."""
|
| 2 |
|
| 3 |
from enum import Enum
|
| 4 |
+
from typing import Any, Literal
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
|
|
|
| 120 |
cost_cap_usd: float | None = Field(default=None, ge=0)
|
| 121 |
|
| 122 |
|
| 123 |
+
class DatasetUploadResponse(BaseModel):
    """Response for a dataset file uploaded to the Hub."""

    # Session the upload belongs to.
    session_id: str
    # Hub repository that stores the file.
    repo_id: str
    repo_type: Literal["dataset"] = "dataset"
    private: bool = True
    # Identifier of this upload.
    upload_id: str
    # Dataset config name under which the file is exposed.
    config_name: str
    # Filename as stored in the repository.
    filename: str
    # Path of the file inside the repository.
    path_in_repo: str
    # Size of the uploaded file in bytes.
    size_bytes: int
    format: Literal["csv", "json", "jsonl"]
    # Browser URL of the file on the Hub.
    hub_url: str
    # Python snippet that loads the file with ``datasets``.
    load_dataset_snippet: str
|
| 138 |
+
|
| 139 |
+
|
| 140 |
class HealthResponse(BaseModel):
|
| 141 |
"""Health check response."""
|
| 142 |
|
|
@@ -21,10 +21,18 @@ from fastapi import (
|
|
| 21 |
)
|
| 22 |
from fastapi.exceptions import RequestValidationError
|
| 23 |
from fastapi.responses import StreamingResponse
|
| 24 |
-
from
|
|
|
|
| 25 |
from pydantic import ValidationError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
from models import (
|
| 27 |
ApprovalRequest,
|
|
|
|
| 28 |
HealthResponse,
|
| 29 |
LLMHealthResponse,
|
| 30 |
SessionInfo,
|
|
@@ -58,6 +66,7 @@ PREMIUM_MODEL_IDS = {
|
|
| 58 |
DEFAULT_CLAUDE_MODEL_ID,
|
| 59 |
"openai/gpt-5.5",
|
| 60 |
}
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
def _claude_picker_model_id() -> str:
|
|
@@ -203,6 +212,63 @@ def _user_hf_token(user: dict[str, Any] | None) -> str | None:
|
|
| 203 |
return user.get(INTERNAL_HF_TOKEN_KEY)
|
| 204 |
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
async def _check_session_access(
|
| 207 |
session_id: str,
|
| 208 |
user: dict[str, Any],
|
|
@@ -542,6 +608,86 @@ async def set_session_notifications(
|
|
| 542 |
}
|
| 543 |
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
@router.patch("/session/{session_id}/yolo")
|
| 546 |
async def set_session_yolo(
|
| 547 |
session_id: str,
|
|
|
|
| 21 |
)
|
| 22 |
from fastapi.exceptions import RequestValidationError
|
| 23 |
from fastapi.responses import StreamingResponse
|
| 24 |
+
from huggingface_hub.errors import HfHubHTTPError
|
| 25 |
+
from litellm import Message, acompletion
|
| 26 |
from pydantic import ValidationError
|
| 27 |
+
from starlette.datastructures import FormData, UploadFile
|
| 28 |
+
from dataset_uploads import (
|
| 29 |
+
MAX_DATASET_UPLOAD_BYTES,
|
| 30 |
+
dataset_context_note,
|
| 31 |
+
push_dataset_upload_to_hub,
|
| 32 |
+
)
|
| 33 |
from models import (
|
| 34 |
ApprovalRequest,
|
| 35 |
+
DatasetUploadResponse,
|
| 36 |
HealthResponse,
|
| 37 |
LLMHealthResponse,
|
| 38 |
SessionInfo,
|
|
|
|
| 66 |
DEFAULT_CLAUDE_MODEL_ID,
|
| 67 |
"openai/gpt-5.5",
|
| 68 |
}
|
| 69 |
+
# Extra bytes tolerated on top of MAX_DATASET_UPLOAD_BYTES when checking the
# request's Content-Length (presumably to cover multipart framing overhead).
DATASET_UPLOAD_MULTIPART_SLACK_BYTES = 1024 * 1024
|
| 70 |
|
| 71 |
|
| 72 |
def _claude_picker_model_id() -> str:
|
|
|
|
| 212 |
return user.get(INTERNAL_HF_TOKEN_KEY)
|
| 213 |
|
| 214 |
|
| 215 |
+
def _reject_oversize_dataset_upload(request: Request) -> None:
    """Reject a request whose declared Content-Length is too large.

    A missing or non-numeric header is ignored here.

    Raises:
        HTTPException: 413 when Content-Length exceeds
            ``MAX_DATASET_UPLOAD_BYTES`` plus the multipart slack.
    """
    raw_content_length = request.headers.get("content-length")
    if raw_content_length is None:
        return
    try:
        content_length = int(raw_content_length)
    except (TypeError, ValueError):
        return

    limit = MAX_DATASET_UPLOAD_BYTES + DATASET_UPLOAD_MULTIPART_SLACK_BYTES
    if content_length > limit:
        raise HTTPException(
            status_code=413,
            detail="Dataset upload exceeds the 100 MB limit.",
        )
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _dataset_upload_file_from_form(form: FormData) -> UploadFile:
    """Return the single uploaded file from a multipart form.

    Raises:
        HTTPException: 400 when the form does not contain exactly one file,
            or when that file was not sent in the ``file`` field.
    """
    uploaded_files = [
        (key, value)
        for key, value in form.multi_items()
        if isinstance(value, UploadFile)
    ]
    if len(uploaded_files) != 1:
        raise HTTPException(
            status_code=400,
            detail="Upload exactly one dataset file.",
        )

    field_name, upload = uploaded_files[0]
    if field_name != "file":
        raise HTTPException(
            status_code=400,
            detail="Missing 'file' upload field.",
        )
    return upload
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _dataset_upload_hub_http_exception(error: HfHubHTTPError) -> HTTPException:
    """Translate a Hub HTTP error into the HTTPException sent to the client.

    401, 403, 404 and 429 keep their status code with a specific message;
    every other Hub error becomes a 502.
    """
    details_by_status = {
        401: "Hugging Face rejected the token used for the dataset upload.",
        403: "Hugging Face denied permission to create or write to the dataset repo.",
        404: "Could not find the Hugging Face namespace or dataset repo.",
        429: "Hugging Face Hub rate limit reached while uploading the dataset.",
    }
    # ``response`` may lack a status code, hence the defensive getattr.
    status_code = getattr(error.response, "status_code", None)
    detail = details_by_status.get(status_code)
    if detail is not None:
        return HTTPException(status_code=status_code, detail=detail)
    return HTTPException(
        status_code=502,
        detail="Hugging Face Hub upload failed. Please try again.",
    )
|
| 270 |
+
|
| 271 |
+
|
| 272 |
async def _check_session_access(
|
| 273 |
session_id: str,
|
| 274 |
user: dict[str, Any],
|
|
|
|
| 608 |
}
|
| 609 |
|
| 610 |
|
| 611 |
+
@router.post("/session/{session_id}/datasets", response_model=DatasetUploadResponse)
async def upload_session_dataset(
    session_id: str,
    request: Request,
    user: dict = Depends(get_current_user),
) -> DatasetUploadResponse:
    """Upload a CSV/JSON dataset file to a private Hub dataset for this session.

    On success the uploaded file's Hub reference is appended to the session
    context as a user message and the session snapshot is persisted.

    Raises:
        HTTPException: 404 for an unknown or inactive session, 409 while the
            agent is processing or tools await approval, 401 without a Hub
            token, 400/413 for invalid uploads, and 502 (or the translated Hub
            status) when the upload itself fails.
    """
    file: UploadFile | None = None
    try:
        _reject_oversize_dataset_upload(request)
        agent_session = await _check_session_access(session_id, user, request)
        if not agent_session or not agent_session.is_active:
            raise HTTPException(status_code=404, detail="Session not found")
        if agent_session.is_processing:
            raise HTTPException(
                status_code=409,
                detail="Cannot upload a dataset while the agent is processing.",
            )
        if agent_session.session.pending_approval:
            raise HTTPException(
                status_code=409,
                detail="Approve or reject pending tools before uploading a dataset.",
            )

        # Token precedence: request token without env fallback, then the
        # user's stored token, then the request token with default fallback.
        hf_token = (
            resolve_hf_request_token(request, include_env_fallback=False)
            or _user_hf_token(user)
            or resolve_hf_request_token(request)
        )
        if not hf_token:
            raise HTTPException(
                status_code=401,
                detail="A Hugging Face token is required to upload datasets.",
            )

        form = await request.form(
            max_files=1,
            max_fields=1,
            max_part_size=MAX_DATASET_UPLOAD_BYTES,
        )
        file = _dataset_upload_file_from_form(form)
        hf_username = user.get("username") or agent_session.hf_username
        uploaded = await push_dataset_upload_to_hub(
            upload=file,
            session_id=session_id,
            hf_username=hf_username,
            hf_token=hf_token,
        )
        agent_session.session.context_manager.add_message(
            Message(role="user", content=dataset_context_note(uploaded))
        )
        await session_manager.persist_session_snapshot(agent_session)
        logger.info(
            "Uploaded dataset file %s to %s for session %s",
            uploaded.filename,
            uploaded.repo_id,
            session_id,
        )
        return DatasetUploadResponse(**uploaded.response_payload())
    except HTTPException:
        # Already a client-facing error: pass it through untouched.
        raise
    except HfHubHTTPError as e:
        logger.warning(
            "Hub rejected dataset upload for session %s: status=%s request_id=%s",
            session_id,
            getattr(e.response, "status_code", None),
            getattr(e, "request_id", None),
        )
        # Chain explicitly so the Hub error stays attached as the cause.
        raise _dataset_upload_hub_http_exception(e) from e
    except Exception as exc:
        logger.exception("Dataset upload failed for session %s", session_id)
        raise HTTPException(
            status_code=502,
            detail="Dataset upload failed. Please try again.",
        ) from exc
    finally:
        if file is not None:
            await file.close()
|
| 689 |
+
|
| 690 |
+
|
| 691 |
@router.patch("/session/{session_id}/yolo")
|
| 692 |
async def set_session_yolo(
|
| 693 |
session_id: str,
|
|
@@ -11,12 +11,15 @@ import {
|
|
| 11 |
ListItemIcon,
|
| 12 |
ListItemText,
|
| 13 |
Chip,
|
|
|
|
| 14 |
Snackbar,
|
|
|
|
| 15 |
} from '@mui/material';
|
| 16 |
import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
|
| 17 |
import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
|
| 18 |
import StopIcon from '@mui/icons-material/Stop';
|
| 19 |
-
import
|
|
|
|
| 20 |
import { useUserQuota } from '@/hooks/useUserQuota';
|
| 21 |
import ClaudeCapDialog from '@/components/ClaudeCapDialog';
|
| 22 |
import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
|
|
@@ -118,18 +121,49 @@ interface ChatInputProps {
|
|
| 118 |
initialModelPath?: string | null;
|
| 119 |
onSend: (text: string) => void;
|
| 120 |
onStop?: () => void;
|
|
|
|
| 121 |
isProcessing?: boolean;
|
| 122 |
disabled?: boolean;
|
| 123 |
placeholder?: string;
|
| 124 |
}
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
|
| 127 |
const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath);
|
| 128 |
const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0];
|
| 129 |
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
const [input, setInput] = useState('');
|
| 132 |
const inputRef = useRef<HTMLTextAreaElement>(null);
|
|
|
|
| 133 |
const [modelOptions, setModelOptions] = useState<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
|
| 134 |
const modelOptionsRef = useRef<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
|
| 135 |
const sessionIdRef = useRef<string | undefined>(sessionId);
|
|
@@ -150,6 +184,11 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 150 |
const updateSessionModel = useSessionStore((s) => s.updateSessionModel);
|
| 151 |
const [awaitingTopUp, setAwaitingTopUp] = useState(false);
|
| 152 |
const [modelSwitchError, setModelSwitchError] = useState<string | null>(null);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
const lastSentRef = useRef<string>('');
|
| 154 |
|
| 155 |
useEffect(() => {
|
|
@@ -216,12 +255,75 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 216 |
}, [disabled, isProcessing]);
|
| 217 |
|
| 218 |
const handleSend = useCallback(() => {
|
| 219 |
-
if (input.trim() && !disabled) {
|
| 220 |
lastSentRef.current = input;
|
| 221 |
onSend(input);
|
| 222 |
setInput('');
|
| 223 |
}
|
| 224 |
-
}, [input, disabled, onSend]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
// When the chat transport reports a premium-model quota 429, restore the typed
|
| 227 |
// text so the user doesn't lose their message.
|
|
@@ -231,6 +333,18 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 231 |
}
|
| 232 |
}, [claudeQuotaExhausted]);
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
// Refresh the quota display whenever the session changes (user might
|
| 235 |
// have started another tab that spent quota).
|
| 236 |
useEffect(() => {
|
|
@@ -382,9 +496,12 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 382 |
<Box
|
| 383 |
className="composer"
|
| 384 |
sx={{
|
| 385 |
-
display: '
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
| 388 |
bgcolor: 'var(--composer-bg)',
|
| 389 |
borderRadius: 'var(--radius-md)',
|
| 390 |
p: '12px',
|
|
@@ -420,7 +537,7 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 420 |
}
|
| 421 |
}}
|
| 422 |
sx={{
|
| 423 |
-
|
| 424 |
'& .MuiInputBase-root': {
|
| 425 |
p: 0,
|
| 426 |
backgroundColor: 'transparent',
|
|
@@ -431,11 +548,46 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 431 |
}
|
| 432 |
}}
|
| 433 |
/>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
{isProcessing ? (
|
| 435 |
<IconButton
|
| 436 |
onClick={onStop}
|
| 437 |
sx={{
|
| 438 |
-
|
|
|
|
|
|
|
| 439 |
p: 1.5,
|
| 440 |
borderRadius: '10px',
|
| 441 |
color: 'var(--muted-text)',
|
|
@@ -455,9 +607,11 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 455 |
) : (
|
| 456 |
<IconButton
|
| 457 |
onClick={handleSend}
|
| 458 |
-
disabled={disabled || !input.trim()}
|
| 459 |
sx={{
|
| 460 |
-
|
|
|
|
|
|
|
| 461 |
p: 1,
|
| 462 |
borderRadius: '10px',
|
| 463 |
color: 'var(--muted-text)',
|
|
@@ -475,6 +629,65 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 475 |
</IconButton>
|
| 476 |
)}
|
| 477 |
</Box>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
|
| 479 |
{/* Powered By Badge */}
|
| 480 |
<Box
|
|
|
|
| 11 |
ListItemIcon,
|
| 12 |
ListItemText,
|
| 13 |
Chip,
|
| 14 |
+
LinearProgress,
|
| 15 |
Snackbar,
|
| 16 |
+
Tooltip,
|
| 17 |
} from '@mui/material';
|
| 18 |
import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
|
| 19 |
import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
|
| 20 |
import StopIcon from '@mui/icons-material/Stop';
|
| 21 |
+
import AddIcon from '@mui/icons-material/Add';
|
| 22 |
+
import { apiFetch, apiUpload } from '@/utils/api';
|
| 23 |
import { useUserQuota } from '@/hooks/useUserQuota';
|
| 24 |
import ClaudeCapDialog from '@/components/ClaudeCapDialog';
|
| 25 |
import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
|
|
|
|
| 121 |
initialModelPath?: string | null;
|
| 122 |
onSend: (text: string) => void;
|
| 123 |
onStop?: () => void;
|
| 124 |
+
onDatasetUploaded?: () => Promise<boolean> | boolean;
|
| 125 |
isProcessing?: boolean;
|
| 126 |
disabled?: boolean;
|
| 127 |
placeholder?: string;
|
| 128 |
}
|
| 129 |
|
| 130 |
+
interface DatasetUploadResponse {
|
| 131 |
+
session_id: string;
|
| 132 |
+
repo_id: string;
|
| 133 |
+
repo_type: 'dataset';
|
| 134 |
+
private: true;
|
| 135 |
+
upload_id: string;
|
| 136 |
+
config_name: string;
|
| 137 |
+
filename: string;
|
| 138 |
+
path_in_repo: string;
|
| 139 |
+
size_bytes: number;
|
| 140 |
+
format: 'csv' | 'json' | 'jsonl';
|
| 141 |
+
hub_url: string;
|
| 142 |
+
load_dataset_snippet: string;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
const MAX_DATASET_UPLOAD_BYTES = 100 * 1024 * 1024;
|
| 146 |
+
const DATASET_UPLOAD_ACCEPT = '.csv,.json,.jsonl';
|
| 147 |
+
const DATASET_UPLOAD_EXTENSIONS = new Set(['csv', 'json', 'jsonl']);
|
| 148 |
+
|
| 149 |
const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
|
| 150 |
const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath);
|
| 151 |
const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0];
|
| 152 |
|
| 153 |
+
/**
 * Render a byte count as a short human-readable string.
 *
 * Counts below 1024 are shown as whole bytes, counts below 1024 * 1024 as
 * kilobytes with one decimal, and everything else as megabytes with one
 * decimal.
 */
const formatBytes = (bytes: number) => {
  const KIB = 1024;
  const MIB = KIB * KIB;
  if (bytes < KIB) {
    return `${bytes} B`;
  }
  const unit = bytes < MIB
    ? { size: KIB, label: 'KB' }
    : { size: MIB, label: 'MB' };
  return `${(bytes / unit.size).toFixed(1)} ${unit.label}`;
};
|
| 158 |
+
|
| 159 |
+
/**
 * Build the huggingface.co dataset URL for `repoId`, percent-encoding each
 * slash-separated segment of the id while keeping the slashes themselves.
 */
const datasetRepoUrl = (repoId: string) => {
  const encodedPath = repoId
    .split('/')
    .map((segment) => encodeURIComponent(segment))
    .join('/');
  return `https://huggingface.co/datasets/${encodedPath}`;
};
|
| 162 |
+
|
| 163 |
+
export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, onDatasetUploaded, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
|
| 164 |
const [input, setInput] = useState('');
|
| 165 |
const inputRef = useRef<HTMLTextAreaElement>(null);
|
| 166 |
+
const fileInputRef = useRef<HTMLInputElement>(null);
|
| 167 |
const [modelOptions, setModelOptions] = useState<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
|
| 168 |
const modelOptionsRef = useRef<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
|
| 169 |
const sessionIdRef = useRef<string | undefined>(sessionId);
|
|
|
|
| 184 |
const updateSessionModel = useSessionStore((s) => s.updateSessionModel);
|
| 185 |
const [awaitingTopUp, setAwaitingTopUp] = useState(false);
|
| 186 |
const [modelSwitchError, setModelSwitchError] = useState<string | null>(null);
|
| 187 |
+
const [datasetUploadError, setDatasetUploadError] = useState<string | null>(null);
|
| 188 |
+
const [datasetUploadSuccess, setDatasetUploadSuccess] = useState<string | null>(null);
|
| 189 |
+
const [uploadedDatasets, setUploadedDatasets] = useState<DatasetUploadResponse[]>([]);
|
| 190 |
+
const [isUploadingDataset, setIsUploadingDataset] = useState(false);
|
| 191 |
+
const [datasetUploadProgress, setDatasetUploadProgress] = useState<number | null>(null);
|
| 192 |
const lastSentRef = useRef<string>('');
|
| 193 |
|
| 194 |
useEffect(() => {
|
|
|
|
| 255 |
}, [disabled, isProcessing]);
|
| 256 |
|
| 257 |
// Submit the current draft: skipped when the draft is blank, the composer is
// disabled, or a dataset upload is in flight. On submit the raw text is
// recorded in lastSentRef, handed to onSend, and the input is cleared.
const handleSend = useCallback(() => {
  const hasText = Boolean(input.trim());
  if (!hasText || disabled || isUploadingDataset) return;

  lastSentRef.current = input;
  onSend(input);
  setInput('');
}, [input, disabled, isUploadingDataset, onSend]);
|
| 264 |
+
|
| 265 |
+
// Forward a click to the element held in fileInputRef (a no-op while the ref
// is still empty).
const handleDatasetUploadClick = useCallback(() => {
  const picker = fileInputRef.current;
  if (picker) {
    picker.click();
  }
}, []);
|
| 268 |
+
|
| 269 |
+
// Validate the file picked in the dataset <input>, upload it to the current
// session, and surface the result through the upload error/success state.
//
// Client-side checks (session present, allowed extension, size limit,
// non-empty) run before any network traffic. While the request is in flight
// isUploadingDataset is true and datasetUploadProgress tracks the percentage
// reported by apiUpload; both are reset in `finally`.
const handleDatasetFileChange = useCallback(
  async (event: React.ChangeEvent<HTMLInputElement>) => {
    const file = event.target.files?.[0];
    // Clear the input's value right away (presumably so that re-selecting the
    // same file fires another change event — confirm against browser docs).
    event.target.value = '';
    if (!file) return;

    if (!sessionId) {
      setDatasetUploadError('Start a session before uploading a dataset.');
      return;
    }

    // Only the text after the LAST dot is an extension. A dotless name such
    // as "csv" has no extension and must not pass the allow-list (the former
    // split('.').pop() returned the whole name in that case).
    const dotIndex = file.name.lastIndexOf('.');
    const extension = dotIndex >= 0
      ? file.name.slice(dotIndex + 1).toLowerCase()
      : '';
    if (!DATASET_UPLOAD_EXTENSIONS.has(extension)) {
      setDatasetUploadError('Only CSV, JSON, and JSONL dataset files are supported.');
      return;
    }
    if (file.size > MAX_DATASET_UPLOAD_BYTES) {
      setDatasetUploadError(
        `Dataset files must be 100 MB or smaller. ${file.name} is ${formatBytes(file.size)}.`
      );
      return;
    }
    if (file.size === 0) {
      setDatasetUploadError('Uploaded dataset file is empty.');
      return;
    }

    const formData = new FormData();
    formData.append('file', file);
    setIsUploadingDataset(true);
    setDatasetUploadProgress(0);
    setDatasetUploadError(null);
    setDatasetUploadSuccess(null);
    try {
      const res = await apiUpload(`/api/session/${sessionId}/datasets`, formData, {
        onProgress: ({ percent }) => {
          // A null progress value switches the bar to indeterminate; this is
          // used both when the total is unknown and once the body is fully sent.
          setDatasetUploadProgress(percent !== null && percent < 100 ? percent : null);
        },
      });
      if (!res.ok) {
        setDatasetUploadError(await readApiErrorMessage(res, 'Dataset upload failed.'));
        return;
      }
      const payload = await res.json() as DatasetUploadResponse;
      // Newest upload first.
      setUploadedDatasets((previous) => [payload, ...previous]);
      setDatasetUploadSuccess(`Uploaded ${payload.filename} to ${payload.repo_id}`);
      await onDatasetUploaded?.();
    } catch (error) {
      setDatasetUploadError(
        error instanceof Error ? error.message : 'Dataset upload failed.'
      );
    } finally {
      setIsUploadingDataset(false);
      setDatasetUploadProgress(null);
    }
  },
  [sessionId, onDatasetUploaded],
);
|
| 327 |
|
| 328 |
// When the chat transport reports a premium-model quota 429, restore the typed
|
| 329 |
// text so the user doesn't lose their message.
|
|
|
|
| 333 |
}
|
| 334 |
}, [claudeQuotaExhausted]);
|
| 335 |
|
| 336 |
+
useEffect(() => {
|
| 337 |
+
if (!datasetUploadError) return;
|
| 338 |
+
const timeout = window.setTimeout(() => setDatasetUploadError(null), 7000);
|
| 339 |
+
return () => window.clearTimeout(timeout);
|
| 340 |
+
}, [datasetUploadError]);
|
| 341 |
+
|
| 342 |
+
useEffect(() => {
|
| 343 |
+
if (!datasetUploadSuccess) return;
|
| 344 |
+
const timeout = window.setTimeout(() => setDatasetUploadSuccess(null), 5000);
|
| 345 |
+
return () => window.clearTimeout(timeout);
|
| 346 |
+
}, [datasetUploadSuccess]);
|
| 347 |
+
|
| 348 |
// Refresh the quota display whenever the session changes (user might
|
| 349 |
// have started another tab that spent quota).
|
| 350 |
useEffect(() => {
|
|
|
|
| 496 |
<Box
|
| 497 |
className="composer"
|
| 498 |
sx={{
|
| 499 |
+
display: 'grid',
|
| 500 |
+
gridTemplateColumns: 'auto 1fr auto',
|
| 501 |
+
gridTemplateRows: 'auto auto',
|
| 502 |
+
columnGap: '10px',
|
| 503 |
+
rowGap: '4px',
|
| 504 |
+
alignItems: 'end',
|
| 505 |
bgcolor: 'var(--composer-bg)',
|
| 506 |
borderRadius: 'var(--radius-md)',
|
| 507 |
p: '12px',
|
|
|
|
| 537 |
}
|
| 538 |
}}
|
| 539 |
sx={{
|
| 540 |
+
gridColumn: '1 / -1',
|
| 541 |
'& .MuiInputBase-root': {
|
| 542 |
p: 0,
|
| 543 |
backgroundColor: 'transparent',
|
|
|
|
| 548 |
}
|
| 549 |
}}
|
| 550 |
/>
|
| 551 |
+
<input
|
| 552 |
+
ref={fileInputRef}
|
| 553 |
+
type="file"
|
| 554 |
+
accept={DATASET_UPLOAD_ACCEPT}
|
| 555 |
+
onChange={handleDatasetFileChange}
|
| 556 |
+
style={{ display: 'none' }}
|
| 557 |
+
/>
|
| 558 |
+
<Box sx={{ gridColumn: '1', gridRow: '2', display: 'flex' }}>
|
| 559 |
+
<Tooltip title="Upload dataset">
|
| 560 |
+
<span>
|
| 561 |
+
<IconButton
|
| 562 |
+
onClick={handleDatasetUploadClick}
|
| 563 |
+
disabled={disabled || isProcessing || isUploadingDataset || !sessionId}
|
| 564 |
+
sx={{
|
| 565 |
+
p: 1,
|
| 566 |
+
borderRadius: '50%',
|
| 567 |
+
color: uploadedDatasets.length ? 'var(--accent-yellow)' : 'var(--muted-text)',
|
| 568 |
+
transition: 'all 0.2s',
|
| 569 |
+
'&:hover': {
|
| 570 |
+
color: 'var(--accent-yellow)',
|
| 571 |
+
bgcolor: 'var(--hover-bg)',
|
| 572 |
+
},
|
| 573 |
+
'&.Mui-disabled': {
|
| 574 |
+
opacity: 0.3,
|
| 575 |
+
},
|
| 576 |
+
}}
|
| 577 |
+
aria-label="Upload dataset"
|
| 578 |
+
>
|
| 579 |
+
<AddIcon fontSize="small" />
|
| 580 |
+
</IconButton>
|
| 581 |
+
</span>
|
| 582 |
+
</Tooltip>
|
| 583 |
+
</Box>
|
| 584 |
{isProcessing ? (
|
| 585 |
<IconButton
|
| 586 |
onClick={onStop}
|
| 587 |
sx={{
|
| 588 |
+
gridColumn: '3',
|
| 589 |
+
gridRow: '2',
|
| 590 |
+
justifySelf: 'end',
|
| 591 |
p: 1.5,
|
| 592 |
borderRadius: '10px',
|
| 593 |
color: 'var(--muted-text)',
|
|
|
|
| 607 |
) : (
|
| 608 |
<IconButton
|
| 609 |
onClick={handleSend}
|
| 610 |
+
disabled={disabled || isUploadingDataset || !input.trim()}
|
| 611 |
sx={{
|
| 612 |
+
gridColumn: '3',
|
| 613 |
+
gridRow: '2',
|
| 614 |
+
justifySelf: 'end',
|
| 615 |
p: 1,
|
| 616 |
borderRadius: '10px',
|
| 617 |
color: 'var(--muted-text)',
|
|
|
|
| 629 |
</IconButton>
|
| 630 |
)}
|
| 631 |
</Box>
|
| 632 |
+
{isUploadingDataset && (
|
| 633 |
+
<Box sx={{ mt: 1, px: 0.5 }}>
|
| 634 |
+
<LinearProgress
|
| 635 |
+
variant={datasetUploadProgress === null ? 'indeterminate' : 'determinate'}
|
| 636 |
+
value={datasetUploadProgress ?? 0}
|
| 637 |
+
aria-label="Dataset upload progress"
|
| 638 |
+
sx={{
|
| 639 |
+
height: 4,
|
| 640 |
+
borderRadius: 999,
|
| 641 |
+
bgcolor: 'rgba(255,255,255,0.08)',
|
| 642 |
+
'& .MuiLinearProgress-bar': {
|
| 643 |
+
borderRadius: 999,
|
| 644 |
+
bgcolor: 'var(--accent-yellow)',
|
| 645 |
+
},
|
| 646 |
+
}}
|
| 647 |
+
/>
|
| 648 |
+
</Box>
|
| 649 |
+
)}
|
| 650 |
+
{(datasetUploadError || datasetUploadSuccess) && (
|
| 651 |
+
<Box sx={{ display: 'flex', justifyContent: 'center', mt: 1 }}>
|
| 652 |
+
<Alert
|
| 653 |
+
severity={datasetUploadError ? 'error' : 'success'}
|
| 654 |
+
variant="filled"
|
| 655 |
+
onClose={() => {
|
| 656 |
+
setDatasetUploadError(null);
|
| 657 |
+
setDatasetUploadSuccess(null);
|
| 658 |
+
}}
|
| 659 |
+
sx={{ fontSize: '0.8rem', maxWidth: 520, width: '100%' }}
|
| 660 |
+
>
|
| 661 |
+
{datasetUploadError ?? datasetUploadSuccess}
|
| 662 |
+
</Alert>
|
| 663 |
+
</Box>
|
| 664 |
+
)}
|
| 665 |
+
{uploadedDatasets.length > 0 && (
|
| 666 |
+
<Box sx={{ display: 'flex', flexWrap: 'wrap', gap: 0.75, justifyContent: 'center', mt: 1 }}>
|
| 667 |
+
{uploadedDatasets.map((dataset) => (
|
| 668 |
+
<Chip
|
| 669 |
+
key={dataset.upload_id}
|
| 670 |
+
size="small"
|
| 671 |
+
label={`Dataset: ${dataset.filename}`}
|
| 672 |
+
component="a"
|
| 673 |
+
href={datasetRepoUrl(dataset.repo_id)}
|
| 674 |
+
target="_blank"
|
| 675 |
+
rel="noreferrer"
|
| 676 |
+
clickable
|
| 677 |
+
sx={{
|
| 678 |
+
maxWidth: '100%',
|
| 679 |
+
bgcolor: 'rgba(255,255,255,0.08)',
|
| 680 |
+
color: 'var(--text)',
|
| 681 |
+
border: '1px solid var(--divider)',
|
| 682 |
+
'& .MuiChip-label': {
|
| 683 |
+
overflow: 'hidden',
|
| 684 |
+
textOverflow: 'ellipsis',
|
| 685 |
+
},
|
| 686 |
+
}}
|
| 687 |
+
/>
|
| 688 |
+
))}
|
| 689 |
+
</Box>
|
| 690 |
+
)}
|
| 691 |
|
| 692 |
{/* Powered By Badge */}
|
| 693 |
<Box
|
|
@@ -27,7 +27,16 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
|
|
| 27 |
const sessionMeta = sessions.find((s) => s.id === sessionId);
|
| 28 |
const isExpired = sessionMeta?.expired === true;
|
| 29 |
|
| 30 |
-
const {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
sessionId,
|
| 32 |
isActive,
|
| 33 |
onReady: () => logger.log(`Session ${sessionId} ready`),
|
|
@@ -116,6 +125,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
|
|
| 116 |
initialModelPath={sessionMeta?.model}
|
| 117 |
onSend={handleSendMessage}
|
| 118 |
onStop={handleStop}
|
|
|
|
| 119 |
isProcessing={busy}
|
| 120 |
disabled={!isConnected || activityStatus.type === 'waiting-approval'}
|
| 121 |
placeholder={
|
|
|
|
| 27 |
const sessionMeta = sessions.find((s) => s.id === sessionId);
|
| 28 |
const isExpired = sessionMeta?.expired === true;
|
| 29 |
|
| 30 |
+
const {
|
| 31 |
+
messages,
|
| 32 |
+
sendMessage,
|
| 33 |
+
stop,
|
| 34 |
+
status,
|
| 35 |
+
undoLastTurn,
|
| 36 |
+
editAndRegenerate,
|
| 37 |
+
approveTools,
|
| 38 |
+
refreshMessages,
|
| 39 |
+
} = useAgentChat({
|
| 40 |
sessionId,
|
| 41 |
isActive,
|
| 42 |
onReady: () => logger.log(`Session ${sessionId} ready`),
|
|
|
|
| 125 |
initialModelPath={sessionMeta?.model}
|
| 126 |
onSend={handleSendMessage}
|
| 127 |
onStop={handleStop}
|
| 128 |
+
onDatasetUploaded={refreshMessages}
|
| 129 |
isProcessing={busy}
|
| 130 |
disabled={!isConnected || activityStatus.type === 'waiting-approval'}
|
| 131 |
placeholder={
|
|
@@ -804,6 +804,48 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 804 |
}
|
| 805 |
}, [sessionId, chat]);
|
| 806 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
return {
|
| 808 |
messages: chat.messages,
|
| 809 |
sendMessage: chat.sendMessage,
|
|
@@ -812,5 +854,6 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 812 |
undoLastTurn,
|
| 813 |
editAndRegenerate,
|
| 814 |
approveTools,
|
|
|
|
| 815 |
};
|
| 816 |
}
|
|
|
|
| 804 |
}
|
| 805 |
}, [sessionId, chat]);
|
| 806 |
|
| 807 |
+
// Re-fetch this session's messages and session info from the backend and push
// them into the chat state. Resolves to true when the message list was
// fetched and non-empty, false on a failed messages request, an empty or
// non-array payload, or any thrown error.
const refreshMessages = useCallback(async () => {
  try {
    const [messagesResponse, sessionResponse] = await Promise.all([
      apiFetch(`/api/session/${sessionId}/messages`),
      apiFetch(`/api/session/${sessionId}`),
    ]);
    if (!messagesResponse.ok) {
      return false;
    }

    const backendMessages = await messagesResponse.json();
    const hasMessages = Array.isArray(backendMessages) && backendMessages.length > 0;
    if (!hasMessages) {
      return false;
    }
    saveBackendMessages(sessionId, backendMessages);

    // Tool-call ids still awaiting approval; left undefined when the session
    // info request failed or reported none.
    let pendingIds: Set<string> | undefined;
    if (sessionResponse.ok) {
      const sessionInfo = await sessionResponse.json();
      if (Array.isArray(sessionInfo.pending_approval)) {
        const ids: Set<string> = new Set(
          sessionInfo.pending_approval.map(
            (toolCall: { tool_call_id: string }) => toolCall.tool_call_id,
          ),
        );
        pendingIds = ids;
        if (ids.size > 0) {
          setNeedsAttention(sessionId, true);
        }
      }
      if (sessionInfo.auto_approval) {
        updateSessionYolo(sessionId, sessionInfo.auto_approval);
      }
    }

    const { messages: currentMessages, setMessages: applyMessages } = chatActionsRef.current;
    const uiMessages = llmMessagesToUIMessages(backendMessages, pendingIds, currentMessages);
    if (applyMessages && uiMessages.length > 0) {
      applyMessages(uiMessages);
      saveMessages(sessionId, uiMessages);
    }
    return true;
  } catch {
    return false;
  }
}, [sessionId, setNeedsAttention, updateSessionYolo]);
|
| 848 |
+
|
| 849 |
return {
|
| 850 |
messages: chat.messages,
|
| 851 |
sendMessage: chat.sendMessage,
|
|
|
|
| 854 |
undoLastTurn,
|
| 855 |
editAndRegenerate,
|
| 856 |
approveTools,
|
| 857 |
+
refreshMessages,
|
| 858 |
};
|
| 859 |
}
|
|
@@ -7,15 +7,36 @@
|
|
| 7 |
|
| 8 |
import { triggerLogin } from '@/hooks/useAuth';
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
/** Wrapper around fetch with credentials and common headers. */
|
| 11 |
export async function apiFetch(
|
| 12 |
path: string,
|
| 13 |
options: RequestInit = {}
|
| 14 |
): Promise<Response> {
|
| 15 |
-
const headers
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
const response = await fetch(path, {
|
| 21 |
...options,
|
|
@@ -23,19 +44,50 @@ export async function apiFetch(
|
|
| 23 |
credentials: 'include', // Send cookies with every request
|
| 24 |
});
|
| 25 |
|
| 26 |
-
|
| 27 |
-
if (response.status === 401) {
|
| 28 |
-
try {
|
| 29 |
-
const authStatus = await fetch('/auth/status', { credentials: 'include' });
|
| 30 |
-
const data = await authStatus.json();
|
| 31 |
-
if (data.auth_enabled) {
|
| 32 |
-
triggerLogin();
|
| 33 |
-
throw new Error('Authentication required — redirecting to login.');
|
| 34 |
-
}
|
| 35 |
-
} catch (e) {
|
| 36 |
-
if (e instanceof Error && e.message.includes('redirecting')) throw e;
|
| 37 |
-
}
|
| 38 |
-
}
|
| 39 |
|
| 40 |
return response;
|
| 41 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
import { triggerLogin } from '@/hooks/useAuth';
|
| 9 |
|
| 10 |
+
export interface ApiUploadProgress {
|
| 11 |
+
loaded: number;
|
| 12 |
+
total: number | null;
|
| 13 |
+
percent: number | null;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
/**
 * React to a 401 response: ask `/auth/status` whether auth is enabled and, if
 * so, call `triggerLogin()` and throw so the caller stops processing.
 *
 * Any other failure while checking (network error, bad JSON, ...) is
 * swallowed, leaving the caller to handle the 401 response itself. Responses
 * with a status other than 401 are ignored.
 */
async function handleUnauthorized(response: Response): Promise<void> {
  if (response.status !== 401) {
    return;
  }

  let loginTriggered = false;
  try {
    const statusResponse = await fetch('/auth/status', { credentials: 'include' });
    const statusPayload = await statusResponse.json();
    if (statusPayload.auth_enabled) {
      triggerLogin();
      loginTriggered = true;
    }
  } catch (e) {
    // Only errors that announce a login redirect propagate; the rest are ignored.
    if (e instanceof Error && e.message.includes('redirecting')) throw e;
  }

  if (loginTriggered) {
    throw new Error('Authentication required — redirecting to login.');
  }
}
|
| 29 |
+
|
| 30 |
/** Wrapper around fetch with credentials and common headers. */
|
| 31 |
export async function apiFetch(
|
| 32 |
path: string,
|
| 33 |
options: RequestInit = {}
|
| 34 |
): Promise<Response> {
|
| 35 |
+
const headers = new Headers(options.headers);
|
| 36 |
+
const isFormData = options.body instanceof FormData;
|
| 37 |
+
if (!isFormData && !headers.has('Content-Type')) {
|
| 38 |
+
headers.set('Content-Type', 'application/json');
|
| 39 |
+
}
|
| 40 |
|
| 41 |
const response = await fetch(path, {
|
| 42 |
...options,
|
|
|
|
| 44 |
credentials: 'include', // Send cookies with every request
|
| 45 |
});
|
| 46 |
|
| 47 |
+
await handleUnauthorized(response);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
return response;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
/**
 * Convert the raw header text returned by `XMLHttpRequest.getAllResponseHeaders()`
 * into a `Headers` object.
 *
 * Lines are split on CR/LF; a line without a colon, or with the colon in the
 * first position, is skipped. Repeated names are appended, not overwritten.
 */
function headersFromXhr(rawHeaders: string): Headers {
  const parsed = new Headers();
  for (const headerLine of rawHeaders.trim().split(/[\r\n]+/)) {
    const colonAt = headerLine.indexOf(':');
    if (colonAt <= 0) continue;

    const name = headerLine.slice(0, colonAt).trim();
    const value = headerLine.slice(colonAt + 1).trim();
    parsed.append(name, value);
  }
  return parsed;
}
|
| 64 |
+
|
| 65 |
+
/**
 * POST `formData` to `path` with cookies, reporting upload progress.
 *
 * Uses XMLHttpRequest so that `options.onProgress` can be fed from
 * `xhr.upload.onprogress`. The XHR result is repackaged as a `Response` so
 * callers can treat it like the result of `apiFetch`.
 *
 * Resolves with the `Response` for any completed HTTP exchange (including
 * non-2xx statuses). Rejects on network error, abort, a login redirect raised
 * by `handleUnauthorized`, or when the XHR result cannot be represented as a
 * `Response`.
 */
export async function apiUpload(
  path: string,
  formData: FormData,
  options: { onProgress?: (progress: ApiUploadProgress) => void } = {},
): Promise<Response> {
  return new Promise<Response>((resolve, reject) => {
    const xhr = new XMLHttpRequest();
    xhr.open('POST', path);
    xhr.withCredentials = true;
    xhr.upload.onprogress = (event) => {
      // total/percent are null when the browser cannot compute the length.
      const total = event.lengthComputable ? event.total : null;
      const percent = total
        ? Math.min(100, Math.round((event.loaded / total) * 100))
        : null;
      options.onProgress?.({ loaded: event.loaded, total, percent });
    };
    xhr.onerror = () => reject(new Error('Network error while uploading.'));
    xhr.onabort = () => reject(new Error('Dataset upload was canceled.'));
    xhr.onload = () => {
      try {
        // The Response constructor throws when a body (even an empty string)
        // is supplied with a null-body status, so pass null for those.
        const isNullBodyStatus = [204, 205, 304].includes(xhr.status);
        const response = new Response(isNullBodyStatus ? null : xhr.responseText, {
          status: xhr.status,
          statusText: xhr.statusText,
          headers: headersFromXhr(xhr.getAllResponseHeaders()),
        });
        handleUnauthorized(response).then(() => resolve(response)).catch(reject);
      } catch (error) {
        // Anything thrown synchronously here would otherwise escape the event
        // handler and leave the promise pending forever.
        reject(error instanceof Error ? error : new Error('Invalid response while uploading.'));
      }
    };
    xhr.send(formData);
  });
}
|
|
@@ -28,6 +28,7 @@ dependencies = [
|
|
| 28 |
"websockets>=13.0",
|
| 29 |
"apscheduler>=3.10,<4",
|
| 30 |
"pymongo>=4.17.0",
|
|
|
|
| 31 |
]
|
| 32 |
|
| 33 |
[project.optional-dependencies]
|
|
|
|
| 28 |
"websockets>=13.0",
|
| 29 |
"apscheduler>=3.10,<4",
|
| 30 |
"pymongo>=4.17.0",
|
| 31 |
+
"python-multipart>=0.0.20",
|
| 32 |
]
|
| 33 |
|
| 34 |
[project.optional-dependencies]
|
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
|
| 6 |
+
import httpx
|
| 7 |
+
import pytest
|
| 8 |
+
from fastapi import HTTPException, UploadFile
|
| 9 |
+
from huggingface_hub.errors import HfHubHTTPError
|
| 10 |
+
from starlette.datastructures import FormData
|
| 11 |
+
|
| 12 |
+
_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend"
|
| 13 |
+
if str(_BACKEND_DIR) not in sys.path:
|
| 14 |
+
sys.path.insert(0, str(_BACKEND_DIR))
|
| 15 |
+
|
| 16 |
+
import dataset_uploads # noqa: E402
|
| 17 |
+
from routes import agent # noqa: E402
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _upload(filename: str, content: bytes = b"a,b\n1,2\n") -> UploadFile:
    """Build an ``UploadFile`` named *filename* backed by an in-memory buffer of *content*."""
    buffer = io.BytesIO(content)
    return UploadFile(filename=filename, file=buffer)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _track_close(upload: UploadFile):
    """Replace ``upload.close`` with a wrapper that records when it runs.

    Returns a dict whose ``"closed"`` entry starts as ``False`` and is set to
    ``True`` when the replacement coroutine is awaited; the original ``close``
    is still awaited afterwards.
    """
    tracker = {"closed": False}
    wrapped_close = upload.close

    async def tracked_close():
        tracker["closed"] = True
        await wrapped_close()

    upload.close = tracked_close
    return tracker
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _request(
    upload: UploadFile | None = None,
    headers: dict[str, str] | None = None,
):
    """Create a stand-in request object plus a dict recording ``form()`` use.

    The fake exposes ``headers`` (the given mapping, or an empty dict),
    an empty ``cookies`` dict, and an async ``form()`` that marks
    ``"form_called"`` in the returned dict. ``form()`` yields a ``FormData``
    holding *upload* under ``"file"``; when *upload* is ``None`` it raises
    ``AssertionError`` instead.
    """
    tracker = {"form_called": False}

    class FakeRequest:
        def __init__(self):
            self.headers = headers if headers else {}
            self.cookies = {}

        async def form(self, **_kwargs):
            tracker["form_called"] = True
            if upload is not None:
                return FormData([("file", upload)])
            raise AssertionError("request.form() should not be called")

    return FakeRequest(), tracker
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_sanitize_dataset_filename_strips_paths_and_unsafe_chars():
    """Path components and unsafe characters are removed; an empty name gets a default."""
    expectations = {
        "../../bad file (final).CSV": "bad-file-final.csv",
        "": "dataset.csv",
    }
    for raw_name, expected in expectations.items():
        assert dataset_uploads.sanitize_dataset_filename(raw_name) == expected
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_dataset_format_rejects_unsupported_extension():
    """Unsupported or missing extensions raise ``HTTPException`` (400 for ``.txt``)."""
    with pytest.raises(HTTPException) as rejected:
        dataset_uploads.dataset_format_from_filename("notes.txt")
    assert rejected.value.status_code == 400

    # A name with no extension at all is rejected as well.
    with pytest.raises(HTTPException):
        dataset_uploads.dataset_format_from_filename("notes")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_dataset_repo_card_exposes_each_upload_as_config():
    """Each uploaded file appears once as a named config in the generated card."""
    repo_files = [
        "README.md",
        "uploads/oldabc/rows.jsonl",
        "uploads/oldabc/rows.jsonl",  # duplicate on purpose: must not yield two configs
        "uploads/newdef/table.csv",
    ]
    raw_card = dataset_uploads.dataset_repo_card(
        "alice/ml-intern-s1-datasets",
        repo_files,
    )
    card = raw_card.decode("utf-8")

    expected_fragments = (
        "configs:",
        "- config_name: upload_oldabc",
        ' path: "uploads/oldabc/rows.jsonl"',
        "- config_name: upload_newdef",
        ' path: "uploads/newdef/table.csv"',
    )
    for fragment in expected_fragments:
        assert fragment in card
    assert card.count("- config_name: upload_oldabc") == 1
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@pytest.mark.asyncio
async def test_validate_dataset_upload_rejects_size_over_limit(monkeypatch):
    """A 4-byte upload against a patched 3-byte limit is rejected with HTTP 413."""
    monkeypatch.setattr(dataset_uploads, "MAX_DATASET_UPLOAD_BYTES", 3)
    oversized = _upload("rows.csv", b"abcd")
    try:
        with pytest.raises(HTTPException) as rejected:
            await dataset_uploads.validate_dataset_upload(oversized)
    finally:
        # Always release the in-memory file, even if the expectation fails.
        await oversized.close()

    assert rejected.value.status_code == 413
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@pytest.mark.asyncio
async def test_push_dataset_upload_creates_private_repo_and_uploads_file(monkeypatch):
    """Pushing an upload creates a private repo, uploads the file, then the README.

    ``HfApi`` is replaced by a recording fake and ``uuid4`` by a fixed value so
    that the repo id, upload path, and config name are deterministic.
    """
    expected_repo = "alice/ml-intern-12345678-datasets"
    csv_bytes = b"a,b\n1,2\n"
    created_apis = []

    class FakeApi:
        """Records every Hub call instead of performing it."""

        def __init__(self, token):
            self.token = token
            self.create_calls = []
            self.settings_calls = []
            self.list_calls = []
            self.upload_calls = []
            created_apis.append(self)

        def create_repo(self, **kwargs):
            self.create_calls.append(kwargs)

        def update_repo_settings(self, **kwargs):
            self.settings_calls.append(kwargs)

        def list_repo_files(self, **kwargs):
            self.list_calls.append(kwargs)
            return [
                "README.md",
                "uploads/oldupload/old.jsonl",
                "uploads/notes.txt",
            ]

        def upload_file(self, **kwargs):
            # Every upload other than the README must carry the dataset bytes.
            if kwargs["path_in_repo"] != "README.md":
                assert kwargs["path_or_fileobj"] == csv_bytes
            self.upload_calls.append(kwargs)

    monkeypatch.setattr(dataset_uploads, "HfApi", FakeApi)
    monkeypatch.setattr(
        dataset_uploads.uuid,
        "uuid4",
        lambda: SimpleNamespace(hex="feedfacecafebeef"),
    )

    upload = _upload("../Data Set.CSV")
    try:
        result = await dataset_uploads.push_dataset_upload_to_hub(
            upload=upload,
            session_id="12345678-90ab-cdef-1234-567890abcdef",
            hf_username="alice",
            hf_token="hf-token",
        )
    finally:
        await upload.close()

    api = created_apis[0]
    assert api.token == "hf-token"

    repo_ref = {"repo_id": expected_repo, "repo_type": "dataset"}
    assert api.create_calls == [{**repo_ref, "private": True, "exist_ok": True}]
    assert api.settings_calls == [{**repo_ref, "private": True}]
    assert api.list_calls == [repo_ref]

    uploaded_paths = [call["path_in_repo"] for call in api.upload_calls]
    assert uploaded_paths == [
        "uploads/feedfacecafe/Data-Set.csv",
        "README.md",
    ]

    readme = api.upload_calls[1]["path_or_fileobj"].decode("utf-8")
    for fragment in (
        "- config_name: upload_oldupload",
        ' path: "uploads/oldupload/old.jsonl"',
        "- config_name: upload_feedfacecafe",
        ' path: "uploads/feedfacecafe/Data-Set.csv"',
    ):
        assert fragment in readme

    assert result.repo_id == expected_repo
    assert result.config_name == "upload_feedfacecafe"
    assert result.format == "csv"
    assert result.load_dataset_snippet == (
        "from datasets import load_dataset\n\n"
        'dataset = load_dataset("alice/ml-intern-12345678-datasets", '
        '"upload_feedfacecafe", split="train", token=True)'
    )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@pytest.mark.asyncio
async def test_upload_route_requires_hf_token_without_parsing_upload(monkeypatch):
    """Without any HF token the route answers 401 and never touches the upload.

    ``HF_TOKEN`` is removed from the environment and the user dict carries no
    token; the fake request and close tracker confirm that neither
    ``request.form()`` nor ``upload.close()`` was called.
    """
    monkeypatch.delenv("HF_TOKEN", raising=False)
    upload = _upload("rows.csv")
    close_tracker = _track_close(upload)
    request, form_tracker = _request(upload)

    async def fake_check_session_access(*_args, **_kwargs):
        # An active, idle session with nothing pending approval.
        return SimpleNamespace(
            is_active=True,
            is_processing=False,
            session=SimpleNamespace(pending_approval=None),
            hf_username="alice",
        )

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    user = {"user_id": "u1", "username": "alice"}

    try:
        with pytest.raises(HTTPException) as rejected:
            await agent.upload_session_dataset("s1", request, user)

        assert rejected.value.status_code == 401
        assert form_tracker["form_called"] is False
        assert close_tracker["closed"] is False
    finally:
        await upload.close()
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@pytest.mark.asyncio
async def test_upload_route_rejects_content_length_before_parsing(monkeypatch):
    """A Content-Length one byte over limit + multipart slack yields 413 up front.

    The session-access check is stubbed to fail loudly, and the fake request
    and close tracker confirm the body was neither parsed nor closed.
    """
    upload = _upload("rows.csv")
    close_tracker = _track_close(upload)
    too_large = (
        dataset_uploads.MAX_DATASET_UPLOAD_BYTES
        + agent.DATASET_UPLOAD_MULTIPART_SLACK_BYTES
        + 1
    )
    request, form_tracker = _request(
        upload,
        headers={"content-length": str(too_large)},
    )

    async def fake_check_session_access(*_args, **_kwargs):
        raise AssertionError("session access should not run for oversized uploads")

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    user = {
        "user_id": "u1",
        "username": "alice",
        agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
    }

    try:
        with pytest.raises(HTTPException) as rejected:
            await agent.upload_session_dataset("s1", request, user)

        assert rejected.value.status_code == 413
        assert form_tracker["form_called"] is False
        assert close_tracker["closed"] is False
    finally:
        await upload.close()
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@pytest.mark.asyncio
async def test_upload_route_rejects_busy_session_without_parsing_upload(monkeypatch):
    """A session that is still processing is rejected with 409 before parsing.

    The stubbed session reports ``is_processing=True``; the route must raise
    without reading the multipart form or closing the upload.
    """
    upload = _upload("rows.csv")
    close_state = _track_close(upload)
    request, request_state = _request(upload)

    async def fake_check_session_access(*_args, **_kwargs):
        return SimpleNamespace(
            is_active=True,
            is_processing=True,
            session=SimpleNamespace(pending_approval=None),
            hf_username="alice",
        )

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)

    # Close in ``finally`` (as the sibling pre-parse rejection tests do) so a
    # failing assertion cannot leave the upload's underlying file open.
    try:
        with pytest.raises(HTTPException) as exc_info:
            await agent.upload_session_dataset(
                "s1",
                request,
                {
                    "user_id": "u1",
                    "username": "alice",
                    agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
                },
            )

        assert exc_info.value.status_code == 409
        assert request_state["form_called"] is False
        assert close_state["closed"] is False
    finally:
        await upload.close()
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@pytest.mark.asyncio
async def test_upload_route_appends_context_note_and_persists(monkeypatch):
    """Successful upload: response mirrors the hub result, a note is recorded.

    Checks that exactly one ``user``-role message starting with ``[SYSTEM:`` is
    added to the context manager, that the session snapshot is persisted once,
    and that the form was parsed and the upload closed.
    """
    upload = _upload("rows.jsonl", b'{"text":"hi"}\n')
    closed = _track_close(upload)
    request, form = _request(upload)
    context_notes = []
    snapshots = []

    inner_session = SimpleNamespace(
        pending_approval=None,
        context_manager=SimpleNamespace(add_message=context_notes.append),
    )
    agent_session = SimpleNamespace(
        is_active=True,
        is_processing=False,
        session=inner_session,
        hf_username="alice",
    )
    hub_result = dataset_uploads.DatasetUpload(
        session_id="s1",
        repo_id="alice/ml-intern-s1-datasets",
        repo_type="dataset",
        private=True,
        upload_id="abc123",
        config_name="upload_abc123",
        filename="rows.jsonl",
        original_filename="rows.jsonl",
        path_in_repo="uploads/abc123/rows.jsonl",
        size_bytes=14,
        format="jsonl",
        hub_url="https://huggingface.co/datasets/alice/ml-intern-s1-datasets/blob/main/uploads/abc123/rows.jsonl",
        load_dataset_snippet='dataset = load_dataset("json")',
    )
    user = {
        "user_id": "u1",
        "username": "alice",
        agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
    }

    async def grant_access(*_args, **_kwargs):
        return agent_session

    async def push_stub(**kwargs):
        # The route must hand over the very same upload object and the token.
        assert kwargs["upload"] is upload
        assert kwargs["hf_token"] == "hf-token"
        return hub_result

    async def record_snapshot(value):
        snapshots.append(value)

    monkeypatch.setattr(agent, "_check_session_access", grant_access)
    monkeypatch.setattr(agent, "push_dataset_upload_to_hub", push_stub)
    monkeypatch.setattr(
        agent.session_manager, "persist_session_snapshot", record_snapshot
    )

    response = await agent.upload_session_dataset("s1", request, user)

    assert response.repo_id == hub_result.repo_id
    assert response.config_name == hub_result.config_name
    assert response.path_in_repo == hub_result.path_in_repo

    assert len(context_notes) == 1
    note = context_notes[0]
    assert note.role == "user"
    assert note.content.startswith("[SYSTEM:")
    assert hub_result.config_name in note.content
    assert hub_result.path_in_repo in note.content

    assert snapshots == [agent_session]
    assert form["form_called"] is True
    assert closed["closed"] is True
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
@pytest.mark.asyncio
async def test_upload_route_closes_upload_when_hub_upload_fails(monkeypatch):
    """A generic hub failure surfaces as 502 and the upload is still closed.

    The push stub raises ``RuntimeError``; the route is expected to answer with
    a fixed retry message after having parsed the form and closed the upload.
    """
    upload = _upload("rows.csv")
    closed = _track_close(upload)
    request, form = _request(upload)
    user = {
        "user_id": "u1",
        "username": "alice",
        agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
    }

    async def idle_session(*_args, **_kwargs):
        inner = SimpleNamespace(pending_approval=None)
        return SimpleNamespace(
            is_active=True,
            is_processing=False,
            session=inner,
            hf_username="alice",
        )

    async def failing_push(**_kwargs):
        raise RuntimeError("hub unavailable")

    monkeypatch.setattr(agent, "_check_session_access", idle_session)
    monkeypatch.setattr(agent, "push_dataset_upload_to_hub", failing_push)

    with pytest.raises(HTTPException) as raised:
        await agent.upload_session_dataset("s1", request, user)

    error = raised.value
    assert error.status_code == 502
    assert error.detail == "Dataset upload failed. Please try again."
    assert form["form_called"] is True
    assert closed["closed"] is True
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
@pytest.mark.asyncio
async def test_upload_route_maps_hub_permission_error_safely(monkeypatch):
    """A hub 403 becomes a 403 with a fixed detail that omits the hub message.

    The stubbed ``HfHubHTTPError`` embeds ``hf_secret`` in its message; the
    test asserts that string does not appear in the detail returned to the
    caller.
    """
    upload = _upload("rows.csv")
    closed = _track_close(upload)
    request, form = _request(upload)
    user = {
        "user_id": "u1",
        "username": "alice",
        agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
    }

    async def idle_session(*_args, **_kwargs):
        inner = SimpleNamespace(pending_approval=None)
        return SimpleNamespace(
            is_active=True,
            is_processing=False,
            session=inner,
            hf_username="alice",
        )

    async def forbidden_push(**_kwargs):
        hub_request = httpx.Request("POST", "https://huggingface.co/api/datasets")
        hub_response = httpx.Response(
            403,
            request=hub_request,
            headers={"x-request-id": "req-123"},
        )
        raise HfHubHTTPError(
            "403 Forbidden: token hf_secret cannot write",
            response=hub_response,
            server_message="token hf_secret cannot write",
        )

    monkeypatch.setattr(agent, "_check_session_access", idle_session)
    monkeypatch.setattr(agent, "push_dataset_upload_to_hub", forbidden_push)

    with pytest.raises(HTTPException) as raised:
        await agent.upload_session_dataset("s1", request, user)

    error = raised.value
    assert error.status_code == 403
    assert error.detail == (
        "Hugging Face denied permission to create or write to the dataset repo."
    )
    assert "hf_secret" not in error.detail
    assert form["form_called"] is True
    assert closed["closed"] is True
|
|
@@ -1788,6 +1788,7 @@ dependencies = [
|
|
| 1788 |
{ name = "pydantic" },
|
| 1789 |
{ name = "pymongo" },
|
| 1790 |
{ name = "python-dotenv" },
|
|
|
|
| 1791 |
{ name = "requests" },
|
| 1792 |
{ name = "rich" },
|
| 1793 |
{ name = "thefuzz" },
|
|
@@ -1840,6 +1841,7 @@ requires-dist = [
|
|
| 1840 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
|
| 1841 |
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" },
|
| 1842 |
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
|
|
|
| 1843 |
{ name = "requests", specifier = ">=2.33.0" },
|
| 1844 |
{ name = "rich", specifier = ">=13.0.0" },
|
| 1845 |
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.12" },
|
|
|
|
| 1788 |
{ name = "pydantic" },
|
| 1789 |
{ name = "pymongo" },
|
| 1790 |
{ name = "python-dotenv" },
|
| 1791 |
+
{ name = "python-multipart" },
|
| 1792 |
{ name = "requests" },
|
| 1793 |
{ name = "rich" },
|
| 1794 |
{ name = "thefuzz" },
|
|
|
|
| 1841 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
|
| 1842 |
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" },
|
| 1843 |
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
| 1844 |
+
{ name = "python-multipart", specifier = ">=0.0.20" },
|
| 1845 |
{ name = "requests", specifier = ">=2.33.0" },
|
| 1846 |
{ name = "rich", specifier = ">=13.0.0" },
|
| 1847 |
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.12" },
|