스프링 - LLM Response를 Redis Streams와 SSE로 스트리밍 해보자
작성일자 : 2024년 11월 25일
개요
OpenAI, Claude와 같은 LLM API 제공자들은 응답 형태로 메세지 스트림을 제공하기도 합니다.
서비스의 프론트엔드와 LLM API 사이에서, 스프링부트를 통해 요청을 핸들링하고 싶은 경우에 Redis Streams와 SSE를 활용하여 프론트엔드로 메세지를 스트리밍할 수 있습니다.
이번 포스팅에서는 LLM API의 응답을 Redis Streams와 SSE로 스트리밍하는 방법을 알아보겠습니다.
왜 Redis Streams를 사용하게 되었는가?
상황
프론트엔드로의 메세지 스트림은 일반적인 HTTP 요청-응답 방식, 또는 SSE(Server-Sent Events), WebSocket을 통해 구현할 수 있습니다.
이 중에서 SSE(Server-Sent Events) 방식으로 로직을 구성하는 경우에는 아래와 같은 프론트엔드 요청 흐름이 만들어집니다.
프론트엔드에서의 요청 흐름
- 프론트엔드에서 백엔드로 사용자의 메세지를 담아 요청을 보냅니다. (HTTP POST)
- 백엔드에서는
EventSource
로 SSE Connection을 맺을 프론트엔드를 위하여, 스트림될 메세지의 ID를 생성해 응답합니다. - 프론트엔드에서는 스트림될 메세지의 ID와
EventSource
로 백엔드에게 SSE Connection을 맺고 메세지를 수신합니다.
const url = `http://localhost:8080/messages/${streamMessageId}`;
const eventSource = new EventSource(url);
eventSource.onmessage = event => {
if (event.data === "[DONE]") {
eventSource.close();
return;
}
onMessage(event.data);
};
발생 가능한 문제점
Step 1
-> Step 2
와 같이 먼저 사용자의 메세지를 POST 요청을 통해 백엔드로 전달하고, Response로 돌아온 streamMessageId
를 이용하여 다시 SSE Connection을 맺는 방식은 Step 1
과 Step 2
사이에 요청 딜레이가 발생하여 메세지의 일부가 유실될 수 있다는 단점이 존재합니다.
때문에 이러한 문제를 해결하기 위해 Redis Streams를 사용하게 되었으며, Redis Streams는 Redis Pub/Sub과 다르게 메세지를 저장하고, 늦게 접속한 클라이언트에게도 메세지를 전달할 수 있는 장점이 있습니다.
전체적인 흐름
- 프론트엔드에서 백엔드로 사용자의 메세지를 담아 요청을 보냅니다.
- 백엔드에서는 스트림될 메세지의 ID를 생성해 응답하며, LLM API에 사용자의 메세지로 요청을 보냅니다.
- 백엔드에서 LLM API로 부터 전달받는 메세지 스트림을 Redis Streams에 저장합니다.
- 프론트엔드에서는
streamMessageId
를 이용하여 SSE Connection을 맺습니다. - 백엔드에서는
streamMessageId
키로 Redis Streams에 저장된 메세지를 읽어 프론트엔드로 전달합니다.
구현
해당 코드는 Anthropic의 Claude API를 사용하는 것을 전제로 작성되었습니다.
Redis Docker Compose
version: '3.8'
services:
redis:
image: redis:latest
container_name: redis
ports:
- "6379:6379"
volumes:
- redis_data:/data
command: redis-server --appendonly yes --requirepass password
restart: unless-stopped
networks:
- redis_network
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 3
redis-commander:
image: rediscommander/redis-commander:latest
container_name: redis-commander
environment:
- REDIS_HOSTS=local:redis:6379:0:password
ports:
- "8081:8081"
depends_on:
- redis
restart: unless-stopped
networks:
- redis_network
volumes:
redis_data:
driver: local
networks:
redis_network:
driver: bridge
Configuration
application.yml
llm:
claude:
api:
key: <your-anthropic-api-key>
endpoint: https://api.anthropic.com/v1/messages
spring:
data:
redis:
host: localhost
port: 6379
password: password
RedisConfig
@Configuration
public class RedisConfig {
@Bean
public ReactiveRedisTemplate<String, String> reactiveRedisTemplate(
ReactiveRedisConnectionFactory connectionFactory) {
RedisSerializationContext<String, String> serializationContext =
RedisSerializationContext
.<String, String>newSerializationContext(new StringRedisSerializer())
.key(new StringRedisSerializer())
.value(new StringRedisSerializer())
.hashKey(new StringRedisSerializer())
.hashValue(new StringRedisSerializer())
.build();
return new ReactiveRedisTemplate<>(connectionFactory, serializationContext);
}
@Bean
public StreamReceiver<String, MapRecord<String, String, String>> streamReceiver(
ReactiveRedisConnectionFactory connectionFactory) {
return StreamReceiver.create(connectionFactory);
}
}
WebClientConfig
@Configuration
public class WebClientConfig {
@Value("${llm.claude.api.key}")
private String claudeApiKey;
@Value("${llm.claude.api.endpoint}")
private String claudeApiEndpoint;
@Bean
public WebClient claudeWebClient() {
return WebClient.builder()
.baseUrl(claudeApiEndpoint)
.defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.defaultHeader("anthropic-version", "2023-06-01")
.defaultHeader("x-api-key", claudeApiKey)
.build();
}
}
Controller 레이어 및 DTO
LlmController
@RestController
@RequiredArgsConstructor
public class LlmController {
private final LlmService llmService;
@PostMapping("/messages")
public Mono<ResponseEntity<SendMessageResponseDto>> sendMessage(
@RequestBody SendMessageRequestDto requestDto) {
SendMessageResponseDto responseDto = llmService.sendMessage(requestDto);
return Mono.just(ResponseEntity.ok(responseDto));
}
@GetMapping(value = "/messages/{streamMessageId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ServerSentEvent<String>> streamMessage(
@PathVariable("streamMessageId") String streamMessageId) {
return llmService.streamMessage(streamMessageId)
.map(data -> ServerSentEvent.builder(data).build());
}
}
SendMessageRequestDto
@Getter
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class SendMessageRequestDto {
private String message;
}
SendMessageResponseDto
@Getter
public class SendMessageResponseDto {
private String streamMessageId;
@Builder
public SendMessageResponseDto(String streamMessageId) {
this.streamMessageId = streamMessageId;
}
}
Service 레이어
sendMessage
public SendMessageResponseDto sendMessage(SendMessageRequestDto requestDto) {
// 스트림될 메시지의 고유 ID 생성
String streamMessageId = UUID.randomUUID().toString();
// Claude API 요청 생성
ClaudeRequestBody claudeRequestBody = createClaudeRequestBody(requestDto);
// Claude API 요청 보내고 응답되는 메세지 스트림 처리
claudeWebClient.post()
.bodyValue(claudeRequestBody)
.retrieve()
.bodyToFlux(String.class)
.map(this::parseJson)
.filter(json -> json != null)
.takeUntil(node -> "message_stop".equals(node.path("type").asText()))
.filter(this::isContentDelta)
.map(this::extractDeltaContent)
.flatMap(content -> reactiveRedisTemplate.opsForStream()
.add("message-stream:" + streamMessageId,
Collections.singletonMap("content", content)))
.doOnComplete(() -> publishEndMessage(streamMessageId))
.doOnError(error -> {
log.error("Error sending message", error);
clearStream(streamMessageId);
})
.subscribe();
// streamMessageId 반환
return SendMessageResponseDto.builder()
.streamMessageId(streamMessageId)
.build();
}
map(this::parseJson)
: Claude API 응답을 JSON으로 파싱filter(json -> json != null)
: JSON이 null이 아닌 경우만 필터링takeUntil(node -> "message_stop".equals(node.path("type").asText()))
:type
이message_stop
인 경우까지만 스트림 처리(API 제공자에 따라 상이함)filter(this::isContentDelta)
:content
필드가delta
인 경우만 필터링(Claude의 경우 메세지 응답 본문이type
이content_block_delta
인 경우에 들어있음)map(this::extractDeltaContent)
: 메세지 응답 본문 추출flatMap(...)
: Redis Stream에 메세지 추가doOnComplete(() -> publishEndMessage(streamMessageId))
: 메세지 전송 완료 시, Redis Streams에[DONE]
메세지 발행
streamMessage
private final StreamReceiver<String, MapRecord<String, String, String>> streamReceiver;
public Flux<String> streamMessage(String streamMessageId) {
return streamReceiver.receive(StreamOffset.fromStart("message-stream:" + streamMessageId))
.map(record -> record.getValue().get("content"))
.takeUntil(content -> "[DONE]".equals(content))
.timeout(Duration.ofMinutes(1))
.doFinally(signalType -> clearStream(streamMessageId));
}
private void clearStream(String streamMessageId) {
reactiveRedisTemplate
.delete("message-stream:" + streamMessageId)
.subscribe();
}
streamReceiver.receive(...)
:streamReceiver
를 통해 Redis Stream에서 메세지 수신StreamOffset.fromStart("message-stream:" + streamMessageId)
: 스트림의 시작부터 메세지 수신doFinally(signalType -> clearStream(streamMessageId))
: 스트림 종료 시 Redis에 저장된 Stream 삭제
Service 전체 코드
@Slf4j
@Service
@RequiredArgsConstructor
public class LlmService {
private final WebClient claudeWebClient;
private final ReactiveRedisTemplate<String, String> reactiveRedisTemplate;
private final StreamReceiver<String, MapRecord<String, String, String>> streamReceiver;
public SendMessageResponseDto sendMessage(SendMessageRequestDto requestDto) {
String streamMessageId = UUID.randomUUID().toString();
ClaudeRequestBody claudeRequestBody = createClaudeRequestBody(requestDto);
claudeWebClient.post()
.bodyValue(claudeRequestBody)
.retrieve()
.bodyToFlux(String.class)
.map(this::parseJson)
.filter(json -> json != null)
.takeUntil(node -> "message_stop".equals(node.path("type").asText()))
.filter(this::isContentDelta)
.map(this::extractDeltaContent)
.flatMap(content -> reactiveRedisTemplate.opsForStream()
.add("message-stream:" + streamMessageId,
Collections.singletonMap("content", content)))
.doOnComplete(() -> publishEndMessage(streamMessageId))
.doOnError(error -> {
log.error("Error sending message", error);
clearStream(streamMessageId);
})
.subscribe();
return SendMessageResponseDto.builder()
.streamMessageId(streamMessageId)
.build();
}
private ClaudeRequestBody createClaudeRequestBody(SendMessageRequestDto requestDto) {
ClaudeRequestBody.Message message = ClaudeRequestBody.Message.builder()
.role("user")
.content(requestDto.getMessage())
.build();
return ClaudeRequestBody.builder()
.model("claude-3-5-sonnet-20241022")
.stream(true)
.maxTokens(1024)
.messages(Collections.singletonList(message))
.build();
}
public Flux<String> streamMessage(String streamMessageId) {
return streamReceiver.receive(StreamOffset.fromStart("message-stream:" + streamMessageId))
.map(record -> record.getValue().get("content"))
.takeUntil(content -> "[DONE]".equals(content))
.timeout(Duration.ofMinutes(1))
.doFinally(signalType -> clearStream(streamMessageId));
}
private JsonNode parseJson(String line) {
ObjectMapper objectMapper = new ObjectMapper();
try {
return objectMapper.readTree(line);
} catch (JsonMappingException e) {
return null;
} catch (JsonProcessingException e) {
return null;
}
}
private boolean isContentDelta(JsonNode node) {
return "content_block_delta".equals(node.path("type").asText());
}
private String extractDeltaContent(JsonNode node) {
return node.path("delta")
.path("text")
.asText("");
}
private void publishEndMessage(String streamMessageId) {
reactiveRedisTemplate.opsForStream()
.add("message-stream:" + streamMessageId,
Collections.singletonMap("content", "[DONE]"))
.subscribe();
}
private void clearStream(String streamMessageId) {
reactiveRedisTemplate
.delete("message-stream:" + streamMessageId)
.subscribe();
}
}
ClaudeRequestBody
@Getter
public class ClaudeRequestBody {
private String model;
private boolean stream;
@JsonProperty("max_tokens")
private int maxTokens;
private List<Message> messages;
@Builder
public ClaudeRequestBody(String model, boolean stream, int maxTokens, List<Message> messages) {
this.model = model;
this.stream = stream;
this.maxTokens = maxTokens;
this.messages = messages;
}
@Getter
public static class Message {
private String role;
private String content;
@Builder
public Message(String role, String content) {
this.role = role;
this.content = content;
}
}
}
테스트
주의!
Postman을 통해 테스트를 진행할 경우, 모든 메세지 스트림을 수신해도 메세지 수신 요청을 보낸 시점부터의 메세지만 출력되거나, 마지막 메세지만 출력될 수 있습니다.
따라서, 테스트를 진행할 때는 curl
명령어를 사용하는 것을 권장합니다.
sendMessage
요청
curl -X POST http://localhost:8080/messages \
-H "Content-Type: application/json" \
-d '{
"message": "hello!"
}'
응답
{
"streamMessageId":"f9ef8fbe-fd91-4218-855f-e65ff28df7a1"
}
streamMessage
요청
curl --location 'http://localhost:8080/messages/f9ef8fbe-fd91-4218-855f-e65ff28df7a1'
응답
data:Hi
data: there! How can I help you today?
data:[DONE]