本例接入了硅基流动(SiliconFlow)提供的 DeepSeek-R1-Distill-Qwen-7B 模型,并把模型流式输出的 chunk 逐条推送到前端。
采用SSE协议,具体代码如下:
环境配置
- Java 11
- SpringBoot 2.7.x
稳定性:
在内网稳定运行了大半个月,公网在阿里云ECS上稳定跑了一个多月
比较坑的点
Controller 没有异步处理会导致流式输出失效:必须先返回 SseEmitter,再在独立线程中读取上游响应并逐条发送。
package cn.nolaurene.cms.controller.tc;
import cn.nolaurene.cms.common.dto.tc.MessageRequest;
import cn.nolaurene.cms.common.dto.tc.StreamChatRequest;
import cn.nolaurene.cms.common.dto.tc.WhaleMessage;
import com.alibaba.fastjson2.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* @author nolaurence
* @date 2025/3/21.
*/
@Slf4j
@RestController
@RequestMapping("/v1")
public class StreamOutputController {
private static final String MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B";
private static final String CHAT_HOST = "https://api.siliconflow.cn";
private static final String CHAT_ENDPOINT = "/v1/chat/completions";
@Value("${silicon.flow.ak}")
private String apiKey;
private final ThreadPoolExecutor executor = new ThreadPoolExecutor(
5, // 核心线程数(固定大小)
5, // 最大线程数(与核心线程数相同)
0L, // 保持空闲线程的时间
TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>() // 任务队列
);
@CrossOrigin
@PostMapping("/chat")
public SseEmitter chat(@RequestBody StreamChatRequest request, HttpServletResponse httpServletResponse) {
httpServletResponse.setContentType("text/event-stream");
SseEmitter sseEmitter = new SseEmitter(60000L);
executor.submit(() -> {
try {
question(request.getMessages(), sseEmitter, httpServletResponse);
} catch (Exception e) {
sseEmitter.completeWithError(e);
}
});
return sseEmitter;
}
/**
* 调用whale接口
*
* @param messages 用户询问的问题
* @param sseEmitter 用于发送消息的sseEmitter
*/
public void question(List<MessageRequest> messages, SseEmitter sseEmitter, HttpServletResponse httpServletResponse) {
try {
long startTimeMs = System.currentTimeMillis();
// 组装请求参数
JSONObject body = new JSONObject();
body.put("model", MODEL);
body.put("stream", true);
List<JSONObject> whaleMessageList = messages.stream().map(message ->
new WhaleMessage(message.getRole(), message.getContent()).getJsonObject()).collect(Collectors.toList());
// JSONObject message = WhaleMessage.ofUser(question);
body.put("messages", whaleMessageList);
// openapi协议请求
CloseableHttpClient httpClient = HttpClients.createDefault();
HttpPost httpPost = new HttpPost(CHAT_HOST + CHAT_ENDPOINT);
httpPost.addHeader("Content-Type", "application/json");
httpPost.addHeader("Authorization", "Bearer " + apiKey);
StringEntity stringEntity = new StringEntity(body.toString(), "UTF-8");
httpPost.setEntity(stringEntity);
// 发送请求
try (CloseableHttpResponse response = httpClient.execute(httpPost);
InputStream inputStream = response.getEntity().getContent();
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream), 512)) {
log.info("建立链接时间:{}ms", System.currentTimeMillis() - startTimeMs);
String line;
long count = 0;
while ((line = reader.readLine()) != null) {
if (line.startsWith("data: ")) {
String data = line.substring(6).trim();
if ("[DONE]".equals(data)) {
break;
}
// 接收到chunk后使用sseEmitter发送出去
sseEmitter.send(SseEmitter.event()
.data(data)
.id(String.valueOf(System.currentTimeMillis())));
httpServletResponse.getOutputStream().flush();
if (count == 0) {
log.info("首次发送耗时:{}ms", System.currentTimeMillis() - startTimeMs);
}
count++;
}
}
sseEmitter.complete();
log.info("完成时间:{}ms", System.currentTimeMillis() - startTimeMs);
} catch (IOException e) {
sseEmitter.completeWithError(e);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}