Commit c2c857f: update downloads
fionabos committed Aug 3, 2024 (parent 70ace2d)
Showing 1 changed file with 73 additions and 78 deletions.
@@ -40,6 +40,7 @@ public class MainActivity extends AppCompatActivity implements Consumer<String>
private ImageButton sendMsgIB;
private TextView generatedTV;
private TextView promptTV;
private TextView progressText;
private static final String TAG = "genai.demo.MainActivity";

private static boolean fileExists(Context context, String fileName) {
@@ -92,7 +93,7 @@ public void onClick(View v) {
}

String promptQuestion = userMsgEdt.getText().toString();
String promptQuestion_formatted = "<system>You are a helpful AI assistant. Answer in one paragraph or less<|end|><|user|>"+promptQuestion+"<|end|>\n<assistant|>";
String promptQuestion_formatted = "<system>You are a helpful AI assistant. Answer in two paragraphs or less<|end|><|user|>"+promptQuestion+"<|end|>\n<assistant|>";
Log.i("GenAI: prompt question", promptQuestion_formatted);
setVisibility();
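For reference, the formatted prompt above follows the Phi-3 chat template (the <|system|>, <|user|>, <|assistant|>, and <|end|> markers). A small helper such as the sketch below, which is not part of this commit and whose name is purely illustrative, would build the same string:

// Illustrative helper (not in this commit): builds the same Phi-3-style prompt as above.
private static String formatPrompt(String question) {
    return "<|system|>You are a helpful AI assistant. Answer in two paragraphs or less<|end|>"
            + "<|user|>" + question + "<|end|>\n<|assistant|>";
}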

@@ -112,7 +113,7 @@ public void run() {

GeneratorParams generatorParams = model.createGeneratorParams();
//generatorParams.setSearchOption("length_penalty", 1000);
generatorParams.setSearchOption("max_length", 500);
//generatorParams.setSearchOption("max_length", 500);

Sequences encodedPrompt = tokenizer.encode(promptQuestion_formatted);
generatorParams.setInput(encodedPrompt);
@@ -128,7 +129,9 @@ public void run() {
tokenListener.accept(stream.decode(token));
}

//generator.close();
generator.close();
generatorParams.close();

}
catch (GenAIException e) {
throw new RuntimeException(e);
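The commit also enables the previously commented-out generator.close() and adds generatorParams.close(). As a minimal sketch (not part of this commit), the same block could be wrapped in try/finally so both native objects are released even when generation throws; the loop body is reproduced from the collapsed lines of this example, and the exact Generator calls may differ between onnxruntime-genai versions:

GeneratorParams generatorParams = model.createGeneratorParams();
Generator generator = null;
try {
    Sequences encodedPrompt = tokenizer.encode(promptQuestion_formatted);
    generatorParams.setInput(encodedPrompt);
    generator = new Generator(model, generatorParams);
    TokenizerStream stream = tokenizer.createStream();
    while (!generator.isDone()) {
        generator.computeLogits();
        generator.generateNextToken();
        // Stream each decoded token back to the UI listener, as above.
        tokenListener.accept(stream.decode(generator.getLastTokenInSequence()));
    }
} catch (GenAIException e) {
    throw new RuntimeException(e);
} finally {
    // Release native resources on both the success and the error path.
    if (generator != null) generator.close();
    generatorParams.close();
}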
@@ -153,89 +156,81 @@ protected void onDestroy() {
}

private void downloadModels(Context context) throws GenAIException {
List<Pair<String, String>> urlFilePairs = Arrays.asList(
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/added_tokens.json?download=true",
"added_tokens.json"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/config.json?download=true",
"config.json"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/configuration_phi3.py?download=true",
"configuration_phi3.py"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/genai_config.json?download=true",
"genai_config.json"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx?download=true",
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data?download=true",
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/special_tokens_map.json?download=true",
"special_tokens_map.json"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer.json?download=true",
"tokenizer.json"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer.model?download=true",
"tokenizer.model"),
new Pair<>(
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer_config.json?download=true",
"tokenizer_config.json"));

final String baseUrl = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/";
List<String> files = Arrays.asList(
"added_tokens.json",
"config.json",
"configuration_phi3.py",
"genai_config.json",
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx",
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer.model",
"tokenizer_config.json");

List<Pair<String, String>> urlFilePairs = new ArrayList<>();
for (String file : files) {
if (/*file.endsWith(".data") ||*/ !fileExists(context, file)) {
urlFilePairs.add(new Pair<>(
baseUrl + file,// + "?download=true",
file));
}
}
if (urlFilePairs.isEmpty()) {
// Display a message using Toast
Toast.makeText(this, "All files already exist. Skipping download.", Toast.LENGTH_SHORT).show();
Log.d(TAG, "All files already exist. Skipping download.");
model = new Model(getFilesDir().getPath());
tokenizer = model.createTokenizer();
return;
}

progressText.setText("Downloading...");
progressText.setVisibility(View.VISIBLE);

Toast.makeText(this,
"Downloading model for the app... Model Size greater than 2GB, please allow a few minutes to download.",
Toast.LENGTH_SHORT).show();

ExecutorService executor = Executors.newSingleThreadExecutor();
for (int i = 0; i < urlFilePairs.size(); i++) {
final int index = i;
String url = urlFilePairs.get(index).first;
String fileName = urlFilePairs.get(index).second;
if (fileExists(context, fileName)) {
// Display a message using Toast
Toast.makeText(this, "File already exists. Skipping Download.", Toast.LENGTH_SHORT).show();

Log.d(TAG, "File " + fileName + " already exists. Skipping download.");
// Note: the files are always downloaded together, so if one file already exists
// we assume the whole download step has already been done.
if (index == urlFilePairs.size() - 1) {
model = new Model(getFilesDir().getPath());
tokenizer = model.createTokenizer();
break;
}
continue;
}
executor.execute(() -> {
ModelDownloader.downloadModel(context, url, fileName, new ModelDownloader.DownloadCallback() {
private long pctDone = 0;
@Override
public void onDownloadProgress(long bytesDone, long bytesTotal) {
if (bytesTotal > 0) {
long newPctDone = bytesDone * 100 / bytesTotal;
if (newPctDone > pctDone) {
pctDone = newPctDone;
Log.d(TAG, "Download" + fileName + ": " + pctDone
+ "% of " + (bytesTotal/1024) + " KB");
}
}
executor.execute(() -> {
ModelDownloader.downloadModel(context, urlFilePairs, new ModelDownloader.DownloadCallback() {
@Override
public void onProgress(long lastBytesRead, long bytesRead, long bytesTotal) {
long lastPctDone = 100 * lastBytesRead / bytesTotal;
long pctDone = 100 * bytesRead / bytesTotal;
if (pctDone > lastPctDone) {
Log.d(TAG, "Downloading files: " + pctDone + "%");
//if (pctDone / 10 > lastPctDone / 10) {
runOnUiThread(() -> {
progressText.setText("Downloading: " + pctDone + "%");
//Toast.makeText(context, "Downloading: " + pctDone + "%", Toast.LENGTH_SHORT).show();
});
//}
}
@Override
public void onDownloadComplete() throws GenAIException {
Log.d(TAG, "Download complete for " + fileName);
if (index == urlFilePairs.size() - 1) {
// Last download completed, create GenAIWrapper
model = new Model(getFilesDir().getPath());
tokenizer = model.createTokenizer();
Log.d(TAG, "All downloads completed");
}
}
@Override
public void onDownloadComplete() {
Log.d(TAG, "All downloads completed.");

// All downloads completed; create the Model and Tokenizer
try {
model = new Model(getFilesDir().getPath());
tokenizer = model.createTokenizer();
runOnUiThread(() -> {
Toast.makeText(context, "All downloads completed", Toast.LENGTH_SHORT).show();
progressText.setVisibility(View.INVISIBLE);
});
} catch (GenAIException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
});

}
});
}
});
executor.shutdown();
}
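This commit also changes the call site from one ModelDownloader.downloadModel call per file to a single call that receives the whole list of (url, fileName) pairs and reports aggregate progress. The ModelDownloader class itself is not part of this diff, so the following is only a rough sketch of what such an implementation could look like: the class name, method signature, and callback methods are taken from the call site above, and everything else (HEAD requests for sizes, buffer size, error handling) is an assumption.

import android.content.Context;
import android.util.Pair;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.List;

public class ModelDownloader {
    public interface DownloadCallback {
        // lastBytesRead lets the caller limit UI updates to whole-percent changes.
        void onProgress(long lastBytesRead, long bytesRead, long bytesTotal);
        void onDownloadComplete();
    }

    // Downloads every (url, fileName) pair into the app's files directory and
    // reports combined progress across all files.
    public static void downloadModel(Context context, List<Pair<String, String>> urlFilePairs,
                                     DownloadCallback callback) {
        try {
            // First pass: sum up the total size with HEAD requests (assumes the
            // server reports Content-Length; getContentLengthLong needs API 24+).
            long bytesTotal = 0;
            for (Pair<String, String> pair : urlFilePairs) {
                HttpURLConnection head = (HttpURLConnection) new URL(pair.first).openConnection();
                head.setRequestMethod("HEAD");
                bytesTotal += head.getContentLengthLong();
                head.disconnect();
            }

            // Second pass: stream each file to disk, reporting progress as it goes.
            long bytesRead = 0;
            byte[] buffer = new byte[8192];
            for (Pair<String, String> pair : urlFilePairs) {
                File target = new File(context.getFilesDir(), pair.second);
                HttpURLConnection conn = (HttpURLConnection) new URL(pair.first).openConnection();
                try (InputStream in = conn.getInputStream();
                     FileOutputStream out = new FileOutputStream(target)) {
                    int n;
                    while ((n = in.read(buffer)) != -1) {
                        out.write(buffer, 0, n);
                        long lastBytesRead = bytesRead;
                        bytesRead += n;
                        if (bytesTotal > 0) {
                            callback.onProgress(lastBytesRead, bytesRead, bytesTotal);
                        }
                    }
                } finally {
                    conn.disconnect();
                }
            }
            callback.onDownloadComplete();
        } catch (Exception e) {
            // A production version would likely surface errors through the callback
            // and download to a temporary file before renaming it into place.
            throw new RuntimeException(e);
        }
    }
}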

