diff --git a/backend/data/test_output_few_shot.json b/backend/data/test_output_few_shot.json new file mode 100644 index 0000000..385c4af --- /dev/null +++ b/backend/data/test_output_few_shot.json @@ -0,0 +1,58 @@ +[ + { + "id": 62157, + "artist": "Chris Brown", + "title": "Say Goodbye", + "attachment_style": "avoidant", + "predicted": "unknown" + }, + { + "id": 115478, + "artist": "*NSYNC", + "title": "Bye Bye Bye", + "attachment_style": "avoidant", + "predicted": "avoidant" + }, + { + "id": 1644, + "artist": "Michael Jackson", + "title": "Billie Jean", + "attachment_style": "avoidant", + "predicted": "avoidant" + }, + { + "id": 6591973, + "artist": "TLC", + "title": "No Scrubs", + "attachment_style": "avoidant", + "predicted": "unknown" + }, + { + "id": 58062, + "artist": "Rihanna", + "title": "Take a Bow", + "attachment_style": "avoidant", + "predicted": "secure" + }, + { + "id": 5053007, + "artist": "The Weeknd", + "title": "Heartless", + "attachment_style": "avoidant", + "predicted": "secure" + }, + { + "id": 129806, + "artist": "The Police", + "title": "Every Breath You Take", + "attachment_style": "anxious", + "predicted": "secure" + }, + { + "id": 209747, + "artist": "Miley Cyrus", + "title": "Wrecking Ball", + "attachment_style": "anxious", + "predicted": "unknown" + } +] \ No newline at end of file diff --git a/backend/src/model.py b/backend/src/model.py index e8b15bd..797a1da 100644 --- a/backend/src/model.py +++ b/backend/src/model.py @@ -13,28 +13,33 @@ class AttachmentStyleProbabilities(TypedDict): class TextGenerationModel: def __init__(self, model_id: str = "meta-llama/Llama-3.2-1B-Instruct"): logging.set_verbosity_error() - self.pipe = pipeline( - "text-generation", - model=model_id, - torch_dtype=torch.bfloat16, - device_map="auto", - # TODO: make it more deterministic - # temperature=0.01 - ) - - def classify_attachment_style(self, lyrics: str, max_new_tokens: int = 256) -> AttachmentStyleProbabilities: - system_message = """ + self.system_message 
= """ You are an assistant that classifies attachment styles based on song lyrics. Analyze the following song lyrics and classify the attachment style it reflects. Use the four main attachment styles: secure, anxious, avoidant, and disorganized. Provide the confidence/probability for each style. The total of all probabilities should sum to 1. Do not include any commentary, explanations, or additional text. Return the probabilites as float. """ - - messages = [ - {"role": "system", "content": system_message}, - {"role": "user", "content": lyrics} - ] + self.pipe = pipeline( + "text-generation", + model=model_id, + torch_dtype=torch.bfloat16, + device_map="auto" + ) + + def classify_attachment_style(self, lyrics: str, max_new_tokens: int = 256, system_message: dict = {}) -> AttachmentStyleProbabilities: + + if system_message["role"] != "system": + raise ValueError("System messages have role 'system'.") + if system_message["content"] == None: + raise ValueError("System messages must have content.") + + messages = [{"role": "system", "content": self.system_message}] + + if system_message: + messages.append(system_message) + + messages.append({"role": "user", "content": lyrics}) outputs = self.pipe(messages, max_new_tokens=max_new_tokens) answer = outputs[0]["generated_text"][-1]["content"] diff --git a/backend/src/test.py b/backend/src/test.py index 89fdc8e..f37162f 100644 --- a/backend/src/test.py +++ b/backend/src/test.py @@ -5,18 +5,12 @@ from env import load_env from model import TextGenerationModel import json -from tqdm import tqdm import numpy as np load_env() - - - -async def main(): - token = os.getenv("TOKEN") - genius = API_Client(token) +async def test(genius: API_Client, model: TextGenerationModel): if os.path.exists("./data/test_output.json"): print("Loading existing data") with open("./data/test_output.json", "r") as json_file: @@ -29,7 +23,7 @@ async def main(): with open("./data/train.json", "r") as json_file: songs = json.load(json_file) - model = 
def _first_song_per_style(songs: list) -> dict:
    """Map each attachment style to the first song in *songs* labelled with it.

    A plain dict is used instead of a set of styles so that the styles keep
    their first-seen order; iterating a set of strings yields a different
    order from one interpreter run to the next.
    """
    first_by_style = {}
    for song in songs:
        first_by_style.setdefault(song["attachment_style"], song)
    return first_by_style


def _build_few_shot_message(lyrics_by_style: dict) -> dict:
    """Return the extra system message listing one lyrics example per style."""
    # The join happens outside the f-string: a backslash inside an f-string
    # replacement field is a SyntaxError on Python versions before 3.12.
    examples = "\n\n".join(
        f"{style}: {lyrics}" for style, lyrics in lyrics_by_style.items()
    )
    return {
        "role": "system",
        "content": f"\nHere are some examples:\n\n{examples}\n",
    }


async def test_few_shot(genius: API_Client, model: TextGenerationModel, *, sample_size: int = 10):
    """Evaluate few-shot attachment-style classification on a sample of songs.

    If ``./data/test_output_few_shot.json`` already exists, its predictions
    are loaded and the fraction whose ``predicted`` equals
    ``attachment_style`` is printed; nothing is fetched or classified.

    Otherwise the first ``sample_size`` songs of ``./data/train.json`` are
    loaded. The first song of every attachment style in that sample supplies
    the few-shot example for the style (its lyrics come from
    ``genius.get_lyrics``) and is removed from the evaluated songs. Every
    remaining song is classified with ``model.classify_attachment_style``;
    its ``predicted`` field is the highest-probability style, or
    ``"unknown"`` when the model returns nothing usable. The results are
    written to the cache file only if at least one song was evaluated.

    Args:
        genius: Client exposing an awaitable ``get_lyrics(song_id)``.
        model: Classifier exposing ``classify_attachment_style``.
        sample_size: Number of leading songs of the training file to use.
    """
    output_path = "./data/test_output_few_shot.json"

    if os.path.exists(output_path):
        print("Loading existing data")
        with open(output_path, "r", encoding="utf-8") as json_file:
            predictions = json.load(json_file)
        if not predictions:
            # An empty cache has no accuracy; avoid dividing by zero.
            print("Accuracy: n/a (no cached predictions)")
            return
        correct = sum(
            song["predicted"] == song["attachment_style"] for song in predictions
        )
        print(f"Accuracy: {correct/len(predictions)}")
        return

    print("Fetching new data")
    with open("./data/train.json", "r", encoding="utf-8") as json_file:
        songs = json.load(json_file)[:sample_size]

    example_songs = _first_song_per_style(songs)
    lyrics_by_style = {}
    for style, example in example_songs.items():
        lyrics_by_style[style] = await genius.get_lyrics(example["id"])
    few_shot_message = _build_few_shot_message(lyrics_by_style)

    # Songs used as few-shot examples must not be evaluated as well.
    few_shot_ids = {example["id"] for example in example_songs.values()}
    songs = [song for song in songs if song["id"] not in few_shot_ids]

    async def get_lyrics_and_classify(genius, song: dict) -> dict:
        lyrics = await genius.get_lyrics(song["id"])
        attachment_style = model.classify_attachment_style(
            lyrics, system_message=few_shot_message
        )
        if attachment_style:
            song["predicted"] = max(attachment_style, key=attachment_style.get)
        else:
            song["predicted"] = "unknown"
        return song

    songs = await tqdm_asyncio.gather(
        *[get_lyrics_and_classify(genius, song) for song in songs]
    )
    if songs:
        with open(output_path, "w", encoding="utf-8") as json_file:
            json.dump(songs, json_file, indent=4)


async def main():
    """Run the few-shot evaluation with a Genius client built from ``TOKEN``."""
    token = os.getenv("TOKEN")
    genius = API_Client(token)
    try:
        model = TextGenerationModel()
        await test_few_shot(genius, model)
    finally:
        # Close the client even when model loading or the evaluation fails.
        await genius.close()