HandsomeSB committed
Commit 9aa1873 · 1 Parent(s): 6a50f6f

token streaming

Files changed (1)
  1. main.py +102 -42
main.py CHANGED
@@ -107,23 +107,112 @@ def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
 
     steps = []
 
+    # Track the actual output tokens (for streaming display)
+    output_tokens = []
+    # Track metadata for each token: 'accepted', 'rejected', or 'resampled'
+    token_metadata = []
+
+    def build_html():
+        html = "<div style='font-family: monospace;'>"
+
+        # Final output box - shows the streaming tokens with color coding
+        html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
+        html += f"<b>Final Output:</b><br/>"
+        if output_tokens:
+            for i, token_id in enumerate(output_tokens):
+                token_text = tokenizer.decode([token_id])
+                token_display = token_text.replace("<", "&lt;").replace(">", "&gt;")
+
+                # Apply color based on metadata
+                if i < len(token_metadata):
+                    if token_metadata[i] == 'accepted':
+                        html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
+                    elif token_metadata[i] == 'resampled':
+                        html += f"<span style='background: #5AADCC; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
+                    elif token_metadata[i] == 'rejected':
+                        html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 1px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
+                else:
+                    html += token_display
+        html += "</div>"
+
+        # Acceptance rate
+        if total_drafted > 0:
+            html += f"<div style='margin-bottom: 20px; padding: 10px; background: #e0e0e0; border-radius: 5px;'>"
+            html += f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = {total_accepted/total_drafted*100:.1f}%"
+            html += "</div>"
+
+        # Decoding steps
+        html += "<div style='margin-bottom: 10px;'><b>Decoding Steps:</b></div>"
+        for i, step in enumerate(steps):
+            html += f"<div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'>"
+            html += f"<b>Step {i+1}:</b> "
+
+            for j, token in enumerate(step["drafted"]):
+                token_display = token.replace("<", "&lt;").replace(">", "&gt;")
+                if j < step["accepted"]:
+                    html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 2px; border-radius: 3px;'>{token_display}</span>"
+                else:
+                    html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
+
+            if step["resampled"]:
+                resampled_display = step["resampled"].replace("<", "&lt;").replace(">", "&gt;")
+                html += f" → <span style='background: #5AADCC; padding: 2px 4px; border-radius: 3px;'>{resampled_display}</span>"
+
+            html += "</div>"
+        html += "</div>"
+        return html
+
     while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
-        print(steps)
+        # Draft phase
         drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)
-        accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)
-
+        drafted_token_ids = drafted[0, -len(drafted_probs):].tolist()
+        drafted_tokens = [tokenizer.decode([t]) for t in drafted_token_ids]
+
+        # Immediately show drafted tokens in output (optimistically)
+        output_tokens.extend(drafted_token_ids)
+        # Mark all as accepted initially (will be corrected after verification)
+        token_metadata.extend(['accepted'] * len(drafted_token_ids))
+
+        # Create a temporary step showing all drafted tokens as accepted
+        temp_step = {
+            "drafted": drafted_tokens,
+            "accepted": len(drafted_tokens),
+            "resampled": None
+        }
+        steps.append(temp_step)
         total_drafted += len(drafted_probs)
+
+        # Yield the state with drafted tokens showing
+        yield build_html()
+
+        # Verify phase
+        accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)
         total_accepted += num_accepted
 
-        # Extract token IDs for visualization
-        drafted_token_ids = drafted[0, -len(drafted_probs):].tolist()
+        # Now update the step with actual acceptance information
+        # Remove the optimistically added tokens and metadata
+        output_tokens = output_tokens[:-len(drafted_token_ids)]
+        token_metadata = token_metadata[:-len(drafted_token_ids)]
+
+        # Add back the actually accepted tokens with correct metadata
+        for i, token_id in enumerate(accepted_tokens):
+            output_tokens.append(token_id)
+            if i < num_accepted:
+                # This token was accepted from the draft
+                token_metadata.append('accepted')
+            else:
+                # This is the resampled token
+                token_metadata.append('resampled')
 
-        step = {
-            "drafted": [tokenizer.decode([t]) for t in drafted_token_ids],
+        # Update the step with real acceptance info
+        steps[-1] = {
+            "drafted": drafted_tokens,
             "accepted": num_accepted,
             "resampled": tokenizer.decode([accepted_tokens[-1]]) if num_accepted < len(accepted_tokens) else None
         }
-        steps.append(step)
+
+        # Yield the corrected state
+        yield build_html()
 
         valid_len = result.shape[-1] + num_accepted
         result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)
@@ -135,40 +224,9 @@ def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
 
         if eos_token in accepted_tokens or im_end_token in accepted_tokens:
             break
-
-    # Extract final output
-    final_output = tokenizer.decode(result[0])
-
-    # Build HTML visualization
-    html = "<div style='font-family: monospace;'>"
-    html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
-    html += f"<b>Final Output:</b><br/>{final_output}"
-    html += "</div>"
-    html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2pd solid white; border-radius: 5px;'>"
-    html += f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = {total_accepted/total_drafted*100:.1f}%"
-    html += "</div>"
-    html += "<div style='margin-bottom: 10px;'><b>Decoding Steps:</b></div>"
-
-    for i, step in enumerate(steps):
-        html += f"<div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'>"
-        html += f"<b>Step {i+1}:</b> "
-
-        for j, token in enumerate(step["drafted"]):
-            # Escape HTML special characters
-            token_display = token.replace("<", "&lt;").replace(">", "&gt;")
-            if j < step["accepted"]:
-                html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 2px; border-radius: 3px;'>{token_display}</span>"
-            else:
-                html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
 
-        if step["resampled"]:
-            resampled_display = step["resampled"].replace("<", "&lt;").replace(">", "&gt;")
-            html += f" → <span style='background: #5AADCC; padding: 2px 4px; border-radius: 3px;'>{resampled_display}</span>"
-
-        html += "</div>"
-    html += "</div>"
-
-    return html
+    # Final yield with complete output
+    yield build_html()
 
 demo = gr.Interface(
     fn=generate_visual,
@@ -190,6 +248,8 @@ demo = gr.Interface(
     - 🟢 Green = Accepted tokens from draft model
     - 🔴 Red = Rejected tokens (with strikethrough)
     - 🔵 Blue = Resampled tokens from verify model
+
+    **Watch the tokens stream in real-time!** Draft tokens appear immediately, then get accepted or rejected by the verify model.
     """,
     examples=[
         ["What is a deal flow in a VC fund?", 80, 15, 0.5],
@@ -199,4 +259,4 @@ demo = gr.Interface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
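
The change above turns generate_visual into a generator: instead of returning one final HTML string, it yields a fresh build_html() snapshot after each draft phase and again after verification, and Gradio re-renders the output component on every yield. Below is a minimal, self-contained sketch of that streaming pattern; the stream_tokens function and the plain "text"/"html" components are placeholders for illustration, not the Space's actual code, and older Gradio 3.x releases may need demo.queue() before launch() for generator outputs to stream.

import time
import gradio as gr

def stream_tokens(prompt):
    # Stand-in for the draft/verify loop: yield a growing HTML snapshot
    # after every step, the same way build_html() is yielded above.
    html = "<div style='font-family: monospace;'>"
    for word in prompt.split():
        html += f"<span style='padding: 2px 4px;'>{word} </span>"
        yield html + "</div>"  # partial render; Gradio updates the output on each yield
        time.sleep(0.1)        # simulate per-step latency
    yield html + "</div>"      # final render

demo = gr.Interface(fn=stream_tokens, inputs="text", outputs="html")

if __name__ == "__main__":
    demo.launch()  # on some Gradio 3.x versions, use demo.queue().launch() for streaming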