From c9ab9104e554e4f24b55f69cf24b784091c7bfb1 Mon Sep 17 00:00:00 2001 From: Dimitri Kennedy Date: Wed, 31 Jan 2024 23:07:14 -0500 Subject: [PATCH] providers + prep (#99) --- .changeset/selfish-birds-appear.md | 5 ++ .example.env | 3 +- .github/workflows/test.yml | 1 + bun.lockb | Bin 274869 -> 274805 bytes package.json | 3 +- src/constants/modes.ts | 35 --------- src/constants/providers.ts | 67 ++++++++++++++++++ src/dsl/validator.ts | 6 +- src/instructor.ts | 57 ++++++++++++--- src/oai/params.ts | 100 -------------------------- src/oai/parser.ts | 110 ----------------------------- src/types/index.ts | 14 ++-- tests/mode.test.ts | 93 ++++++++++++++++++------ 13 files changed, 201 insertions(+), 293 deletions(-) create mode 100644 .changeset/selfish-birds-appear.md delete mode 100644 src/constants/modes.ts create mode 100644 src/constants/providers.ts delete mode 100644 src/oai/params.ts delete mode 100644 src/oai/parser.ts diff --git a/.changeset/selfish-birds-appear.md b/.changeset/selfish-birds-appear.md new file mode 100644 index 00000000..7cd50353 --- /dev/null +++ b/.changeset/selfish-birds-appear.md @@ -0,0 +1,5 @@ +--- +"@instructor-ai/instructor": patch +--- + +Adding explicit support for non-oai providers - currently anyscale and together ai - will do explicit checks on mode selected vs provider and model diff --git a/.example.env b/.example.env index 43340ada..4fa45c0a 100644 --- a/.example.env +++ b/.example.env @@ -1,2 +1,3 @@ OPENAI_API_KEY= -ANYSCALE_API_KEY= \ No newline at end of file +ANYSCALE_API_KEY= +TOGETHER_API_KEY= \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c5ad0ed2..a94e4e86 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,6 +14,7 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} ANYSCALE_API_KEY: ${{ secrets.ANYSCALE_API_KEY }} + TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }} steps: - uses: actions/checkout@v3 diff --git a/bun.lockb b/bun.lockb index 0d79edfd0f1206f80bc26a17d8f0b728d1a9669d..10705c0082b5b71fc7616ba9a70092295a09d819 100755 GIT binary patch delta 17428 zcmeI4d3Y36y2h)j38Vu7LRbGlS7RjZPB2{us848u=InxuP+~YbHxuQT)z05 zjqe}MT6#rx`b5g5?Kf6%oFvNo@uj=bDrn)Tc53dluIM;n$B9;RQzP4{t4iI~l^mzC+Dr|a zIP3ANQ%;UAiSdxddYUqQiz@Je8={as{G@Oy+j*aGJ zqpO=24QJ-KGopE!Io{q1(K|A`x|z{kffdfm@wUfB`($+uhuAzBCCkd{;C71c%F1!4 zMZ>Lg;vS)W`0h_QZo*PZd=-{r1=3SH%{sUfbyijy^ zW`}TNHc}r-jENQ#qTx0<;f;8m@G7W0CleYoAsMqI!}oc<-1fwD)OtF0WXTHFv&?`ur2MU~3ZIQ)Fja7TbrtSZh( zZ5L%q-ryyw4%dT{ltalG4;sKkQ2R-sB$aRP%Acs(>Fp)G36!L2KON-Hu>C@&q+%AT zRF-OBw+6K(sz7rf31Edq7;mDNk-FSc^5a$<4r zHUC^Sa;^=}rJdeOD{KH|sdnXRt2gcX6GI2hS{L;Un?-x>SI;bH1kU}qKwwyW>mKX51=|ztwQ;89@dMgva5N~0@k1^ zKdzU5P+3oyFID_?sQNvPs%m|SA90>ji6m8n4X6h25~`{#rdv(7q1rM#th^Idl4?P( zn!gKGoHtR$-DCCdnC=af{lAZll2idcG(BK?2-O0QpbGR?R7tA#pPQaAU#fQhX=SOh zr%b;#UmDVB28a>GoIk306)-b!Sr(<@6?cM_V>CFGmsJsTm}ng* zp&HOM3lOVX>DlPEFt9Mg9J0{bE(VK7Fg|ZL62c|C5zV zl>M)aHud{f5v%I(XY-{h??V&OPtBL=pgL;x$ILHF)&FyAe?ska{3)aSmwGW(!>`Og zW&XLSntsEJc7>i_D9+zZ�VtHJyK2Wt?lhFWNKlu!k#U>c7qMgpq9l~Edm32V1>VA%|F5s{RF5mMS~S%4O&^dZOb@ zum-WJz!UMMlTfYW6jV)bvhvMVz6DjgnW{8(qy0Xw6I^DB%#BrT+4<&6t5J?xxhz%H zZ+X!czYo>Y?l*k^)#h1c<%dxvG!pE>$1vnQVHN+1YW3IBzXAFNs%y-jO!rybpR1kX zRiQQ?@}iOd#q=XoYjgzViJ5Z})q=i4mHsEHey8ZC_+OiTqwO!_TMH;v_O$8W&6ld( z4_20H7ln9{UqOX)R1Kse?G-8#wg3sHl`W7|4HK;#tFn@;J{eWHj$TaFFO{+ecAoiC zZNVm1HjM;^IyAEm&8fS$Tv^l2ma=TKPxi{h)X17^{d?owz1h$H`VN)j+12KMmCzZ>E)J zna)O)q+0N;R+egc^HBA>BVw7QR`FXjIHjQKct5H$@#CmcS*kT&YxPnM@F^>+j~D~GNKm1X{qH@MEm?3UM9b+6+Cp7`Q~ZMcRm-2w2AiV% z>Fl{YebuHqTRXj}Rs`F#Y}K24wbh%JKb^Jh@vlEzwcDD1JwV9_&emV&j`F9o+D_$9 zXZ4g;Z}jq~v*k}`%b(6R<#tME$MUDMY=lW{44oRL=^pbRwEXF;wn+KY+485edU~rj zNcq#*@~5*UPif1a&N};oFJXCCls}yz+Kbe5;J|KCq%2XiN<$F%?Ur?bT~;@!sX zC3lXvZ$`?=LEGZi-1bflXXoi-3)2qgu1n6}5p}jNt1R$YK9TL+-ut}M z?vJx39DC;BuUZ|Nzvqk8$)m1*&fgz)Q~Z7jZeKUfU!LF&^6yD-6a6zXD*1VpF?v_V zSX&t*?0+ZYTN%TuU{v@F`QfS<@l`Pjt70ViFUWXa zM!jkn$^O`C7-Onoydk5OpIRLwr8>s!>KJwWT{2#kk)DK+>d#2Rn4W~OUq)TOMGcJR zH82*}z^LcHFJrHaF3A}6{YA+bwU2XV-M^ITGE{szULn-@9YH=#o}7 zi?2B2uG~n7lh1p}T^h3Ul=8XR&_@*Q%yoKRPSQjUd3w&j!MqXX=}vv4pZfwKUbm{z zR;4d>O26OK+zy4P0+(958_d&F`#gWCLbSB1(y1T%Zt*gj$=37_edAdj`jJT|$G}}Wfq%Rb%HgBP|t0&*ztq}FCYLQjdr>cQ@ znh?UD(*R1p;QN)eOT&A^Puoa{KicNJ+o~E-^|pCS%xjFd&^%2K;V+nQe$t<;5Z_YN z4*HE%H6#C>dA@ngDW3tQ<#<}*7Ap8#6`~?#g}B$MG)=u7OVs?}`^;-aULSX~!|z8m zpmg}PpR|b(_gK|~D1S~SD4d$FLbXJFU$Rn_q=!*$Dt%S5$}dm|SxRfHc^mSp$t%R; zsH)q-W9B_+p60f%^?$Drzf&91Q>X^m0S1!S5&5)v+2o%G9W&43@fS>kzuk8?6Jh{G z9VRc}YtWrx3L`xa)m$2R7ehXtj+u?-={uOg{s4v0*Pl8THd~dB))rb>lBU%t#3hho z-b?1?;wAYHDulk{Z9s9WRdpw?36^Q-HuHLr?`Gb1^LpYP_D?FrgKU7t6gBULR?el+ zpG*_{M=eK@|5Zamiyd;?NL)IAt8kvrd53KoM@}(2lv`B>=0#SdCLTJ5p z=={a1t|c#&KEmVA84AOo6KI4Aq4IElw?b&ffexKRRyCZwJbvU73|`IiK@X*sj##?_ z@~M8>7DCkIxX7ibxk+j@5~f0T_{_Xf}F^jA= zU+JgkiR3kKC4G|CcT1B%+ou3MV<9G!|K2Z92vuplN@qD~z*_3{pm)EPI1RMtCV*?+ zP2{ilf3FZ)cP%+&RX3B@y34Cz-gNRenx|icDeMdw=eyeoq1hf9RM@I!lGjK?>3k}! ztVX78np33^vtg<~fDrCffAlub57)!F0P6c^wt01G=H$H{gDbFR+r`X!Uz#kvOR#L)tpfMx?|?!yXL1*>5RXj;N_(Di&G=o+qTwx&-_hZ&(p zoMABfKwr2FG|S^s=nb0PQrADT!)q4#hW3dRt3oxX4ju7y0gr<>$-f1D4Jb zL-q&I0px-QA&7%8RE8>$2vwmP)PR~0WrzM2bW5OH0A1Htf-cw(fv&!fz@wlm@Z)~& z%U<1Dx@z`^D_{Uz30L`(U-lYBt|L1f@}VcJW?+v&ck;hO+oA2D17t%-XafyFGu`#W z3{9<{0-B1i>G{{fC-mJ1AHaw35gdSnupM^5PF5uO6>mVh7bt9ijj#ze!;7#5R=`Sl z819Ap;4Zk^U;c`h8M%V&K)4cg{m|K~GqeL;EOc?u9gZ$bna~WH!xNxMGMZ2`1ZJ|M zG)Ye<4bA%L2i>6u^aM>w>Iyk<320VR2Td2!JwX#_3XNHD-47%|CxTT%;~|(dhPvRv z0qPEdX1i&M+&l0tya#*Xeb5ZT2Ve!P1Rw5!c?@tqEPw&}l|~*5{XK=Jp(puXa4FnM zS=V&^B%~9xfHcrgY|e+;PzTO~#q7<-1aAUOp&2xUWT*+XKo|TvkODdj>$feL`jy5- zu#W&Af-bBdfi9dMgRYXFfTr>4GOx?}d$1Swz&6lrkM3J`z_ai?tcL}#5N?M>FdycD zzR9==ehI^F=4A%l0yAM2=mwwwM!;GIryn%D5Bep?beI7%VHV5=twa&bfq5_=7QjM3 z?KLksX))Qm;BHt7%lw|NapY|w+n6OZf%@bdz>{qAL$DQ!VG&#lL*Y6Y4!V}=Hl-uz zUPM#0X0bi9(T;E-q=V+;=_aBws>e@Z2*FX>9@E6l<79Lnpew8{UYa`hFzAQFkHDj# zdxB@FTMznjxFuW&t>AhFJ{E3-aWEdL(5D)N;Ty`|Lwnll;;;LPFF;eR-`3pxm&v>c z`j`}fBA5enp%TPH9GoV=F?b7#;RRUBfJVV+7y?~bGtCUlg4VEx0X&D_fwo@}L-V)w z1DK7l3D&XvV9H^03Ny%`5BYSujD`hxr-Gqb2B5LoR*0!C?mSXLt}Eg8q0{0rgyv$!IGSL09tG&=5vYzLNpyX6|li zL`U6u{T)t$?v?k$VK@Ssi?5F;Z@`=I3cLn8;Td=ibnhJ-GomGH!FOn3Fbot(_j9G1 z!J23qs_u8IfbOQcGr%LX`4ql@RiHs%4W;#SX_rHJDe;y-1KRw6eg_(aR_eG8xX;We zT~mFu+X(735k4cxT!PJl*)Rg~;S>Rde$H1qVBKVXkEgP3^9E5TPq%;nNi4Ztdb3d%Wa0bn$y$=Y&Kp^xFd^tMOv- zI|sbsUN|cwD?=COU;2X%d2P<`NiAb6Dr$H%?as-aS9$Jr7iF~0$kKbu-*U)H@xMIe z^$IVi8AC35$j?3OrF!cD@Mi{yGl*CA4U-gPrv$ z%CE7o$+~5?%=g^Z867e%VgSc&$wS^v-8ufH1#5!7ZB%vI^FO3--8l9h+p%cKV@F@< z@!;z>(}H?MU+pIy@%nj3Lw^1dFD0Y&|4U5i^TP2~d()rLLRe20^79a{SMVP>;!UV) z|8t^f`N7K${-*uw?L7C^)~q<2=4wCl6R!7*aF}B4(*G(^ zbNa1y8?XH6zMyxOKmL%H>aRHD#T8%lskf(+yV{RC5sb6=suSL?#x*+4I&(A>E%|!& z$_@2GbAzqjP9u%i#}86(KQuIyTiknS=n^+!x;~GP_Gc9@8x~sW`aOn+*7-U4p_In$ zGp5~e%hZu$Mi&;OPn|YpbU`8iv?G1m#Pl1dPMnZlQsM7j6gt27w){}jc>nnDP`~0Q J#)J}f{TI*x-^&01 delta 17600 zcmeI4cX$-l`p0J`iDVN3gouEEB%vb-gkk~#0@6ZL5DTF6VvuGiA}DNvAYH*#1}mT< zR;(BhdjavY$U*4_#stqGo-;+P7aLtWtQt!B7QlqR@19F0k zUB@}raf;?l9qXuEK)pKkNoWE(_55*DMon;>(_%{7<)iX{ZJHBX+J45gwkN57F8(Cq zk@oR%j+0EiIezI0Xk~Qr*mi2}IrBoUTYo||+*NSjRLOA?6`rmr1%EZ_BjZcEf8~|x zQR)Hw19{PKCQq0=E@SLy=Va&2!g=mbA^g zKfc|?8-ul&T)`>h)CLu)>c<0n>eL6e;xq*nDf8(xH*JBcM62S=L%5WsiZi!Cog_}1 z(st#l2Gka$;nUAl_3+tis>C^brfs&^`_qQO$@H%IB7-}C{j)=X9{Qlg`oD+ zKt(G5OjRfm4a!i%3ruIBDpGZr1Bzb+DpL99sX{4BHLxNDDMg_6mw@st02L|qptCSs zb;?wATntLM)Y_G$O1uoz{xWMPmH%?9M{5;)*N6Sgf2r}m)mS#ZUIQ2qF@C2ymK#B( zEY&%4Cun9ifr``(joVXyg7*79)!=?mk*f7(kpDqYk;=aXy%qCcHSJG$Gc<2dmcAA=nGKt<|?YJ5;XsM$Rr z7o|MaF+T*2p{7!q%3sgwQst|U%9U#Vax@rjm}VUtSckGyp@tSOPt~N6wUf$h%!>xr z)cjJJ&0LO;jOOMzQ5A|*W(!{QzHMpsR;c#)Bn3@nW?Ma4wZZo0m&)v5^RPK)uq}{SM#52eyPlEyeQto>OJM8NcBDs_OU=;3zVlC$SKw?T6Ot7 z%luMpaER5VGKX3{T4@(_&LN-vKI$~D6L(WiB|2vZ2V2o zb!eO}u^TYd_7=UE%G}6{HgE^3`ki|Doyv8W`K8*xJ*fIUfGV`v{99C^NELq=jnn7D z6Bvp;Y5J7u)2L=`Jb>a9^t zakgnkRA+q;R2%Gxsz_z#TRmFk>SOVK7VoFeUnS^|YQqCjo!w(l?a_Etr7Tsbz~WNP z%v7sOmG^wB|0fMb2h^d^I!;41pbIQPv}(^aaD|F_QNQ`tt}NA#EVTIVlv}K_1LgX) z!Aq^5RP{xsi%|`Hxy8#;wOQ`F*k7ye0H(S34t)Eoa zkcZLm`S++bc+47WDtci}rN8^(#-+4--{QtCiv}&f`(feNo<5yLOcdaf}eK*Q?f9EsvOLbCxVev1`UzV!> zSJwV(YbTZYugLrVU)0p_TMHakl=4SZO@H7;hvEpTJinM8Rge;`a=DCN?Y!{kuMF*u zsS+JSU7CQZLuKobWbLH#SGBrSW_7DeW!B(Dhp-N+_VujZ!0N$97#c}qR7I){>#>ST zS*nDsh$~@hYbVv6W(QRH^teVvs(MG$PUe@=Kj`FP=-Bo{mEd%15Un~{hTvDCp%#~_ zezw)6>OaQnQkmneUX~{4{5#(oM5_`P;+IZCwWl*sHJxSk*;c<8RlB(gnyNm}{ETTm6LwsSiw*Im4X z-+wYx{)}oy?xtTV`Vy*(&p%CfTfXvCP4>jFeT{rCHR%Va_UIFohic9@s5bO1s`6)4 z{SMPldA~RP!SqMV_hStEFZ&2DVV~wl1yx7a>Qb3;=0C>#e?&Eqc-kv}qV=zAniRA| zNvIlDwR*J5m16N4sOq)#Vyb?1sB6G=%`epqHnDoNs(mwSKdiYmh*pg>!~9a2Syq?I zY-e?;HhhxRr7}D4;)UiOEd7Sr#iC`Yw${z!<*A06rXzOspzSdBx;r6ox zr&@i0#ieej|Dl>eZElcxqgChU5c~=a<3+<5Y4x#kQl#3-IIBmi-kIl_KU($KHO=}> zk1c&;IvC;8Vx|QyK=n?WYxQ}ik<&=34Huy*Qf+V{s(#B&L*`$DhCf13^;?VT!}!kF z@CSnoYGKE3sy)8j{8G)-y;fJNT80w-<)^l`{lD_qHhf6_+o!g{zy8=(a}fO)w{(W` zO5IMh>Q4VZJhct}+v8hJ`5%64OIscD(V#;!7F41&d~)KGj#rQ(Rr>-}bpFZE$uR{q zfIs@!_HTJ=tK%R2_*Nq;0+k}rfGz>;!2(c`s=iPaN?EFYi$VFydRmJ+><=o@GEj%h zKt(El^y6CjRj&enFt$F`+!v&Kw@-lr9hZF}mpph)#TuXt=*@z}QFv2Dd;+wk*UeKu4)w!K1E3W`)S zQt{Zf;<0Uck8gFlR6Mq=cx+qo*tX)aZCQ_T_4!co*p^A4bkTQ?->II9l=rw+?G@KN z>T#`#RQ<|%TpMmjyyCGfGx6si+qMkP-yeT$TQWP|ZRq-6q_}I5A5i8zkP5%7B_ z;`m;UyAp9E`aj5VYh@h6D&t7x-Ba1vHh5{|0=kR%-ORd75aN3tKNf@6yuldIrJ z@wduRP!&hLsyJ%+6RP56__LF7G_8(f zw;c8SX4P@LDaVTHIO_Xv$+0*EM<-^_P4kzexM!8*9P6&x7h2ik*wEuGVoP58(Jk;w z1{`tMZF5uoLtDI*PK`>BkTcTmQ#E7v=uJTxWj_5D@UZztn@@`_9`U$I^KMh@%6TL=bBGX_WSy)l%j=&CR<3qH=cqm_Y^b%WqG`FDz?gnsLW)@_xCDA zXA50~>V`W7^2~R!`S=3>XPknRx#p{ZJ=M>8h!kDJ0Ul>qNQ*AJo3F@x`Zb}C`67!= z^s7Q&zgQ_I+egbK)|>^8&Kc%gfKNMA7Y3UzvP`5Nz7zaiN|9!vC6=N-p>*>t#m7G< z6(Uy&ttinxrNK*n+BQbsASVi2ZIldbqQ z=vJAI_!X8L-ojQ?~ zJo(szl#8N8JTiMjH^1wnq)_(8RR3Y2zSzZ}Q|x*3^~27y6fc;sKfcTTRZ5|gSf`G@ zQ>xXeP)mn_^0N7~TwI4H6NxMp(yv>G{k=-@KFxIML{NAUd<7}HQKc9HSNX+CQH-fmN25{bY#0VA zd+@2}IWPh`zz5b&ZASUKltQZzbn0mEN};Cm@uxA)$L1S}9eHfM&wQis)$!9FC&h7` z7hS1-VxckE`q``-d}_Y2ij!97f)+e!Q{!a&g-Y>0A(hW9bS`#e5pHB%Q~^Gv)w!Ux zQA#lZ68$HXLaUxs4q9j;wx&}ve#m^2u=NHw6aCtJld%i^pXfA&k%d!JvGqfu zioSj8N2l{a^P~MgVkz`n;?MplrBF!wb<{%Bv8O;d4KWRR1{9k=vUW;4R^Z>O6xt_k zImVjL!amP@vF5vwx^_nG;>OQX0LG%|HdG~Zn8 znf^dhxHJ86PkMf)K9Eu&4bq_j=nbqltlous*Xd{VyI@nSpZS#6JZ>x59`c7h<@HK? z0aF)vUC_^gVK5v@t=9hs~_x%*$Z@y?+Yh_F6#Z6MP0N9Kt6N=U8(b+3v`8( zLCdK}!YCLGV?kHybKyLg2$NtkOoaoC`g72SU>E&7+!do6bO)WHJwcyMI@ff))`v_# z&m+{x&bKZM@IUDqN zNGtsG*`UQir-Hu5_Jm%buekbp+!apNvK=jeY7e?oXaY?^KV<4IAq6^;EPy6}3m(Kk zEPP1pBiI9aD)R=s32(vM@D6DC;W}6kH^S9$4J>4UmpWrgS2yTppfB8y^8x6Eoe#aC zh`O%y`qPq*a3VAS{q@cXPz!29UC{k(W0E(4rqB!;`dK@@lwfVl;~@t0XE9mO7IxER z59|e9Jav`SRZ&+#UD|Y6e*@lxw?NnO*WhV*26U6M9ZEnqYMWs(ECJn(FM~_9aCjlc z1uzTr^<*|&1p4kV7v_O(1jfKvxSPT0Pa)m`{Sn9-#!6V__kWHv?*`1Z@FX;5drcq>I~_JL=lkF(*b2*F7z~GzFbW34P|&SK zE@*9+){tr?T`pumOVH9i-C87}l_3%0bmo0Q<1gVL==NzT2D*BcKnY!fhl;N2ihKee8!|DGUOAGZ_V=VGIm~ z?w~({*Vm91pzjjh=-VCkGmv-TCb$^}x*Y#A2}BwclXO0G!PfyA!dU7n7=Ui^RzV{= z>ZZ#<52NqF`>-EA0o|zU3yJReUxuBauM0ciLD&MiBQMR#@mJEwhL>w&FanfFcX<(Q z#hz%vs_ufTz)afoV1T;Y{R|Gl2GF1fLnOX{c3r85$ae*#(?)lLKVh#3@}dD8gadLw zWKSQ)egxF15Dt(;$#nP1xWh9u8a+%h(TsRxz!%{E8K3I9`5Qt^zRJ|Uvz(8jNz-|m z0mTI8`^7JMDQ8T=)XlPT>IPG)ozb>+GpXC!IWP;9Tlc2ATUHxwBcih7br>km`+HyX znz(U(;3aR6ZjyBP90uhmF47_5i!)yudOXyE8c-Fg03TycB{&A+;V2~m4L}l9hGU@` zq(F5@hU1_n*NgCXY$dM)sc-_+g?doG&&c7F6ia-Co{WJbk)$;SNH2S`qN+g_}6~q^$O_v#9kMD*sFEY3AA8(iaz_~^LsCIHl)!aE3=JSG$+C&6^*RE zsL`e?<}C2stjz38-DTv)_*3XxC!ZEfbJ5E0HmrN7Ud;UPSWe2!=Dx*W|FKuw{o3F5 zvDZIe?|;q4&CTCiHnVrLKEI73^50)f>$COC7H?+Up>2eR_v?SCG33?*JG!rbewOFf zj`IiX^QO4-{Kxi@`zHUbeO?=Pq@T1OIqG*s>O}s3i&JMcU)yfi-M@tQrd1}lO9TDO zX;3@zA731M;k-Hx*X_B^bLW+AvESE1A5$=H@+kky;W795=>;*p{p&`=)J<%cIpf?p7TWz{OnS-T5iw2TOC}b?B<%b@ DMCREe diff --git a/package.json b/package.json index 3d3c35ea..a60a991d 100644 --- a/package.json +++ b/package.json @@ -51,8 +51,7 @@ }, "homepage": "https://github.com/instructor-ai/instructor-js#readme", "dependencies": { - "zod-stream": "^0.0.5", - "zod-to-json-schema": "^3.22.3", + "zod-stream": "0.0.6", "zod-validation-error": "^2.1.0" }, "peerDependencies": { diff --git a/src/constants/modes.ts b/src/constants/modes.ts deleted file mode 100644 index ffdf69eb..00000000 --- a/src/constants/modes.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { - OAIBuildFunctionParams, - OAIBuildMessageBasedParams, - OAIBuildToolFunctionParams -} from "@/oai/params" -import { - OAIResponseFnArgsParser, - OAIResponseJSONParser, - OAIResponseJSONStringParser, - OAIResponseToolArgsParser -} from "@/oai/parser" - -export const MODE = { - FUNCTIONS: "FUNCTIONS", - TOOLS: "TOOLS", - JSON: "JSON", - MD_JSON: "MD_JSON", - JSON_SCHEMA: "JSON_SCHEMA" -} as const - -export const MODE_TO_PARSER = { - [MODE.FUNCTIONS]: OAIResponseFnArgsParser, - [MODE.TOOLS]: OAIResponseToolArgsParser, - [MODE.JSON]: OAIResponseJSONStringParser, - [MODE.MD_JSON]: OAIResponseJSONParser, - [MODE.JSON_SCHEMA]: OAIResponseJSONStringParser -} - -export const MODE_TO_PARAMS = { - [MODE.FUNCTIONS]: OAIBuildFunctionParams, - [MODE.TOOLS]: OAIBuildToolFunctionParams, - [MODE.JSON]: OAIBuildMessageBasedParams, - [MODE.MD_JSON]: OAIBuildMessageBasedParams, - [MODE.JSON_SCHEMA]: OAIBuildMessageBasedParams -} diff --git a/src/constants/providers.ts b/src/constants/providers.ts new file mode 100644 index 00000000..55065af1 --- /dev/null +++ b/src/constants/providers.ts @@ -0,0 +1,67 @@ +import { MODE, type Mode } from "zod-stream" + +export const PROVIDERS = { + OAI: "OAI", + ANYSCALE: "ANYSCALE", + TOGETHER: "TOGETHER", + OTHER: "OTHER" +} as const + +export type Provider = keyof typeof PROVIDERS + +export const PROVIDER_SUPPORTED_MODES: { + [key in Provider]: Mode[] +} = { + [PROVIDERS.OTHER]: [MODE.FUNCTIONS, MODE.TOOLS, MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA], + [PROVIDERS.OAI]: [MODE.FUNCTIONS, MODE.TOOLS, MODE.JSON, MODE.MD_JSON], + [PROVIDERS.ANYSCALE]: [MODE.TOOLS, MODE.JSON, MODE.JSON_SCHEMA], + [PROVIDERS.TOGETHER]: [MODE.TOOLS, MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA] +} as const + +export const NON_OAI_PROVIDER_URLS = { + [PROVIDERS.ANYSCALE]: "api.endpoints.anyscale", + [PROVIDERS.TOGETHER]: "api.together.xyz", + [PROVIDERS.OAI]: "api.openai.com" +} as const + +export const PROVIDER_SUPPORTED_MODES_BY_MODEL = { + [PROVIDERS.OTHER]: { + [MODE.FUNCTIONS]: ["*"], + [MODE.TOOLS]: ["*"], + [MODE.JSON]: ["*"], + [MODE.MD_JSON]: ["*"], + [MODE.JSON_SCHEMA]: ["*"] + }, + [PROVIDERS.OAI]: { + [MODE.FUNCTIONS]: ["*"], + [MODE.TOOLS]: ["*"], + [MODE.JSON]: [ + "gpt-3.5-turbo-1106", + "gpt-4-1106-preview", + "gpt-4-0125-preview", + "gpt-4-turbo-preview" + ], + [MODE.MD_JSON]: ["*"] + }, + [PROVIDERS.TOGETHER]: { + [MODE.JSON_SCHEMA]: [ + "mistralai/Mixtral-8x7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.1", + "togethercomputer/CodeLlama-34b-Instruct" + ], + [MODE.MD_JSON]: ["*"], + [MODE.TOOLS]: [ + "mistralai/Mixtral-8x7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.1", + "togethercomputer/CodeLlama-34b-Instruct" + ] + }, + [PROVIDERS.ANYSCALE]: { + [MODE.JSON_SCHEMA]: [ + "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mixtral-8x7B-Instruct-v0.1" + ], + [MODE.MD_JSON]: ["*"], + [MODE.TOOLS]: ["mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"] + } +} diff --git a/src/dsl/validator.ts b/src/dsl/validator.ts index 4967548a..d89b3587 100644 --- a/src/dsl/validator.ts +++ b/src/dsl/validator.ts @@ -2,8 +2,7 @@ import { OAIClientExtended } from "@/instructor" import type { ChatCompletionCreateParams } from "openai/resources/chat/completions.mjs" import { RefinementCtx, z } from "zod" -// eslint-disable-next-line @typescript-eslint/no-explicit-any -type AsyncSuperRefineFunction = (data: any, ctx: RefinementCtx) => Promise +type AsyncSuperRefineFunction = (data: string, ctx: RefinementCtx) => Promise export const LLMValidator = ( instructor: OAIClientExtended, @@ -15,7 +14,7 @@ export const LLMValidator = ( reason: z.string().optional() }) - const fn = async (value, ctx) => { + return async (value, ctx) => { const validated = await instructor.chat.completions.create({ max_retries: 0, ...params, @@ -41,5 +40,4 @@ export const LLMValidator = ( }) } } - return fn } diff --git a/src/instructor.ts b/src/instructor.ts index 279ff3a8..3c5587f7 100644 --- a/src/instructor.ts +++ b/src/instructor.ts @@ -2,21 +2,27 @@ import { ChatCompletionCreateParamsWithModel, InstructorConfig, LogLevel, - Mode, ReturnTypeBasedOnParams } from "@/types" import OpenAI from "openai" import { z } from "zod" -import ZodStream, { OAIStream, withResponseModel } from "zod-stream" +import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream" import { fromZodError } from "zod-validation-error" -import { MODE, MODE_TO_PARSER } from "@/constants/modes" +import { + NON_OAI_PROVIDER_URLS, + Provider, + PROVIDER_SUPPORTED_MODES, + PROVIDER_SUPPORTED_MODES_BY_MODEL, + PROVIDERS +} from "./constants/providers" const MAX_RETRIES_DEFAULT = 0 class Instructor { readonly client: OpenAI readonly mode: Mode + readonly provider: Provider readonly debug: boolean = false /** @@ -29,11 +35,39 @@ class Instructor { this.mode = mode this.debug = debug - //TODO: probably some more sophisticated validation we can do here re: modes and otherwise. - // but just throwing quick here for now. - if (mode === MODE.JSON_SCHEMA) { - if (!this.client.baseURL.includes("anyscale")) { - throw new Error("JSON_SCHEMA mode is only support on Anyscale.") + const provider = + this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.ANYSCALE) ? PROVIDERS.ANYSCALE + : this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.TOGETHER) ? PROVIDERS.TOGETHER + : this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.TOGETHER) ? PROVIDERS.OAI + : PROVIDERS.OTHER + + this.provider = provider + + this.validateOptions() + } + + private validateOptions() { + const isModeSupported = PROVIDER_SUPPORTED_MODES[this.provider].includes(this.mode) + + if (this.provider === PROVIDERS.OTHER) { + this.log("debug", "Unknown provider - cant validate options.") + } + + if (!isModeSupported) { + throw new Error(`Mode ${this.mode} is not supported by provider ${this.provider}`) + } + } + + private validateModelModeSupport( + params: ChatCompletionCreateParamsWithModel + ) { + if (this.provider !== PROVIDERS.OAI) { + const modelSupport = PROVIDER_SUPPORTED_MODES_BY_MODEL[this.provider][this.mode] + + if (!modelSupport.includes("*") && !modelSupport.includes(params.model)) { + throw new Error( + `Model ${params.model} is not supported by provider ${this.provider} in mode ${this.mode}` + ) } } } @@ -98,9 +132,10 @@ class Instructor { this.log("debug", response_model.name, "making completion call with params: ", resolvedParams) const completion = await this.client.chat.completions.create(resolvedParams) - const parser = MODE_TO_PARSER[this.mode] - const parsedCompletion = parser(completion as OpenAI.Chat.Completions.ChatCompletion) + const parsedCompletion = OAIResponseParser( + completion as OpenAI.Chat.Completions.ChatCompletion + ) try { return JSON.parse(parsedCompletion) as z.infer } catch (error) { @@ -200,6 +235,8 @@ class Instructor { >( params: P ): Promise> => { + this.validateModelModeSupport(params) + if (this.isChatCompletionCreateParamsWithModel(params)) { if (params.stream) { return this.chatCompletionStream(params) as ReturnTypeBasedOnParams< diff --git a/src/oai/params.ts b/src/oai/params.ts deleted file mode 100644 index 19afb3b6..00000000 --- a/src/oai/params.ts +++ /dev/null @@ -1,100 +0,0 @@ -import { omit } from "@/lib" -import { ChatCompletionCreateParamsWithModel, Mode } from "@/types" -import { ChatCompletionCreateParams } from "openai/resources/index.mjs" -import { z } from "zod" -import { JsonSchema7Type } from "zod-to-json-schema" - -import { MODE } from "@/constants/modes" - -type ParseParams = { - name: string - description?: string -} & JsonSchema7Type - -export function OAIBuildFunctionParams( - definition: ParseParams, - params: Omit, "response_model"> -): ChatCompletionCreateParams { - const { name, description, ...definitionParams } = definition - - return { - ...params, - function_call: { - name: name - }, - functions: [ - ...(params?.functions ?? []), - { - name: name, - description: description ?? undefined, - parameters: definitionParams - } - ] - } -} - -export function OAIBuildToolFunctionParams( - definition: ParseParams, - params: Omit, "response_model"> -): ChatCompletionCreateParams { - const { name, description, ...definitionParams } = definition - - return { - ...params, - tool_choice: { - type: "function", - function: { name } - }, - tools: [ - { - type: "function", - function: { - name: name, - description: description, - parameters: definitionParams - } - }, - ...(params?.tools ?? []) - ] - } -} - -export function OAIBuildMessageBasedParams( - definition: ParseParams, - params: Omit, "response_model">, - mode: Mode -): ChatCompletionCreateParams { - const MODE_SPECIFIC_CONFIGS = { - [MODE.JSON]: { - response_format: { type: "json_object" } - }, - [MODE.JSON_SCHEMA]: { - response_format: { - type: "json_object", - schema: omit(["name", "description"], definition) - } - } - } - - const modeConfig = MODE_SPECIFIC_CONFIGS[mode] - - const t = { - ...params, - ...modeConfig, - messages: [ - ...params.messages, - { - role: "system", - content: ` - Given a user prompt, you will return fully valid JSON based on the following description and schema. - You will return no other prose. You will take into account any descriptions or required parameters within the schema - and return a valid and fully escaped JSON object that matches the schema and those instructions. - - description: ${definition.description} - json schema: ${JSON.stringify(definition)} - ` - } - ] - } - return t -} diff --git a/src/oai/parser.ts b/src/oai/parser.ts deleted file mode 100644 index 1f6776c8..00000000 --- a/src/oai/parser.ts +++ /dev/null @@ -1,110 +0,0 @@ -import OpenAI from "openai" - -/** - * `OAIResponseTextParser` parses a JSON string and extracts the text content. - * - * @param {string} data - The JSON string to parse. - * @returns {string} - The extracted text content. - * - */ -export function OAIResponseTextParser( - data: - | string - | OpenAI.Chat.Completions.ChatCompletionChunk - | OpenAI.Chat.Completions.ChatCompletion -) { - const parsedData = typeof data === "string" ? JSON.parse(data) : data - - const text = parsedData?.choices[0]?.message?.content ?? null - - return text -} - -/** - * `OAIResponseFnArgsParser` parses a JSON string and extracts the function call arguments. - * - * @param {string} data - The JSON string to parse. - * @returns {Object} - The extracted function call arguments. - * - */ -export function OAIResponseFnArgsParser( - data: - | string - | OpenAI.Chat.Completions.ChatCompletionChunk - | OpenAI.Chat.Completions.ChatCompletion -) { - const parsedData = typeof data === "string" ? JSON.parse(data) : data - const text = - parsedData.choices?.[0].delta?.function_call?.arguments ?? - parsedData.choices?.[0]?.message?.function_call?.arguments ?? - null - - return text -} - -/** - * `OAIResponseToolArgsParser` parses a JSON string and extracts the tool call arguments. - * - * @param {string} data - The JSON string to parse. - * @returns {Object} - The extracted tool call arguments. - * - */ -export function OAIResponseToolArgsParser( - data: - | string - | OpenAI.Chat.Completions.ChatCompletionChunk - | OpenAI.Chat.Completions.ChatCompletion -) { - const parsedData = typeof data === "string" ? JSON.parse(data) : data - - const text = - parsedData.choices?.[0].delta?.tool_calls?.[0]?.function?.arguments ?? - parsedData.choices?.[0]?.message?.tool_calls?.[0]?.function?.arguments ?? - null - - return text -} - -/** - * `OAIResponseJSONParser` parses a JSON string and extracts the JSON content. - * - * @param {string} data - The JSON string to parse. - * @returns {Object} - The extracted JSON content. - * - */ -export function OAIResponseJSONStringParser( - data: - | string - | OpenAI.Chat.Completions.ChatCompletionChunk - | OpenAI.Chat.Completions.ChatCompletion -) { - const parsedData = typeof data === "string" ? JSON.parse(data) : data - const text = - parsedData.choices?.[0].delta?.content ?? parsedData?.choices[0]?.message?.content ?? null - - return text -} - -/** - * `OAIResponseJSONParser` parses a JSON string and extracts the JSON content. - * - * @param {string} data - The JSON string to parse. - * @returns {Object} - The extracted JSON content. - * - * - */ -export function OAIResponseJSONParser( - data: - | string - | OpenAI.Chat.Completions.ChatCompletionChunk - | OpenAI.Chat.Completions.ChatCompletion -) { - const parsedData = typeof data === "string" ? JSON.parse(data) : data - const text = - parsedData.choices?.[0].delta?.content ?? parsedData?.choices[0]?.message?.content ?? null - - const jsonRegex = /```json\n([\s\S]*?)\n```/ - const match = text.match(jsonRegex) - - return match ? match[1] : text -} diff --git a/src/types/index.ts b/src/types/index.ts index 32e83bd1..3cf212e1 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -1,25 +1,19 @@ import OpenAI from "openai" import { Stream } from "openai/streaming" import { z } from "zod" - -import { MODE } from "@/constants/modes" - -export type Mode = keyof typeof MODE +import { type Mode as ZMode, type ResponseModel as ZResponseModel } from "zod-stream" export type LogLevel = "debug" | "info" | "warn" | "error" +export type Mode = ZMode +export type ResponseModel = ZResponseModel + export type InstructorConfig = { client: OpenAI mode: Mode debug?: boolean } -export type ResponseModel = { - schema: T - name: string - description?: string -} - export type InstructorChatCompletionParams = { response_model: ResponseModel max_retries?: number diff --git a/tests/mode.test.ts b/tests/mode.test.ts index 3ceeff68..9d5bea6d 100644 --- a/tests/mode.test.ts +++ b/tests/mode.test.ts @@ -1,24 +1,76 @@ import Instructor from "@/instructor" -import { Mode } from "@/types" import { describe, expect, test } from "bun:test" import OpenAI from "openai" import { z } from "zod" +import { type Mode } from "zod-stream" -import { MODE } from "@/constants/modes" +import { Provider, PROVIDER_SUPPORTED_MODES_BY_MODEL, PROVIDERS } from "@/constants/providers" -const models_latest = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"] -const models_old = ["gpt-3.5-turbo", "gpt-4"] -const models_anyscale = ["mistralai/Mistral-7B-Instruct-v0.1"] +const default_oai_model = "gpt-4-1106-preview" +const default_anyscale_model = "mistralai/Mixtral-8x7B-Instruct-v0.1" +const default_together_model = "mistralai/Mixtral-8x7B-Instruct-v0.1" -const createTestCases = (): { model: string; mode: Mode }[] => { - const { FUNCTIONS, JSON_SCHEMA, ...rest } = MODE - const modes = Object.values(rest) +const provider_config = { + [PROVIDERS.OAI]: { + baseURL: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY + }, + [PROVIDERS.ANYSCALE]: { + baseURL: "https://api.endpoints.anyscale.com/v1", + apiKey: process.env.ANYSCALE_API_KEY + }, + [PROVIDERS.TOGETHER]: { + baseURL: "https://api.together.xyz", + apiKey: process.env.TOGETHER_API_KEY + } +} + +const createTestCases = (): { model: string; mode: Mode; provider: Provider }[] => { + const testCases: { model: string; mode: Mode; provider: Provider }[] = [] + + Object.entries(PROVIDER_SUPPORTED_MODES_BY_MODEL[PROVIDERS.OAI]).forEach( + ([mode, models]: [Mode, string[]]) => { + if (models.includes("*")) { + testCases.push({ model: default_oai_model, mode, provider: PROVIDERS.OAI }) + } else { + models.forEach(model => testCases.push({ model, mode, provider: PROVIDERS.OAI })) + } + } + ) + + Object.entries(PROVIDER_SUPPORTED_MODES_BY_MODEL).forEach( + ([provider, modesByModel]: [Provider, Record]) => { + if (provider === PROVIDERS.ANYSCALE) { + Object.entries(modesByModel).forEach(([mode, models]: [Mode, string[]]) => { + if (models.includes("*")) { + testCases.push({ + model: default_anyscale_model, + mode, + provider + }) + } else { + models.forEach(model => testCases.push({ model, mode, provider })) + } + }) + } + + if (provider === PROVIDERS.TOGETHER) { + Object.entries(modesByModel).forEach(([mode, models]: [Mode, string[]]) => { + if (models.includes("*")) { + testCases.push({ + model: default_together_model, + mode, + provider + }) + } else { + models.forEach(model => testCases.push({ model, mode, provider })) + } + }) + } + } + ) - return [ - ...models_anyscale.flatMap(model => ({ model, mode: JSON_SCHEMA })), - ...models_latest.flatMap(model => modes.map(mode => ({ model, mode }))), - ...models_old.flatMap(model => ({ model, mode: FUNCTIONS })) - ] + return testCases } const UserSchema = z.object({ @@ -28,11 +80,10 @@ const UserSchema = z.object({ }) }) -async function extractUser(model: string, mode: Mode) { - const anyscale = mode === MODE.JSON_SCHEMA +async function extractUser(model: string, mode: Mode, provider: Provider) { + const config = provider_config[provider] const oai = new OpenAI({ - baseURL: anyscale ? "https://api.endpoints.anyscale.com/v1" : undefined, - apiKey: anyscale ? process.env.ANYSCALE_API_KEY : process.env.OPENAI_API_KEY ?? undefined, + ...config, organization: process.env.OPENAI_ORG_ID ?? undefined }) @@ -46,7 +97,7 @@ async function extractUser(model: string, mode: Mode) { model: model, response_model: { schema: UserSchema, name: "User" }, max_retries: 4, - seed: !anyscale ? 1 : undefined + seed: provider === PROVIDERS.OAI ? 1 : undefined }) return user @@ -55,9 +106,9 @@ async function extractUser(model: string, mode: Mode) { describe("Modes", async () => { const testCases = createTestCases() - for await (const { model, mode } of testCases) { - test(`Should return extracted name and age for model ${model} and mode ${mode}`, async () => { - const user = await extractUser(model, mode) + for await (const { model, mode, provider } of testCases) { + test(`${provider}: Should return extracted name and age for model ${model} and mode ${mode}`, async () => { + const user = await extractUser(model, mode, provider) expect(user.name).toEqual("Jason Liu") expect(user.age).toEqual(30)