
Commit

openai upgrade
jonfleming committed Feb 3, 2024
1 parent 059844a commit 4e3689f
Showing 7 changed files with 103 additions and 71 deletions.
4 changes: 3 additions & 1 deletion Amy/settings.py
@@ -159,4 +159,6 @@
CELERY_TIMEZONE = 'UTC'
CELERY_TASK_TRACK_STARTED = True
CELERY_TASK_TIME_LIMIT = 30 * 60
CELERY_BROKER_URL = 'redis://localhost:6379/0'
CELERY_BROKER_URL = 'redis://127.0.0.1:6379/0'
CELERYD_POOL = 'solo'
CELERYD_CONCURRENCY = 1
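
The broker URL now uses the explicit 127.0.0.1 loopback address, and the two new settings pin the worker to the solo pool with one task at a time, so tasks run in the worker's main process (useful on Windows, where the default prefork pool is problematic). A minimal sketch of the equivalent new-style lowercase Celery settings applied directly to the app object; this is for orientation only and is not how the project configures it:

# Sketch: equivalent new-style Celery settings (the project keeps the
# uppercase variants in Amy/settings.py instead).
from celery import Celery

app = Celery('Amy')
app.conf.broker_url = 'redis://127.0.0.1:6379/0'
app.conf.worker_pool = 'solo'        # run tasks in the worker's main process
app.conf.worker_concurrency = 1      # one task at a time
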
34 changes: 22 additions & 12 deletions chat/celery.py
@@ -1,8 +1,12 @@
import os
import json
import logging
import random

from celery import Celery

logger = logging.getLogger(__name__)

# Set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'Amy.settings')
app = Celery('Amy', broker='redis://localhost:6379/0')
@@ -12,19 +16,25 @@

@app.task(bind=True)
def classify_user_input(self, id):
import chat.lang as lang
import chat.models as models

user_input = models.UserInput.objects.filter(pk=id)[0]
categories = lang.get_categories()
category_list = ','.join(f"'{x}'" for x in categories )
args = {'<<TEXT>>': user_input.user_text, '<<CATEGORIES>>': category_list, '<<USER>>': user_input.user}
prompt = render_template('classify.txt', args)
result = lang.completion(prompt).strip()
print(f'Classifying user input {id}')
logger.info(f'Classifying user input {id}')

# import chat.lang as lang
# import chat.models as models

# user_input = models.UserInput.objects.filter(pk=id)[0]
# categories = lang.get_categories()
# category_list = ','.join(f"'{x}'" for x in categories )
# print(f"Category list: {category_list}")

# args = {'<<TEXT>>': user_input.user_text, '<<CATEGORIES>>': category_list, '<<USER>>': user_input.user}
# prompt = render_template('classify.txt', args)
# result = lang.completion(prompt).strip()
# print(f"Result: {result}")

if result in categories:
user_input.category = result
user_input.save()
# if result in categories:
# user_input.category = result
# user_input.save()

def render_template(template_name, args):
import chat.lang as lang
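
The task body above is reduced to a log line, with the original classification flow (render the classify.txt prompt, call the completion endpoint, save the category) retained as comments. For orientation, a sketch of what that flow looks like against the v1 OpenAI client this commit upgrades to; the prompt string and helper name below are illustrative, not part of the repository:

# Illustrative sketch: assumes openai>=1.0 and OPENAI_API_KEY in the
# environment; the real prompt lives in the classify.txt template.
from openai import OpenAI

client = OpenAI()

def classify_text(text, categories):
    category_list = ', '.join(f"'{c}'" for c in categories)
    prompt = f"Classify the following text as one of {category_list}.\n\nText: {text}\nCategory:"
    response = client.completions.create(
        model='gpt-3.5-turbo-instruct',
        prompt=prompt,
        temperature=0,
        max_tokens=10,
    )
    return response.choices[0].text.strip()
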
8 changes: 4 additions & 4 deletions chat/lang.py
@@ -18,7 +18,7 @@

load_dotenv()
CHAT_MODEL = 'gpt-3.5-turbo'
COMPLETIONS_MODEL = 'text-davinci-003'
COMPLETIONS_MODEL = 'gpt-3.5-turbo-instruct'
EMBEDDING_MODEL = 'text-embedding-ada-002'
CATEGORIES = ['Childhood', 'Education', 'Career', 'Family', 'Spiritual', 'Story']

@@ -54,7 +54,7 @@ def chat_completion(messages):

logger.info(f'Chat messages: {json.dumps(messages, indent=4)}')
logger.info('get_completion_from_open_ai::starting::')
response = openai.ChatCompletion.create(
response = client.chat.completions.create(
messages=messages,
temperature=0,
max_tokens=300,
@@ -91,7 +91,7 @@ def conversation_history(exchanges, prompt_text, user_text, chat_mode):
return messages

def first_chat_completion_choice(completion_response):
if 'choices' not in completion_response or len(completion_response['choices']) == 0:
if not hasattr(completion_response, 'choices') or len(completion_response.choices) == 0:
logger.warning('get_chat_completion_from_open_ai_failed')
response = completion_response
else:
@@ -100,7 +100,7 @@ def first_completion_choice(completion_response):
return response

def first_completion_choice(completion_response):
if 'choices' not in completion_response or len(completion_response['choices']) == 0:
if not hasattr(completion_response, 'choices') or len(completion_response.choices) == 0:
logger.warning('get_completion_from_open_ai_failed')
response = completion_response
else:
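
This file moves the completions model off the deprecated text-davinci-003 onto gpt-3.5-turbo-instruct and switches from the pre-1.0 module-level openai.ChatCompletion API to a client instance; v1 responses are typed objects rather than dicts, which is why the choices checks change from key lookups to hasattr. A minimal sketch of the v1 usage these hunks assume (client construction is outside the hunks shown):

# Sketch of the openai>=1.0 interface assumed by these changes; the client
# reads OPENAI_API_KEY from the environment.
from openai import OpenAI

client = OpenAI()

response = client.chat.completions.create(
    model='gpt-3.5-turbo',
    messages=[{'role': 'user', 'content': 'Hello'}],
    temperature=0,
    max_tokens=300,
)

# Responses are pydantic models, so fields are attributes, not dict keys.
print(response.choices[0].message.content)
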
125 changes: 72 additions & 53 deletions chat/static/chat/js/streaming-client-api.js
@@ -27,6 +27,7 @@ const signalingStatusLabel = document.getElementById("signaling-status-label")
const dIdKey = document.getElementById("d-id-key")
const dIdImage = document.getElementById("d-id-image")
const stats = document.getElementById("stats-box")
const statusMessage = document.getElementById("status-message")

window.dragable(document.getElementById("dragable"))

@@ -74,20 +75,24 @@ window.connect = async () => {
return
}

await fetch(
`${DID_API.url}/talks/streams/${streamId}/sdp`,
{
method: "POST",
headers: {
Authorization: `Basic ${DID_API.key}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
answer: sessionClientAnswer,
session_id: sessionId,
}),
}
)
try {
await fetch(
`${DID_API.url}/talks/streams/${streamId}/sdp`,
{
method: "POST",
headers: {
Authorization: `Basic ${DID_API.key}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
answer: sessionClientAnswer,
session_id: sessionId,
}),
}
)
} catch (e) {
window.stat(e)
}
}

window.talk = async (text) => {
@@ -96,50 +101,60 @@ window.talk = async (text) => {
peerConnection?.iceConnectionState === "connected"
) {
isPlaying = true
await fetch(`${DID_API.url}/talks/streams/${streamId}`, {
method: "POST",
headers: {
Authorization: `Basic ${DID_API.key}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
script: {
type: "text",
subtitles: "false",
provider: {
type: "microsoft",
voice_id: "en-US-SaraNeural",
voice_config: {
style: "Cheerful",
rate: "1.25"
},
},
ssml: true,
input: text, // Use the user input as the input value
},
config: {
fluent: true,
pad_audio: 0,
driver_expressions: {
expressions: [
{ expression: "neutral", start_frame: 0, intensity: 0 },
],
transition_frames: 0,
},
align_driver: true,
align_expand_factor: 0,
auto_match: true,
motion_factor: 0,
normalization_factor: 0,
sharpen: true,
stitch: true,
result_format: "mp4",
const script = {
type: "text",
subtitles: "false",
provider: {
type: "microsoft",
voice_id: "en-US-SaraNeural",
voice_config: {
style: "Cheerful",
rate: "1.25"
},
},
ssml: true,
input: text, // Use the user input as the input value
}
const config = {
fluent: true,
pad_audio: 0,
driver_expressions: {
expressions: [
{ expression: "neutral", start_frame: 0, intensity: 0 },
],
transition_frames: 0,
},
align_driver: true,
align_expand_factor: 0,
auto_match: true,
motion_factor: 0,
normalization_factor: 0,
sharpen: true,
stitch: true,
result_format: "mp4",
}

const body = JSON.stringify({
script: script,
config: config,
driver_url: "bank://lively/",
config: { stich: true },
session_id: sessionId,
}),
})

try {
await fetch(`${DID_API.url}/talks/streams/${streamId}`, {
method: "POST",
headers: {
Authorization: `Basic ${DID_API.key}`,
"Content-Type": "application/json",
},
body: body,
})
} catch (e) {
window.stat(e)
}


return "OK"
}
@@ -189,6 +204,10 @@ window.startStats = (callback) => {
}, 3000)
}

function setStatusMessage(msg) {
statusMessage.innerText = msg
}

function setLabelStatus(label, status) {
label.className = status
const title = label.getAttribute('data-title')
1 change: 1 addition & 0 deletions chat/templates/chat/navbar.html
@@ -15,6 +15,7 @@
<label id="peer-status-label" data-title="Peer Status">&#x25AC;</span></label>
<label id="signaling-status-label" data-title="Signaling Status">&#x25AC;</label>
<label><a id="connect-btn" class="nav-link" onclick="window.connect()" style="display: block;"><i class="fa-solid fa-link"></i></a></label>
<label id="status-message"></label>
</div>
{% endif %}
<button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarText" aria-controls="navbarText" aria-expanded="False" aria-label="Toggle navigation">
2 changes: 1 addition & 1 deletion chat/views.py
@@ -117,7 +117,7 @@ def __init__(self, category, summarization, text):

def summary(request):
if request.method == 'GET':
proto = 'ws' if request.META['HTTP_X_FORWARDED_PROTO'] == 'http' else 'wss'
proto = 'wss' if 'HTTP_X_FORWARDED_PROTO' in request.META and request.META['HTTP_X_FORWARDED_PROTO'] == 'https' else 'ws'
ws_url = f'{proto}://{request.get_host()}/ws/summary/'
return render(request, 'chat/summary.html', {'summary': None, 'wsUrl': ws_url })

Binary file modified db.sqlite3