fixed mime types
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
maubot: 0.1.0
|
maubot: 0.1.0
|
||||||
id: nigzu.com.maubot-stt
|
id: nigzu.com.maubot-stt
|
||||||
version: 0.3.1
|
version: 0.3.2
|
||||||
license: MIT
|
license: MIT
|
||||||
modules:
|
modules:
|
||||||
- openai-whisper
|
- openai-whisper
|
||||||
|
|||||||
161
openai-whisper copy.py
Normal file
161
openai-whisper copy.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import aiohttp
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Type
|
||||||
|
from mautrix.client import Client
|
||||||
|
from maubot.handlers import event
|
||||||
|
from maubot import Plugin, MessageEvent
|
||||||
|
from mautrix.errors import MatrixRequestError
|
||||||
|
from mautrix.types import EventType, MessageType, RelationType, TextMessageEventContent, Format,RelatesTo,InReplyTo
|
||||||
|
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||||
|
|
||||||
|
ALLOWED_EXTENSIONS = ['flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'mpga', 'oga', 'ogg', 'wav', 'webm']
|
||||||
|
ALLOWED_MIME_TYPES = ['audio/flac','audio/mp4','video/mpeg','audio/ogg','audio/wav','video/webm']
|
||||||
|
|
||||||
|
class Config(BaseProxyConfig):
|
||||||
|
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||||
|
helper.copy("whisper_endpoint")
|
||||||
|
helper.copy("openai_api_key")
|
||||||
|
helper.copy("allowed_users")
|
||||||
|
helper.copy("allowed_rooms")
|
||||||
|
helper.copy("prompt")
|
||||||
|
helper.copy("search_reminders_and_events")
|
||||||
|
helper.copy("language")
|
||||||
|
|
||||||
|
class WhisperPlugin(Plugin):
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
await super().start()
|
||||||
|
self.config.load_and_update()
|
||||||
|
self.whisper_endpoint = self.config['whisper_endpoint']
|
||||||
|
self.api_key = self.config['openai_api_key']
|
||||||
|
self.prompt = self.config['prompt']
|
||||||
|
self.language = self.config['language']
|
||||||
|
self.allowed_users = self.config['allowed_users']
|
||||||
|
self.allowed_rooms = self.config['allowed_rooms']
|
||||||
|
self.search_reminders_and_events = self.config['search_reminders_and_events']
|
||||||
|
self.log.debug("Whisper plugin started")
|
||||||
|
|
||||||
|
async def should_respond(self, event: MessageEvent) -> bool:
|
||||||
|
if event.sender == self.client.mxid:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.allowed_users and event.sender not in self.allowed_users:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.allowed_rooms and event.room_id not in self.allowed_rooms:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if event.content.info.mimetype not in ALLOWED_MIME_TYPES:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return event.content.msgtype == MessageType.AUDIO or event.content.msgtype == MessageType.FILE
|
||||||
|
|
||||||
|
@event.on(EventType.ROOM_MESSAGE)
|
||||||
|
async def on_message(self, event: MessageEvent) -> None:
|
||||||
|
if not await self.should_respond(event):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await event.mark_read()
|
||||||
|
await self.client.set_typing(event.room_id, timeout=99999)
|
||||||
|
|
||||||
|
audio_bytes = await self.client.download_media(url=event.content.url)
|
||||||
|
transcription = await self.transcribe_audio(audio_bytes)
|
||||||
|
|
||||||
|
if self.search_reminders_and_events:
|
||||||
|
transcription = await self.study_transcribe(transcription)
|
||||||
|
|
||||||
|
await self.client.set_typing(event.room_id, timeout=0)
|
||||||
|
content = TextMessageEventContent(
|
||||||
|
msgtype=MessageType.TEXT,
|
||||||
|
body=transcription,
|
||||||
|
format=Format.HTML,
|
||||||
|
formatted_body=transcription
|
||||||
|
)
|
||||||
|
in_reply_to = InReplyTo(event_id=event.event_id)
|
||||||
|
if event.content.relates_to and event.content.relates_to.rel_type == RelationType.THREAD:
|
||||||
|
await event.respond(content, in_thread=True)
|
||||||
|
else:
|
||||||
|
content.relates_to = RelatesTo(
|
||||||
|
in_reply_to=in_reply_to
|
||||||
|
)
|
||||||
|
await event.respond(content)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception(f"Something went wrong: {e}")
|
||||||
|
await event.respond(f"Something went wrong: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def transcribe_audio(self, audio_bytes: bytes) -> str:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
|
}
|
||||||
|
data = aiohttp.FormData()
|
||||||
|
data.add_field('file', audio_bytes, filename='audio.mp3', content_type='audio/mpeg')
|
||||||
|
data.add_field('model', 'whisper-1')
|
||||||
|
if self.prompt:
|
||||||
|
data.add_field('prompt', f"{self.prompt}")
|
||||||
|
if self.language:
|
||||||
|
data.add_field('language', f"{self.language}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(self.whisper_endpoint, headers=headers, data=data) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
self.log.error(f"Error response from API: {await response.text()}")
|
||||||
|
return f"Error: {await response.text()}"
|
||||||
|
response_json = await response.json()
|
||||||
|
return response_json.get("text", "Sorry, I can't transcribe the audio.")
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception(f"Failed to transcribe audio, msg: {e}")
|
||||||
|
return "Sorry, an error occurred while transcribing the audio."
|
||||||
|
|
||||||
|
async def study_transcribe(self, transcription: str) -> str:
|
||||||
|
prompt = f"""
|
||||||
|
Voici la transcription du message vocal :
|
||||||
|
|
||||||
|
{transcription}
|
||||||
|
Ton objectif est d'analyser cette transcription afin de déterminer si l'utilisateur tente de créer un rappel ou un évènement.
|
||||||
|
|
||||||
|
- Si l'utilisateur essaie de créer un rappel, la sortie doit prendre la forme :
|
||||||
|
!rappel <date> <message>
|
||||||
|
- Si l'utilisateur essaie de créer un évènement, la sortie doit prendre la forme :
|
||||||
|
!agenda ##ROOM## <message>
|
||||||
|
- Si l'utilisateur ne cherche ni à créer un rappel ni un évènement, renvoie seulement la transcription telle quelle, sans ajout d'explication, de texte supplémentaire ou de ponctuation superflue.
|
||||||
|
Ne fournis aucun autre texte ni explication dans ta réponse, uniquement la sortie demandée.
|
||||||
|
"""
|
||||||
|
|
||||||
|
url = "https://api.openai.com/v1/chat/completions"
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
response_json = await response.json()
|
||||||
|
return response_json.get('choices', [])[0].get('message', {}).get('content', transcription)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_class(cls) -> Type[BaseProxyConfig]:
|
||||||
|
return Config
|
||||||
|
|
||||||
|
def save_config(self) -> None:
|
||||||
|
self.config.save()
|
||||||
|
|
||||||
|
async def update_config(self, new_config: dict) -> None:
|
||||||
|
self.config.update(new_config)
|
||||||
|
self.save_config()
|
||||||
|
self.log.debug("Configuration updated and saved")
|
||||||
@@ -9,11 +9,14 @@ from mautrix.client import Client
|
|||||||
from maubot.handlers import event
|
from maubot.handlers import event
|
||||||
from maubot import Plugin, MessageEvent
|
from maubot import Plugin, MessageEvent
|
||||||
from mautrix.errors import MatrixRequestError
|
from mautrix.errors import MatrixRequestError
|
||||||
from mautrix.types import EventType, MessageType, RelationType, TextMessageEventContent, Format,RelatesTo,InReplyTo
|
from mautrix.types import (
|
||||||
|
EventType, MessageType, RelationType,
|
||||||
|
TextMessageEventContent, Format, RelatesTo, InReplyTo
|
||||||
|
)
|
||||||
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||||
|
|
||||||
ALLOWED_EXTENSIONS = ['flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'mpga', 'oga', 'ogg', 'wav', 'webm']
|
ALLOWED_EXTENSIONS = ['flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'mpga', 'oga', 'ogg', 'wav', 'webm']
|
||||||
ALLOWED_MIME_TYPES = ['audio/flac','audio/mp4','video/mpeg','audio/ogg','audio/wav','video/webm']
|
ALLOWED_MIME_TYPES = ['audio/flac', 'audio/mp4', 'video/mpeg', 'audio/ogg', 'audio/wav', 'video/webm']
|
||||||
|
|
||||||
class Config(BaseProxyConfig):
|
class Config(BaseProxyConfig):
|
||||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||||
@@ -48,11 +51,15 @@ class WhisperPlugin(Plugin):
|
|||||||
|
|
||||||
if self.allowed_rooms and event.room_id not in self.allowed_rooms:
|
if self.allowed_rooms and event.room_id not in self.allowed_rooms:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if event.content.info.mimetype not in ALLOWED_MIME_TYPES:
|
# Extraction de la partie principale du MIME type (avant les éventuels paramètres)
|
||||||
|
mime_type = ""
|
||||||
|
if event.content.info and event.content.info.mimetype:
|
||||||
|
mime_type = event.content.info.mimetype.split(";")[0]
|
||||||
|
if mime_type not in ALLOWED_MIME_TYPES:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return event.content.msgtype == MessageType.AUDIO or event.content.msgtype == MessageType.FILE
|
return event.content.msgtype in (MessageType.AUDIO, MessageType.FILE)
|
||||||
|
|
||||||
@event.on(EventType.ROOM_MESSAGE)
|
@event.on(EventType.ROOM_MESSAGE)
|
||||||
async def on_message(self, event: MessageEvent) -> None:
|
async def on_message(self, event: MessageEvent) -> None:
|
||||||
@@ -64,7 +71,16 @@ class WhisperPlugin(Plugin):
|
|||||||
await self.client.set_typing(event.room_id, timeout=99999)
|
await self.client.set_typing(event.room_id, timeout=99999)
|
||||||
|
|
||||||
audio_bytes = await self.client.download_media(url=event.content.url)
|
audio_bytes = await self.client.download_media(url=event.content.url)
|
||||||
transcription = await self.transcribe_audio(audio_bytes)
|
if not audio_bytes:
|
||||||
|
await event.respond("Erreur lors du téléchargement du fichier audio.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Récupère le nom de fichier s'il est défini, sinon utilise une valeur par défaut
|
||||||
|
filename = getattr(event.content, "filename", "audio.mp3")
|
||||||
|
# Utilise le MIME type tel quel, ou une valeur par défaut
|
||||||
|
mime_type = event.content.info.mimetype if event.content.info and event.content.info.mimetype else "audio/mpeg"
|
||||||
|
|
||||||
|
transcription = await self.transcribe_audio(audio_bytes, filename, mime_type)
|
||||||
|
|
||||||
if self.search_reminders_and_events:
|
if self.search_reminders_and_events:
|
||||||
transcription = await self.study_transcribe(transcription)
|
transcription = await self.study_transcribe(transcription)
|
||||||
@@ -89,25 +105,25 @@ class WhisperPlugin(Plugin):
|
|||||||
self.log.exception(f"Something went wrong: {e}")
|
self.log.exception(f"Something went wrong: {e}")
|
||||||
await event.respond(f"Something went wrong: {e}")
|
await event.respond(f"Something went wrong: {e}")
|
||||||
|
|
||||||
|
async def transcribe_audio(self, audio_bytes: bytes, filename: str = "audio.mp3", mime_type: str = "audio/mpeg") -> str:
|
||||||
async def transcribe_audio(self, audio_bytes: bytes) -> str:
|
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
}
|
}
|
||||||
data = aiohttp.FormData()
|
data = aiohttp.FormData()
|
||||||
data.add_field('file', audio_bytes, filename='audio.mp3', content_type='audio/mpeg')
|
data.add_field('file', audio_bytes, filename=filename, content_type=mime_type)
|
||||||
data.add_field('model', 'whisper-1')
|
data.add_field('model', 'whisper-1')
|
||||||
if self.prompt:
|
if self.prompt:
|
||||||
data.add_field('prompt', f"{self.prompt}")
|
data.add_field('prompt', self.prompt)
|
||||||
if self.language:
|
if self.language:
|
||||||
data.add_field('language', f"{self.language}")
|
data.add_field('language', self.language)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(self.whisper_endpoint, headers=headers, data=data) as response:
|
async with session.post(self.whisper_endpoint, headers=headers, data=data) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
self.log.error(f"Error response from API: {await response.text()}")
|
error_text = await response.text()
|
||||||
return f"Error: {await response.text()}"
|
self.log.error(f"Error response from API: {error_text}")
|
||||||
|
return f"Error: {error_text}"
|
||||||
response_json = await response.json()
|
response_json = await response.json()
|
||||||
return response_json.get("text", "Sorry, I can't transcribe the audio.")
|
return response_json.get("text", "Sorry, I can't transcribe the audio.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user