Improved prompting to avoid spaces or symbols in domain names.
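For context: the tightened prompt works together with a letters-only guard that already exists in the notebook (_ALPHA_RE and _sld_is_english_alpha in the diff below). A minimal standalone sketch of that check, folding the notebook's _to_com normalization into it for brevity, might look like:

import re

# Mirrors _ALPHA_RE from the notebook: the second-level label must be English letters only.
_ALPHA_RE = re.compile(r"^[a-z]+$", re.IGNORECASE)

def _sld_is_english_alpha(domain: str) -> bool:
    # Normalize to a .com FQDN, then test the label just before ".com".
    d = domain.strip().lower()
    fqdn = d if d.endswith(".com") else f"{d}.com"
    sld = fqdn[:-4].split(".")[-1]
    return bool(sld) and bool(_ALPHA_RE.fullmatch(sld))

# Candidates containing spaces or symbols are rejected before any RDAP lookup.
assert _sld_is_english_alpha("foo")
assert not _sld_is_english_alpha("foo bar")
assert not _sld_is_english_alpha("foo-bar.com")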
@@ -32,7 +32,7 @@
"import os\n",
"import json\n",
"import requests\n",
"from typing import Dict, List, Tuple\n",
"from typing import Dict, List, Tuple, Any, Optional\n",
"import re\n",
"\n",
"from dotenv import load_dotenv\n",
@@ -47,6 +47,73 @@
"openai = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "361f7fe3",
"metadata": {},
"outputs": [],
"source": [
"# --- robust logging that works inside VS Code notebooks + Gradio threads ---\n",
"import sys, logging, threading\n",
"from collections import deque\n",
"from typing import Any\n",
"\n",
"DEBUG_LLM = True # toggle on/off noisy logs\n",
"CLEAR_LOG_ON_RUN = True # clear panel before each submit\n",
"\n",
"_LOG_BUFFER = deque(maxlen=2000) # keep ~2000 lines in memory\n",
"_LOG_LOCK = threading.Lock()\n",
"\n",
"class GradioBufferHandler(logging.Handler):\n",
" def emit(self, record: logging.LogRecord) -> None:\n",
" try:\n",
" msg = self.format(record)\n",
" except Exception:\n",
" msg = record.getMessage()\n",
" with _LOG_LOCK:\n",
" for line in (msg.splitlines() or [\"\"]):\n",
" _LOG_BUFFER.append(line)\n",
"\n",
"def get_log_text() -> str:\n",
" with _LOG_LOCK:\n",
" return \"\\n\".join(_LOG_BUFFER)\n",
"\n",
"def clear_log_buffer() -> None:\n",
" with _LOG_LOCK:\n",
" _LOG_BUFFER.clear()\n",
"\n",
"def _setup_logger() -> logging.Logger:\n",
" logger = logging.getLogger(\"aidf\")\n",
" logger.setLevel(logging.DEBUG if DEBUG_LLM else logging.INFO)\n",
" logger.handlers.clear()\n",
" fmt = logging.Formatter(\"%(asctime)s | %(levelname)s | %(message)s\", \"%H:%M:%S\")\n",
"\n",
" stream = logging.StreamHandler(stream=sys.stdout) # captured by VS Code notebook\n",
" stream.setFormatter(fmt)\n",
"\n",
" buf = GradioBufferHandler() # shown inside the Gradio panel\n",
" buf.setFormatter(fmt)\n",
"\n",
" logger.addHandler(stream)\n",
" logger.addHandler(buf)\n",
" logger.propagate = False\n",
" return logger\n",
"\n",
"logger = _setup_logger()\n",
"\n",
"def dbg_json(obj: Any, title: str = \"\") -> None:\n",
" \"\"\"Convenience: pretty-print JSON-ish objects to the logger.\"\"\"\n",
" try:\n",
" txt = json.dumps(obj, ensure_ascii=False, indent=2)\n",
" except Exception:\n",
" txt = str(obj)\n",
" if title:\n",
" logger.debug(\"%s\\n%s\", title, txt)\n",
" else:\n",
" logger.debug(\"%s\", txt)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -56,12 +123,35 @@
"source": [
"RDAP_URL = \"https://rdap.verisign.com/com/v1/domain/{}\"\n",
"\n",
"_ALPHA_RE = re.compile(r\"^[a-z]+$\", re.IGNORECASE)\n",
"\n",
"def _to_com(domain: str) -> str:\n",
" d = domain.strip().lower()\n",
" return d if d.endswith(\".com\") else f\"{d}.com\"\n",
"\n",
"def _sld_is_english_alpha(fqdn: str) -> bool:\n",
" \"\"\"\n",
" True only if the second-level label (just before .com) is made up\n",
" exclusively of English letters (a-z).\n",
" Examples:\n",
" foo.com -> True\n",
" foo-bar.com -> False\n",
" foo1.com -> False\n",
" café.com -> False\n",
" xn--cafe.com -> False\n",
" www.foo.com -> True (checks 'foo')\n",
" \"\"\"\n",
" if not fqdn.endswith(\".com\"):\n",
" return False\n",
" sld = fqdn[:-4].split(\".\")[-1] # take label immediately before .com\n",
" return bool(sld) and bool(_ALPHA_RE.fullmatch(sld))\n",
"\n",
"def check_com_availability(domain: str) -> Dict:\n",
" fqdn = _to_com(domain)\n",
" # Skip API if not strictly English letters\n",
" if not _sld_is_english_alpha(fqdn):\n",
" return {\"domain\": fqdn, \"available\": False, \"status\": 0}\n",
"\n",
" try:\n",
" r = requests.get(RDAP_URL.format(fqdn), timeout=6)\n",
" return {\"domain\": fqdn, \"available\": (r.status_code == 404), \"status\": r.status_code}\n",
@@ -80,8 +170,15 @@
" \"\"\"\n",
" session = requests.Session()\n",
" results: List[Dict] = []\n",
"\n",
" for d in domains:\n",
" fqdn = _to_com(d)\n",
"\n",
" # Skip API if not strictly English letters\n",
" if not _sld_is_english_alpha(fqdn):\n",
" results.append({\"domain\": fqdn, \"available\": False, \"status\": 0})\n",
" continue\n",
"\n",
" try:\n",
" r = session.get(RDAP_URL.format(fqdn), timeout=6)\n",
" ok = (r.status_code == 404)\n",
@@ -133,19 +230,31 @@
"outputs": [],
"source": [
"def handle_tool_calls(message) -> List[Dict]:\n",
" \"\"\"\n",
" Translates model tool_calls into tool results for follow-up completion.\n",
" \"\"\"\n",
" results = []\n",
" for call in (message.tool_calls or []):\n",
" if call.function.name == \"check_com_availability\":\n",
" args = json.loads(call.function.arguments or \"{}\")\n",
" fn = getattr(call.function, \"name\", None)\n",
" args_raw = getattr(call.function, \"arguments\", \"\") or \"{}\"\n",
" try:\n",
" args = json.loads(args_raw)\n",
" except Exception:\n",
" args = {}\n",
"\n",
" logger.debug(\"TOOL CALL -> %s | args=%s\", fn, json.dumps(args, ensure_ascii=False))\n",
"\n",
" if fn == \"check_com_availability_bulk\":\n",
" payload = check_com_availability_bulk(args.get(\"domains\", []))\n",
" elif fn == \"check_com_availability\":\n",
" payload = check_com_availability(args.get(\"domain\", \"\"))\n",
" results.append({\n",
" \"role\": \"tool\",\n",
" \"tool_call_id\": call.id,\n",
" \"content\": json.dumps(payload)\n",
" })\n",
" else:\n",
" payload = {\"error\": f\"unknown tool {fn}\"}\n",
"\n",
" logger.debug(\"TOOL RESULT <- %s | %s\", fn, json.dumps(payload, ensure_ascii=False))\n",
"\n",
" results.append({\n",
" \"role\": \"tool\",\n",
" \"tool_call_id\": call.id,\n",
" \"content\": json.dumps(payload),\n",
" })\n",
" return results\n"
]
},
@@ -159,26 +268,67 @@
"SYSTEM_PROMPT = \"\"\"You are the Agent for project \"AI Domain Finder\".\n",
"Goal: suggest .com domains and verify availability using the tool ONLY (no guessing).\n",
"\n",
"Instructions:\n",
"- Always propose 5-12 brandable .com candidates based on:\n",
"Do this each interaction:\n",
"- Generate up to ~20 short, brandable .com candidates from:\n",
" (1) Industry, (2) Target Customers, (3) Description.\n",
"- For each candidate, CALL the tool check_com_availability.\n",
"- Respond ONLY after checking all candidates.\n",
"- Output Markdown with three sections and these exact headings:\n",
" 1) Available .com domains:\n",
" - itemized list (root + .com)\n",
" 2) Preferred domain:\n",
" - a single best pick\n",
" 3) Audio explanation:\n",
" - 1-2 concise sentences explaining the preference\n",
"- Use the BULK tool `check_com_availability_bulk` with a list of candidates\n",
" (roots or FQDNs). Prefer a single call or very few batched calls.\n",
"- If >= 5 available .coms are found, STOP checking and finalize the answer.\n",
"\n",
"Output Markdown with EXACT section headings:\n",
"1) Available .com domains:\n",
" - itemized list of available .coms only (root + .com)\n",
"2) Preferred domain:\n",
" - a single best pick\n",
"3) Audio explanation:\n",
" - 1–2 concise sentences explaining the preference\n",
"\n",
"Constraints:\n",
"- Use customer-familiar words where helpful.\n",
"- Keep names short, simple, pronounceable; avoid hyphens/numbers unless meaningful.\n",
"- Never include TLDs other than .com.\n",
"- Domain names must consist of lowercase English letters only; no symbols or spaces.\n",
"\"\"\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72e9d8c2",
"metadata": {},
"outputs": [],
"source": [
"def _asdict_tool_call(tc: Any) -> dict:\n",
" try:\n",
" return {\n",
" \"id\": getattr(tc, \"id\", None),\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": getattr(tc.function, \"name\", None),\n",
" \"arguments\": getattr(tc.function, \"arguments\", None),\n",
" },\n",
" }\n",
" except Exception:\n",
" return {\"type\": \"function\", \"function\": {\"name\": None, \"arguments\": None}}\n",
"\n",
"def _asdict_message(msg: Any) -> dict:\n",
" if isinstance(msg, dict):\n",
" return msg\n",
" role = getattr(msg, \"role\", None)\n",
" content = getattr(msg, \"content\", None)\n",
" tool_calls = getattr(msg, \"tool_calls\", None)\n",
" out = {\"role\": role, \"content\": content}\n",
" if tool_calls:\n",
" out[\"tool_calls\"] = [_asdict_tool_call(tc) for tc in tool_calls]\n",
" return out\n",
"\n",
"def _sanitized_messages_for_log(messages: list[dict | Any]) -> list[dict]:\n",
" return [_asdict_message(m) for m in messages]\n",
"\n",
"def _limit_text(s: str, limit: int = 40000) -> str:\n",
" return s if len(s) <= limit else (s[:limit] + \"\\n... [truncated]\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -186,22 +336,58 @@
"metadata": {},
"outputs": [],
"source": [
"def run_agent_with_tools(history: List[Dict]) -> str:\n",
"def run_agent_with_tools(history: List[Dict]) -> Tuple[str, List[str], str]:\n",
" \"\"\"\n",
" history: list of {\"role\": \"...\", \"content\": \"...\"} messages\n",
" returns assistant markdown string (includes sections required by SYSTEM_PROMPT)\n",
" Returns:\n",
" reply_md: final assistant markdown\n",
" tool_available: .coms marked available by RDAP tools (order-preserving, deduped)\n",
" dbg_text: concatenated log buffer (for the UI panel)\n",
" \"\"\"\n",
" messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] + history\n",
" messages: List[Dict] = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] + history\n",
" tool_available: List[str] = []\n",
"\n",
" dbg_json(_sanitized_messages_for_log(messages), \"=== LLM REQUEST (initial messages) ===\")\n",
" resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n",
"\n",
" while resp.choices[0].finish_reason == \"tool_calls\":\n",
" tool_msg = resp.choices[0].message\n",
" tool_results = handle_tool_calls(tool_msg)\n",
" tool_msg_sdk = resp.choices[0].message\n",
" tool_msg = _asdict_message(tool_msg_sdk)\n",
" dbg_json(tool_msg, \"=== ASSISTANT (tool_calls) ===\")\n",
"\n",
" tool_results = handle_tool_calls(tool_msg_sdk)\n",
"\n",
" # Accumulate authoritative availability directly from tool outputs\n",
" for tr in tool_results:\n",
" try:\n",
" data = json.loads(tr[\"content\"])\n",
" if isinstance(data, dict) and isinstance(data.get(\"available\"), list):\n",
" for d in data[\"available\"]:\n",
" tool_available.append(_to_com(d))\n",
" except Exception:\n",
" pass\n",
"\n",
" dbg_json([json.loads(tr[\"content\"]) for tr in tool_results], \"=== TOOL RESULTS ===\")\n",
"\n",
" messages.append(tool_msg)\n",
" messages.extend(tool_results)\n",
" dbg_json(_sanitized_messages_for_log(messages), \"=== LLM REQUEST (messages + tools) ===\")\n",
"\n",
" resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n",
"\n",
" return resp.choices[0].message.content"
" # Dedup preserve order\n",
" seen, uniq = set(), []\n",
" for d in tool_available:\n",
" if d not in seen:\n",
" seen.add(d)\n",
" uniq.append(d)\n",
"\n",
" reply_md = resp.choices[0].message.content\n",
" logger.debug(\"=== FINAL ASSISTANT ===\\n%s\", _limit_text(reply_md))\n",
" dbg_json(uniq, \"=== AVAILABLE FROM TOOLS (authoritative) ===\")\n",
"\n",
" # Return current buffer text for the UI panel\n",
" dbg_text = _limit_text(get_log_text(), 40000)\n",
" return reply_md, uniq, dbg_text\n"
]
},
{
@@ -233,33 +419,6 @@
" return audio.content\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bdf7c67",
"metadata": {},
"outputs": [],
"source": [
"def chat(message: str, history_ui: List[Dict]) -> Tuple[List[Dict], bytes]:\n",
" \"\"\"\n",
" Gradio ChatInterface callback.\n",
" - message: latest user text (free-form)\n",
" - history_ui: [{\"role\": \"user\"/\"assistant\", \"content\": \"...\"}]\n",
" Returns: updated history, audio bytes for the 'Audio explanation'.\n",
" \"\"\"\n",
" # Convert Gradio UI history to OpenAI-format history\n",
" history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history_ui]\n",
" history.append({\"role\": \"user\", \"content\": message})\n",
"\n",
" reply_md = run_agent_with_tools(history)\n",
" history.append({\"role\": \"assistant\", \"content\": reply_md})\n",
"\n",
" audio_text = extract_audio_text(reply_md)\n",
" audio_bytes = synth_audio(audio_text)\n",
"\n",
" return history, audio_bytes\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -269,6 +428,8 @@
"source": [
"\n",
"_DOMAIN_RE = re.compile(r\"\\b[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\\.com\\b\", re.I)\n",
"_HDR_AVAIL = re.compile(r\"^\\s*[\\d\\.\\)\\-]*\\s*available\\s+.*\\.com\\s+domains\", re.I)\n",
"_HDR_PREF = re.compile(r\"^\\s*[\\d\\.\\)\\-]*\\s*preferred\\s+domain\", re.I)\n",
"\n",
"def _norm_domain(s: str) -> str:\n",
" s = s.strip().lower()\n",
@@ -279,16 +440,19 @@
" out = []\n",
" in_section = False\n",
" for ln in lines:\n",
" if ln.strip().lower().startswith(\"1) available .com domains\"):\n",
" if _HDR_AVAIL.search(ln):\n",
" in_section = True\n",
" continue\n",
" if in_section and ln.strip().lower().startswith(\"2) preferred\"):\n",
" if in_section and _HDR_PREF.search(ln):\n",
" break\n",
" if in_section:\n",
" if ln.strip().startswith((\"-\", \"*\")) or _DOMAIN_RE.search(ln):\n",
" for m in _DOMAIN_RE.findall(ln):\n",
" out.append(_norm_domain(m))\n",
" # dedupe while preserving order\n",
" for m in _DOMAIN_RE.findall(ln):\n",
" out.append(_norm_domain(m))\n",
" # Fallback: if the header wasn't found, collect all .coms then we'll still\n",
" # rely on agent instruction to list only available, which should be safe.\n",
" if not out:\n",
" out = [_norm_domain(m) for m in _DOMAIN_RE.findall(md)]\n",
" # dedupe preserve order\n",
" seen, uniq = set(), []\n",
" for d in out:\n",
" if d not in seen:\n",
@@ -297,14 +461,17 @@
" return uniq\n",
"\n",
"def parse_preferred(md: str) -> str:\n",
" # look in the preferred section; fallback to first domain anywhere\n",
" lower = md.lower()\n",
" idx = lower.find(\"2) preferred domain\")\n",
" if idx != -1:\n",
" seg = md[idx: idx + 500]\n",
" m = _DOMAIN_RE.search(seg)\n",
" if m:\n",
" return _norm_domain(m.group(0))\n",
" # search the preferred section first\n",
" lines = md.splitlines()\n",
" start = None\n",
" for i, ln in enumerate(lines):\n",
" if _HDR_PREF.search(ln):\n",
" start = i\n",
" break\n",
" segment = \"\\n\".join(lines[start:start+8]) if start is not None else md[:500]\n",
" m = _DOMAIN_RE.search(segment)\n",
" if m:\n",
" return _norm_domain(m.group(0))\n",
" m = _DOMAIN_RE.search(md)\n",
" return _norm_domain(m.group(0)) if m else \"\"\n",
"\n",
@@ -323,13 +490,22 @@
" return \"### Preferred domain\\n\\n*– not chosen yet –*\"\n",
" return f\"### Preferred domain\\n\\n`{d}`\"\n",
"\n",
"def build_initial_message(industry: str, customers: str, desc: str) -> str:\n",
" return (\n",
" \"Please propose .com domains based on:\\n\"\n",
" f\"Industry: {industry}\\n\"\n",
" f\"Target Customers: {customers}\\n\"\n",
" f\"Description: {desc}\"\n",
" )\n"
"def build_context_msg(known_avail: Optional[List[str]], preferred_now: Optional[str]) -> str:\n",
" \"\"\"\n",
" Create a short 'state so far' block that we prepend to the next user turn\n",
" so the model always sees the preferred and cumulative available list.\n",
" \"\"\"\n",
" lines = []\n",
" if (preferred_now or \"\").strip():\n",
" lines.append(f\"Preferred domain so far: {preferred_now.strip().lower()}\")\n",
" if known_avail:\n",
" lines.append(\"Available .com domains discovered so far:\")\n",
" for d in known_avail:\n",
" if d:\n",
" lines.append(f\"- {d.strip().lower()}\")\n",
" if not lines:\n",
" return \"\"\n",
" return \"STATE TO CARRY OVER FROM PREVIOUS TURNS:\\n\" + \"\\n\".join(lines)"
]
},
{
@@ -338,52 +514,102 @@
"id": "07f079d6",
"metadata": {},
"outputs": [],
"source": [
"def run_and_extract(history: List[Dict]) -> Tuple[str, List[str], str, str, str]:\n",
" reply_md, avail_from_tools, dbg_text = run_agent_with_tools(history)\n",
" parsed_avail = parse_available(reply_md)\n",
" new_avail = merge_and_sort(avail_from_tools, parsed_avail)\n",
" preferred = parse_preferred(reply_md)\n",
" audio_text = extract_audio_text(reply_md)\n",
" return reply_md, new_avail, preferred, audio_text, dbg_text\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cd5d8ef",
"metadata": {},
"outputs": [],
"source": [
"def initial_submit(industry: str, customers: str, desc: str,\n",
" history: list[dict], known_avail: list[str], preferred_now: str):\n",
" msg = build_initial_message(industry, customers, desc)\n",
" history = (history or []) + [{\"role\": \"user\", \"content\": msg}]\n",
" history: List[Dict], known_avail: List[str], preferred_now: str):\n",
" if CLEAR_LOG_ON_RUN:\n",
" clear_log_buffer()\n",
"\n",
" reply_md, new_avail, preferred, audio_text = run_and_extract(history)\n",
" logger.info(\"Initial submit | industry=%r | customers=%r | desc_len=%d\",\n",
" industry, customers, len(desc or \"\"))\n",
"\n",
" # Build context (usually empty on the very first run, but future inits also work)\n",
" ctx = build_context_msg(known_avail or [], preferred_now or \"\")\n",
"\n",
" user_msg = (\n",
" \"Please propose .com domains based on:\\n\"\n",
" f\"Industry: {industry}\\n\"\n",
" f\"Target Customers: {customers}\\n\"\n",
" f\"Description: {desc}\"\n",
" )\n",
"\n",
" # Single user turn that includes state + prompt so the model always sees memory\n",
" full_content = (ctx + \"\\n\\n\" if ctx else \"\") + user_msg\n",
"\n",
" history = (history or []) + [{\"role\": \"user\", \"content\": full_content}]\n",
" reply_md, new_avail, preferred, audio_text, dbg_text = run_and_extract(history)\n",
" history += [{\"role\": \"assistant\", \"content\": reply_md}]\n",
"\n",
" all_avail = merge_and_sort(known_avail or [], new_avail)\n",
" all_avail = merge_and_sort(known_avail or [], new_avail or [])\n",
" preferred_final = preferred or preferred_now or \"\"\n",
" audio_bytes = synth_audio(audio_text)\n",
"\n",
" return (\n",
" history, # s_history\n",
" all_avail, # s_available\n",
" all_avail, # s_available (cumulative)\n",
" preferred_final, # s_preferred\n",
" gr.update(value=fmt_preferred_md(preferred_final)), # preferred_md\n",
" gr.update(value=fmt_available_md(all_avail)), # available_md\n",
" gr.update(value=\"\", visible=True), # reply_in -> now visible\n",
" gr.update(value=audio_bytes, visible=True), # audio_out\n",
" gr.update(value=fmt_preferred_md(preferred_final)),\n",
" gr.update(value=fmt_available_md(all_avail)),\n",
" gr.update(value=\"\", visible=True), # reply_in: show after first run\n",
" gr.update(value=audio_bytes, visible=True), # audio_out\n",
" gr.update(value=dbg_text), # debug_box\n",
" gr.update(value=\"Find Domains (done)\", interactive=False), # NEW: disable Find\n",
" gr.update(visible=True), # NEW: show Send button\n",
" )\n",
"\n",
"def refine_submit(reply: str,\n",
" history: list[dict], known_avail: list[str], preferred_now: str):\n",
" if not reply.strip():\n",
" history: List[Dict], known_avail: List[str], preferred_now: str):\n",
" # If empty, do nothing (keeps UI state untouched)\n",
" if not (reply or \"\").strip():\n",
" return (\"\", history, known_avail, preferred_now,\n",
" gr.update(), gr.update(), gr.update())\n",
" gr.update(), gr.update(), gr.update(), gr.update())\n",
"\n",
" history = (history or []) + [{\"role\": \"user\", \"content\": reply.strip()}]\n",
" reply_md, new_avail, preferred, audio_text = run_and_extract(history)\n",
" if CLEAR_LOG_ON_RUN:\n",
" clear_log_buffer()\n",
" logger.info(\"Refine submit | user_reply_len=%d\", len(reply))\n",
"\n",
" # Always prepend memory + the user's refinement so the model can iterate properly\n",
" ctx = build_context_msg(known_avail or [], preferred_now or \"\")\n",
" full_content = (ctx + \"\\n\\n\" if ctx else \"\") + reply.strip()\n",
"\n",
" history = (history or []) + [{\"role\": \"user\", \"content\": full_content}]\n",
" reply_md, new_avail, preferred, audio_text, dbg_text = run_and_extract(history)\n",
" history += [{\"role\": \"assistant\", \"content\": reply_md}]\n",
"\n",
" all_avail = merge_and_sort(known_avail or [], new_avail)\n",
" all_avail = merge_and_sort(known_avail or [], new_avail or [])\n",
" preferred_final = preferred or preferred_now or \"\"\n",
" audio_bytes = synth_audio(audio_text)\n",
"\n",
" return (\n",
" \"\", # clear Reply box\n",
" history, # s_history\n",
" all_avail, # s_available\n",
" preferred_final, # s_preferred\n",
" gr.update(value=fmt_preferred_md(preferred_final)), # preferred_md\n",
" gr.update(value=fmt_available_md(all_avail)), # available_md\n",
" gr.update(value=audio_bytes, visible=True), # audio_out\n",
" )\n"
" \"\", # clear Reply box\n",
" history, # s_history\n",
" all_avail, # s_available (cumulative)\n",
" preferred_final, # s_preferred\n",
" gr.update(value=fmt_preferred_md(preferred_final)),\n",
" gr.update(value=fmt_available_md(all_avail)),\n",
" gr.update(value=audio_bytes, visible=True),\n",
" gr.update(value=dbg_text), # debug_box\n",
" )\n",
"\n",
"def clear_debug():\n",
" clear_log_buffer()\n",
" return gr.update(value=\"\")\n"
]
},
{
@@ -412,38 +638,61 @@
"\n",
" audio_out = gr.Audio(label=\"Audio explanation\", autoplay=True, visible=False)\n",
"\n",
" reply_in = gr.Textbox(\n",
" label=\"Reply\",\n",
" placeholder=\"Chat with agent to refine the outputs\",\n",
" lines=2,\n",
" visible=False, # 👈 hidden for the first input\n",
" )\n",
" with gr.Row():\n",
" reply_in = gr.Textbox(\n",
" label=\"Reply\",\n",
" placeholder=\"Chat with the agent to refine the outputs\",\n",
" lines=2,\n",
" visible=False, # hidden for the first input\n",
" )\n",
" send_btn = gr.Button(\"Send\", variant=\"primary\", visible=False)\n",
"\n",
" with gr.Column(scale=3): # RIGHT 30%\n",
" preferred_md = gr.Markdown(fmt_preferred_md(\"\"))\n",
" available_md = gr.Markdown(fmt_available_md([]))\n",
"\n",
" with gr.Accordion(\"Debug log\", open=False):\n",
" debug_box = gr.Textbox(label=\"Log\", value=\"\", lines=16, interactive=False)\n",
" clear_btn = gr.Button(\"Clear log\", size=\"sm\")\n",
"\n",
" # Events\n",
" # Initial run: also disables Find and shows Send\n",
" find_btn.click(\n",
" initial_submit,\n",
" inputs=[industry_in, customers_in, desc_in, s_history, s_available, s_preferred],\n",
" outputs=[\n",
" s_history, s_available, s_preferred,\n",
" preferred_md, available_md,\n",
" reply_in, # 👈 becomes visible after first run\n",
" audio_out # 👈 becomes visible after first run\n",
" reply_in, # visible after first run\n",
" audio_out, # visible after first run\n",
" debug_box,\n",
" find_btn, # NEW: disable + relabel\n",
" send_btn, # NEW: show the Send button\n",
" ],\n",
" )\n",
"\n",
" # Multi-turn submit via Enter in the textbox\n",
" reply_in.submit(\n",
" refine_submit,\n",
" inputs=[reply_in, s_history, s_available, s_preferred],\n",
" outputs=[\n",
" reply_in, s_history, s_available, s_preferred,\n",
" preferred_md, available_md, audio_out\n",
" preferred_md, available_md, audio_out, debug_box\n",
" ],\n",
" )\n",
"\n",
" # Multi-turn submit via explicit Send button\n",
" send_btn.click(\n",
" refine_submit,\n",
" inputs=[reply_in, s_history, s_available, s_preferred],\n",
" outputs=[\n",
" reply_in, s_history, s_available, s_preferred,\n",
" preferred_md, available_md, audio_out, debug_box\n",
" ],\n",
" )\n",
"\n",
" clear_btn.click(clear_debug, inputs=[], outputs=[debug_box])\n",
"\n",
"ui.launch(inbrowser=True, show_error=True)\n"
]
}
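Note: the agent loop in this diff passes a TOOLS list to openai.chat.completions.create, but TOOLS itself is defined outside the changed lines. A plausible minimal sketch for the two tools dispatched in handle_tool_calls, assuming the standard OpenAI function-calling schema (names and argument shapes taken from the diff; descriptions illustrative only), could look like:

# Hypothetical sketch -- the real TOOLS constant is defined elsewhere in the notebook.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "check_com_availability",
            "description": "Check a single .com domain via RDAP; available when RDAP returns 404.",
            "parameters": {
                "type": "object",
                "properties": {"domain": {"type": "string"}},
                "required": ["domain"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "check_com_availability_bulk",
            "description": "Check a list of candidate .com domains (roots or FQDNs) in one call.",
            "parameters": {
                "type": "object",
                "properties": {"domains": {"type": "array", "items": {"type": "string"}}},
                "required": ["domains"],
            },
        },
    },
]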
Block a user