diff --git a/cmakepc.cmake b/cmakepc.cmake index 25ed690..34901a3 100644 --- a/cmakepc.cmake +++ b/cmakepc.cmake @@ -59,5 +59,8 @@ zadd_executable_simple(TARGET alsaplayer_main.out SRC dep/zlinuxcomponents/alsaplayer/alsaplayer_main.cpp) zadd_executable_simple(TARGET audio_recorder_main.out SRC dep/zlinuxcomponents/audio/audio_recorder_main.cpp) -zadd_executable_simple(TARGET test_asr_main.out SRC src/test_asr_main.cpp) -zadd_executable_simple(TARGET aiui_main.out SRC dep/zlinuxcomponents/aiui_ws/aiui_main.c) \ No newline at end of file +zadd_executable_simple( + TARGET test_asr_main.out SRC src/service/voiceprocess/asr_service.cpp + src/test_asr_main.cpp) +zadd_executable_simple(TARGET aiui_main.out SRC + dep/zlinuxcomponents/aiui_ws/aiui_main.c) diff --git a/dep/iflytopcpp b/dep/iflytopcpp index e3fb64e..59531ea 160000 --- a/dep/iflytopcpp +++ b/dep/iflytopcpp @@ -1 +1 @@ -Subproject commit e3fb64e8cb81f56ca1bff28657946a5453cf325d +Subproject commit 59531ea26be163327f7cbc6bc31b9845521afff5 diff --git a/src/service/voiceprocess/asr_service.cpp b/src/service/voiceprocess/asr_service.cpp index baf3e5a..eeca9d4 100644 --- a/src/service/voiceprocess/asr_service.cpp +++ b/src/service/voiceprocess/asr_service.cpp @@ -1,12 +1,30 @@ #include "asr_service.hpp" +#include "iflytopcpp/core/components/audio/wavheader.hpp" #include "zlinuxcomponents/aiui_ws/aiui.h" using namespace iflytop; using namespace core; using namespace std; static AiuiService *thisClass = nullptr; +static string getTTSFileName() { + static int i = 0; + i++; + if (i > 10) i = 0; + return fmt::format("/tmp/aiui_service_tts{}.wav", i); +} +/** + * aiui的识别结果返回大致如下,先返回一个nlp识别结果的包,然后再返回多个包含tts音频数据的json包 + * nlp->tts0->tts1->...->ttsn + * 1. 通过json["data"]["sub"]来判断是什么包,sub的值有:nlp,tts + * 2. 通过json["data"]["is_last"]判断是否是当前种类的最后一包. + * 3. 通过json["data"]["is_finish"]判断是否是本次传输的最后一包 + * + */ +extern "C" { +extern size_t aiui_base64_decode(char *source, unsigned char *target, size_t targetlen); +} void aiui_message_cb(const char *data, int len) { if (thisClass) { thisClass->call_aiui_message_cb(data, len); @@ -28,11 +46,78 @@ void AiuiService::aiuiDestroy() { aiui_destroy(); thisClass = nullptr; } +void AiuiService::parseTTSContent(json &rxjson) { + if (!oneAsrResultTransmition) { + logger->error("oneAsrResultTransmition is null"); + return; + } + /** + * @brief 解码base64数据,并保存起来 + */ + string base64Data = rxjson["data"]["content"].get(); + size_t size = base64Data.size() * 3 / 4 + 50; + shared_ptr binary = make_shared(size); + size_t realsize = aiui_base64_decode((char *)base64Data.c_str(), binary->data(), size); + binary->resize(realsize); + oneAsrResultTransmition->ttsRawData.push_back(binary); +} + +void AiuiService::endOneAsrResultTransmition() { + /** + * @brief 如果是最后一包,则将接收到的tts保存成文件,然后同nlp结果打包在一起,然后上报。 + */ + string ttsFileName = getTTSFileName(); + FILE *fp = fopen(ttsFileName.c_str(), "wb"); + size_t ttsframetotalsize = 0; + for (auto &binary : oneAsrResultTransmition->ttsRawData) { + ttsframetotalsize += binary->size(); + } + WAVHeader wavHeader(16000, 16, 1, ttsframetotalsize / 2); + if (fp) { + fwrite(wavHeader.data(), 1, wavHeader.size(), fp); + for (auto &binary : oneAsrResultTransmition->ttsRawData) { + fwrite(binary->data(), 1, binary->size(), fp); + ttsframetotalsize += binary->size(); + } + fclose(fp); + } + oneAsrResultTransmition->asrResult["data"]["intent"]["answer"]["ttsurl"] = ttsFileName; + /** + * @brief 上报识别结果 + */ + onAsrResult(oneAsrResultTransmition->asrResult); + oneAsrResultTransmition.reset(); +} void AiuiService::call_aiui_message_cb(const char *data, int len) { try { - json j = json::parse(data); - onAsrResult(j); + if (!oneAsrResultTransmition) { + oneAsrResultTransmition = make_unique(); + logger->info("start rx aiui data"); + } + json rxjson = json::parse(data); + if (rxjson["action"].get() != "result") { + return; + } + + /** + * @brief 处理接收到的数据 + */ + if (rxjson["data"]["sub"].get() == "nlp") { + // 解析nlpcontent + oneAsrResultTransmition->asrResult = rxjson; + } else if (rxjson["data"]["sub"].get() == "tts") { + // 解析ttscontent + parseTTSContent(rxjson); + } + /** + * @brief 结束本次接收 + */ + if (rxjson["data"]["is_finish"].get()) { + logger->info("end rx aiui data"); + // will call onAsrResult in this function + endOneAsrResultTransmition(); + } } catch (const std::exception &e) { logger->error("AiuiService::call_aiui_message_cb error: {}", e.what()); } diff --git a/src/service/voiceprocess/asr_service.hpp b/src/service/voiceprocess/asr_service.hpp index 2092b58..4e0e206 100644 --- a/src/service/voiceprocess/asr_service.hpp +++ b/src/service/voiceprocess/asr_service.hpp @@ -13,6 +13,7 @@ #include #include +#include "iflytopcpp/core/basic/ds/binary.hpp" #include "iflytopcpp/core/basic/nlohmann/json.hpp" #include "iflytopcpp/core/basic/nod/nod.hpp" #include "iflytopcpp/core/spdlogfactory/logger.hpp" @@ -35,9 +36,19 @@ using namespace std; using namespace core; using namespace nlohmann; -class AiuiService : public enable_shared_from_this { +class OneAsrResultTransmition { + public: + json asrResult; + list ttsResult; + list> ttsRawData; +}; + +class AiuiService : public enable_shared_from_this { ENABLE_LOGGER(AiuiService); + bool transFinished = false; + unique_ptr oneAsrResultTransmition; + public: nod::signal onAsrResult; nod::signal onError; @@ -51,6 +62,10 @@ class AiuiService : public enable_shared_from_this { void call_aiui_message_cb(const char *data, int len); void call_aiui_error_cb(int code, const char *str); + + private: + void parseTTSContent(json &rxjson); + void endOneAsrResultTransmition(); }; class AsrService : public enable_shared_from_this { diff --git a/src/test_asr_main.cpp b/src/test_asr_main.cpp index 2ea5012..8d76012 100644 --- a/src/test_asr_main.cpp +++ b/src/test_asr_main.cpp @@ -14,6 +14,7 @@ #include "service/light_control_service.hpp" #include "service/main_control_service.hpp" #include "service/report_service.hpp" +#include "service/voiceprocess/asr_service.hpp" // #include "zlinuxcomponents/alsaplayer/AudioPlayerAlsaImpl.hpp" // #include "zlinuxcomponents/audio/audio_recoder.hpp" @@ -46,12 +47,22 @@ int Main::main(int argc, char *argv[]) { const char *appid = "5938b7c7"; // 应用ID,在AIUI开放平台创建并设置 const char *key = "19c1f7becc78eedc7826b485aabe30de"; // 接口密钥,在AIUI开放平台查看 - const char *param = "{\"result_level\":\"plain\",\"auth_id\":\"ac30105366ea460f9ff08ddac0c4f71e\",\"data_" "type\":\"text\"," "\"scene\":\"main_box\",\"sample_rate\":\"16000\", " "\"context\":\"{\\\"sdk_support\\\":[\\\"nlp\\\",\\\"tts\\\"]}\"}"; + shared_ptr aiuiService(new AiuiService()); + logger->info("test_asr_main.cpp"); + aiuiService->aiuiInit(appid, key, param); + aiuiService->onAsrResult.connect([this](json j) { + logger->info("onAsrResult: {}", j.dump(2)); + }); + aiuiService->aiuiWrite("今天天气怎么样", strlen("今天天气怎么样")); +// aiuiService->aiuiFinished(); +// aiuiService->aiuiDestroy(); +#if 0 + aiui_init(appid, key, param, onmessage, onerror); @@ -62,6 +73,7 @@ int Main::main(int argc, char *argv[]) { aiui_finished(); aiui_destroy(); +#endif while (true) sleep(1000); }