158 lines
5.5 KiB
Python
158 lines
5.5 KiB
Python
import networkx as nx
|
|
import matplotlib.pyplot as plt
|
|
from typing import List, Dict
|
|
import math
|
|
from openai import OpenAI
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv(override=True)
|
|
|
|
|
|
class TokenPredictor:
|
|
def __init__(self, model_name: str):
|
|
self.client = OpenAI()
|
|
self.messages = []
|
|
self.predictions = []
|
|
self.model_name = model_name
|
|
|
|
def predict_tokens(self, prompt: str, max_tokens: int = 100) -> List[Dict]:
|
|
"""
|
|
Generate text token by token and track prediction probabilities.
|
|
Returns list of predictions with top token and alternatives.
|
|
"""
|
|
response = self.client.chat.completions.create(
|
|
model=self.model_name,
|
|
messages=[{"role": "user", "content": prompt}],
|
|
max_tokens=max_tokens,
|
|
temperature=0, # Use temperature 0 for deterministic output
|
|
logprobs=True,
|
|
seed=42,
|
|
top_logprobs=3, # Get top 3 token predictions
|
|
stream=True, # Stream the response
|
|
)
|
|
|
|
predictions = []
|
|
for chunk in response:
|
|
if chunk.choices[0].delta.content:
|
|
token = chunk.choices[0].delta.content
|
|
logprobs = chunk.choices[0].logprobs.content[0].top_logprobs
|
|
logprob_dict = {item.token: item.logprob for item in logprobs}
|
|
|
|
# Get top predicted token and probability
|
|
top_token = token
|
|
top_prob = logprob_dict[token]
|
|
|
|
# Get alternative predictions
|
|
alternatives = []
|
|
for alt_token, alt_prob in logprob_dict.items():
|
|
if alt_token != token:
|
|
alternatives.append((alt_token, math.exp(alt_prob)))
|
|
alternatives.sort(key=lambda x: x[1], reverse=True)
|
|
|
|
prediction = {
|
|
"token": top_token,
|
|
"probability": math.exp(top_prob),
|
|
"alternatives": alternatives[:2], # Keep top 2 alternatives
|
|
}
|
|
predictions.append(prediction)
|
|
|
|
return predictions
|
|
|
|
|
|
def create_token_graph(model_name: str, predictions: List[Dict]) -> nx.DiGraph:
|
|
"""
|
|
Create a directed graph showing token predictions and alternatives.
|
|
"""
|
|
G = nx.DiGraph()
|
|
|
|
G.add_node("START", token=model_name, prob="START", color="lightgreen", size=4000)
|
|
|
|
# First, create all main token nodes in sequence
|
|
for i, pred in enumerate(predictions):
|
|
token_id = f"t{i}"
|
|
G.add_node(
|
|
token_id,
|
|
token=pred["token"],
|
|
prob=f"{pred['probability'] * 100:.1f}%",
|
|
color="lightblue",
|
|
size=6000,
|
|
)
|
|
|
|
if i == 0:
|
|
G.add_edge("START", token_id)
|
|
else:
|
|
G.add_edge(f"t{i - 1}", token_id)
|
|
|
|
# Then add alternative nodes with a different y-position
|
|
last_id = None
|
|
for i, pred in enumerate(predictions):
|
|
parent_token = "START" if i == 0 else f"t{i - 1}"
|
|
|
|
# Add alternative token nodes slightly below main sequence
|
|
for j, (alt_token, alt_prob) in enumerate(pred["alternatives"]):
|
|
alt_id = f"t{i}_alt{j}"
|
|
G.add_node(
|
|
alt_id, token=alt_token, prob=f"{alt_prob * 100:.1f}%", color="lightgray", size=6000
|
|
)
|
|
|
|
# Add edge from main token to its alternatives only
|
|
G.add_edge(parent_token, alt_id)
|
|
last_id = parent_token
|
|
|
|
G.add_node("END", token="END", prob="100%", color="red", size=6000)
|
|
G.add_edge(last_id, "END")
|
|
|
|
return G
|
|
|
|
|
|
def visualize_predictions(G: nx.DiGraph, figsize=(14, 80)):
|
|
"""
|
|
Visualize the token prediction graph with vertical layout and alternating alternatives.
|
|
"""
|
|
plt.figure(figsize=figsize)
|
|
|
|
# Create custom positioning for nodes
|
|
pos = {}
|
|
spacing_y = 5 # Vertical spacing between main tokens
|
|
spacing_x = 5 # Horizontal spacing for alternatives
|
|
|
|
# Position main token nodes in a vertical line
|
|
main_nodes = [n for n in G.nodes() if "_alt" not in n]
|
|
for i, node in enumerate(main_nodes):
|
|
pos[node] = (0, -i * spacing_y) # Center main tokens vertically
|
|
|
|
# Position alternative nodes to left and right of main tokens
|
|
for node in G.nodes():
|
|
if "_alt" in node:
|
|
main_token = node.split("_")[0]
|
|
alt_num = int(node.split("_alt")[1])
|
|
if main_token in pos:
|
|
# Place first alternative to left, second to right
|
|
x_offset = -spacing_x if alt_num == 0 else spacing_x
|
|
pos[node] = (x_offset, pos[main_token][1] + 0.05)
|
|
|
|
# Draw nodes
|
|
node_colors = [G.nodes[node]["color"] for node in G.nodes()]
|
|
node_sizes = [G.nodes[node]["size"] for node in G.nodes()]
|
|
nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes)
|
|
|
|
# Draw all edges as straight lines
|
|
nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True, arrowsize=20, alpha=0.7)
|
|
|
|
# Add labels with token and probability
|
|
labels = {node: f"{G.nodes[node]['token']}\n{G.nodes[node]['prob']}" for node in G.nodes()}
|
|
nx.draw_networkx_labels(G, pos, labels, font_size=14)
|
|
|
|
plt.title("Token prediction.")
|
|
plt.axis("off")
|
|
|
|
# Adjust plot limits to ensure all nodes are visible
|
|
margin = 8
|
|
x_values = [x for x, y in pos.values()]
|
|
y_values = [y for x, y in pos.values()]
|
|
plt.xlim(min(x_values) - margin, max(x_values) + margin)
|
|
plt.ylim(min(y_values) - margin, max(y_values) + margin)
|
|
|
|
# plt.tight_layout()
|
|
return plt
|