Spaces:
Configuration error
Configuration error
freeze demo using model revisions
#4
by Xenova HF Staff - opened
- src/hooks/useLLM.ts +45 -8
src/hooks/useLLM.ts
CHANGED
|
@@ -5,6 +5,21 @@ import {
|
|
| 5 |
TextStreamer,
|
| 6 |
} from "@huggingface/transformers";
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
interface LLMState {
|
| 9 |
isLoading: boolean;
|
| 10 |
isReady: boolean;
|
|
@@ -24,7 +39,7 @@ let moduleCache: {
|
|
| 24 |
};
|
| 25 |
} = {};
|
| 26 |
|
| 27 |
-
export const useLLM = (modelId?: string) => {
|
| 28 |
const [state, setState] = useState<LLMState>({
|
| 29 |
isLoading: false,
|
| 30 |
isReady: false,
|
|
@@ -105,15 +120,37 @@ export const useLLM = (modelId?: string) => {
|
|
| 105 |
}
|
| 106 |
};
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
});
|
| 111 |
-
|
| 112 |
-
const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
|
| 113 |
dtype: "q4f16",
|
| 114 |
device: "webgpu",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
progress_callback: progressCallback,
|
| 116 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
const instance = { model, tokenizer };
|
| 119 |
instanceRef.current = instance;
|
|
@@ -231,4 +268,4 @@ export const useLLM = (modelId?: string) => {
|
|
| 231 |
clearPastKeyValues,
|
| 232 |
cleanup,
|
| 233 |
};
|
| 234 |
-
};
|
|
|
|
| 5 |
TextStreamer,
|
| 6 |
} from "@huggingface/transformers";
|
| 7 |
|
| 8 |
+
// Size labels of the models that have an entry in MODEL_CONFIGS; useLLM also accepts arbitrary strings and falls back to defaults for them.
|
| 9 |
+
export type SupportedModelId = "350M" | "700M" | "1.2B";
|
| 10 |
+
|
| 11 |
+
// Load settings for one model; useLLM copies these fields into the options it passes to from_pretrained.
export interface ModelConfig {
|
| 12 |
+
// Forwarded as the `dtype` model option (every entry in this file uses "q4f16").
dtype: string;
|
| 13 |
+
// Forwarded as the `device` model option (every entry in this file uses "webgpu").
device: string;
|
| 14 |
+
// Optional revision; when set, useLLM forwards it to both the tokenizer and the model loader.
revision?: string;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
// Per-model load settings keyed by SupportedModelId. Each entry pins a fixed `revision`
// string (presumably a commit hash of the model repository — TODO confirm) so the demo
// keeps loading the same files. Ids without an entry get default dtype/device and no
// revision in useLLM.
export const MODEL_CONFIGS: Record<SupportedModelId, ModelConfig> = {
|
| 18 |
+
"350M": { dtype: "q4f16", device: "webgpu", revision: "5bc4b3e8cfd21660c0b1b9faa447ffbd9926b829" },
|
| 19 |
+
"700M": { dtype: "q4f16", device: "webgpu", revision: "bf72eeabfe73a798674db899830a0dca99f8eabc" },
|
| 20 |
+
"1.2B": { dtype: "q4f16", device: "webgpu", revision: "7f871660813dc1f34f0d304c77506c5fbdb440a0" },
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
interface LLMState {
|
| 24 |
isLoading: boolean;
|
| 25 |
isReady: boolean;
|
|
|
|
| 39 |
};
|
| 40 |
} = {};
|
| 41 |
|
| 42 |
+
export const useLLM = (modelId?: SupportedModelId | string) => {
|
| 43 |
const [state, setState] = useState<LLMState>({
|
| 44 |
isLoading: false,
|
| 45 |
isReady: false,
|
|
|
|
| 120 |
}
|
| 121 |
};
|
| 122 |
|
| 123 |
+
// Fall back to defaults if an unknown modelId string is passed
|
| 124 |
+
const config = MODEL_CONFIGS[modelId as SupportedModelId] || {
|
|
|
|
|
|
|
|
|
|
| 125 |
dtype: "q4f16",
|
| 126 |
device: "webgpu",
|
| 127 |
+
};
|
| 128 |
+
|
| 129 |
+
const tokenizerOptions: Record<string, any> = {
|
| 130 |
+
progress_callback: progressCallback,
|
| 131 |
+
};
|
| 132 |
+
if (config.revision) {
|
| 133 |
+
tokenizerOptions.revision = config.revision;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
const tokenizer = await AutoTokenizer.from_pretrained(
|
| 137 |
+
MODEL_ID,
|
| 138 |
+
tokenizerOptions
|
| 139 |
+
);
|
| 140 |
+
|
| 141 |
+
const modelOptions: Record<string, any> = {
|
| 142 |
+
dtype: config.dtype,
|
| 143 |
+
device: config.device,
|
| 144 |
progress_callback: progressCallback,
|
| 145 |
+
};
|
| 146 |
+
if (config.revision) {
|
| 147 |
+
modelOptions.revision = config.revision;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
const model = await AutoModelForCausalLM.from_pretrained(
|
| 151 |
+
MODEL_ID,
|
| 152 |
+
modelOptions
|
| 153 |
+
);
|
| 154 |
|
| 155 |
const instance = { model, tokenizer };
|
| 156 |
instanceRef.current = instance;
|
|
|
|
| 268 |
clearPastKeyValues,
|
| 269 |
cleanup,
|
| 270 |
};
|
| 271 |
+
};
|