Skip to content

Commit

Permalink
start with few_shot testing
Browse files Browse the repository at this point in the history
  • Loading branch information
maxzirps committed Dec 24, 2024
1 parent 0aadd4e commit 6f8f359
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 24 deletions.
58 changes: 58 additions & 0 deletions backend/data/test_output_few_shot.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
[
{
"id": 62157,
"artist": "Chris Brown",
"title": "Say Goodbye",
"attachment_style": "avoidant",
"predicted": "unknown"
},
{
"id": 115478,
"artist": "*NSYNC",
"title": "Bye Bye Bye",
"attachment_style": "avoidant",
"predicted": "avoidant"
},
{
"id": 1644,
"artist": "Michael Jackson",
"title": "Billie Jean",
"attachment_style": "avoidant",
"predicted": "avoidant"
},
{
"id": 6591973,
"artist": "TLC",
"title": "No Scrubs",
"attachment_style": "avoidant",
"predicted": "unknown"
},
{
"id": 58062,
"artist": "Rihanna",
"title": "Take a Bow",
"attachment_style": "avoidant",
"predicted": "secure"
},
{
"id": 5053007,
"artist": "The Weeknd",
"title": "Heartless",
"attachment_style": "avoidant",
"predicted": "secure"
},
{
"id": 129806,
"artist": "The Police",
"title": "Every Breath You Take",
"attachment_style": "anxious",
"predicted": "secure"
},
{
"id": 209747,
"artist": "Miley Cyrus",
"title": "Wrecking Ball",
"attachment_style": "anxious",
"predicted": "unknown"
}
]
37 changes: 21 additions & 16 deletions backend/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,33 @@ class AttachmentStyleProbabilities(TypedDict):
class TextGenerationModel:
def __init__(self, model_id: str = "meta-llama/Llama-3.2-1B-Instruct"):
logging.set_verbosity_error()
self.pipe = pipeline(
"text-generation",
model=model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
# TODO: make it more deterministic
# temperature=0.01
)

def classify_attachment_style(self, lyrics: str, max_new_tokens: int = 256) -> AttachmentStyleProbabilities:
system_message = """
self.system_message = """
You are an assistant that classifies attachment styles based on song lyrics.
Analyze the following song lyrics and classify the attachment style it reflects.
Use the four main attachment styles: secure, anxious, avoidant, and disorganized.
Provide the confidence/probability for each style. The total of all probabilities should sum to 1.
Do not include any commentary, explanations, or additional text. Return the probabilites as float.
"""

messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": lyrics}
]
self.pipe = pipeline(
"text-generation",
model=model_id,
torch_dtype=torch.bfloat16,
device_map="auto"
)

def classify_attachment_style(self, lyrics: str, max_new_tokens: int = 256, system_message: dict = {}) -> AttachmentStyleProbabilities:

if system_message["role"] != "system":
raise ValueError("System messages have role 'system'.")
if system_message["content"] == None:
raise ValueError("System messages must have content.")

messages = [{"role": "system", "content": self.system_message}]

if system_message:
messages.append(system_message)

messages.append({"role": "user", "content": lyrics})

outputs = self.pipe(messages, max_new_tokens=max_new_tokens)
answer = outputs[0]["generated_text"][-1]["content"]
Expand Down
69 changes: 61 additions & 8 deletions backend/src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,12 @@
from env import load_env
from model import TextGenerationModel
import json
from tqdm import tqdm
import numpy as np

load_env()





async def main():
token = os.getenv("TOKEN")
genius = API_Client(token)
async def test(genius: API_Client, model: TextGenerationModel):
if os.path.exists("./data/test_output.json"):
print("Loading existing data")
with open("./data/test_output.json", "r") as json_file:
Expand All @@ -29,7 +23,7 @@ async def main():
with open("./data/train.json", "r") as json_file:
songs = json.load(json_file)

model = TextGenerationModel()


async def get_lyrics_and_classify(genius, song: dict) -> dict:
lyrics = await genius.get_lyrics(song["id"])
Expand All @@ -45,6 +39,65 @@ async def get_lyrics_and_classify(genius, song: dict) -> dict:
with open("./data/test_output.json", "w") as json_file:
json.dump(songs, json_file, indent=4)


async def test_few_shot(genius: API_Client, model: TextGenerationModel):
    """Run the few-shot classification test, or re-score a saved run.

    If ``./data/test_output_few_shot.json`` already exists, the saved
    predictions are loaded and the accuracy (share of songs whose
    ``predicted`` equals ``attachment_style``) is printed.

    Otherwise the first 10 songs of ``./data/train.json`` are loaded, the
    first song of each attachment style is used as a few-shot example, the
    remaining songs are classified, and the results are written to the
    output file.

    Args:
        genius: Client used to fetch lyrics via ``get_lyrics(song_id)``.
        model: Classifier called via
            ``classify_attachment_style(lyrics, system_message=...)``.
    """
    output_path = "./data/test_output_few_shot.json"

    if os.path.exists(output_path):
        print("Loading existing data")
        with open(output_path, "r", encoding="utf-8") as json_file:
            predictions = json.load(json_file)
        if not predictions:
            # Avoid dividing by zero on an empty result file.
            print("No predictions found; accuracy is undefined.")
            return
        correct = sum(
            song["predicted"] == song["attachment_style"] for song in predictions
        )
        print(f"Accuracy: {correct/len(predictions)}")
        return

    print("Fetching new data")
    with open("./data/train.json", "r", encoding="utf-8") as json_file:
        songs = json.load(json_file)[:10]

    # Pick the first song of every attachment style, in order of first
    # appearance. Iterating a set of style names instead would make the
    # prompt order vary between runs (string hashing is randomised).
    example_songs = {}
    for song in songs:
        example_songs.setdefault(song["attachment_style"], song)

    few_shot_examples = {}
    for style, example in example_songs.items():
        few_shot_examples[style] = await genius.get_lyrics(example["id"])
    few_shot_ids = {example["id"] for example in example_songs.values()}

    # Join outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError on Python versions before 3.12.
    examples_text = "\n\n".join(
        f"{style}: {lyrics}" for style, lyrics in few_shot_examples.items()
    )
    few_shot_message = {
        "role": "system",
        "content": f"\nHere are some examples:\n{examples_text}\n",
    }

    # Songs used as examples must not be part of the evaluated set.
    songs = [song for song in songs if song["id"] not in few_shot_ids]

    async def get_lyrics_and_classify(genius, song: dict) -> dict:
        """Fetch lyrics for *song*, classify them and store the top style."""
        lyrics = await genius.get_lyrics(song["id"])
        attachment_style = model.classify_attachment_style(
            lyrics, system_message=few_shot_message
        )
        if attachment_style:
            # The style with the highest reported probability wins.
            song["predicted"] = max(attachment_style, key=attachment_style.get)
        else:
            song["predicted"] = "unknown"
        return song

    songs = await tqdm_asyncio.gather(
        *[get_lyrics_and_classify(genius, song) for song in songs]
    )
    if songs:
        with open(output_path, "w", encoding="utf-8") as json_file:
            json.dump(songs, json_file, indent=4)



async def main():
    """Build the Genius client and the model, then run the few-shot test.

    The API token is read from the ``TOKEN`` environment variable. The
    client is always closed, even when model loading or the test run fails.
    """
    token = os.getenv("TOKEN")
    genius = API_Client(token)
    try:
        model = TextGenerationModel()
        await test_few_shot(genius, model)
    finally:
        # Release the client even if an exception propagates.
        await genius.close()


Expand Down

0 comments on commit 6f8f359

Please sign in to comment.