Question About Perplexity Calculation in "3.2 Rationale-Guided Filtering" and Request for Relevant Code #1

Closed
jmycsu opened this issue Jan 25, 2025 · 5 comments


jmycsu commented Jan 25, 2025

Dear authors,

First of all, congratulations on having your paper accepted to NAACL 2025!

I have a question regarding the perplexity calculation described in Section 3.2 ("Rationale-Guided Filtering") of your paper. I've tried to implement this calculation myself, but the perplexity values I obtain seem to be incorrect. Could you please help me identify whether there's anything wrong with my implementation? Also, would it be possible for you to release your own implementation of the perplexity calculation? It would be incredibly helpful for understanding and reproducing the results in your paper.

Thank you so much for your time and assistance!

Below is the code I’ve implemented for this purpose:

def perplexity(self, input_text):
    # Build the chat-formatted prompt and tokenize it
    text = self.tokenizer.apply_chat_template(
        input_text,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = self.tokenizer(text, return_tensors="pt")
    input_ids = model_inputs.input_ids.to(self.chat_model.device)
    input_length = input_ids.shape[1]
    attention_mask = torch.ones_like(input_ids)

    # Greedy generation with per-step scores so token log-probs can be recovered
    outputs = self.chat_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=4096,
        return_dict_in_generate=True,
        output_scores=True,
    )

    generated_tokens = outputs.sequences[:, input_length:]
    text = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

    # Keep only non-special generated tokens when averaging log-probs
    useful_tokens_index = torch.tensor(
        [i for i, token_id in enumerate(generated_tokens[0])
         if token_id.item() not in self.tokenizer.all_special_ids]
    )

    transition_scores = self.chat_model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )
    logprobs = transition_scores[0]
    seq_logprobs = logprobs[useful_tokens_index]

    # Perplexity = exp(average negative log-likelihood per generated token)
    avg_nll = -torch.mean(seq_logprobs)
    perplexity = torch.exp(avg_nll).item()

    return perplexity
jw-sohn (Collaborator) commented Jan 25, 2025

Thank you for your interest in RAG^2!

I will share the exact code I used:

from vllm import LLM, SamplingParams
import numpy as np

# Initialize the model
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    dtype="bfloat16",
    gpu_memory_utilization=0.9,
    max_model_len=8192,
)

# Define inputs
inputs = ["Your first input prompt here.", "Your second input prompt here."]

# Get tokenizer
tokenizer = llm.get_tokenizer()

# Generate outputs
generated = llm.generate(
    inputs,
    SamplingParams(
        temperature=0.0,
        top_k=1,
        stop_token_ids=[tokenizer.vocab.get("<|eot_id|>", None)],  # Use .get to avoid KeyError
        max_tokens=8192,
        logprobs=1
    ),
)

# Calculate perplexity
perplexity_list = []
for g in generated:
    cumulative_logprob = g.outputs[0].cumulative_logprob
    num_tokens = len(g.outputs[0].logprobs)
    avg_logprob = cumulative_logprob / num_tokens
    perplexity = np.exp(-avg_logprob)
    perplexity_list.append(perplexity)

print("Perplexity list:", perplexity_list)

Let me know if you encounter any specific errors!

jmycsu (Author) commented Jan 25, 2025

Thank you so much for your response and for providing the code—it’s been very helpful and has greatly clarified my understanding.

I do have a follow-up question regarding the perplexity values. Could you please share the approximate range of perplexity values you obtained? In my own implementation, the perplexity values are quite small, typically between 1.x and 2.x (for example, 1.48). I wonder if this range is correct or if it suggests an error in my calculation.

Thanks again for your time and kind explanation.
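
As a quick sanity check on the magnitude (a back-of-the-envelope calculation on the 1.48 example above, not a number from the paper): since perplexity = exp(-average log-probability per token), a perplexity of 1.48 corresponds to an average per-token log-probability of about -0.39 nats, i.e. an average per-token probability of roughly 0.68, which is plausible for confident greedy generations.

import math

# Back-of-the-envelope check on the 1.48 example above (hypothetical number).
avg_logprob = -math.log(1.48)
print(avg_logprob)            # ≈ -0.392 nats per generated token
print(math.exp(avg_logprob))  # ≈ 0.676 average per-token probability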

jw-sohn (Collaborator) commented Jan 25, 2025

Yes, the perplexity values can indeed be small, but you can still compute filtering thresholds from the top percentages (e.g., top 5%, 10%, 25%) of the scores across the entire training set.

Let me know if you have further questions or need clarification!
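
For concreteness, here is a minimal sketch of how such percentage-based thresholds could be computed over a list of per-sample perplexities with NumPy. The values in perplexity_list are hypothetical, and whether the filter keeps samples below or above a threshold follows the criterion described in Section 3.2 of the paper rather than this sketch.

import numpy as np

# Hypothetical per-sample perplexities over the training set, e.g. collected
# with the vLLM snippet above; the values here are made up for illustration.
perplexity_list = [1.12, 1.48, 1.35, 2.10, 1.05, 1.77, 1.21, 1.60]
scores = np.asarray(perplexity_list)

# Candidate thresholds at the top 5%, 10%, and 25% of the distribution.
for pct in (5, 10, 25):
    low_cut = np.percentile(scores, pct)          # cut-off if "top" means lowest perplexity
    high_cut = np.percentile(scores, 100 - pct)   # cut-off if "top" means highest perplexity
    print(f"{pct}%: low-tail threshold {low_cut:.3f}, high-tail threshold {high_cut:.3f}")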

jmycsu (Author) commented Jan 25, 2025

Thank you so much for your timely responses and for addressing all my questions. Wishing you all the best in your future work!

jw-sohn closed this as completed Jan 25, 2025

jw-sohn (Collaborator) commented Jan 25, 2025

Feel free to reach out whenever you need help :)
