Improved prompting to avoid spaces or symbols in domain names.

This commit is contained in:
Nik
2025-10-23 13:48:16 +05:30
parent ca9c77d91c
commit ec0296b261

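For context, a minimal runnable sketch of the label check this prompt change is paired with; it mirrors the `_sld_is_english_alpha` helper added in the diff below, and candidates that fail it are never sent to the RDAP API.

import re

_ALPHA_RE = re.compile(r"^[a-z]+$", re.IGNORECASE)

def _sld_is_english_alpha(fqdn: str) -> bool:
    # True only when the label directly before ".com" is purely a-z
    if not fqdn.endswith(".com"):
        return False
    sld = fqdn[:-4].split(".")[-1]
    return bool(sld) and bool(_ALPHA_RE.fullmatch(sld))

assert _sld_is_english_alpha("foo.com")
assert _sld_is_english_alpha("www.foo.com")      # checks 'foo'
assert not _sld_is_english_alpha("foo-bar.com")  # hyphen rejected, RDAP call skipped
assert not _sld_is_english_alpha("foo1.com")     # digit rejected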

@@ -32,7 +32,7 @@
"import os\n", "import os\n",
"import json\n", "import json\n",
"import requests\n", "import requests\n",
"from typing import Dict, List, Tuple\n", "from typing import Dict, List, Tuple, Any, Optional\n",
"import re\n", "import re\n",
"\n", "\n",
"from dotenv import load_dotenv\n", "from dotenv import load_dotenv\n",
@@ -47,6 +47,73 @@
"openai = OpenAI()" "openai = OpenAI()"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"id": "361f7fe3",
"metadata": {},
"outputs": [],
"source": [
"# --- robust logging that works inside VS Code notebooks + Gradio threads ---\n",
"import sys, logging, threading\n",
"from collections import deque\n",
"from typing import Any\n",
"\n",
"DEBUG_LLM = True # toggle on/off noisy logs\n",
"CLEAR_LOG_ON_RUN = True # clear panel before each submit\n",
"\n",
"_LOG_BUFFER = deque(maxlen=2000) # keep ~2000 lines in memory\n",
"_LOG_LOCK = threading.Lock()\n",
"\n",
"class GradioBufferHandler(logging.Handler):\n",
" def emit(self, record: logging.LogRecord) -> None:\n",
" try:\n",
" msg = self.format(record)\n",
" except Exception:\n",
" msg = record.getMessage()\n",
" with _LOG_LOCK:\n",
" for line in (msg.splitlines() or [\"\"]):\n",
" _LOG_BUFFER.append(line)\n",
"\n",
"def get_log_text() -> str:\n",
" with _LOG_LOCK:\n",
" return \"\\n\".join(_LOG_BUFFER)\n",
"\n",
"def clear_log_buffer() -> None:\n",
" with _LOG_LOCK:\n",
" _LOG_BUFFER.clear()\n",
"\n",
"def _setup_logger() -> logging.Logger:\n",
" logger = logging.getLogger(\"aidf\")\n",
" logger.setLevel(logging.DEBUG if DEBUG_LLM else logging.INFO)\n",
" logger.handlers.clear()\n",
" fmt = logging.Formatter(\"%(asctime)s | %(levelname)s | %(message)s\", \"%H:%M:%S\")\n",
"\n",
" stream = logging.StreamHandler(stream=sys.stdout) # captured by VS Code notebook\n",
" stream.setFormatter(fmt)\n",
"\n",
" buf = GradioBufferHandler() # shown inside the Gradio panel\n",
" buf.setFormatter(fmt)\n",
"\n",
" logger.addHandler(stream)\n",
" logger.addHandler(buf)\n",
" logger.propagate = False\n",
" return logger\n",
"\n",
"logger = _setup_logger()\n",
"\n",
"def dbg_json(obj: Any, title: str = \"\") -> None:\n",
" \"\"\"Convenience: pretty-print JSON-ish objects to the logger.\"\"\"\n",
" try:\n",
" txt = json.dumps(obj, ensure_ascii=False, indent=2)\n",
" except Exception:\n",
" txt = str(obj)\n",
" if title:\n",
" logger.debug(\"%s\\n%s\", title, txt)\n",
" else:\n",
" logger.debug(\"%s\", txt)\n"
]
},
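A brief usage sketch for the logging cell above (the names come straight from that cell; `json` is imported in the notebook's first cell):

logger.info("submitting candidates")                       # goes to stdout and the in-memory buffer
dbg_json({"domains": ["foo.com", "bar.com"]}, "payload")    # pretty-printed at DEBUG level
panel_text = get_log_text()   # text the Gradio debug panel renders
clear_log_buffer()            # wiped before each run when CLEAR_LOG_ON_RUN is True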
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -56,12 +123,35 @@
"source": [ "source": [
"RDAP_URL = \"https://rdap.verisign.com/com/v1/domain/{}\"\n", "RDAP_URL = \"https://rdap.verisign.com/com/v1/domain/{}\"\n",
"\n", "\n",
"_ALPHA_RE = re.compile(r\"^[a-z]+$\", re.IGNORECASE)\n",
"\n",
"def _to_com(domain: str) -> str:\n", "def _to_com(domain: str) -> str:\n",
" d = domain.strip().lower()\n", " d = domain.strip().lower()\n",
" return d if d.endswith(\".com\") else f\"{d}.com\"\n", " return d if d.endswith(\".com\") else f\"{d}.com\"\n",
"\n", "\n",
"def _sld_is_english_alpha(fqdn: str) -> bool:\n",
" \"\"\"\n",
" True only if the second-level label (just before .com) is made up\n",
" exclusively of English letters (a-z).\n",
" Examples:\n",
" foo.com -> True\n",
" foo-bar.com -> False\n",
" foo1.com -> False\n",
" café.com -> False\n",
" xn--cafe.com -> False\n",
" www.foo.com -> True (checks 'foo')\n",
" \"\"\"\n",
" if not fqdn.endswith(\".com\"):\n",
" return False\n",
" sld = fqdn[:-4].split(\".\")[-1] # take label immediately before .com\n",
" return bool(sld) and bool(_ALPHA_RE.fullmatch(sld))\n",
"\n",
"def check_com_availability(domain: str) -> Dict:\n", "def check_com_availability(domain: str) -> Dict:\n",
" fqdn = _to_com(domain)\n", " fqdn = _to_com(domain)\n",
" # Skip API if not strictly English letters\n",
" if not _sld_is_english_alpha(fqdn):\n",
" return {\"domain\": fqdn, \"available\": False, \"status\": 0}\n",
"\n",
" try:\n", " try:\n",
" r = requests.get(RDAP_URL.format(fqdn), timeout=6)\n", " r = requests.get(RDAP_URL.format(fqdn), timeout=6)\n",
" return {\"domain\": fqdn, \"available\": (r.status_code == 404), \"status\": r.status_code}\n", " return {\"domain\": fqdn, \"available\": (r.status_code == 404), \"status\": r.status_code}\n",
@@ -80,8 +170,15 @@
" \"\"\"\n", " \"\"\"\n",
" session = requests.Session()\n", " session = requests.Session()\n",
" results: List[Dict] = []\n", " results: List[Dict] = []\n",
"\n",
" for d in domains:\n", " for d in domains:\n",
" fqdn = _to_com(d)\n", " fqdn = _to_com(d)\n",
"\n",
" # Skip API if not strictly English letters\n",
" if not _sld_is_english_alpha(fqdn):\n",
" results.append({\"domain\": fqdn, \"available\": False, \"status\": 0})\n",
" continue\n",
"\n",
" try:\n", " try:\n",
" r = session.get(RDAP_URL.format(fqdn), timeout=6)\n", " r = session.get(RDAP_URL.format(fqdn), timeout=6)\n",
" ok = (r.status_code == 404)\n", " ok = (r.status_code == 404)\n",
@@ -133,19 +230,31 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"def handle_tool_calls(message) -> List[Dict]:\n", "def handle_tool_calls(message) -> List[Dict]:\n",
" \"\"\"\n",
" Translates model tool_calls into tool results for follow-up completion.\n",
" \"\"\"\n",
" results = []\n", " results = []\n",
" for call in (message.tool_calls or []):\n", " for call in (message.tool_calls or []):\n",
" if call.function.name == \"check_com_availability\":\n", " fn = getattr(call.function, \"name\", None)\n",
" args = json.loads(call.function.arguments or \"{}\")\n", " args_raw = getattr(call.function, \"arguments\", \"\") or \"{}\"\n",
" try:\n",
" args = json.loads(args_raw)\n",
" except Exception:\n",
" args = {}\n",
"\n",
" logger.debug(\"TOOL CALL -> %s | args=%s\", fn, json.dumps(args, ensure_ascii=False))\n",
"\n",
" if fn == \"check_com_availability_bulk\":\n",
" payload = check_com_availability_bulk(args.get(\"domains\", []))\n",
" elif fn == \"check_com_availability\":\n",
" payload = check_com_availability(args.get(\"domain\", \"\"))\n", " payload = check_com_availability(args.get(\"domain\", \"\"))\n",
" results.append({\n", " else:\n",
" \"role\": \"tool\",\n", " payload = {\"error\": f\"unknown tool {fn}\"}\n",
" \"tool_call_id\": call.id,\n", "\n",
" \"content\": json.dumps(payload)\n", " logger.debug(\"TOOL RESULT <- %s | %s\", fn, json.dumps(payload, ensure_ascii=False))\n",
" })\n", "\n",
" results.append({\n",
" \"role\": \"tool\",\n",
" \"tool_call_id\": call.id,\n",
" \"content\": json.dumps(payload),\n",
" })\n",
" return results\n" " return results\n"
] ]
}, },
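For reference, each entry that `handle_tool_calls` appends has the shape sketched below; the id and payload values are illustrative only.

example_tool_result = {
    "role": "tool",
    "tool_call_id": "call_abc123",   # illustrative; in practice copied from the model's tool_call
    "content": json.dumps({"domain": "foo.com", "available": True, "status": 404}),
}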
@@ -159,26 +268,67 @@
"SYSTEM_PROMPT = \"\"\"You are the Agent for project \"AI Domain Finder\".\n", "SYSTEM_PROMPT = \"\"\"You are the Agent for project \"AI Domain Finder\".\n",
"Goal: suggest .com domains and verify availability using the tool ONLY (no guessing).\n", "Goal: suggest .com domains and verify availability using the tool ONLY (no guessing).\n",
"\n", "\n",
"Instructions:\n", "Do this each interaction:\n",
"- Always propose 5-12 brandable .com candidates based on:\n", "- Generate up to ~20 short, brandable .com candidates from:\n",
" (1) Industry, (2) Target Customers, (3) Description.\n", " (1) Industry, (2) Target Customers, (3) Description.\n",
"- For each candidate, CALL the tool check_com_availability.\n", "- Use the BULK tool `check_com_availability_bulk` with a list of candidates\n",
"- Respond ONLY after checking all candidates.\n", " (roots or FQDNs). Prefer a single call or very few batched calls.\n",
"- Output Markdown with three sections and these exact headings:\n", "- If >= 5 available .coms are found, STOP checking and finalize the answer.\n",
" 1) Available .com domains:\n", "\n",
" - itemized list (root + .com)\n", "Output Markdown with EXACT section headings:\n",
" 2) Preferred domain:\n", "1) Available .com domains:\n",
" - a single best pick\n", " - itemized list of available .coms only (root + .com)\n",
" 3) Audio explanation:\n", "2) Preferred domain:\n",
" - 1-2 concise sentences explaining the preference\n", " - a single best pick\n",
"3) Audio explanation:\n",
" - 12 concise sentences explaining the preference\n",
"\n", "\n",
"Constraints:\n", "Constraints:\n",
"- Use customer-familiar words where helpful.\n", "- Use customer-familiar words where helpful.\n",
"- Keep names short, simple, pronounceable; avoid hyphens/numbers unless meaningful.\n", "- Keep names short, simple, pronounceable; avoid hyphens/numbers unless meaningful.\n",
"- Never include TLDs other than .com.\n", "- Never include TLDs other than .com.\n",
"- domain is made up of english alphabets in lower case only no symbols or spaces to use\n",
"\"\"\"\n" "\"\"\"\n"
] ]
}, },
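The `TOOLS` list itself is defined outside this diff; a hypothetical schema entry for the bulk tool, shaped to match the `{"domains": [...]}` arguments that `handle_tool_calls` reads, might look like the sketch below.

# Hypothetical sketch only -- the actual TOOLS definition is not part of this commit.
BULK_TOOL = {
    "type": "function",
    "function": {
        "name": "check_com_availability_bulk",
        "description": "Check .com availability via RDAP for a list of candidate names.",
        "parameters": {
            "type": "object",
            "properties": {
                "domains": {"type": "array", "items": {"type": "string"}},
            },
            "required": ["domains"],
        },
    },
}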
{
"cell_type": "code",
"execution_count": null,
"id": "72e9d8c2",
"metadata": {},
"outputs": [],
"source": [
"def _asdict_tool_call(tc: Any) -> dict:\n",
" try:\n",
" return {\n",
" \"id\": getattr(tc, \"id\", None),\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": getattr(tc.function, \"name\", None),\n",
" \"arguments\": getattr(tc.function, \"arguments\", None),\n",
" },\n",
" }\n",
" except Exception:\n",
" return {\"type\": \"function\", \"function\": {\"name\": None, \"arguments\": None}}\n",
"\n",
"def _asdict_message(msg: Any) -> dict:\n",
" if isinstance(msg, dict):\n",
" return msg\n",
" role = getattr(msg, \"role\", None)\n",
" content = getattr(msg, \"content\", None)\n",
" tool_calls = getattr(msg, \"tool_calls\", None)\n",
" out = {\"role\": role, \"content\": content}\n",
" if tool_calls:\n",
" out[\"tool_calls\"] = [_asdict_tool_call(tc) for tc in tool_calls]\n",
" return out\n",
"\n",
"def _sanitized_messages_for_log(messages: list[dict | Any]) -> list[dict]:\n",
" return [_asdict_message(m) for m in messages]\n",
"\n",
"def _limit_text(s: str, limit: int = 40000) -> str:\n",
" return s if len(s) <= limit else (s[:limit] + \"\\n... [truncated]\")\n"
]
},
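A tiny sketch of the helpers above in use (plain dict messages pass through unchanged; the truncation marker is appended past the limit):

print(_asdict_message({"role": "user", "content": "hi"}))
# {'role': 'user', 'content': 'hi'}
print(_limit_text("x" * 50, limit=10))
# xxxxxxxxxx
# ... [truncated]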
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -186,22 +336,58 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def run_agent_with_tools(history: List[Dict]) -> str:\n", "def run_agent_with_tools(history: List[Dict]) -> Tuple[str, List[str], str]:\n",
" \"\"\"\n", " \"\"\"\n",
" history: list of {\"role\": \"...\", \"content\": \"...\"} messages\n", " Returns:\n",
" returns assistant markdown string (includes sections required by SYSTEM_PROMPT)\n", " reply_md: final assistant markdown\n",
" tool_available: .coms marked available by RDAP tools (order-preserving, deduped)\n",
" dbg_text: concatenated log buffer (for the UI panel)\n",
" \"\"\"\n", " \"\"\"\n",
" messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] + history\n", " messages: List[Dict] = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] + history\n",
" tool_available: List[str] = []\n",
"\n",
" dbg_json(_sanitized_messages_for_log(messages), \"=== LLM REQUEST (initial messages) ===\")\n",
" resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n", " resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n",
"\n", "\n",
" while resp.choices[0].finish_reason == \"tool_calls\":\n", " while resp.choices[0].finish_reason == \"tool_calls\":\n",
" tool_msg = resp.choices[0].message\n", " tool_msg_sdk = resp.choices[0].message\n",
" tool_results = handle_tool_calls(tool_msg)\n", " tool_msg = _asdict_message(tool_msg_sdk)\n",
" dbg_json(tool_msg, \"=== ASSISTANT (tool_calls) ===\")\n",
"\n",
" tool_results = handle_tool_calls(tool_msg_sdk)\n",
"\n",
" # Accumulate authoritative availability directly from tool outputs\n",
" for tr in tool_results:\n",
" try:\n",
" data = json.loads(tr[\"content\"])\n",
" if isinstance(data, dict) and isinstance(data.get(\"available\"), list):\n",
" for d in data[\"available\"]:\n",
" tool_available.append(_to_com(d))\n",
" except Exception:\n",
" pass\n",
"\n",
" dbg_json([json.loads(tr[\"content\"]) for tr in tool_results], \"=== TOOL RESULTS ===\")\n",
"\n",
" messages.append(tool_msg)\n", " messages.append(tool_msg)\n",
" messages.extend(tool_results)\n", " messages.extend(tool_results)\n",
" dbg_json(_sanitized_messages_for_log(messages), \"=== LLM REQUEST (messages + tools) ===\")\n",
"\n",
" resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n", " resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n",
"\n", "\n",
" return resp.choices[0].message.content" " # Dedup preserve order\n",
" seen, uniq = set(), []\n",
" for d in tool_available:\n",
" if d not in seen:\n",
" seen.add(d)\n",
" uniq.append(d)\n",
"\n",
" reply_md = resp.choices[0].message.content\n",
" logger.debug(\"=== FINAL ASSISTANT ===\\n%s\", _limit_text(reply_md))\n",
" dbg_json(uniq, \"=== AVAILABLE FROM TOOLS (authoritative) ===\")\n",
"\n",
" # Return current buffer text for the UI panel\n",
" dbg_text = _limit_text(get_log_text(), 40000)\n",
" return reply_md, uniq, dbg_text\n"
] ]
}, },
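A minimal call sketch for the new three-value return (the history content is illustrative):

history = [{"role": "user", "content": "Industry: coffee\nTarget Customers: students\nDescription: campus cafe"}]
reply_md, tool_available, dbg_text = run_agent_with_tools(history)
# tool_available: .coms the RDAP tools reported as free (order-preserving, deduped)
# dbg_text: current log buffer text, shown in the UI debug panel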
{ {
@@ -233,33 +419,6 @@
" return audio.content\n" " return audio.content\n"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"id": "7bdf7c67",
"metadata": {},
"outputs": [],
"source": [
"def chat(message: str, history_ui: List[Dict]) -> Tuple[List[Dict], bytes]:\n",
" \"\"\"\n",
" Gradio ChatInterface callback.\n",
" - message: latest user text (free-form)\n",
" - history_ui: [{\"role\": \"user\"/\"assistant\", \"content\": \"...\"}]\n",
" Returns: updated history, audio bytes for the 'Audio explanation'.\n",
" \"\"\"\n",
" # Convert Gradio UI history to OpenAI-format history\n",
" history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history_ui]\n",
" history.append({\"role\": \"user\", \"content\": message})\n",
"\n",
" reply_md = run_agent_with_tools(history)\n",
" history.append({\"role\": \"assistant\", \"content\": reply_md})\n",
"\n",
" audio_text = extract_audio_text(reply_md)\n",
" audio_bytes = synth_audio(audio_text)\n",
"\n",
" return history, audio_bytes\n"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -269,6 +428,8 @@
"source": [ "source": [
"\n", "\n",
"_DOMAIN_RE = re.compile(r\"\\b[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\\.com\\b\", re.I)\n", "_DOMAIN_RE = re.compile(r\"\\b[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\\.com\\b\", re.I)\n",
"_HDR_AVAIL = re.compile(r\"^\\s*[\\d\\.\\)\\-]*\\s*available\\s+.*\\.com\\s+domains\", re.I)\n",
"_HDR_PREF = re.compile(r\"^\\s*[\\d\\.\\)\\-]*\\s*preferred\\s+domain\", re.I)\n",
"\n", "\n",
"def _norm_domain(s: str) -> str:\n", "def _norm_domain(s: str) -> str:\n",
" s = s.strip().lower()\n", " s = s.strip().lower()\n",
@@ -279,16 +440,19 @@
" out = []\n", " out = []\n",
" in_section = False\n", " in_section = False\n",
" for ln in lines:\n", " for ln in lines:\n",
" if ln.strip().lower().startswith(\"1) available .com domains\"):\n", " if _HDR_AVAIL.search(ln):\n",
" in_section = True\n", " in_section = True\n",
" continue\n", " continue\n",
" if in_section and ln.strip().lower().startswith(\"2) preferred\"):\n", " if in_section and _HDR_PREF.search(ln):\n",
" break\n", " break\n",
" if in_section:\n", " if in_section:\n",
" if ln.strip().startswith((\"-\", \"*\")) or _DOMAIN_RE.search(ln):\n", " for m in _DOMAIN_RE.findall(ln):\n",
" for m in _DOMAIN_RE.findall(ln):\n", " out.append(_norm_domain(m))\n",
" out.append(_norm_domain(m))\n", " # Fallback: if the header wasn't found, collect all .coms then we'll still\n",
" # dedupe while preserving order\n", " # rely on agent instruction to list only available, which should be safe.\n",
" if not out:\n",
" out = [_norm_domain(m) for m in _DOMAIN_RE.findall(md)]\n",
" # dedupe preserve order\n",
" seen, uniq = set(), []\n", " seen, uniq = set(), []\n",
" for d in out:\n", " for d in out:\n",
" if d not in seen:\n", " if d not in seen:\n",
@@ -297,14 +461,17 @@
" return uniq\n", " return uniq\n",
"\n", "\n",
"def parse_preferred(md: str) -> str:\n", "def parse_preferred(md: str) -> str:\n",
" # look in the preferred section; fallback to first domain anywhere\n", " # search the preferred section first\n",
" lower = md.lower()\n", " lines = md.splitlines()\n",
" idx = lower.find(\"2) preferred domain\")\n", " start = None\n",
" if idx != -1:\n", " for i, ln in enumerate(lines):\n",
" seg = md[idx: idx + 500]\n", " if _HDR_PREF.search(ln):\n",
" m = _DOMAIN_RE.search(seg)\n", " start = i\n",
" if m:\n", " break\n",
" return _norm_domain(m.group(0))\n", " segment = \"\\n\".join(lines[start:start+8]) if start is not None else md[:500]\n",
" m = _DOMAIN_RE.search(segment)\n",
" if m:\n",
" return _norm_domain(m.group(0))\n",
" m = _DOMAIN_RE.search(md)\n", " m = _DOMAIN_RE.search(md)\n",
" return _norm_domain(m.group(0)) if m else \"\"\n", " return _norm_domain(m.group(0)) if m else \"\"\n",
"\n", "\n",
@@ -323,13 +490,22 @@
" return \"### Preferred domain\\n\\n* not chosen yet *\"\n", " return \"### Preferred domain\\n\\n* not chosen yet *\"\n",
" return f\"### Preferred domain\\n\\n`{d}`\"\n", " return f\"### Preferred domain\\n\\n`{d}`\"\n",
"\n", "\n",
"def build_initial_message(industry: str, customers: str, desc: str) -> str:\n", "def build_context_msg(known_avail: Optional[List[str]], preferred_now: Optional[str]) -> str:\n",
" return (\n", " \"\"\"\n",
" \"Please propose .com domains based on:\\n\"\n", " Create a short 'state so far' block that we prepend to the next user turn\n",
" f\"Industry: {industry}\\n\"\n", " so the model always sees the preferred and cumulative available list.\n",
" f\"Target Customers: {customers}\\n\"\n", " \"\"\"\n",
" f\"Description: {desc}\"\n", " lines = []\n",
" )\n" " if (preferred_now or \"\").strip():\n",
" lines.append(f\"Preferred domain so far: {preferred_now.strip().lower()}\")\n",
" if known_avail:\n",
" lines.append(\"Available .com domains discovered so far:\")\n",
" for d in known_avail:\n",
" if d:\n",
" lines.append(f\"- {d.strip().lower()}\")\n",
" if not lines:\n",
" return \"\"\n",
" return \"STATE TO CARRY OVER FROM PREVIOUS TURNS:\\n\" + \"\\n\".join(lines)"
] ]
}, },
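For illustration, `build_context_msg` turns the accumulated state into a plain-text block like this (values illustrative):

print(build_context_msg(["foo.com", "bar.com"], "foo.com"))
# STATE TO CARRY OVER FROM PREVIOUS TURNS:
# Preferred domain so far: foo.com
# Available .com domains discovered so far:
# - foo.com
# - bar.com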
{ {
@@ -338,52 +514,102 @@
"id": "07f079d6", "id": "07f079d6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"def run_and_extract(history: List[Dict]) -> Tuple[str, List[str], str, str, str]:\n",
" reply_md, avail_from_tools, dbg_text = run_agent_with_tools(history)\n",
" parsed_avail = parse_available(reply_md)\n",
" new_avail = merge_and_sort(avail_from_tools, parsed_avail)\n",
" preferred = parse_preferred(reply_md)\n",
" audio_text = extract_audio_text(reply_md)\n",
" return reply_md, new_avail, preferred, audio_text, dbg_text\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cd5d8ef",
"metadata": {},
"outputs": [],
"source": [ "source": [
"def initial_submit(industry: str, customers: str, desc: str,\n", "def initial_submit(industry: str, customers: str, desc: str,\n",
" history: list[dict], known_avail: list[str], preferred_now: str):\n", " history: List[Dict], known_avail: List[str], preferred_now: str):\n",
" msg = build_initial_message(industry, customers, desc)\n", " if CLEAR_LOG_ON_RUN:\n",
" history = (history or []) + [{\"role\": \"user\", \"content\": msg}]\n", " clear_log_buffer()\n",
"\n", "\n",
" reply_md, new_avail, preferred, audio_text = run_and_extract(history)\n", " logger.info(\"Initial submit | industry=%r | customers=%r | desc_len=%d\",\n",
" industry, customers, len(desc or \"\"))\n",
"\n",
" # Build context (usually empty on the very first run, but future inits also work)\n",
" ctx = build_context_msg(known_avail or [], preferred_now or \"\")\n",
"\n",
" user_msg = (\n",
" \"Please propose .com domains based on:\\n\"\n",
" f\"Industry: {industry}\\n\"\n",
" f\"Target Customers: {customers}\\n\"\n",
" f\"Description: {desc}\"\n",
" )\n",
"\n",
" # Single user turn that includes state + prompt so the model always sees memory\n",
" full_content = (ctx + \"\\n\\n\" if ctx else \"\") + user_msg\n",
"\n",
" history = (history or []) + [{\"role\": \"user\", \"content\": full_content}]\n",
" reply_md, new_avail, preferred, audio_text, dbg_text = run_and_extract(history)\n",
" history += [{\"role\": \"assistant\", \"content\": reply_md}]\n", " history += [{\"role\": \"assistant\", \"content\": reply_md}]\n",
"\n", "\n",
" all_avail = merge_and_sort(known_avail or [], new_avail)\n", " all_avail = merge_and_sort(known_avail or [], new_avail or [])\n",
" preferred_final = preferred or preferred_now or \"\"\n", " preferred_final = preferred or preferred_now or \"\"\n",
" audio_bytes = synth_audio(audio_text)\n", " audio_bytes = synth_audio(audio_text)\n",
"\n", "\n",
" return (\n", " return (\n",
" history, # s_history\n", " history, # s_history\n",
" all_avail, # s_available\n", " all_avail, # s_available (cumulative)\n",
" preferred_final, # s_preferred\n", " preferred_final, # s_preferred\n",
" gr.update(value=fmt_preferred_md(preferred_final)), # preferred_md\n", " gr.update(value=fmt_preferred_md(preferred_final)),\n",
" gr.update(value=fmt_available_md(all_avail)), # available_md\n", " gr.update(value=fmt_available_md(all_avail)),\n",
" gr.update(value=\"\", visible=True), # reply_in -> now visible\n", " gr.update(value=\"\", visible=True), # reply_in: show after first run\n",
" gr.update(value=audio_bytes, visible=True), # audio_out\n", " gr.update(value=audio_bytes, visible=True), # audio_out\n",
" gr.update(value=dbg_text), # debug_box\n",
" gr.update(value=\"Find Domains (done)\", interactive=False), # NEW: disable Find\n",
" gr.update(visible=True), # NEW: show Send button\n",
" )\n", " )\n",
"\n", "\n",
"def refine_submit(reply: str,\n", "def refine_submit(reply: str,\n",
" history: list[dict], known_avail: list[str], preferred_now: str):\n", " history: List[Dict], known_avail: List[str], preferred_now: str):\n",
" if not reply.strip():\n", " # If empty, do nothing (keeps UI state untouched)\n",
" if not (reply or \"\").strip():\n",
" return (\"\", history, known_avail, preferred_now,\n", " return (\"\", history, known_avail, preferred_now,\n",
" gr.update(), gr.update(), gr.update())\n", " gr.update(), gr.update(), gr.update(), gr.update())\n",
"\n", "\n",
" history = (history or []) + [{\"role\": \"user\", \"content\": reply.strip()}]\n", " if CLEAR_LOG_ON_RUN:\n",
" reply_md, new_avail, preferred, audio_text = run_and_extract(history)\n", " clear_log_buffer()\n",
" logger.info(\"Refine submit | user_reply_len=%d\", len(reply))\n",
"\n",
" # Always prepend memory + the user's refinement so the model can iterate properly\n",
" ctx = build_context_msg(known_avail or [], preferred_now or \"\")\n",
" full_content = (ctx + \"\\n\\n\" if ctx else \"\") + reply.strip()\n",
"\n",
" history = (history or []) + [{\"role\": \"user\", \"content\": full_content}]\n",
" reply_md, new_avail, preferred, audio_text, dbg_text = run_and_extract(history)\n",
" history += [{\"role\": \"assistant\", \"content\": reply_md}]\n", " history += [{\"role\": \"assistant\", \"content\": reply_md}]\n",
"\n", "\n",
" all_avail = merge_and_sort(known_avail or [], new_avail)\n", " all_avail = merge_and_sort(known_avail or [], new_avail or [])\n",
" preferred_final = preferred or preferred_now or \"\"\n", " preferred_final = preferred or preferred_now or \"\"\n",
" audio_bytes = synth_audio(audio_text)\n", " audio_bytes = synth_audio(audio_text)\n",
"\n", "\n",
" return (\n", " return (\n",
" \"\", # clear Reply box\n", " \"\", # clear Reply box\n",
" history, # s_history\n", " history, # s_history\n",
" all_avail, # s_available\n", " all_avail, # s_available (cumulative)\n",
" preferred_final, # s_preferred\n", " preferred_final, # s_preferred\n",
" gr.update(value=fmt_preferred_md(preferred_final)), # preferred_md\n", " gr.update(value=fmt_preferred_md(preferred_final)),\n",
" gr.update(value=fmt_available_md(all_avail)), # available_md\n", " gr.update(value=fmt_available_md(all_avail)),\n",
" gr.update(value=audio_bytes, visible=True), # audio_out\n", " gr.update(value=audio_bytes, visible=True),\n",
" )\n" " gr.update(value=dbg_text), # debug_box\n",
" )\n",
"\n",
"def clear_debug():\n",
" clear_log_buffer()\n",
" return gr.update(value=\"\")\n"
] ]
}, },
{ {
@@ -412,38 +638,61 @@
"\n", "\n",
" audio_out = gr.Audio(label=\"Audio explanation\", autoplay=True, visible=False)\n", " audio_out = gr.Audio(label=\"Audio explanation\", autoplay=True, visible=False)\n",
"\n", "\n",
" reply_in = gr.Textbox(\n", " with gr.Row():\n",
" label=\"Reply\",\n", " reply_in = gr.Textbox(\n",
" placeholder=\"Chat with agent to refine the outputs\",\n", " label=\"Reply\",\n",
" lines=2,\n", " placeholder=\"Chat with the agent to refine the outputs\",\n",
" visible=False, # 👈 hidden for the first input\n", " lines=2,\n",
" )\n", " visible=False, # hidden for the first input\n",
" )\n",
" send_btn = gr.Button(\"Send\", variant=\"primary\", visible=False)\n",
"\n", "\n",
" with gr.Column(scale=3): # RIGHT 30%\n", " with gr.Column(scale=3): # RIGHT 30%\n",
" preferred_md = gr.Markdown(fmt_preferred_md(\"\"))\n", " preferred_md = gr.Markdown(fmt_preferred_md(\"\"))\n",
" available_md = gr.Markdown(fmt_available_md([]))\n", " available_md = gr.Markdown(fmt_available_md([]))\n",
"\n", "\n",
" with gr.Accordion(\"Debug log\", open=False):\n",
" debug_box = gr.Textbox(label=\"Log\", value=\"\", lines=16, interactive=False)\n",
" clear_btn = gr.Button(\"Clear log\", size=\"sm\")\n",
"\n",
" # Events\n", " # Events\n",
" # Initial run: also disables Find and shows Send\n",
" find_btn.click(\n", " find_btn.click(\n",
" initial_submit,\n", " initial_submit,\n",
" inputs=[industry_in, customers_in, desc_in, s_history, s_available, s_preferred],\n", " inputs=[industry_in, customers_in, desc_in, s_history, s_available, s_preferred],\n",
" outputs=[\n", " outputs=[\n",
" s_history, s_available, s_preferred,\n", " s_history, s_available, s_preferred,\n",
" preferred_md, available_md,\n", " preferred_md, available_md,\n",
" reply_in, # 👈 becomes visible after first run\n", " reply_in, # visible after first run\n",
" audio_out # 👈 becomes visible after first run\n", " audio_out, # visible after first run\n",
" debug_box,\n",
" find_btn, # NEW: disable + relabel\n",
" send_btn, # NEW: show the Send button\n",
" ],\n", " ],\n",
" )\n", " )\n",
"\n", "\n",
" # Multi-turn submit via Enter in the textbox\n",
" reply_in.submit(\n", " reply_in.submit(\n",
" refine_submit,\n", " refine_submit,\n",
" inputs=[reply_in, s_history, s_available, s_preferred],\n", " inputs=[reply_in, s_history, s_available, s_preferred],\n",
" outputs=[\n", " outputs=[\n",
" reply_in, s_history, s_available, s_preferred,\n", " reply_in, s_history, s_available, s_preferred,\n",
" preferred_md, available_md, audio_out\n", " preferred_md, available_md, audio_out, debug_box\n",
" ],\n", " ],\n",
" )\n", " )\n",
"\n", "\n",
" # Multi-turn submit via explicit Send button\n",
" send_btn.click(\n",
" refine_submit,\n",
" inputs=[reply_in, s_history, s_available, s_preferred],\n",
" outputs=[\n",
" reply_in, s_history, s_available, s_preferred,\n",
" preferred_md, available_md, audio_out, debug_box\n",
" ],\n",
" )\n",
"\n",
" clear_btn.click(clear_debug, inputs=[], outputs=[debug_box])\n",
"\n",
"ui.launch(inbrowser=True, show_error=True)\n" "ui.launch(inbrowser=True, show_error=True)\n"
] ]
} }