import networkx as nx import matplotlib.pyplot as plt from typing import List, Dict import math from openai import OpenAI from dotenv import load_dotenv load_dotenv(override=True) class TokenPredictor: def __init__(self, model_name: str): self.client = OpenAI() self.messages = [] self.predictions = [] self.model_name = model_name def predict_tokens(self, prompt: str, max_tokens: int = 100) -> List[Dict]: """ Generate text token by token and track prediction probabilities. Returns list of predictions with top token and alternatives. """ response = self.client.chat.completions.create( model=self.model_name, messages=[{"role": "user", "content": prompt}], max_tokens=max_tokens, temperature=0, # Use temperature 0 for deterministic output logprobs=True, seed=42, top_logprobs=3, # Get top 3 token predictions stream=True, # Stream the response ) predictions = [] for chunk in response: if chunk.choices[0].delta.content: token = chunk.choices[0].delta.content logprobs = chunk.choices[0].logprobs.content[0].top_logprobs logprob_dict = {item.token: item.logprob for item in logprobs} # Get top predicted token and probability top_token = token top_prob = logprob_dict[token] # Get alternative predictions alternatives = [] for alt_token, alt_prob in logprob_dict.items(): if alt_token != token: alternatives.append((alt_token, math.exp(alt_prob))) alternatives.sort(key=lambda x: x[1], reverse=True) prediction = { "token": top_token, "probability": math.exp(top_prob), "alternatives": alternatives[:2], # Keep top 2 alternatives } predictions.append(prediction) return predictions def create_token_graph(model_name: str, predictions: List[Dict]) -> nx.DiGraph: """ Create a directed graph showing token predictions and alternatives. """ G = nx.DiGraph() G.add_node("START", token=model_name, prob="START", color="lightgreen", size=4000) # First, create all main token nodes in sequence for i, pred in enumerate(predictions): token_id = f"t{i}" G.add_node( token_id, token=pred["token"], prob=f"{pred['probability'] * 100:.1f}%", color="lightblue", size=6000, ) if i == 0: G.add_edge("START", token_id) else: G.add_edge(f"t{i - 1}", token_id) # Then add alternative nodes with a different y-position last_id = None for i, pred in enumerate(predictions): parent_token = "START" if i == 0 else f"t{i - 1}" # Add alternative token nodes slightly below main sequence for j, (alt_token, alt_prob) in enumerate(pred["alternatives"]): alt_id = f"t{i}_alt{j}" G.add_node( alt_id, token=alt_token, prob=f"{alt_prob * 100:.1f}%", color="lightgray", size=6000 ) # Add edge from main token to its alternatives only G.add_edge(parent_token, alt_id) last_id = parent_token G.add_node("END", token="END", prob="100%", color="red", size=6000) G.add_edge(last_id, "END") return G def visualize_predictions(G: nx.DiGraph, figsize=(14, 80)): """ Visualize the token prediction graph with vertical layout and alternating alternatives. """ plt.figure(figsize=figsize) # Create custom positioning for nodes pos = {} spacing_y = 5 # Vertical spacing between main tokens spacing_x = 5 # Horizontal spacing for alternatives # Position main token nodes in a vertical line main_nodes = [n for n in G.nodes() if "_alt" not in n] for i, node in enumerate(main_nodes): pos[node] = (0, -i * spacing_y) # Center main tokens vertically # Position alternative nodes to left and right of main tokens for node in G.nodes(): if "_alt" in node: main_token = node.split("_")[0] alt_num = int(node.split("_alt")[1]) if main_token in pos: # Place first alternative to left, second to right x_offset = -spacing_x if alt_num == 0 else spacing_x pos[node] = (x_offset, pos[main_token][1] + 0.05) # Draw nodes node_colors = [G.nodes[node]["color"] for node in G.nodes()] node_sizes = [G.nodes[node]["size"] for node in G.nodes()] nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes) # Draw all edges as straight lines nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True, arrowsize=20, alpha=0.7) # Add labels with token and probability labels = {node: f"{G.nodes[node]['token']}\n{G.nodes[node]['prob']}" for node in G.nodes()} nx.draw_networkx_labels(G, pos, labels, font_size=14) plt.title("Token prediction.") plt.axis("off") # Adjust plot limits to ensure all nodes are visible margin = 8 x_values = [x for x, y in pos.values()] y_values = [y for x, y in pos.values()] plt.xlim(min(x_values) - margin, max(x_values) + margin) plt.ylim(min(y_values) - margin, max(y_values) + margin) # plt.tight_layout() return plt