|
@@ -0,0 +1,262 @@
|
|
|
+package com.fdkankan.fusion.xfyun;
|
|
|
+
|
|
|
+import com.alibaba.fastjson.JSON;
|
|
|
+import com.alibaba.fastjson.JSONArray;
|
|
|
+import com.alibaba.fastjson.JSONObject;
|
|
|
+import com.fdkankan.fusion.common.ResultCode;
|
|
|
+import com.fdkankan.fusion.entity.XfyunConfig;
|
|
|
+import com.fdkankan.fusion.exception.BusinessException;
|
|
|
+import com.fdkankan.fusion.mapper.IXfyunConfigMapper;
|
|
|
+import com.fdkankan.fusion.service.IXfyunConfigService;
|
|
|
+import com.google.gson.Gson;
|
|
|
+import lombok.AllArgsConstructor;
|
|
|
+import lombok.Data;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import okhttp3.*;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.scheduling.annotation.Async;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+import org.springframework.stereotype.Service;
|
|
|
+
|
|
|
+import javax.annotation.PostConstruct;
|
|
|
+import javax.crypto.Mac;
|
|
|
+import javax.crypto.spec.SecretKeySpec;
|
|
|
+import java.io.IOException;
|
|
|
+import java.net.URL;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
+import java.text.SimpleDateFormat;
|
|
|
+import java.util.*;
|
|
|
+
|
|
|
+@Slf4j
|
|
|
+public class XfyunWebSocketListener extends WebSocketListener {
|
|
|
+ private Boolean wsCloseFlag ;
|
|
|
+ private String totalAnswer ;
|
|
|
+ private String appid;
|
|
|
+ private String newQuestion;
|
|
|
+ private String imagePath;
|
|
|
+
|
|
|
+ public XfyunWebSocketListener( String appid, String newQuestion, String imagePath) {
|
|
|
+ this.wsCloseFlag = false;
|
|
|
+ this.totalAnswer = "";
|
|
|
+ this.appid = appid;
|
|
|
+ this.newQuestion = newQuestion;
|
|
|
+ this.imagePath = imagePath;
|
|
|
+ }
|
|
|
+
|
|
|
+ public String createXfyunResult(XfyunWebSocketListener xfyunWebSocketListener, XfyunConfig xfyunConfig) {
|
|
|
+ try {
|
|
|
+ String authUrl = getAuthUrl(xfyunConfig.getHostUrl(), xfyunConfig.getApiKey(), xfyunConfig.getApiSecret());
|
|
|
+ OkHttpClient client = new OkHttpClient.Builder().build();
|
|
|
+ String url = authUrl.toString().replace("http://", "ws://").replace("https://", "wss://");
|
|
|
+ Request request = new Request.Builder().url(url).build();
|
|
|
+ WebSocket webSocket = client.newWebSocket(request, xfyunWebSocketListener);
|
|
|
+ while (!wsCloseFlag){
|
|
|
+ Thread.sleep(200);
|
|
|
+ }
|
|
|
+
|
|
|
+ }catch (Exception e){
|
|
|
+ log.info("创建XfyunWebSocket失败",e);
|
|
|
+ }
|
|
|
+ return totalAnswer;
|
|
|
+ }
|
|
|
+ class MyThread extends Thread {
|
|
|
+ private WebSocket webSocket;
|
|
|
+
|
|
|
+
|
|
|
+ public MyThread(WebSocket webSocket) {
|
|
|
+ this.webSocket = webSocket;
|
|
|
+ }
|
|
|
+ public void run() {
|
|
|
+ try {
|
|
|
+ JSONObject requestJson = new JSONObject();
|
|
|
+
|
|
|
+ JSONObject header = new JSONObject(); // header参数
|
|
|
+ header.put("app_id", appid);
|
|
|
+ header.put("uid", UUID.randomUUID().toString().substring(0, 10));
|
|
|
+
|
|
|
+ JSONObject parameter = new JSONObject(); // parameter参数
|
|
|
+ JSONObject chat = new JSONObject();
|
|
|
+ chat.put("domain", "image");
|
|
|
+ chat.put("temperature", 0.5);
|
|
|
+ chat.put("max_tokens", 4096);
|
|
|
+ chat.put("auditing", "default");
|
|
|
+ parameter.put("chat", chat);
|
|
|
+
|
|
|
+ JSONObject payload = new JSONObject(); // payload参数
|
|
|
+ JSONObject message = new JSONObject();
|
|
|
+ JSONArray text = new JSONArray();
|
|
|
+
|
|
|
+ RoleContent roleContent = new RoleContent();
|
|
|
+ // 添加图片信息
|
|
|
+ roleContent.role = "user";
|
|
|
+ roleContent.content = Base64.getEncoder().encodeToString(ImageUtil.read(imagePath));
|
|
|
+ roleContent.content_type = "image";
|
|
|
+ text.add(JSON.toJSON(roleContent));
|
|
|
+ // 添加对图片提出要求的信息
|
|
|
+ RoleContent roleContent1 = new RoleContent();
|
|
|
+ roleContent1.role = "user";
|
|
|
+ roleContent1.content = newQuestion;
|
|
|
+ roleContent1.content_type = "text";
|
|
|
+ text.add(JSON.toJSON(roleContent1));
|
|
|
+
|
|
|
+
|
|
|
+ message.put("text", text);
|
|
|
+ payload.put("message", message);
|
|
|
+
|
|
|
+
|
|
|
+ requestJson.put("header", header);
|
|
|
+ requestJson.put("parameter", parameter);
|
|
|
+ requestJson.put("payload", payload);
|
|
|
+ // System.err.println(requestJson); // 可以打印看每次的传参明细
|
|
|
+ webSocket.send(requestJson.toString());
|
|
|
+ // 等待服务端返回完毕后关闭
|
|
|
+ while (true) {
|
|
|
+ // System.err.println(wsCloseFlag + "---");
|
|
|
+ Thread.sleep(200);
|
|
|
+ if (wsCloseFlag) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ webSocket.close(1000, "");
|
|
|
+ } catch (Exception e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+ @Override
|
|
|
+ public void onOpen(WebSocket webSocket, Response response) {
|
|
|
+ super.onOpen(webSocket, response);
|
|
|
+ MyThread myThread = new MyThread(webSocket);
|
|
|
+ myThread.start();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onMessage(WebSocket webSocket, String text) {
|
|
|
+ JsonParse myJsonParse = new Gson().fromJson(text, JsonParse.class);
|
|
|
+ if (myJsonParse.header.code != 0) {
|
|
|
+ log.info("发生错误,错误码为:" + myJsonParse.header.code);
|
|
|
+ log.info("本次请求的sid为:" + myJsonParse.header.sid);
|
|
|
+ webSocket.close(1000, "");
|
|
|
+ }
|
|
|
+ List<Text> textList = myJsonParse.payload.choices.text;
|
|
|
+ for (Text temp : textList) {
|
|
|
+ totalAnswer+=temp.content;
|
|
|
+ }
|
|
|
+ if (myJsonParse.header.status == 2) {
|
|
|
+ wsCloseFlag = true;
|
|
|
+ log.info("讯飞回复:"+totalAnswer.toString());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onFailure(WebSocket webSocket, Throwable t, Response response) {
|
|
|
+ super.onFailure(webSocket, t, response);
|
|
|
+ try {
|
|
|
+ if (null != response) {
|
|
|
+ int code = response.code();
|
|
|
+ System.out.println("onFailure code:" + code);
|
|
|
+ System.out.println("onFailure body:" + response.body().string());
|
|
|
+ if (101 != code) {
|
|
|
+ System.out.println("connection failed");
|
|
|
+ System.exit(0);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (IOException e) {
|
|
|
+ // TODO Auto-generated catch block
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ // 鉴权方法
|
|
|
+ public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
|
|
|
+ URL url = new URL(hostUrl);
|
|
|
+ // 时间
|
|
|
+ SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
|
|
|
+ format.setTimeZone(TimeZone.getTimeZone("GMT"));
|
|
|
+ String date = format.format(new Date());
|
|
|
+ // 拼接
|
|
|
+ String preStr = "host: " + url.getHost() + "\n" +
|
|
|
+ "date: " + date + "\n" +
|
|
|
+ "GET " + url.getPath() + " HTTP/1.1";
|
|
|
+ // System.err.println(preStr);
|
|
|
+ // SHA256加密
|
|
|
+ Mac mac = Mac.getInstance("hmacsha256");
|
|
|
+ SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
|
|
|
+ mac.init(spec);
|
|
|
+
|
|
|
+ byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
|
|
|
+ // Base64加密
|
|
|
+ String sha = Base64.getEncoder().encodeToString(hexDigits);
|
|
|
+ // System.err.println(sha);
|
|
|
+ // 拼接
|
|
|
+ String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
|
|
|
+ // 拼接地址
|
|
|
+ HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().//
|
|
|
+ addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).//
|
|
|
+ addQueryParameter("date", date).//
|
|
|
+ addQueryParameter("host", url.getHost()).//
|
|
|
+ build();
|
|
|
+
|
|
|
+ // System.err.println(httpUrl.toString());
|
|
|
+ return httpUrl.toString();
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ //返回的json结果拆解
|
|
|
+ class JsonParse {
|
|
|
+ Header header;
|
|
|
+ Payload payload;
|
|
|
+ }
|
|
|
+
|
|
|
+ class Header {
|
|
|
+ int code;
|
|
|
+ int status;
|
|
|
+ String sid;
|
|
|
+ }
|
|
|
+
|
|
|
+ class Payload {
|
|
|
+ Choices choices;
|
|
|
+ }
|
|
|
+
|
|
|
+ class Choices {
|
|
|
+ List<Text> text;
|
|
|
+ }
|
|
|
+
|
|
|
+ class Text {
|
|
|
+ String role;
|
|
|
+ String content;
|
|
|
+ }
|
|
|
+ class RoleContent{
|
|
|
+ String role;
|
|
|
+ String content;
|
|
|
+
|
|
|
+ String content_type;
|
|
|
+
|
|
|
+ public String getContent_type() {
|
|
|
+ return content_type;
|
|
|
+ }
|
|
|
+
|
|
|
+ public void setContent_type(String content_type) {
|
|
|
+ this.content_type = content_type;
|
|
|
+ }
|
|
|
+
|
|
|
+ public String getRole() {
|
|
|
+ return role;
|
|
|
+ }
|
|
|
+
|
|
|
+ public void setRole(String role) {
|
|
|
+ this.role = role;
|
|
|
+ }
|
|
|
+
|
|
|
+ public String getContent() {
|
|
|
+ return content;
|
|
|
+ }
|
|
|
+
|
|
|
+ public void setContent(String content) {
|
|
|
+ this.content = content;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|