Files
LLM_Engineering_OLD/week3/visualizer.py

158 lines
5.5 KiB
Python

import networkx as nx
import matplotlib.pyplot as plt
from typing import List, Dict
import math
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv(override=True)
class TokenPredictor:
def __init__(self, model_name: str):
self.client = OpenAI()
self.messages = []
self.predictions = []
self.model_name = model_name
def predict_tokens(self, prompt: str, max_tokens: int = 100) -> List[Dict]:
"""
Generate text token by token and track prediction probabilities.
Returns list of predictions with top token and alternatives.
"""
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
temperature=0, # Use temperature 0 for deterministic output
logprobs=True,
seed=42,
top_logprobs=3, # Get top 3 token predictions
stream=True, # Stream the response
)
predictions = []
for chunk in response:
if chunk.choices[0].delta.content:
token = chunk.choices[0].delta.content
logprobs = chunk.choices[0].logprobs.content[0].top_logprobs
logprob_dict = {item.token: item.logprob for item in logprobs}
# Get top predicted token and probability
top_token = token
top_prob = logprob_dict[token]
# Get alternative predictions
alternatives = []
for alt_token, alt_prob in logprob_dict.items():
if alt_token != token:
alternatives.append((alt_token, math.exp(alt_prob)))
alternatives.sort(key=lambda x: x[1], reverse=True)
prediction = {
"token": top_token,
"probability": math.exp(top_prob),
"alternatives": alternatives[:2], # Keep top 2 alternatives
}
predictions.append(prediction)
return predictions
def create_token_graph(model_name: str, predictions: List[Dict]) -> nx.DiGraph:
"""
Create a directed graph showing token predictions and alternatives.
"""
G = nx.DiGraph()
G.add_node("START", token=model_name, prob="START", color="lightgreen", size=4000)
# First, create all main token nodes in sequence
for i, pred in enumerate(predictions):
token_id = f"t{i}"
G.add_node(
token_id,
token=pred["token"],
prob=f"{pred['probability'] * 100:.1f}%",
color="lightblue",
size=6000,
)
if i == 0:
G.add_edge("START", token_id)
else:
G.add_edge(f"t{i - 1}", token_id)
# Then add alternative nodes with a different y-position
last_id = None
for i, pred in enumerate(predictions):
parent_token = "START" if i == 0 else f"t{i - 1}"
# Add alternative token nodes slightly below main sequence
for j, (alt_token, alt_prob) in enumerate(pred["alternatives"]):
alt_id = f"t{i}_alt{j}"
G.add_node(
alt_id, token=alt_token, prob=f"{alt_prob * 100:.1f}%", color="lightgray", size=6000
)
# Add edge from main token to its alternatives only
G.add_edge(parent_token, alt_id)
last_id = parent_token
G.add_node("END", token="END", prob="100%", color="red", size=6000)
G.add_edge(last_id, "END")
return G
def visualize_predictions(G: nx.DiGraph, figsize=(14, 80)):
"""
Visualize the token prediction graph with vertical layout and alternating alternatives.
"""
plt.figure(figsize=figsize)
# Create custom positioning for nodes
pos = {}
spacing_y = 5 # Vertical spacing between main tokens
spacing_x = 5 # Horizontal spacing for alternatives
# Position main token nodes in a vertical line
main_nodes = [n for n in G.nodes() if "_alt" not in n]
for i, node in enumerate(main_nodes):
pos[node] = (0, -i * spacing_y) # Center main tokens vertically
# Position alternative nodes to left and right of main tokens
for node in G.nodes():
if "_alt" in node:
main_token = node.split("_")[0]
alt_num = int(node.split("_alt")[1])
if main_token in pos:
# Place first alternative to left, second to right
x_offset = -spacing_x if alt_num == 0 else spacing_x
pos[node] = (x_offset, pos[main_token][1] + 0.05)
# Draw nodes
node_colors = [G.nodes[node]["color"] for node in G.nodes()]
node_sizes = [G.nodes[node]["size"] for node in G.nodes()]
nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes)
# Draw all edges as straight lines
nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True, arrowsize=20, alpha=0.7)
# Add labels with token and probability
labels = {node: f"{G.nodes[node]['token']}\n{G.nodes[node]['prob']}" for node in G.nodes()}
nx.draw_networkx_labels(G, pos, labels, font_size=14)
plt.title("Token prediction.")
plt.axis("off")
# Adjust plot limits to ensure all nodes are visible
margin = 8
x_values = [x for x, y in pos.values()]
y_values = [y for x, y in pos.values()]
plt.xlim(min(x_values) - margin, max(x_values) + margin)
plt.ylim(min(y_values) - margin, max(y_values) + margin)
# plt.tight_layout()
return plt