import os, json, torch, gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

# ==== Basic config (swap in a smaller model / fewer steps as needed) ====
BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
ADAPTER_DIR = os.getenv("ADAPTER_DIR", "lora_adapter")
TRAIN_PATH = os.getenv("TRAIN_PATH", "data/sft_train.jsonl")
VAL_PATH = os.getenv("VAL_PATH", "data/sft_val.jsonl")

# ==== Lazy loading: placeholders only; the actual download happens on button click ====
_tokenizer = None
_base_model = None
_gen_model = None  # for inference (possibly with LoRA attached)


def load_base(load_in_4bit=None):
    """Load and cache the tokenizer and base model (4-bit NF4 on GPU, fp32 on CPU)."""
    global _tokenizer, _base_model
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if _base_model is None:
        use_4bit = torch.cuda.is_available() if load_in_4bit is None else load_in_4bit
        if use_4bit:
            bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
            _base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, quantization_config=bnb, device_map="auto"
            )
        else:
            _base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, torch_dtype=torch.float32, device_map="cpu"
            )
    return _tokenizer, _base_model


def train_qlora(max_steps=500, lora_r=16, lora_alpha=32, lora_dropout=0.05,
                per_device_bs=1, grad_accum=8):
    # Prepare the data
    if not os.path.exists(TRAIN_PATH):
        return f"[Error] Train file not found: {TRAIN_PATH}"
    if not os.path.exists(VAL_PATH):
        return f"[Error] Val file not found: {VAL_PATH}"
    train_ds = load_dataset("json", data_files=TRAIN_PATH)["train"]
    val_ds = load_dataset("json", data_files=VAL_PATH)["train"]

    tok, base = load_base(load_in_4bit=True)
    # Make the quantized base trainable (casts norms to fp32, enables input grads).
    # Note: this mutates the cached _base_model in place, which is acceptable for
    # this minimal single-session demo.
    base = prepare_model_for_kbit_training(base)

    # LoRA configuration (gr.Number passes floats, so cast the integer knobs)
    peft_cfg = LoraConfig(
        r=int(lora_r), lora_alpha=int(lora_alpha), lora_dropout=float(lora_dropout),
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base, peft_cfg)

    # Training configuration (TRL). T4 GPUs lack bf16 support, so fall back to fp16 there.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    sft_cfg = SFTConfig(
        output_dir=ADAPTER_DIR,
        max_steps=int(max_steps),
        per_device_train_batch_size=int(per_device_bs),
        gradient_accumulation_steps=int(grad_accum),
        learning_rate=2e-4,
        bf16=use_bf16,
        fp16=torch.cuda.is_available() and not use_bf16,
        logging_steps=20,
        save_steps=200,
        packing=False,
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,  # note: renamed to `processing_class` in newer TRL releases
        train_dataset=train_ds,
        eval_dataset=val_ds,
        args=sft_cfg,
    )
    trainer.train()
    trainer.save_model(ADAPTER_DIR)
    return f"✅ Trained LoRA saved to: {ADAPTER_DIR}"


def load_for_infer(adapter_dir=ADAPTER_DIR):
    global _gen_model
    # Auto-select precision: 4-bit on GPU, fp32 on CPU (bitsandbytes needs CUDA,
    # so forcing 4-bit would break the suggested CPU-inference step).
    tok, base = load_base()
    has_adapter = bool(adapter_dir) and os.path.isdir(adapter_dir)
    _gen_model = PeftModel.from_pretrained(base, adapter_dir) if has_adapter else base
    return "✅ Model ready (with LoRA)" if has_adapter else "✅ Model ready (base only)"


def generate(prompt, max_new_tokens=200, adapter_dir=ADAPTER_DIR):
    if _gen_model is None:
        load_for_infer(adapter_dir)
    tok, _ = load_base()
    inputs = tok(prompt, return_tensors="pt").to(_gen_model.device)
    with torch.no_grad():
        out = _gen_model.generate(
            **inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=0.8
        )
    return tok.decode(out[0], skip_special_tokens=True)
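
# --- Illustration only: assumed JSONL schema (this helper is not called by the app) ---
# SFTTrainer above receives the raw dataset with no dataset_text_field or
# formatting_func, so each JSONL record is assumed to carry either a "text"
# field or a chat-style "messages" list, depending on your TRL version. The
# hypothetical helper below writes (user, assistant) pairs in the chat form;
# adapt the schema to whatever your TRL version expects.
def write_sft_jsonl(pairs, path=TRAIN_PATH):
    """Sketch: dump (user_msg, assistant_msg) tuples as chat-style SFT records."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for user_msg, assistant_msg in pairs:
            record = {"messages": [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg},
            ]}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")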
# ==== Gradio UI ====
with gr.Blocks(title="WeChat Style QLoRA (Minimal)") as demo:
    gr.Markdown("## WeChat Style QLoRA — Minimal Demo  \n"
                "Minimal supervised fine-tuning (SFT) with QLoRA on a private JSONL dataset, plus inference.  \n"
                "**Suggested workflow**: launch on CPU Basic to verify the app → switch to ZeroGPU/T4 and train for 30–60 minutes → save the LoRA → switch back to CPU to test inference.")
    with gr.Tab("Train (QLoRA)"):
        gr.Markdown("**First upload `data/sft_train.jsonl` and `data/sft_val.jsonl` to this Space's `data/` directory.**")
        ms = gr.Number(value=500, label="max_steps")
        r = gr.Number(value=16, label="lora_r")
        a = gr.Number(value=32, label="lora_alpha")
        d = gr.Number(value=0.05, label="lora_dropout")
        bsz = gr.Number(value=1, label="per_device_train_batch_size")
        gas = gr.Number(value=8, label="gradient_accumulation_steps")
        train_btn = gr.Button("Start Training (GPU/ZeroGPU)")
        train_log = gr.Textbox(label="Training Log", interactive=False)
        train_btn.click(fn=train_qlora, inputs=[ms, r, a, d, bsz, gas], outputs=train_log)
    with gr.Tab("Inference"):
        gr.Markdown("Tries to load `lora_adapter/` by default. If you haven't trained yet, you can use the base model directly.")
        adapter = gr.Textbox(value=ADAPTER_DIR, label="LoRA adapter dir")
        load_btn = gr.Button("Load (with/without LoRA)")
        load_log = gr.Textbox(label="Status", interactive=False)
        load_btn.click(fn=load_for_infer, inputs=[adapter], outputs=load_log)
        prompt = gr.Textbox(lines=6, label="Prompt")
        gen_tokens = gr.Slider(32, 512, value=200, step=8, label="max_new_tokens")
        gen_btn = gr.Button("Generate")
        output = gr.Textbox(lines=12, label="Output")
        gen_btn.click(fn=generate, inputs=[prompt, gen_tokens, adapter], outputs=output)
    gr.Markdown("> Tip: before training, switch to **ZeroGPU/T4** under **Settings → Hardware**; when finished, switch back to **CPU Basic** and stop the Space to save cost.")

if __name__ == "__main__":
    demo.queue().launch()
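

# --- Optional post-training sketch (assumption: standard PEFT merge API;
# `merged_model/` is a hypothetical output path, and this helper is not wired
# into the UI). To ship one standalone checkpoint instead of base + adapter,
# the LoRA can be merged into the base weights. Merging wants a non-quantized
# base, so load it in half precision here rather than reusing the 4-bit model.
def merge_adapter(adapter_dir=ADAPTER_DIR, out_dir="merged_model"):
    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16)
    merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()
    merged.save_pretrained(out_dir)
    tok.save_pretrained(out_dir)
    return out_dir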