{ "cells": [ { "attachments": { "67b615fe-0c25-4410-9d58-835982547001.png": { "image/png": "" } }, "cell_type": "markdown", "id": "16dc0e41-80bd-4453-b421-dcf315741bf4", "metadata": {}, "source": [ "# Code generation with RAG and self-correction\n", "\n", "AlphaCodium presented an approach for code generation that uses control flow.\n", "\n", "Main idea: [construct an answer to a coding question iteratively](https://x.com/karpathy/status/1748043513156272416?s=20).\n", "\n", "[AlphaCodium](https://github.com/Codium-ai/AlphaCodium) iteratively tests and improves an answer on public and AI-generated tests for a particular question.\n", "\n", "We will implement some of these ideas from scratch using [LangGraph](https://langchain-ai.github.io/langgraph/):\n", "\n", "1. We start with a set of documentation specified by a user\n", "2. We use a long-context LLM to ingest it and perform RAG to answer a question based on it\n", "3. We invoke a tool to produce a structured output\n", "4. We run two unit tests (an import check and a code execution check) before returning the solution to the user\n", "\n", "![Screenshot 2024-05-23 at 2.17.42 PM.png](attachment:67b615fe-0c25-4410-9d58-835982547001.png)" ] }, { "cell_type": "code", "execution_count": 6, "id": "e3900420", "metadata": {}, "outputs": [], "source": [ "%pip install -U langchain_community langchain-openai langchain-anthropic langchain langgraph bs4" ] }, { "cell_type": "markdown", "id": "38330223-d8c8-4156-82b6-93e63343bc01", "metadata": {}, "source": [ "## Docs\n", "\n", "Load the [LangChain Expression Language](https://python.langchain.com/v0.2/docs/concepts/#langchain-expression-language-lcel) (LCEL) docs as an example."
] }, { "cell_type": "code", "execution_count": 7, "id": "c2eb35d1-4990-47dc-a5c4-208bae588a82", "metadata": {}, "outputs": [], "source": [ "from bs4 import BeautifulSoup as Soup\n", "from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader\n", "\n", "# LCEL docs\n", "url = \"https://python.langchain.com/v0.2/docs/concepts/#langchain-expression-language-lcel\"\n", "loader = RecursiveUrlLoader(\n", "    url=url, max_depth=20, extractor=lambda x: Soup(x, \"html.parser\").text\n", ")\n", "docs = loader.load()\n", "\n", "# Sort the documents by source URL (in reverse order) and concatenate their text\n", "d_sorted = sorted(docs, key=lambda x: x.metadata[\"source\"])\n", "d_reversed = list(reversed(d_sorted))\n", "concatenated_content = \"\\n\\n\\n --- \\n\\n\\n\".join(\n", "    [doc.page_content for doc in d_reversed]\n", ")" ] }, { "cell_type": "markdown", "id": "662d4ff4-1709-412f-bfed-5eb2b8d3d3dc", "metadata": {}, "source": [ "## LLMs\n", "\n", "### Code solution\n", "\n", "Try OpenAI or [Claude 3](https://docs.anthropic.com/en/docs/about-claude/models) with function calling.\n", "\n", "Create `code_gen_chain` with either OpenAI or Claude and test it here; the cell below uses OpenAI." ] }, { "cell_type": "code", "execution_count": 8, "id": "3ba3df70-f6b4-4ea5-a210-e10944960bc6", "metadata": {}, "outputs": [], "source": [ "from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_core.pydantic_v1 import BaseModel, Field\n", "from langchain_openai import ChatOpenAI\n", "\n", "### OpenAI\n", "\n", "# Code generation prompt\n", "code_gen_prompt = ChatPromptTemplate.from_messages(\n", "    [\n", "        (\n", "            \"system\",\n", "            \"\"\"You are a coding assistant with expertise in LCEL, LangChain expression language. \\n \n", "    Here is a full set of LCEL documentation: \\n ------- \\n {context} \\n ------- \\n Answer the user \n", "    question based on the above provided documentation. Ensure any code you provide can be executed \\n \n", "    with all required imports and variables defined. Structure your answer with a description of the code solution. \\n\n", "    Then list the imports. And finally list the functioning code block. Here is the user question:\"\"\",\n", "        ),\n", "        (\"placeholder\", \"{messages}\"),\n", "    ]\n", ")\n", "\n", "\n", "# Data model\n", "class code(BaseModel):\n", "    \"\"\"Schema for code solutions to questions about LCEL.\"\"\"\n", "\n", "    prefix: str = Field(description=\"Description of the problem and approach\")\n", "    imports: str = Field(description=\"Code block import statements\")\n", "    code: str = Field(description=\"Code block not including import statements\")\n", "\n", "\n", "expt_llm = \"gpt-4-0125-preview\"\n", "llm = ChatOpenAI(temperature=0, model=expt_llm)\n", "code_gen_chain = code_gen_prompt | llm.with_structured_output(code)" ] }, { "cell_type": "code", "execution_count": 9, "id": "9f14750f-dddc-485b-ba29-5392cdf4ba43", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "code(prefix='Build a RAG Chain in LCEL', imports='from langchain import LCEL\\nfrom langchain.retrievers import YourRetriever\\nfrom langchain.llms import YourLLM\\nfrom langchain.output_parsers import YourOutputParser', code='# Define your retriever\\nretriever = YourRetriever(...)\\n\\n# Define your LLM\\nllm = YourLLM(...)\\n\\n# Define your output parser (optional)\\noutput_parser = YourOutputParser(...)\\n\\n# Build the RAG chain\\nrag_chain = LCEL.chain(retriever | llm | output_parser)\\n\\n# Example usage\\nresult = rag_chain.invoke(\"Your query here\")\\nprint(result)', description=\"This code snippet demonstrates how to build a Retrieval Augmented Generation (RAG) chain using the LangChain Expression Language (LCEL). The process involves defining a retriever, a language model (LLM), and optionally an output parser. These components are then chained together using the `LCEL.chain` method to create the RAG chain. The `invoke` method is used to execute the chain with a query, and the result is printed out. Note that `YourRetriever`, `YourLLM`, and `YourOutputParser` are placeholders for specific implementations of these components that you would need to replace based on your application's requirements.\")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test\n", "question = \"How do I build a RAG chain in LCEL?\"\n", "solution = code_gen_chain.invoke(\n", "    {\"context\": concatenated_content, \"messages\": [(\"user\", question)]}\n", ")\n", "solution" ] }, { "cell_type": "markdown", "id": "131f2055-2f64-4d19-a3d1-2d3cb8b42894", "metadata": {}, "source": [ "## State\n", "\n", "Our state is a dict with the keys relevant to code generation: `error` (whether a test failed), `messages` (the user question, error messages, and reasoning), `generation` (the code solution), and `iterations` (the number of tries)." ] }, { "cell_type": "code", "execution_count": 10, "id": "c185f1a2-e943-4bed-b833-4243c9c64092", "metadata": {}, "outputs": [], "source": [ "from typing import List, TypedDict\n", "\n", "\n", "class GraphState(TypedDict):\n", "    \"\"\"\n", "    Represents the state of our graph.\n", "\n", "    Attributes:\n", "        error : \"yes\"/\"no\" flag for control flow, indicating whether a test failed\n", "        messages : User question, error messages, and reasoning\n", "        generation : Code solution\n", "        iterations : Number of tries\n", "    \"\"\"\n", "\n", "    error: str\n", "    messages: List\n", "    generation: str\n", "    iterations: int" ] }, { "cell_type": "markdown", "id": "64454465-26a3-40de-ad85-bcf59a2c3086", "metadata": {}, "source": [ "## Graph\n", "\n", "Our graph lays out the logical flow shown in the figure above."
] }, { "cell_type": "code", "execution_count": 11, "id": "b70e8301-63ae-4f7e-ad8f-c9a052fe3566", "metadata": {}, "outputs": [], "source": [ "from langchain_core.pydantic_v1 import BaseModel, Field\n", "\n", "### Parameter\n", "\n", "# Max tries\n", "max_iterations = 3\n", "# Reflect\n", "# flag = 'reflect'\n", "flag = \"do not reflect\"\n", "\n", "### Nodes\n", "\n", "\n", "def generate(state: GraphState):\n", " \"\"\"\n", " Generate a code solution\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): New key added to state, generation\n", " \"\"\"\n", "\n", " print(\"---GENERATING CODE SOLUTION---\")\n", "\n", " # State\n", " messages = state[\"messages\"]\n", " iterations = state[\"iterations\"]\n", " error = state[\"error\"]\n", "\n", " # We have been routed back to generation with an error\n", " if error == \"yes\":\n", " messages += [\n", " (\n", " \"user\",\n", " \"Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:\",\n", " )\n", " ]\n", "\n", " # Solution\n", " code_solution = code_gen_chain.invoke(\n", " {\"context\": concatenated_content, \"messages\": messages}\n", " )\n", " messages += [\n", " (\n", " \"assistant\",\n", " f\"{code_solution.prefix} \\n Imports: {code_solution.imports} \\n Code: {code_solution.code}\",\n", " )\n", " ]\n", "\n", " # Increment\n", " iterations = iterations + 1\n", " return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations}\n", "\n", "\n", "def code_check(state: GraphState):\n", " \"\"\"\n", " Check code\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): New key added to state, error\n", " \"\"\"\n", "\n", " print(\"---CHECKING CODE---\")\n", "\n", " # State\n", " messages = state[\"messages\"]\n", " code_solution = state[\"generation\"]\n", " iterations = state[\"iterations\"]\n", "\n", " # Get solution components\n", " imports = 
code_solution.imports\n", "    code = code_solution.code\n", "\n", "    # Check imports\n", "    try:\n", "        exec(imports)\n", "    except Exception as e:\n", "        print(\"---CODE IMPORT CHECK: FAILED---\")\n", "        error_message = [(\"user\", f\"Your solution failed the import test: {e}\")]\n", "        messages += error_message\n", "        return {\n", "            \"generation\": code_solution,\n", "            \"messages\": messages,\n", "            \"iterations\": iterations,\n", "            \"error\": \"yes\",\n", "        }\n", "\n", "    # Check execution. Run in one fresh globals dict: with exec's default\n", "    # namespaces inside a function, names defined at the top level of the\n", "    # generated code are not visible to functions it defines.\n", "    try:\n", "        exec(imports + \"\\n\" + code, {\"__name__\": \"__main__\"})\n", "    except Exception as e:\n", "        print(\"---CODE BLOCK CHECK: FAILED---\")\n", "        error_message = [(\"user\", f\"Your solution failed the code execution test: {e}\")]\n", "        messages += error_message\n", "        return {\n", "            \"generation\": code_solution,\n", "            \"messages\": messages,\n", "            \"iterations\": iterations,\n", "            \"error\": \"yes\",\n", "        }\n", "\n", "    # No errors\n", "    print(\"---NO CODE TEST FAILURES---\")\n", "    return {\n", "        \"generation\": code_solution,\n", "        \"messages\": messages,\n", "        \"iterations\": iterations,\n", "        \"error\": \"no\",\n", "    }\n", "\n", "\n", "def reflect(state: GraphState):\n", "    \"\"\"\n", "    Reflect on errors\n", "\n", "    Args:\n", "        state (dict): The current graph state\n", "\n", "    Returns:\n", "        state (dict): Updated state, with reflections appended to messages\n", "    \"\"\"\n", "\n", "    print(\"---REFLECTING ON ERRORS---\")\n", "\n", "    # State\n", "    messages = state[\"messages\"]\n", "    iterations = state[\"iterations\"]\n", "    code_solution = state[\"generation\"]\n", "\n", "    # Ask the LLM to reflect on the error and append its reflections to the messages\n", "    reflections = code_gen_chain.invoke(\n", "        {\"context\": concatenated_content, \"messages\": messages}\n", "    )\n", "    messages += [(\"assistant\", f\"Here are reflections on the error: {reflections}\")]\n", "    return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations}\n", "\n", "\n", "### Edges\n", "\n", "\n", "def decide_to_finish(state: GraphState):\n", "    \"\"\"\n", "    Determines whether to finish.\n", "\n", "    Args:\n", "        state (dict): The current graph state\n", "\n", "    Returns:\n", "        str: Next node to call\n", "    \"\"\"\n", "    error = state[\"error\"]\n", "    iterations = state[\"iterations\"]\n", "\n", "    if error == \"no\" or iterations == max_iterations:\n", "        print(\"---DECISION: FINISH---\")\n", "        return \"end\"\n", "    else:\n", "        print(\"---DECISION: RE-TRY SOLUTION---\")\n", "        if flag == \"reflect\":\n", "            return \"reflect\"\n", "        else:\n", "            return \"generate\"" ] }, { "cell_type": "code", "execution_count": 12, "id": "f66b4e00-4731-42c8-bc38-72dd0ff7c92c", "metadata": {}, "outputs": [], "source": [ "from langgraph.graph import END, StateGraph, START\n", "\n", "workflow = StateGraph(GraphState)\n", "\n", "# Define the nodes\n", "workflow.add_node(\"generate\", generate)  # generate a solution\n", "workflow.add_node(\"check_code\", code_check)  # check code\n", "workflow.add_node(\"reflect\", reflect)  # reflect\n", "\n", "# Build graph\n", "workflow.add_edge(START, \"generate\")\n", "workflow.add_edge(\"generate\", \"check_code\")\n", "workflow.add_conditional_edges(\n", "    \"check_code\",\n", "    decide_to_finish,\n", "    {\n", "        \"end\": END,\n", "        \"reflect\": \"reflect\",\n", "        \"generate\": \"generate\",\n", "    },\n", ")\n", "workflow.add_edge(\"reflect\", \"generate\")\n", "app = workflow.compile()" ] }, { "cell_type": "code", "execution_count": 13, "id": "9bcaafe4-ddcf-4fab-8620-2d9b6c508f98", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---GENERATING CODE SOLUTION---\n", "---CHECKING CODE---\n", "---CODE IMPORT CHECK: FAILED---\n", "---DECISION: RE-TRY SOLUTION---\n", "---GENERATING CODE SOLUTION---\n", "---CHECKING CODE---\n", "---CODE IMPORT CHECK: FAILED---\n", "---DECISION: RE-TRY SOLUTION---\n", "---GENERATING CODE SOLUTION---\n", "---CHECKING CODE---\n", "---CODE BLOCK CHECK: FAILED---\n", "---DECISION: FINISH---\n" ] }, { "data": { "text/plain": [ "{'error': 
'yes',\n", " 'messages': [('user',\n", " 'How can I directly pass a string to a runnable and use it to construct the input needed for my prompt?'),\n", " ('assistant',\n", " 'Passing a string directly to a runnable in LCEL \\n Imports: from langchain_core.prompts import PromptTemplate\\nfrom langchain_core import Runnable \\n Code: # Define a custom runnable class that accepts a string and constructs the input for a prompt\\nclass StringInputRunnable(Runnable):\\n def __init__(self, prompt_template):\\n self.prompt_template = prompt_template\\n\\n async def invoke(self, input_data):\\n # Assuming input_data is a string, use it to construct the prompt input\\n prompt_input = {\\'text\\': input_data}\\n # Generate the prompt using the provided template\\n generated_prompt = self.prompt_template.invoke(prompt_input)\\n return generated_prompt\\n\\n# Example usage\\nprompt_template = PromptTemplate.from_template(\"Your prompt here with {text}\")\\nstring_input_runnable = StringInputRunnable(prompt_template)\\n\\n# Example string to pass\\ninput_string = \\'example string\\'\\n# Invoke the runnable with the string\\nresult = await string_input_runnable.invoke(input_string)\\nprint(result)'),\n", " ('user',\n", " \"Your solution failed the import test: cannot import name 'Runnable' from 'langchain_core' (/Users/mark/Library/Python/3.9/lib/python/site-packages/langchain_core/__init__.py)\"),\n", " ('user',\n", " 'Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:'),\n", " ('assistant',\n", " 'To directly pass a string to a runnable and use it to construct the input needed for a prompt in LCEL, you can create a custom runnable class that wraps around a prompt template. This class will take a string as input, use it to fill in the template, and then proceed with the execution. 
Here\\'s how you can implement this: \\n Imports: from langchain.runnables import BaseRunnable\\nfrom langchain.prompts import PromptTemplate \\n Code: class StringInputRunnable(BaseRunnable):\\n def __init__(self, prompt_template_str):\\n self.prompt_template = PromptTemplate.from_template(prompt_template_str)\\n\\n async def invoke(self, input_data, **kwargs):\\n # Construct the prompt input using the input string\\n prompt_input = {\\'text\\': input_data}\\n # Generate the prompt\\n generated_prompt = self.prompt_template.invoke(prompt_input)\\n # Here you would typically pass generated_prompt to a model or another process\\n # For demonstration, we\\'ll just return the generated prompt\\n return generated_prompt\\n\\n# Example usage\\nprompt_template_str = \"This is a test prompt with input: {text}\"\\nrunnable = StringInputRunnable(prompt_template_str)\\n\\n# Simulate invoking the runnable with a string\\ninput_string = \\'example input\\'\\n# Since invoke is an async method, it should be run in an async context\\nimport asyncio\\n\\nasync def run_example():\\n result = await runnable.invoke(input_string)\\n print(f\\'Result: {result}\\')\\n\\n# Run the example\\nasyncio.run(run_example())'),\n", " ('user',\n", " \"Your solution failed the import test: cannot import name 'BaseRunnable' from 'langchain.runnables' (/Users/mark/Library/Python/3.9/lib/python/site-packages/langchain/runnables/__init__.py)\"),\n", " ('user',\n", " 'Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:'),\n", " ('assistant',\n", " 'To directly pass a string to a runnable and construct the input for a prompt in LCEL, you can create a custom class that implements the Runnable interface. This class will accept a string, use it to fill a prompt template, and then execute the necessary logic. 
Here\\'s a structured solution: \\n Imports: from langchain_core.prompts import PromptTemplate\\nfrom langchain_core.runnables import Runnable \\n Code: class StringInputRunnable(Runnable):\\n def __init__(self, template):\\n self.template = PromptTemplate.from_template(template)\\n\\n async def invoke(self, input_data):\\n # Construct the prompt input using the input string\\n prompt_input = {\\'text\\': input_data}\\n # Generate the prompt\\n generated_prompt = self.template.invoke(prompt_input)\\n # Here, you would typically pass generated_prompt to a model or another process\\n # For demonstration, we\\'ll just return the generated prompt\\n return generated_prompt\\n\\n# Example usage\\nasync def run_example():\\n runnable = StringInputRunnable(\"This is a test prompt with input: {text}\")\\n input_string = \\'example input\\'\\n result = await runnable.invoke(input_string)\\n print(f\\'Result: {result}\\')\\n\\n# Since invoke is an async method, it should be run in an async context\\nimport asyncio\\nasyncio.run(run_example())'),\n", " ('user',\n", " \"Your solution failed the code execution test: name 'StringInputRunnable' is not defined\")],\n", " 'generation': code(prefix=\"To directly pass a string to a runnable and construct the input for a prompt in LCEL, you can create a custom class that implements the Runnable interface. This class will accept a string, use it to fill a prompt template, and then execute the necessary logic. 
Here's a structured solution:\", imports='from langchain_core.prompts import PromptTemplate\\nfrom langchain_core.runnables import Runnable', code='class StringInputRunnable(Runnable):\\n def __init__(self, template):\\n self.template = PromptTemplate.from_template(template)\\n\\n async def invoke(self, input_data):\\n # Construct the prompt input using the input string\\n prompt_input = {\\'text\\': input_data}\\n # Generate the prompt\\n generated_prompt = self.template.invoke(prompt_input)\\n # Here, you would typically pass generated_prompt to a model or another process\\n # For demonstration, we\\'ll just return the generated prompt\\n return generated_prompt\\n\\n# Example usage\\nasync def run_example():\\n runnable = StringInputRunnable(\"This is a test prompt with input: {text}\")\\n input_string = \\'example input\\'\\n result = await runnable.invoke(input_string)\\n print(f\\'Result: {result}\\')\\n\\n# Since invoke is an async method, it should be run in an async context\\nimport asyncio\\nasyncio.run(run_example())', description='How to create a custom class in LCEL that accepts a string to construct prompt input and execute logic.'),\n", " 'iterations': 3}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "question = \"How can I directly pass a string to a runnable and use it to construct the input needed for my prompt?\"\n", "app.invoke({\"messages\": [(\"user\", question)], \"iterations\": 0})" ] }, { "cell_type": "markdown", "id": "744f48a5-9ad3-4342-899f-7dd4266a9a15", "metadata": {}, "source": [ "## Eval" ] }, { "cell_type": "markdown", "id": "89852874-b538-4c8d-a4c3-1d68302db492", "metadata": {}, "source": [ "[Here](https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d) is a public dataset of LCEL questions. \n", "\n", "I saved this as `test-LCEL-code-gen`.\n", "\n", "You can also find the csv [here](https://github.com/langchain-ai/lcel-teacher/blob/main/eval/eval.csv)." 
] }, { "cell_type": "code", "execution_count": 14, "id": "678e8954-56b5-4cc6-be26-f7f2a060b242", "metadata": {}, "outputs": [], "source": [ "import langsmith\n", "\n", "# The hosted LangSmith API requires an API key: set the LANGCHAIN_API_KEY\n", "# environment variable, or pass it explicitly with langsmith.Client(api_key=...)\n", "client = langsmith.Client()" ] }, { "cell_type": "code", "execution_count": null, "id": "ef7cf662-7a6f-4dee-965c-6309d4045feb", "metadata": {}, "outputs": [], "source": [ "# Clone the dataset to your tenant to use it\n", "public_dataset = (\n", "    \"https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d\"\n", ")\n", "client.clone_public_dataset(public_dataset)" ] }, { "cell_type": "markdown", "id": "9d171396-022b-47ec-a741-c782aff9fdae", "metadata": {}, "source": [ "Custom evaluators: one checks that the imports execute, the other that the imports and code execute together."
] }, { "cell_type": "code", "execution_count": null, "id": "455a34ea-52cb-4ae5-9f4a-7e4a08cd0c09", "metadata": {}, "outputs": [], "source": [ "from langsmith.schemas import Example, Run\n", "\n", "# Note: both checks call exec() on LLM-generated code in the current Python process.\n", "\n", "\n", "def check_import(run: Run, example: Example) -> dict:\n", "    \"\"\"Score 1 if the generated imports execute without error, else 0.\"\"\"\n", "    imports = run.outputs.get(\"imports\")\n", "    try:\n", "        exec(imports)\n", "        return {\"key\": \"import_check\", \"score\": 1}\n", "    except Exception:\n", "        return {\"key\": \"import_check\", \"score\": 0}\n", "\n", "\n", "def check_execution(run: Run, example: Example) -> dict:\n", "    \"\"\"Score 1 if the generated imports plus code execute without error, else 0.\"\"\"\n", "    imports = run.outputs.get(\"imports\")\n", "    code = run.outputs.get(\"code\")\n", "    try:\n", "        exec(imports + \"\\n\" + code)\n", "        return {\"key\": \"code_execution_check\", \"score\": 1}\n", "    except Exception:\n", "        return {\"key\": \"code_execution_check\", \"score\": 0}" ] }, { "cell_type": "markdown", "id": "c90bf261-0d94-4779-bbde-c76adeefe3d7", "metadata": {}, "source": [ "Compare LangGraph to context stuffing: the base case answers each question with a single `code_gen_chain` call that has the concatenated documentation as context, while the LangGraph app checks its solution and can retry." ] }, { "cell_type": "code", "execution_count": null, "id": "c8fa6bcb-b245-4422-b79a-582cd8a7d7ea", "metadata": {}, "outputs": [], "source": [ "def predict_base_case(example: dict):\n", "    \"\"\"Context stuffing: a single code_gen_chain call, no checks or retries.\"\"\"\n", "    solution = code_gen_chain.invoke(\n", "        {\"context\": concatenated_content, \"messages\": [(\"user\", example[\"question\"])]}\n", "    )\n", "    # code_gen_chain already returns the structured solution with imports and code fields\n", "    return {\"imports\": solution.imports, \"code\": solution.code}\n", "\n", "\n", "def predict_langgraph(example: dict):\n", "    \"\"\"LangGraph: generate, check, and retry via the compiled graph.\"\"\"\n", "    graph = app.invoke({\"messages\": [(\"user\", example[\"question\"])], \"iterations\": 0})\n", "    solution = graph[\"generation\"]\n", "    return {\"imports\": solution.imports, \"code\": solution.code}" ] }, { "cell_type": "code", "execution_count": null, "id": "d9c57468-97f6-47d6-a5e9-c09b53bfdd83", "metadata": {}, "outputs": [], "source": [ "from langsmith.evaluation import evaluate\n", "\n", "# Evaluators\n", "code_evaluator = [check_import, check_execution]\n", "\n", "# Dataset\n", "dataset_name = \"test-LCEL-code-gen\"" ] }, { "cell_type": "code", "execution_count": null, "id": "2dacccf0-d73f-4017-aaf0-9806ffe5bd2c", "metadata": {}, "outputs": [], "source": [ "# Run base case\n", "experiment_results_ = evaluate(\n", "    predict_base_case,\n", "    data=dataset_name,\n", "    evaluators=code_evaluator,\n", "    experiment_prefix=f\"test-without-langgraph-{expt_llm}\",\n", "    max_concurrency=2,\n", "    metadata={\n", "        \"llm\": expt_llm,\n", "    },\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "71d90f9e-9dad-410c-a709-093d275029ae", "metadata": {}, "outputs": [], "source": [ "# Run with LangGraph\n", "experiment_results = evaluate(\n", "    predict_langgraph,\n", "    data=dataset_name,\n", "    evaluators=code_evaluator,\n", "    experiment_prefix=f\"test-with-langgraph-{expt_llm}-{flag}\",\n", "    max_concurrency=2,\n", "    metadata={\n", "        \"llm\": expt_llm,\n", "        \"feedback\": flag,\n", "    },\n", ")" ] }, { "cell_type": "markdown", "id": "d69da747-b4ea-455d-9314-60c3d9d30549", "metadata": {}, "source": [ "`Results:`\n", "\n", "* `LangGraph outperforms base case`: adding a retry loop improves performance\n", "* `Reflection did not help`: reflecting before retrying led to a regression compared with passing errors directly back to the LLM\n", "* `GPT-4 outperforms Claude 3`: Claude 3 Opus had 3 runs fail and Claude 3 Haiku had 1 run fail due to tool-use errors\n", "\n", "Full results are available in LangSmith: https://smith.langchain.com/public/78a3d858-c811-4e46-91cb-0f10ef56260b/d" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer":
"ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 5 }