스프링 - LLM Response를 Redis Streams와 SSE로 스트리밍 해보자

2024. 11. 25. 19:34· Spring
목차
  1. 개요
  2. 왜 Redis Streams를 사용하게 되었는가?
  3. 상황
  4. 프론트엔드에서의 요청 흐름
  5. 발생 가능한 문제점
  6. 전체적인 흐름
  7. 구현
  8. Redis Docker Compose
  9. Configuration
  10. Controller 레이어 및 DTO
  11. Service 레이어
  12. Service 전체 코드
  13. 테스트
  14. 주의!
  15. sendMessage
  16. streamMessage

스프링 - LLM Response를 Redis Streams와 SSE로 스트리밍 해보자

 

작성일자 : 2024년 11월 25일


 

Redis Streams illustrated by Dalle3

 

 

개요

 

OpenAI, Claude와 같은 LLM API 제공자들은 응답 형태로 메세지 스트림을 제공하기도 합니다.

 

서비스의 프론트엔드와 LLM API 사이에서, 스프링부트를 통해 요청을 핸들링하고 싶은 경우에 Redis Streams와 SSE를 활용하여 프론트엔드로 메세지를 스트리밍할 수 있습니다.

 

이번 포스팅에서는 LLM API의 응답을 Redis Streams와 SSE로 스트리밍하는 방법을 알아보겠습니다.

 


 

 

왜 Redis Streams를 사용하게 되었는가?

 

상황

 

프론트엔드로의 메세지 스트림은 일반적인 HTTP 요청-응답 방식, 또는 SSE(Server-Sent Events), WebSocket을 통해 구현할 수 있습니다.

 

이 중에서 SSE(Server-Sent Events) 방식으로 로직을 구성하는 경우에는 아래와 같은 프론트엔드 요청 흐름이 만들어집니다.

 


 

프론트엔드에서의 요청 흐름

 

  1. 프론트엔드에서 백엔드로 사용자의 메세지를 담아 요청을 보냅니다. (HTTP POST)
  2. 백엔드에서는 EventSource로 SSE Connection을 맺을 프론트엔드를 위하여, 스트림될 메세지의 ID를 생성해 응답합니다.
  3. 프론트엔드에서는 스트림될 메세지의 ID와 EventSource로 백엔드에게 SSE Connection을 맺고 메세지를 수신합니다.

 

const url = `http://localhost:8080/messages/${streamMessageId}`;
const eventSource = new EventSource(url);

eventSource.onmessage = event => {
if (event.data === "[DONE]") {
  eventSource.close();
  return;
}

onMessage(event.data);
};

 


 

발생 가능한 문제점

 

Step 1 -> Step 2와 같이 먼저 사용자의 메세지를 POST 요청을 통해 백엔드로 전달하고, Response로 돌아온 streamMessageId를 이용하여 다시 SSE Connection을 맺는 방식은 Step 1과 Step 2 사이에 요청 딜레이가 발생하여 메세지의 일부가 유실될 수 있다는 단점이 존재합니다.


때문에 이러한 문제를 해결하기 위해 Redis Streams를 사용하게 되었으며, Redis Streams는 Redis Pub/Sub과 다르게 메세지를 저장하고, 늦게 접속한 클라이언트에게도 메세지를 전달할 수 있는 장점이 있습니다.

 


 

 

전체적인 흐름

 

  1. 프론트엔드에서 백엔드로 사용자의 메세지를 담아 요청을 보냅니다.
  2. 백엔드에서는 스트림될 메세지의 ID를 생성해 응답하며, LLM API에 사용자의 메세지로 요청을 보냅니다.
  3. 백엔드에서 LLM API로 부터 전달받는 메세지 스트림을 Redis Streams에 저장합니다.
  4. 프론트엔드에서는 streamMessageId를 이용하여 SSE Connection을 맺습니다.
  5. 백엔드에서는 streamMessageId 키로 Redis Streams에 저장된 메세지를 읽어 프론트엔드로 전달합니다.

 


 

 

구현

 

해당 코드는 Anthropic의 Claude API를 사용하는 것을 전제로 작성되었습니다.

 


 

Redis Docker Compose

version: '3.8'

services:
  redis:
    image: redis:latest
    container_name: redis
    ports:
      - "6379:6379"
    volumes:
      - redis_data:/data
    command: redis-server --appendonly yes --requirepass password
    restart: unless-stopped
    networks:
      - redis_network
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 5s
      retries: 3

  redis-commander:
    image: rediscommander/redis-commander:latest
    container_name: redis-commander
    environment:
      - REDIS_HOSTS=local:redis:6379:0:password
    ports:
      - "8081:8081"
    depends_on:
      - redis
    restart: unless-stopped
    networks:
      - redis_network

volumes:
  redis_data:
    driver: local

networks:
  redis_network:
    driver: bridge

 


 

Configuration

 

application.yml

llm:
  claude:
    api:
      key: <your-anthropic-api-key>
      endpoint: https://api.anthropic.com/v1/messages

spring:
  data:
    redis:
      host: localhost
      port: 6379
      password: password

 


 

RedisConfig

@Configuration
public class RedisConfig {

    @Bean
    public ReactiveRedisTemplate<String, String> reactiveRedisTemplate(
            ReactiveRedisConnectionFactory connectionFactory) {
        RedisSerializationContext<String, String> serializationContext =
                RedisSerializationContext
                        .<String, String>newSerializationContext(new StringRedisSerializer())
                        .key(new StringRedisSerializer())
                        .value(new StringRedisSerializer())
                        .hashKey(new StringRedisSerializer())
                        .hashValue(new StringRedisSerializer())
                        .build();
        return new ReactiveRedisTemplate<>(connectionFactory, serializationContext);
    }

    @Bean
    public StreamReceiver<String, MapRecord<String, String, String>> streamReceiver(
            ReactiveRedisConnectionFactory connectionFactory) {
        return StreamReceiver.create(connectionFactory);
    }

}

 


 

WebClientConfig

@Configuration
public class WebClientConfig {

    @Value("${llm.claude.api.key}")
    private String claudeApiKey;

    @Value("${llm.claude.api.endpoint}")
    private String claudeApiEndpoint;

    @Bean
    public WebClient claudeWebClient() {
        return WebClient.builder()
                .baseUrl(claudeApiEndpoint)
                .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
                .defaultHeader("anthropic-version", "2023-06-01")
                .defaultHeader("x-api-key", claudeApiKey)
                .build();
    }

}

 


 

Controller 레이어 및 DTO

 

LlmController

@RestController
@RequiredArgsConstructor
public class LlmController {

    private final LlmService llmService;

    @PostMapping("/messages")
    public Mono<ResponseEntity<SendMessageResponseDto>> sendMessage(
            @RequestBody SendMessageRequestDto requestDto) {
        SendMessageResponseDto responseDto = llmService.sendMessage(requestDto);
        return Mono.just(ResponseEntity.ok(responseDto));
    }

    @GetMapping(value = "/messages/{streamMessageId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> streamMessage(
            @PathVariable("streamMessageId") String streamMessageId) {
        return llmService.streamMessage(streamMessageId)
                .map(data -> ServerSentEvent.builder(data).build());
    }

}

 


 

SendMessageRequestDto

@Getter
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class SendMessageRequestDto {

    private String message;

}

 


 

SendMessageResponseDto

@Getter
public class SendMessageResponseDto {

    private String streamMessageId;

    @Builder
    public SendMessageResponseDto(String streamMessageId) {
        this.streamMessageId = streamMessageId;
    }

}

 


 

Service 레이어

 

sendMessage

public SendMessageResponseDto sendMessage(SendMessageRequestDto requestDto) {
    // 스트림될 메시지의 고유 ID 생성
    String streamMessageId = UUID.randomUUID().toString();

    // Claude API 요청 생성
    ClaudeRequestBody claudeRequestBody = createClaudeRequestBody(requestDto);

    // Claude API 요청 보내고 응답되는 메세지 스트림 처리
    claudeWebClient.post()
            .bodyValue(claudeRequestBody)
            .retrieve()
            .bodyToFlux(String.class)
            .map(this::parseJson)
            .filter(json -> json != null)
            .takeUntil(node -> "message_stop".equals(node.path("type").asText()))
            .filter(this::isContentDelta)
            .map(this::extractDeltaContent)
            .flatMap(content -> reactiveRedisTemplate.opsForStream()
                    .add("message-stream:" + streamMessageId,
                            Collections.singletonMap("content", content)))
            .doOnComplete(() -> publishEndMessage(streamMessageId))
            .doOnError(error -> {
                log.error("Error sending message", error);
                clearStream(streamMessageId);
            })
            .subscribe();

    // streamMessageId 반환
    return SendMessageResponseDto.builder()
            .streamMessageId(streamMessageId)
            .build();
}

 

  • map(this::parseJson): Claude API 응답을 JSON으로 파싱
  • filter(json -> json != null): JSON이 null이 아닌 경우만 필터링
  • takeUntil(node -> "message_stop".equals(node.path("type").asText())): type이 message_stop인 경우까지만 스트림 처리(API 제공자에 따라 상이함)
  • filter(this::isContentDelta): content 필드가 delta인 경우만 필터링(Claude의 경우 메세지 응답 본문이 type이 content_block_delta인 경우에 들어있음)
  • map(this::extractDeltaContent): 메세지 응답 본문 추출
  • flatMap(...): Redis Stream에 메세지 추가
  • doOnComplete(() -> publishEndMessage(streamMessageId)): 메세지 전송 완료 시, Redis Streams에 [DONE] 메세지 발행

 


 

streamMessage

private final StreamReceiver<String, MapRecord<String, String, String>> streamReceiver;

public Flux<String> streamMessage(String streamMessageId) {
    return streamReceiver.receive(StreamOffset.fromStart("message-stream:" + streamMessageId))
            .map(record -> record.getValue().get("content"))
            .takeUntil(content -> "[DONE]".equals(content))
            .timeout(Duration.ofMinutes(1))
            .doFinally(signalType -> clearStream(streamMessageId));
}

private void clearStream(String streamMessageId) {
    reactiveRedisTemplate
            .delete("message-stream:" + streamMessageId)
            .subscribe();
}

 

  • streamReceiver.receive(...): streamReceiver를 통해 Redis Stream에서 메세지 수신
  • StreamOffset.fromStart("message-stream:" + streamMessageId): 스트림의 시작부터 메세지 수신
  • doFinally(signalType -> clearStream(streamMessageId)): 스트림 종료 시 Redis에 저장된 Stream 삭제

 


 

Service 전체 코드

@Slf4j
@Service
@RequiredArgsConstructor
public class LlmService {

    private final WebClient claudeWebClient;
    private final ReactiveRedisTemplate<String, String> reactiveRedisTemplate;
    private final StreamReceiver<String, MapRecord<String, String, String>> streamReceiver;

    public SendMessageResponseDto sendMessage(SendMessageRequestDto requestDto) {
        String streamMessageId = UUID.randomUUID().toString();

        ClaudeRequestBody claudeRequestBody = createClaudeRequestBody(requestDto);

        claudeWebClient.post()
                .bodyValue(claudeRequestBody)
                .retrieve()
                .bodyToFlux(String.class)
                .map(this::parseJson)
                .filter(json -> json != null)
                .takeUntil(node -> "message_stop".equals(node.path("type").asText()))
                .filter(this::isContentDelta)
                .map(this::extractDeltaContent)
                .flatMap(content -> reactiveRedisTemplate.opsForStream()
                        .add("message-stream:" + streamMessageId,
                                Collections.singletonMap("content", content)))
                .doOnComplete(() -> publishEndMessage(streamMessageId))
                .doOnError(error -> {
                    log.error("Error sending message", error);
                    clearStream(streamMessageId);
                })
                .subscribe();

        return SendMessageResponseDto.builder()
                .streamMessageId(streamMessageId)
                .build();
    }

    private ClaudeRequestBody createClaudeRequestBody(SendMessageRequestDto requestDto) {
        ClaudeRequestBody.Message message = ClaudeRequestBody.Message.builder()
                .role("user")
                .content(requestDto.getMessage())
                .build();

        return ClaudeRequestBody.builder()
                .model("claude-3-5-sonnet-20241022")
                .stream(true)
                .maxTokens(1024)
                .messages(Collections.singletonList(message))
                .build();
    }

    public Flux<String> streamMessage(String streamMessageId) {
        return streamReceiver.receive(StreamOffset.fromStart("message-stream:" + streamMessageId))
                .map(record -> record.getValue().get("content"))
                .takeUntil(content -> "[DONE]".equals(content))
                .timeout(Duration.ofMinutes(1))
                .doFinally(signalType -> clearStream(streamMessageId));
    }

    private JsonNode parseJson(String line) {
        ObjectMapper objectMapper = new ObjectMapper();

        try {
            return objectMapper.readTree(line);
        } catch (JsonMappingException e) {
            return null;
        } catch (JsonProcessingException e) {
            return null;
        }
    }

    private boolean isContentDelta(JsonNode node) {
        return "content_block_delta".equals(node.path("type").asText());
    }

    private String extractDeltaContent(JsonNode node) {
        return node.path("delta")
                .path("text")
                .asText("");
    }

    private void publishEndMessage(String streamMessageId) {
        reactiveRedisTemplate.opsForStream()
                .add("message-stream:" + streamMessageId,
                        Collections.singletonMap("content", "[DONE]"))
                .subscribe();
    }

    private void clearStream(String streamMessageId) {
        reactiveRedisTemplate
                .delete("message-stream:" + streamMessageId)
                .subscribe();
    }

}

 


 

ClaudeRequestBody

@Getter
public class ClaudeRequestBody {

    private String model;
    private boolean stream;
    @JsonProperty("max_tokens")
    private int maxTokens;
    private List<Message> messages;

    @Builder
    public ClaudeRequestBody(String model, boolean stream, int maxTokens, List<Message> messages) {
        this.model = model;
        this.stream = stream;
        this.maxTokens = maxTokens;
        this.messages = messages;
    }

    @Getter
    public static class Message {

        private String role;
        private String content;

        @Builder
        public Message(String role, String content) {
            this.role = role;
            this.content = content;
        }
    }

}

 


 

 

테스트

 

주의!

Postman을 통해 테스트를 진행할 경우, 모든 메세지 스트림을 수신해도 메세지 수신 요청을 보낸 시점부터의 메세지만 출력되거나, 마지막 메세지만 출력될 수 있습니다.

 

따라서, 테스트를 진행할 때는 curl 명령어를 사용하는 것을 권장합니다.

 


 

sendMessage

 

요청

curl -X POST http://localhost:8080/messages \
-H "Content-Type: application/json" \
-d '{
    "message": "hello!"
}'

 

응답

{
  "streamMessageId":"f9ef8fbe-fd91-4218-855f-e65ff28df7a1"
}

 


 

streamMessage

 

요청

curl --location 'http://localhost:8080/messages/f9ef8fbe-fd91-4218-855f-e65ff28df7a1'

 

응답

data:Hi

data: there! How can I help you today?

data:[DONE]

 

저작자표시 (새창열림)
  1. 개요
  2. 왜 Redis Streams를 사용하게 되었는가?
  3. 상황
  4. 프론트엔드에서의 요청 흐름
  5. 발생 가능한 문제점
  6. 전체적인 흐름
  7. 구현
  8. Redis Docker Compose
  9. Configuration
  10. Controller 레이어 및 DTO
  11. Service 레이어
  12. Service 전체 코드
  13. 테스트
  14. 주의!
  15. sendMessage
  16. streamMessage
'Spring' 카테고리의 다른 글
  • 스프링 - Redission 분산락으로 동시성 문제 해결하기 예시
  • 스프링 - Redis CacheManager ClassCastException 해결하기
  • 스프링 RabbitMQ 연동하기
  • 스프링에서 각기 다른 Base URL의 WebClient 인스턴스 사용하기
gerrymandering
gerrymandering
gerrymandering
gerrymandering
gerrymandering
전체
오늘
어제
  • 분류 전체보기 (81)
    • SOLID 원칙 (6)
    • 번역 (4)
    • Nginx (1)
    • Tailwind CSS (1)
    • AWS (7)
      • DMS를 사용한 RDS to OpenSearch .. (3)
      • ECS를 이용한 Blue-Green 무중단 배포 .. (7)
    • NextJS (5)
    • 기타 (12)
    • Prompt Engineering (6)
    • 읽어볼만한 글 (3)
      • 기술 (0)
      • 쓸만한 툴 (0)
      • 아이템 (0)
      • 웹 디자인 (0)
      • 기타 (3)
    • Cloud Architecture (4)
    • Trouble Shooting (9)
    • Spring (11)

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

인기 글

최근 댓글

최근 글

글쓰기 / 관리자
hELLO · Designed By 정상우.v4.2.1
gerrymandering
스프링 - LLM Response를 Redis Streams와 SSE로 스트리밍 해보자
상단으로

티스토리툴바

단축키

내 블로그

내 블로그 - 관리자 홈 전환
Q
Q
새 글 쓰기
W
W

블로그 게시글

글 수정 (권한 있는 경우)
E
E
댓글 영역으로 이동
C
C

모든 영역

이 페이지의 URL 복사
S
S
맨 위로 이동
T
T
티스토리 홈 이동
H
H
단축키 안내
Shift + /
⇧ + /

* 단축키는 한글/영문 대소문자로 이용 가능하며, 티스토리 기본 도메인에서만 동작합니다.