Freeze demo using model revisions

#4
by Xenova (HF Staff) — opened
Files changed (1)
  1. src/hooks/useLLM.ts +45 -8
src/hooks/useLLM.ts CHANGED
@@ -5,6 +5,21 @@ import {
5
  TextStreamer,
6
  } from "@huggingface/transformers";
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  interface LLMState {
9
  isLoading: boolean;
10
  isReady: boolean;
@@ -24,7 +39,7 @@ let moduleCache: {
24
  };
25
  } = {};
26
 
27
- export const useLLM = (modelId?: string) => {
28
  const [state, setState] = useState<LLMState>({
29
  isLoading: false,
30
  isReady: false,
@@ -105,15 +120,37 @@ export const useLLM = (modelId?: string) => {
105
  }
106
  };
107
 
108
- const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
109
- progress_callback: progressCallback,
110
- });
111
-
112
- const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
113
  dtype: "q4f16",
114
  device: "webgpu",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  progress_callback: progressCallback,
116
- });
 
 
 
 
 
 
 
 
117
 
118
  const instance = { model, tokenizer };
119
  instanceRef.current = instance;
@@ -231,4 +268,4 @@ export const useLLM = (modelId?: string) => {
231
  clearPastKeyValues,
232
  cleanup,
233
  };
234
- };
 
5
  TextStreamer,
6
  } from "@huggingface/transformers";
7
 
8
+ // Define the supported model IDs
9
+ export type SupportedModelId = "350M" | "700M" | "1.2B";
10
+
11
+ export interface ModelConfig {
12
+ dtype: string;
13
+ device: string;
14
+ revision?: string;
15
+ }
16
+
17
+ export const MODEL_CONFIGS: Record<SupportedModelId, ModelConfig> = {
18
+ "350M": { dtype: "q4f16", device: "webgpu", revision: "5bc4b3e8cfd21660c0b1b9faa447ffbd9926b829" },
19
+ "700M": { dtype: "q4f16", device: "webgpu", revision: "bf72eeabfe73a798674db899830a0dca99f8eabc" },
20
+ "1.2B": { dtype: "q4f16", device: "webgpu", revision: "7f871660813dc1f34f0d304c77506c5fbdb440a0" },
21
+ };
22
+
23
  interface LLMState {
24
  isLoading: boolean;
25
  isReady: boolean;
 
39
  };
40
  } = {};
41
 
42
+ export const useLLM = (modelId?: SupportedModelId | string) => {
43
  const [state, setState] = useState<LLMState>({
44
  isLoading: false,
45
  isReady: false,
 
120
  }
121
  };
122
 
123
+ // Fallback to defaults if an unknown modelId string is passed
124
+ const config = MODEL_CONFIGS[modelId as SupportedModelId] || {
 
 
 
125
  dtype: "q4f16",
126
  device: "webgpu",
127
+ };
128
+
129
+ const tokenizerOptions: Record<string, any> = {
130
+ progress_callback: progressCallback,
131
+ };
132
+ if (config.revision) {
133
+ tokenizerOptions.revision = config.revision;
134
+ }
135
+
136
+ const tokenizer = await AutoTokenizer.from_pretrained(
137
+ MODEL_ID,
138
+ tokenizerOptions
139
+ );
140
+
141
+ const modelOptions: Record<string, any> = {
142
+ dtype: config.dtype,
143
+ device: config.device,
144
  progress_callback: progressCallback,
145
+ };
146
+ if (config.revision) {
147
+ modelOptions.revision = config.revision;
148
+ }
149
+
150
+ const model = await AutoModelForCausalLM.from_pretrained(
151
+ MODEL_ID,
152
+ modelOptions
153
+ );
154
 
155
  const instance = { model, tokenizer };
156
  instanceRef.current = instance;
 
268
  clearPastKeyValues,
269
  cleanup,
270
  };
271
+ };