Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ interface DownloadRepository {
task: Task?,
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
customUrl: String? = null,
)

fun cancelDownloadModel(model: Model)
Expand Down Expand Up @@ -97,14 +98,21 @@ class DefaultDownloadRepository(
task: Task?,
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
customUrl: String?,
) {
val downloadUrl = customUrl ?: model.url
if (downloadUrl.isEmpty()) {
Log.e(TAG, "Cannot download a model without url. Name: ${model.name}")
return
}

// Create input data.
val builder = Data.Builder()
val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes }
val inputDataBuilder =
builder
.putString(KEY_MODEL_NAME, model.name)
.putString(KEY_MODEL_URL, model.url)
.putString(KEY_MODEL_URL, downloadUrl)
.putString(KEY_MODEL_COMMIT_HASH, model.version)
.putString(KEY_MODEL_DOWNLOAD_MODEL_DIR, model.normalizedName)
.putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.width
import android.net.Uri
import android.provider.OpenableColumns
import android.widget.Toast
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.TextButton
import androidx.compose.material3.OutlinedButton
import androidx.compose.material.icons.rounded.Link
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.getValue
import androidx.compose.runtime.setValue
import androidx.compose.foundation.layout.Column
import androidx.compose.ui.platform.LocalContext
import androidx.compose.foundation.text.TextAutoSize
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.BarChart
Expand Down Expand Up @@ -120,6 +136,36 @@ fun DownloadModelPanel(
return !downloadFailed || isLitertLm
}

if (downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED) {
var showFulfillDialog by remember { mutableStateOf(false) }

OutlinedButton(
onClick = { showFulfillDialog = true },
modifier = Modifier.height(42.dp),
contentPadding = PaddingValues(horizontal = 12.dp)
) {
Icon(Icons.Rounded.Link, contentDescription = null, tint = MaterialTheme.colorScheme.primary)
Spacer(modifier = Modifier.width(6.dp))
Text("Import Custom", color = MaterialTheme.colorScheme.primary)
}
Spacer(modifier = Modifier.width(8.dp))

if (showFulfillDialog) {
ModelFulfillmentDialog(
model = model,
onDismiss = { showFulfillDialog = false },
onUrlPicked = { url ->
showFulfillDialog = false
modelManagerViewModel.downloadModel(task = task, model = model, customUrl = url)
},
onLocalUriPicked = { uri ->
showFulfillDialog = false
modelManagerViewModel.fulfillModelWithLocalUri(model, uri)
}
)
}
}

DownloadAndTryButton(
task = task,
model = model,
Expand All @@ -138,3 +184,90 @@ fun DownloadModelPanel(
}
}
}

@Composable
fun ModelFulfillmentDialog(
model: Model,
onDismiss: () -> Unit,
onUrlPicked: (String) -> Unit,
onLocalUriPicked: (Uri) -> Unit,
) {
var url by remember { mutableStateOf("") }
var errorMsg by remember { mutableStateOf("") }
val context = LocalContext.current

val filePickerLauncher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.OpenDocument(),
onResult = { uri ->
if (uri != null) {
context.contentResolver.query(uri, arrayOf(OpenableColumns.DISPLAY_NAME), null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME)
val displayName = cursor.getString(nameIndex)
if (displayName != model.downloadFileName) {
Toast.makeText(context, "File name '$displayName' does not match expected model file '${model.downloadFileName}'", Toast.LENGTH_LONG).show()
} else {
onLocalUriPicked(uri)
return@use
}
}
}
onDismiss()
}
}
)

AlertDialog(
onDismissRequest = onDismiss,
title = { Text("Import Custom Source") },
text = {
Column(verticalArrangement = Arrangement.spacedBy(8.dp)) {
Text("Expected file name: ${model.downloadFileName}", style = MaterialTheme.typography.bodySmall)
OutlinedTextField(
value = url,
onValueChange = {
url = it
errorMsg = ""
},
label = { Text("Model URL") },
isError = errorMsg.isNotEmpty(),
singleLine = true,
modifier = Modifier.fillMaxWidth()
)
if (errorMsg.isNotEmpty()) {
Text(errorMsg, style = MaterialTheme.typography.labelSmall, color = MaterialTheme.colorScheme.error)
}

Text("OR", modifier = Modifier.align(Alignment.CenterHorizontally), style = MaterialTheme.typography.bodyMedium)
OutlinedButton(
onClick = {
filePickerLauncher.launch(arrayOf("*/*"))
},
modifier = Modifier.fillMaxWidth()
) {
Text("Select Local File")
}
}
},
confirmButton = {
Button(
onClick = {
if (url.isNotBlank() && android.webkit.URLUtil.isValidUrl(url) && (url.startsWith("http://") || url.startsWith("https://"))) {
val parsedUrl = Uri.parse(url)
val fileName = parsedUrl.lastPathSegment ?: ""
if (fileName == model.downloadFileName) {
onUrlPicked(url)
} else {
errorMsg = "URL string must end with ${model.downloadFileName}. Found: '$fileName'"
}
} else {
errorMsg = "Invalid HTTP/HTTPS URL"
}
}
) { Text("Set URL") }
},
dismissButton = {
TextButton(onClick = onDismiss) { Text("Cancel") }
}
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.edge.gallery.ui.modelmanager

import android.content.Context
import android.util.Log
import android.net.Uri
import androidx.activity.result.ActivityResult
import androidx.core.net.toUri
import androidx.lifecycle.ViewModel
Expand Down Expand Up @@ -276,7 +277,7 @@ constructor(
}
}

fun downloadModel(task: Task?, model: Model) {
fun downloadModel(task: Task?, model: Model, customUrl: String? = null) {
// Update status.
setDownloadStatus(
curModel = model,
Expand All @@ -291,9 +292,42 @@ constructor(
task = task,
model = model,
onStatusUpdated = this::setDownloadStatus,
customUrl = customUrl,
)
}

fun fulfillModelWithLocalUri(model: Model, uri: Uri) {
setDownloadStatus(
curModel = model,
status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS),
)
deleteModel(model = model)

viewModelScope.launch(Dispatchers.IO) {
try {
val destFile = File(model.getPath(context))
if (destFile.parentFile?.exists() == false) {
destFile.parentFile?.mkdirs()
}
context.contentResolver.openInputStream(uri)?.use { input ->
destFile.outputStream().use { output ->
input.copyTo(output)
}
}
val status = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED,
receivedBytes = destFile.length(),
totalBytes = destFile.length()
)
setDownloadStatus(model, status)
} catch (e: Exception) {
Log.e(TAG, "Failed to copy local file for fulfillment", e)
val status = ModelDownloadStatus(status = ModelDownloadStatusType.FAILED)
setDownloadStatus(model, status)
}
}
}

fun cancelDownloadModel(model: Model) {
downloadRepository.cancelDownloadModel(model)
deleteModel(model = model)
Expand Down