Skip to content

Commit

Permalink
Add support for Assist pipeline, update Wear implementation (home-ass…
Browse files Browse the repository at this point in the history
…istant#3526)

* Group incoming messages by subscription to prevent out-of-order delivery

 - Messages received on the websocket are processed asynchronously, which is usually fine but can cause issues if messages need to be received in a specific order for a subscription. To fix this, process messages in order for the same subscription.

* Implement Assist pipeline API

 - Add basic support for the Assist pipeline API
 - Update conversation function to use the Assist pipeline when on the minimum required version
 - Update UI to refer to Assist pipeline requirement
  • Loading branch information
jpelgrom authored May 13, 2023
1 parent 57024e1 commit 7d6f11a
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ interface IntegrationRepository {

suspend fun shouldNotifySecurityWarning(): Boolean

suspend fun getConversation(speech: String): String?
suspend fun getAssistResponse(speech: String): String?
}

@AssistedFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,21 @@ import io.homeassistant.companion.android.common.data.integration.impl.entities.
import io.homeassistant.companion.android.common.data.integration.impl.entities.Template
import io.homeassistant.companion.android.common.data.integration.impl.entities.UpdateLocationRequest
import io.homeassistant.companion.android.common.data.servers.ServerManager
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEventType
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineIntentEnd
import io.homeassistant.companion.android.common.data.websocket.impl.entities.GetConfigResponse
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import okhttp3.HttpUrl.Companion.toHttpUrlOrNull
import java.util.concurrent.TimeUnit
import javax.inject.Named
import kotlin.coroutines.resume

class IntegrationRepositoryImpl @AssistedInject constructor(
private val integrationService: IntegrationService,
Expand Down Expand Up @@ -64,6 +72,8 @@ class IntegrationRepositoryImpl @AssistedInject constructor(
private const val APPLOCK_TIMEOUT_GRACE_MS = 1000
}

private val ioScope = CoroutineScope(Dispatchers.IO + Job())

private val server get() = serverManager.getServer(serverId)!!

private val webSocketRepository get() = serverManager.webSocketRepository(serverId)
Expand Down Expand Up @@ -523,11 +533,29 @@ class IntegrationRepositoryImpl @AssistedInject constructor(
}?.toList()
}

override suspend fun getConversation(speech: String): String? {
// TODO: Also send back conversation ID for dialogue
val response = webSocketRepository.getConversation(speech)

return response?.response?.speech?.plain?.get("speech")
override suspend fun getAssistResponse(speech: String): String? {
return if (server.version?.isAtLeast(2023, 5, 0) == true) {
var job: Job? = null
val response = suspendCancellableCoroutine { cont ->
job = ioScope.launch {
webSocketRepository.runAssistPipeline(speech)?.collect {
if (!cont.isActive) return@collect
when (it.type) {
AssistPipelineEventType.INTENT_END ->
cont.resume((it.data as AssistPipelineIntentEnd).intentOutput.response.speech.plain["speech"])
AssistPipelineEventType.ERROR,
AssistPipelineEventType.RUN_END -> cont.resume(null)
else -> { /* Do nothing */ }
}
} ?: cont.resume(null)
}
}
job?.cancel()
response
} else {
val response = webSocketRepository.getConversation(speech)
response?.response?.speech?.plain?.get("speech")
}
}

override suspend fun getEntities(): List<Entity<Any>>? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import io.homeassistant.companion.android.common.data.integration.impl.entities.
import io.homeassistant.companion.android.common.data.websocket.impl.WebSocketRepositoryImpl
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryResponse
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryUpdatedEvent
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEvent
import io.homeassistant.companion.android.common.data.websocket.impl.entities.CompressedStateChangedEvent
import io.homeassistant.companion.android.common.data.websocket.impl.entities.ConversationResponse
import io.homeassistant.companion.android.common.data.websocket.impl.entities.CurrentUserResponse
Expand Down Expand Up @@ -48,7 +49,18 @@ interface WebSocketRepository {
suspend fun getThreadDatasets(): List<ThreadDatasetResponse>?
suspend fun getThreadDatasetTlv(datasetId: String): ThreadDatasetTlvResponse?
suspend fun addThreadDataset(tlv: ByteArray): Boolean

/**
* Get an Assist response for the given text input. For core >= 2023.5, use [runAssistPipeline]
* instead.
*/
suspend fun getConversation(speech: String): ConversationResponse?

/**
* Run the Assist pipeline for the given text input
* @return a Flow that will emit all events for the pipeline
*/
suspend fun runAssistPipeline(text: String): Flow<AssistPipelineEvent>?
}

@AssistedFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ import io.homeassistant.companion.android.common.data.websocket.WebSocketRequest
import io.homeassistant.companion.android.common.data.websocket.WebSocketState
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryResponse
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryUpdatedEvent
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEvent
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEventType
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineIntentEnd
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineIntentStart
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineRunStart
import io.homeassistant.companion.android.common.data.websocket.impl.entities.CompressedStateChangedEvent
import io.homeassistant.companion.android.common.data.websocket.impl.entities.ConversationResponse
import io.homeassistant.companion.android.common.data.websocket.impl.entities.CurrentUserResponse
Expand Down Expand Up @@ -81,6 +86,7 @@ class WebSocketRepositoryImpl @AssistedInject constructor(
companion object {
private const val TAG = "WebSocketRepository"

private const val SUBSCRIBE_TYPE_ASSIST_PIPELINE_RUN = "assist_pipeline/run"
private const val SUBSCRIBE_TYPE_SUBSCRIBE_EVENTS = "subscribe_events"
private const val SUBSCRIBE_TYPE_SUBSCRIBE_ENTITIES = "subscribe_entities"
private const val SUBSCRIBE_TYPE_SUBSCRIBE_TRIGGER = "subscribe_trigger"
Expand Down Expand Up @@ -209,6 +215,18 @@ class WebSocketRepositoryImpl @AssistedInject constructor(
return mapResponse(socketResponse)
}

override suspend fun runAssistPipeline(text: String): Flow<AssistPipelineEvent>? =
subscribeTo(
SUBSCRIBE_TYPE_ASSIST_PIPELINE_RUN,
mapOf(
"start_stage" to "intent",
"end_stage" to "intent",
"input" to mapOf(
"text" to text
)
)
)

override suspend fun getStateChanges(): Flow<StateChangedEvent>? =
subscribeToEventsForType(EVENT_STATE_CHANGED)

Expand Down Expand Up @@ -629,6 +647,21 @@ class WebSocketRepositoryImpl @AssistedInject constructor(
Log.w(TAG, "Received no trigger value for trigger subscription, skipping")
return
}
} else if (subscriptionType == SUBSCRIBE_TYPE_ASSIST_PIPELINE_RUN) {
val eventType = response.event?.get("type")
if (eventType?.isTextual == true) {
val eventDataMap = response.event.get("data")
val eventData = when (eventType.textValue()) {
AssistPipelineEventType.RUN_START -> mapper.convertValue(eventDataMap, AssistPipelineRunStart::class.java)
AssistPipelineEventType.INTENT_START -> mapper.convertValue(eventDataMap, AssistPipelineIntentStart::class.java)
AssistPipelineEventType.INTENT_END -> mapper.convertValue(eventDataMap, AssistPipelineIntentEnd::class.java)
else -> null
}
AssistPipelineEvent(eventType.textValue(), eventData)
} else {
Log.w(TAG, "Received Assist pipeline event without type, skipping")
return
}
} else if (eventResponseType != null && eventResponseType.isTextual) {
val eventResponseClass = when (eventResponseType.textValue()) {
EVENT_STATE_CHANGED ->
Expand Down Expand Up @@ -737,17 +770,18 @@ class WebSocketRepositoryImpl @AssistedInject constructor(
listOf(mapper.readValue(text))
}

messages.forEach { message ->
Log.d(TAG, "Message number ${message.id} received")

messages.groupBy { it.id }.values.forEach { messagesForId ->
ioScope.launch {
when (message.type) {
"auth_required" -> Log.d(TAG, "Auth Requested")
"auth_ok" -> handleAuthComplete(true, message.haVersion)
"auth_invalid" -> handleAuthComplete(false, message.haVersion)
"pong", "result" -> handleMessage(message)
"event" -> handleEvent(message)
else -> Log.d(TAG, "Unknown message type: ${message.type}")
messagesForId.forEach { message ->
Log.d(TAG, "Message number ${message.id} received")
when (message.type) {
"auth_required" -> Log.d(TAG, "Auth Requested")
"auth_ok" -> handleAuthComplete(true, message.haVersion)
"auth_invalid" -> handleAuthComplete(false, message.haVersion)
"pong", "result" -> handleMessage(message)
"event" -> handleEvent(message)
else -> Log.d(TAG, "Unknown message type: ${message.type}")
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.homeassistant.companion.android.common.data.websocket.impl.entities

import com.fasterxml.jackson.annotation.JsonIgnoreProperties

data class AssistPipelineEvent(
val type: String,
val data: AssistPipelineEventData?
)

object AssistPipelineEventType {
const val RUN_START = "run-start"
const val RUN_END = "run-end"
const val STT_START = "stt-start"
const val STT_END = "stt-end"
const val INTENT_START = "intent-start"
const val INTENT_END = "intent-end"
const val TTS_START = "tts-start"
const val TTS_END = "tts-end"
const val ERROR = "error"
}

interface AssistPipelineEventData

@JsonIgnoreProperties(ignoreUnknown = true)
data class AssistPipelineRunStart(
val pipeline: String,
val language: String,
val runnerData: Map<String, Any?>
) : AssistPipelineEventData

@JsonIgnoreProperties(ignoreUnknown = true)
data class AssistPipelineIntentStart(
val engine: String,
val language: String,
val intentInput: String
) : AssistPipelineEventData

@JsonIgnoreProperties(ignoreUnknown = true)
data class AssistPipelineIntentEnd(
val intentOutput: ConversationResponse
) : AssistPipelineEventData
4 changes: 3 additions & 1 deletion common/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,9 @@
<string name="tile_vibrate">Vibrate when clicked</string>
<string name="tile_auth_required">Requires unlocked device</string>
<string name="no_results">No results yet</string>
<string name="no_conversation_support">You must be at least on Home Assistant 2023.1 and have the conversation integration enabled</string>
<string name="no_assist_support">You must be at least on Home Assistant %1$s and have the %2$s integration enabled</string>
<string name="no_assist_support_conversation">conversation</string>
<string name="no_assist_support_assist_pipeline">Assist pipeline</string>
<string name="conversation">Conversation</string>
<string name="assist">Assist</string>
<string name="assist_log_in">Log in to Home Assistant to start using Assist</string>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class ConversationActivity : ComponentActivity() {
super.onCreate(savedInstanceState)

lifecycleScope.launch {
conversationViewModel.isSupportConversation()
if (conversationViewModel.supportsConversation) {
conversationViewModel.checkAssistSupport()
if (conversationViewModel.supportsAssist) {
val searchIntent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply {
putExtra(
RecognizerIntent.EXTRA_LANGUAGE_MODEL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ class ConversationViewModel @Inject constructor(
var conversationResult by mutableStateOf("")
private set

var supportsConversation by mutableStateOf(false)
var supportsAssist by mutableStateOf(false)
private set

var useAssistPipeline by mutableStateOf(false)
private set

var isHapticEnabled = mutableStateOf(false)
Expand All @@ -41,20 +44,28 @@ class ConversationViewModel @Inject constructor(
viewModelScope.launch {
conversationResult =
if (serverManager.isRegistered()) {
serverManager.integrationRepository().getConversation(speechResult) ?: ""
serverManager.integrationRepository().getAssistResponse(speechResult) ?: ""
} else {
""
}
}
}

suspend fun isSupportConversation() {
suspend fun checkAssistSupport() {
checkSupportProgress = true
isRegistered = serverManager.isRegistered()
supportsConversation =
serverManager.isRegistered() &&
serverManager.integrationRepository().isHomeAssistantVersionAtLeast(2023, 1, 0) &&
serverManager.webSocketRepository().getConfig()?.components?.contains("conversation") == true

if (serverManager.isRegistered()) {
val config = serverManager.webSocketRepository().getConfig()
val onConversationVersion = serverManager.integrationRepository().isHomeAssistantVersionAtLeast(2023, 1, 0)
val onPipelineVersion = serverManager.integrationRepository().isHomeAssistantVersionAtLeast(2023, 5, 0)

supportsAssist =
(onConversationVersion && !onPipelineVersion && config?.components?.contains("conversation") == true) ||
(onPipelineVersion && config?.components?.contains("assist_pipeline") == true)
useAssistPipeline = onPipelineVersion
}

isHapticEnabled.value = wearPrefsRepository.getWearHapticFeedback()
checkSupportProgress = false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,13 @@ fun ConversationResultView(
SpeechBubble(
text = conversationViewModel.speechResult.ifEmpty {
when {
(conversationViewModel.supportsConversation) -> stringResource(R.string.no_results)
(!conversationViewModel.supportsConversation && !conversationViewModel.checkSupportProgress) -> stringResource(R.string.no_conversation_support)
conversationViewModel.supportsAssist -> stringResource(R.string.no_results)
(!conversationViewModel.supportsAssist && !conversationViewModel.checkSupportProgress) ->
if (conversationViewModel.useAssistPipeline) {
stringResource(R.string.no_assist_support, "2023.5", stringResource(R.string.no_assist_support_assist_pipeline))
} else {
stringResource(R.string.no_assist_support, "2023.1", stringResource(R.string.no_assist_support_conversation))
}
(!conversationViewModel.isRegistered) -> stringResource(R.string.not_registered)
else -> "..."
}
Expand Down

0 comments on commit 7d6f11a

Please sign in to comment.