File size: 2,276 Bytes
f5ec828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b56870fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q git+https://github.com/srush/MiniChain\n",
    "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bef4e04e",
   "metadata": {},
   "source": [
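    "# Answer a math problem with code\n",
    "\n",
    "Ask the LLM to write Python code for a math question, then run that code to get the answer.\n",
    "Adapted from Dust [maths-generate-code](https://dust.tt/spolu/a/d12ac33169)."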
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bcf8e75",
   "metadata": {},
   "outputs": [],
   "source": [
    "from minichain import Backend, JinjaPrompt, Prompt, start_chain, SimplePrompt, show_log"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "715fd76e",
   "metadata": {},
   "source": [
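    "Define prompts. The cell below illustrates the `Prompt` API with a simple yes/no color classifier; the math prompt, which asks the LLM to write code for the question, is used in the following cell."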
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72e3d463",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "class ColorPrompt(Prompt[str, bool]):\n",
    "    def prompt(self, inp: str) -> str:\n",
    "        # Encode the prompting logic\n",
    "        return f\"Answer 'Yes' if this is a color, {inp}. Answer:\"\n",
    "\n",
    "    def parse(self, out: str, inp) -> bool:\n",
    "        # Encode the parsing logic\n",
    "        return out.strip() == \"Yes\"\n",
    "\n",
    "\n",
    "ColorPrompt().show({\"inp\": \"dog\"}, \"No\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77cbf319",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
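    "class MathPrompt(Prompt[dict, str]):\n",
    "    # Asks the LLM for Python code that answers inp[\"question\"].\n",
    "    def prompt(self, inp: dict) -> str:\n",
    "        # NOTE(review): placeholder wording; tune for the target model.\n",
    "        return (\n",
    "            \"Write a short Python program that prints the answer to the question. \"\n",
    "            f\"Question: {inp['question']} Code:\"\n",
    "        )\n",
    "\n",
    "    def parse(self, out: str, inp: dict) -> str:\n",
    "        # Strip surrounding whitespace and pass the code on to the next prompt in the chain.\n",
    "        return out.strip()\n",
    "\n",
    "\n",
    "with start_chain(\"math\") as backend:\n",
    "    question = 'What is the sum of the powers of 3 (3^i) that are smaller than 100?'\n",
    "    prompt = MathPrompt(backend.OpenAI()).chain(SimplePrompt(backend.Python()))\n",
    "    result = prompt({\"question\": question})\n",
    "    print(result)"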
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72cdff4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_log(\"math.log\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}