|
|
|
@@ -12,6 +12,9 @@ from mautrix.errors import MatrixRequestError
|
|
|
|
from mautrix.types import EventType, MessageType, RelationType, TextMessageEventContent, Format,RelatesTo,InReplyTo
|
|
|
|
from mautrix.types import EventType, MessageType, RelationType, TextMessageEventContent, Format,RelatesTo,InReplyTo
|
|
|
|
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
|
|
|
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# File extensions of the audio/video containers the plugin is meant to accept.
# NOTE(review): not referenced in the visible part of the file — confirm usage.
ALLOWED_EXTENSIONS = ['flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'mpga', 'oga', 'ogg', 'wav', 'webm']

# MIME types accepted by should_respond(); anything else is ignored.
# Every extension above needs its MIME type listed here, otherwise the upload
# is rejected: 'audio/mpeg' covers mp3/mpga, 'video/mp4' covers mp4 and
# 'audio/webm' covers audio-only webm recordings.
ALLOWED_MIME_TYPES = [
    'audio/flac',
    'audio/mp4',
    'audio/mpeg',
    'audio/ogg',
    'audio/wav',
    'audio/webm',
    'video/mp4',
    'video/mpeg',
    'video/webm',
]
|
|
|
|
|
|
|
|
|
|
|
|
class Config(BaseProxyConfig):
|
|
|
|
class Config(BaseProxyConfig):
|
|
|
|
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
|
|
|
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
|
|
|
helper.copy("whisper_endpoint")
|
|
|
|
helper.copy("whisper_endpoint")
|
|
|
|
@@ -19,6 +22,7 @@ class Config(BaseProxyConfig):
|
|
|
|
helper.copy("allowed_users")
|
|
|
|
helper.copy("allowed_users")
|
|
|
|
helper.copy("allowed_rooms")
|
|
|
|
helper.copy("allowed_rooms")
|
|
|
|
helper.copy("prompt")
|
|
|
|
helper.copy("prompt")
|
|
|
|
|
|
|
|
helper.copy("search_reminders_and_events")
|
|
|
|
helper.copy("language")
|
|
|
|
helper.copy("language")
|
|
|
|
|
|
|
|
|
|
|
|
class WhisperPlugin(Plugin):
|
|
|
|
class WhisperPlugin(Plugin):
|
|
|
|
@@ -32,6 +36,7 @@ class WhisperPlugin(Plugin):
|
|
|
|
self.language = self.config['language']
|
|
|
|
self.language = self.config['language']
|
|
|
|
self.allowed_users = self.config['allowed_users']
|
|
|
|
self.allowed_users = self.config['allowed_users']
|
|
|
|
self.allowed_rooms = self.config['allowed_rooms']
|
|
|
|
self.allowed_rooms = self.config['allowed_rooms']
|
|
|
|
|
|
|
|
self.search_reminders_and_events = self.config['search_reminders_and_events']
|
|
|
|
self.log.debug("Whisper plugin started")
|
|
|
|
self.log.debug("Whisper plugin started")
|
|
|
|
|
|
|
|
|
|
|
|
async def should_respond(self, event: MessageEvent) -> bool:
|
|
|
|
async def should_respond(self, event: MessageEvent) -> bool:
|
|
|
|
@@ -44,6 +49,9 @@ class WhisperPlugin(Plugin):
|
|
|
|
if self.allowed_rooms and event.room_id not in self.allowed_rooms:
|
|
|
|
if self.allowed_rooms and event.room_id not in self.allowed_rooms:
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if event.content.info.mimetype not in ALLOWED_MIME_TYPES:
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
return event.content.msgtype == MessageType.AUDIO or event.content.msgtype == MessageType.FILE
|
|
|
|
return event.content.msgtype == MessageType.AUDIO or event.content.msgtype == MessageType.FILE
|
|
|
|
|
|
|
|
|
|
|
|
@event.on(EventType.ROOM_MESSAGE)
|
|
|
|
@event.on(EventType.ROOM_MESSAGE)
|
|
|
|
@@ -58,8 +66,10 @@ class WhisperPlugin(Plugin):
|
|
|
|
audio_bytes = await self.client.download_media(url=event.content.url)
|
|
|
|
audio_bytes = await self.client.download_media(url=event.content.url)
|
|
|
|
transcription = await self.transcribe_audio(audio_bytes)
|
|
|
|
transcription = await self.transcribe_audio(audio_bytes)
|
|
|
|
|
|
|
|
|
|
|
|
await self.client.set_typing(event.room_id, timeout=0)
|
|
|
|
if self.search_reminders_and_events:
|
|
|
|
|
|
|
|
transcription = await self.study_transcribe(transcription)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
await self.client.set_typing(event.room_id, timeout=0)
|
|
|
|
content = TextMessageEventContent(
|
|
|
|
content = TextMessageEventContent(
|
|
|
|
msgtype=MessageType.TEXT,
|
|
|
|
msgtype=MessageType.TEXT,
|
|
|
|
body=transcription,
|
|
|
|
body=transcription,
|
|
|
|
@@ -104,6 +114,40 @@ class WhisperPlugin(Plugin):
|
|
|
|
self.log.exception(f"Failed to transcribe audio, msg: {e}")
|
|
|
|
self.log.exception(f"Failed to transcribe audio, msg: {e}")
|
|
|
|
return "Sorry, an error occurred while transcribing the audio."
|
|
|
|
return "Sorry, an error occurred while transcribing the audio."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def study_transcribe(self, transcription: str) -> str:
    """Ask the OpenAI chat API whether the transcription is a reminder or event.

    The transcription is embedded in a (French) instruction prompt and sent to
    the chat-completions endpoint with the ``gpt-4`` model. The model is asked
    to answer with a ``!rappel ...`` or ``!agenda ...`` command, or with the
    unchanged transcription when neither applies.

    This step is best-effort: on any failure (network error, non-200 status,
    malformed or empty response) the original transcription is returned so the
    text that was already transcribed is never lost.

    Args:
        transcription: Text produced by the transcription step.

    Returns:
        The model's answer, or ``transcription`` unchanged when the input is
        empty or the request could not be completed.
    """
    # Nothing to analyse: skip the API round-trip entirely.
    if not transcription or not transcription.strip():
        return transcription

    # Local import: only needed for the timeout exception type below.
    import asyncio

    prompt = f"""
    Voici la transcription du message vocal :

    {transcription}
    Ton objectif est d'analyser cette transcription afin de déterminer si l'utilisateur tente de créer un rappel ou un évènement.

    - Si l'utilisateur essaie de créer un rappel, la sortie doit prendre la forme :
    !rappel <date> <message>
    - Si l'utilisateur essaie de créer un évènement, la sortie doit prendre la forme :
    !agenda ##ROOM## <message>
    - Si l'utilisateur ne cherche ni à créer un rappel ni un évènement, renvoie seulement la transcription telle quelle, sans ajout d'explication, de texte supplémentaire ou de ponctuation superflue.
    Ne fournis aucun autre texte ni explication dans ta réponse, uniquement la sortie demandée.
    """

    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {self.api_key}"
    }
    data = {
        "model": "gpt-4",
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.7
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=data) as response:
                if response.status != 200:
                    # Error payloads carry no "choices"; do not try to parse them.
                    self.log.warning(
                        f"Transcription analysis failed with HTTP status {response.status}"
                    )
                    return transcription
                response_json = await response.json()
    except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        self.log.exception(f"Failed to analyse transcription, msg: {e}")
        return transcription

    choices = response_json.get('choices') if isinstance(response_json, dict) else None
    if not choices:
        # Previously this case raised IndexError on `[0]`.
        self.log.warning("Transcription analysis returned no choices")
        return transcription

    first_choice = choices[0] if isinstance(choices[0], dict) else {}
    content = (first_choice.get('message') or {}).get('content')
    # "content" may be missing, null or blank; honour the declared `str` return.
    if not isinstance(content, str) or not content.strip():
        return transcription
    return content
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
    """Return ``Config``, the ``BaseProxyConfig`` subclass declared in this file."""
    config_class = Config
    return config_class
|
|
|
|
|