Skip to content

Commit

Permalink
Implement server mode.
Browse files Browse the repository at this point in the history
This new mode works by first loading the model then listening for TCP
connections on a port. When a connection is received, arguments will be
parsed using a simple protocol:

- First the number of arguments will be read followed by a newline
  character.
- Then each argument will be read, separated by the 0 byte.
- With this we build an argument vector, similar to what is passed to
  the program entry point. We pass this to gpt_params_parse.

Finally `run` will be executed with the input/output streams connected
to the socket.

Signed-off-by: Thiago Padilha <[email protected]>
  • Loading branch information
tarruda committed Mar 22, 2023
1 parent 6b67f63 commit f877c93
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ add_executable(main
run.cpp)
target_link_libraries(main PRIVATE llama ggml utils)

if(NOT WIN32)
target_sources(main PRIVATE tcp_server.cpp)
endif()

add_executable(quantize quantize.cpp)
target_link_libraries(quantize PRIVATE llama ggml utils)

Expand Down
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,14 @@ utils.o: utils.cpp utils.h
run.o: run.cpp run.h
$(CXX) $(CXXFLAGS) -c run.cpp -o run.o

tcp_server.o: tcp_server.cpp tcp_server.h
$(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o

clean:
rm -f *.o main quantize

main: main.cpp ggml.o llama.o utils.o run.o
$(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o -o main $(LDFLAGS)
main: main.cpp ggml.o llama.o utils.o run.o tcp_server.o
$(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o tcp_server.o -o main $(LDFLAGS)
@echo "\x1b[36mrun ./main -h for help\x1b[0m"

quantize: quantize.cpp ggml.o llama.o utils.o
Expand Down
45 changes: 45 additions & 0 deletions chat_tcp_client.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env bash

PORT=${PORT:-8080}
PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User:Hello, Bob.
Bob:Hello. How may I help you today?
User:Please tell me the largest city in Europe.
Bob:Sure. The largest city in Europe is Moscow, the capital of Russia.
User:"}"
RPROMPT="${RPROMPT:-"User:"}"
N_PREDICT="${N_PREDICT:-"4096"}"
REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}"
N_THREADS="${N_THREADS:-"4"}"

# Open connection to the chat server
exec 3<>/dev/tcp/127.0.0.1/${PORT}

# Pass the arguments. The protocol is really simple:
# 1. Pass the number of arguments followed by a linefeed
# 2. Pass the arguments, with each being followed by "0"
(
echo -en "12\n"
echo -en "-t\x00"
echo -en "$N_THREADS\x00"
echo -en "-n\x00"
echo -en "$N_PREDICT\x00"
echo -en "--repeat_penalty\x00"
echo -en "$REPEAT_PENALTY\x00"
echo -en "--color\x00"
echo -en "-i\x00"
echo -en "-r\x00"
echo -en "$RPROMPT\x00"
echo -en "-p\x00"
echo -en "$PROMPT\x00"
) >&3

trap exit TERM

# When we have passed the arguments, start printing socket data to the screen.
# This is done in a background job because we also want to send data when
# running in interactive mode.
cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" &
cat >&3
wait
6 changes: 6 additions & 0 deletions chat_tcp_server.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

PORT=${PORT:-8080}
MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin}

./main -l ${PORT} -m $MODEL
7 changes: 7 additions & 0 deletions main.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "run.h"
#include "ggml.h"
#include "tcp_server.h"

#include <iostream>

Expand Down Expand Up @@ -125,5 +126,11 @@ int main(int argc, char ** argv) {
exit(0);
}

#ifndef _WIN32
if (params.listen_port != "") {
return listen_tcp(ctx, params);
}
#endif

return run(ctx, params, std::cin, stdout, stderr);
}
245 changes: 245 additions & 0 deletions tcp_server.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
#include "tcp_server.h"
#include "llama.h"
#include "utils.h"

#include <iostream>

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <errno.h>

#include <signal.h>
#include <unistd.h>
#include <sys/wait.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netdb.h>

class PosixStream : public std::istream {
public:
PosixStream(int fd) : std::istream(&buf), buf(fd) {}
~PosixStream() { close(buf.get_fd()); }

private:
class PosixStreamBuf : public std::streambuf {
public:
PosixStreamBuf(int fd) : fd(fd) {}
int get_fd() const { return fd; }

protected:
virtual int_type underflow() {
if (gptr() < egptr()) {
return traits_type::to_int_type(*gptr());
}

ssize_t num_read = ::read(fd, buffer, BUFFER_SIZE);
if (num_read <= 0) {
return traits_type::eof();
}

setg(buffer, buffer, buffer + num_read);
return traits_type::to_int_type(*gptr());
}

private:
static const int BUFFER_SIZE = 1024;
int fd;
char buffer[BUFFER_SIZE];
};

PosixStreamBuf buf;
};

void die(const char *msg, ...)
{
va_list ap;

va_start(ap, msg);
vfprintf(stderr, msg, ap);
va_end(ap);
fputc('\n', stderr);
exit(1);
}

static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) {
bool done = false;
uint8_t *buf = *param_buf;
size_t bufsize = *param_buf_size;
size_t bufpos = 0;
while (!done) {
if (bufpos == bufsize) {
bufsize += 1024;
buf = (uint8_t *)realloc(buf, bufsize);
if (!buf) {
die("failed to allocate memory");
}
}

int c = fgetc(instream);
if (c == EOF) {
die("unexpected EOF client socket");
}
buf[bufpos++] = (uint8_t)c;
if (c == 0) {
// done reading argument
break;
}
}
*param_buf = buf;
*param_buf_size = bufsize;
return strdup((char *)buf);
}

static int read_arguments(int argc, char **argv, FILE *instream) {
int i = 1;
size_t param_buf_size = 0;
uint8_t *param_buf = nullptr;

for (i = 1; i < argc; i++) {
argv[i] = read_argument(&param_buf, &param_buf_size, instream);
}

free(param_buf);
return i;
}

static int serve_model(llama_context * ctx,
gpt_params params,
int sock_fd)
{
int argc;
char **argv;
FILE *instream = fdopen(sock_fd, "r");
FILE *outstream = fdopen(sock_fd, "w");
setvbuf(instream, NULL, _IONBF, 0);

// start by reading the parameter count
if (fscanf(instream, "%d\n", &argc) != 1) {
fprintf(outstream, "Error: First line must be character count\n");
fflush(outstream);
return 1;
}

argc += 1; // add one extra argument to emulate the program command line
argv = (char **)malloc(argc * sizeof *argv);
argv[0] = nullptr;
if (read_arguments(argc, argv, instream) != argc) {
fprintf(outstream, "Error: Failed to read arguments\n");
fflush(outstream);
}

if (gpt_params_parse(argc, argv, params) == false) {
fprintf(outstream, "Error: Failed to parse parameters\n");
fflush(outstream);
return 1;
}

for (int i = 1; i < argc; i++) {
free(argv[i]);
}
free(argv);

PosixStream tcp_instream(sock_fd);

return run(ctx, params, tcp_instream, outstream, outstream);
}

int listen_tcp(llama_context * ctx, gpt_params params) {
int listen_fd;
int status;
pid_t child;
struct addrinfo hints;
struct addrinfo *servinfo, *p;
int yes = 1;

memset(&hints, 0, sizeof hints);
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_PASSIVE;

// This should only ever listen on a loopback address. Access from outside
// should be proxied via socat or similar software
status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
if (status) {
die("getaddrinfo error: %s", gai_strerror(status));
}

// bind to the first addrinfo we can from the getaddrinfo results
for (p = servinfo; p != NULL; p = p->ai_next) {
listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (listen_fd == -1) {
perror("server: socket");
continue;
}

if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &yes, sizeof yes)) {
die("setsockopt error: %s", params.listen_port.c_str(), strerror(errno));
}

if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
struct sockaddr_in addr_in;
socklen_t addr_in_len = sizeof(addr_in);
memset(&addr_in, 0, addr_in_len);
getsockname(listen_fd, (struct sockaddr*)&addr_in, &addr_in_len);

printf("Listening on %s:%d\n", inet_ntoa(addr_in.sin_addr), ntohs(addr_in.sin_port));
break;
}

close(listen_fd);
perror("server: bind");
}

freeaddrinfo(servinfo);

if (p == NULL) {
die("failed to bind: %s", strerror(errno));
}

if (listen(listen_fd, 20)) {
die("listen error: %s", strerror(errno));
}
// Don't track child processes, so ignore SIGCHLD to prevent zombies
signal(SIGCHLD, SIG_IGN);

for (;;) {
struct sockaddr_in client_addr;
socklen_t client_addr_len = 0;
memset(&client_addr, 0, sizeof(client_addr));

int sock_fd = accept(listen_fd,
(struct sockaddr *)&client_addr,
&client_addr_len);
if (sock_fd < 0) {
fprintf(stderr, "accept error: %s\n", strerror(errno));
break;
}

child = fork();
if (child == 0) {
// close the listen_fd since we won't use it in the child
close(listen_fd);
int ret = serve_model(ctx, params, sock_fd);
close(sock_fd);
return ret;
} else {
// close the client since we won't use it in the server
close(sock_fd);
sock_fd = 0;
}
}
close(listen_fd);

// ignore SIGTERM since we'll send it to the group
signal(SIGTERM, SIG_IGN);
// tell children to exit
kill(0, SIGTERM);
// wait for children to terminate
wait(&status);
return 0;
}
7 changes: 7 additions & 0 deletions tcp_server.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#pragma once

#include "utils.h"
#include "llama.h"
#include "run.h"

int listen_tcp(llama_context * ctx, gpt_params params);
8 changes: 8 additions & 0 deletions utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.ignore_eos = true;
} else if (arg == "--n_parts") {
params.n_parts = std::stoi(argv[++i]);
#ifndef _WIN32
} else if (arg == "-l" || arg == "--listen") {
params.listen_port = argv[++i];
#endif
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
Expand Down Expand Up @@ -122,6 +126,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
#ifndef _WIN32
fprintf(stderr, " -l PORT, --listen PORT\n");
fprintf(stderr, " Run in TCP mode, listening on PORT\n");
#endif
fprintf(stderr, "\n");
}

Expand Down
4 changes: 4 additions & 0 deletions utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ struct gpt_params {
bool instruct = false; // instruction mode (used for Alpaca models)
bool ignore_eos = false; // do not stop generating after eos
bool perplexity = false; // compute perplexity over the prompt

#ifndef _WIN32
std::string listen_port = ""; // TCP port for when running in server mode
#endif
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
Expand Down

0 comments on commit f877c93

Please sign in to comment.